Compute square roots in a prime field.
[u/mdw/catacomb] / mp-modsqrt.c
diff --git a/mp-modsqrt.c b/mp-modsqrt.c
new file mode 100644 (file)
index 0000000..a3a6b1f
--- /dev/null
@@ -0,0 +1,215 @@
+/* -*-c-*-
+ *
+ * $Id: mp-modsqrt.c,v 1.1 2000/06/22 19:01:31 mdw Exp $
+ *
+ * Compute square roots modulo a prime
+ *
+ * (c) 2000 Straylight/Edgeware
+ */
+
+/*----- Licensing notice --------------------------------------------------* 
+ *
+ * This file is part of Catacomb.
+ *
+ * Catacomb is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Library General Public License as
+ * published by the Free Software Foundation; either version 2 of the
+ * License, or (at your option) any later version.
+ * 
+ * Catacomb is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU Library General Public License for more details.
+ * 
+ * You should have received a copy of the GNU Library General Public
+ * License along with Catacomb; if not, write to the Free
+ * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
+ * MA 02111-1307, USA.
+ */
+
+/*----- Revision history --------------------------------------------------* 
+ *
+ * $Log: mp-modsqrt.c,v $
+ * Revision 1.1  2000/06/22 19:01:31  mdw
+ * Compute square roots in a prime field.
+ *
+ */
+
+/*----- Header files ------------------------------------------------------*/
+
+#include "fibrand.h"
+#include "grand.h"
+#include "mp.h"
+#include "mpmont.h"
+#include "mprand.h"
+
+/*----- Main code ---------------------------------------------------------*/
+
+/* --- @mp_modsqrt@ --- *
+ *
+ * Arguments:  @mp *d@ = destination integer
+ *             @mp *a@ = source integer
+ *             @mp *p@ = modulus (must be prime)
+ *
+ * Returns:    If %$a$% is a quadratic residue, a square root of %$a$%; else
+ *             a null pointer.
+ *
+ * Use:                Returns an integer %$x$% such that %$x^2 \equiv a \pmod{p}$%,
+ *             if one exists; else a null pointer.  This function will not
+ *             work if %$p$% is composite: you must factor the modulus, take
+ *             a square root mod each factor, and recombine the results
+ *             using the Chinese Remainder Theorem.
+ */
+
+mp *mp_modsqrt(mp *d, mp *a, mp *p)
+{
+  mpmont mm;
+  mp *t;
+  size_t s;
+  mp *b;
+  mp *ainv;
+  mp *c, *r;
+  size_t i, j;
+  mp *dd, *mone;
+
+  /* --- Cope if %$a \not\in Q_p$% --- */
+
+  if (mp_jacobi(a, p) != 1) {
+    if (d)
+      mp_drop(d);
+    return (0);
+  }
+
+  /* --- Choose some quadratic non-residue --- */
+
+  {
+    grand *g = fibrand_create(0);
+
+    b = MP_NEW;
+    do
+      b = mprand_range(b, p, g, 0);
+    while (mp_jacobi(b, p) != -1);
+    g->ops->destroy(g);
+  }
+
+  /* --- Find the inverse of %$a$% --- */
+
+  ainv = MP_NEW;
+  mp_gcd(0, &ainv, 0, a, p);
+  
+  /* --- Split %$p - 1$% into a power of two and an odd number --- */
+
+  t = mp_sub(MP_NEW, p, MP_ONE);
+  t = mp_odd(t, t, &s);
+
+  /* --- Now to really get going --- */
+
+  mpmont_create(&mm, p);
+  c = mpmont_expr(&mm, b, b, t);
+  t = mp_add(t, t, MP_ONE);
+  t = mp_lsr(t, t, 1);
+  r = mpmont_expr(&mm, t, a, t);
+  ainv = mpmont_mul(&mm, ainv, ainv, mm.r2);
+
+  mone = mp_sub(MP_NEW, p, mm.r);
+
+  dd = MP_NEW;
+
+  for (i = 1; i < s; i++) {
+
+    /* --- Compute %$d_0 = r^2a^{-1}$% --- */
+
+    dd = mp_sqr(dd, r);
+    dd = mpmont_reduce(&mm, dd, dd);
+    dd = mpmont_mul(&mm, dd, dd, ainv);
+
+    /* --- Now %$d = d_0^{s - i - 1}$% --- */
+
+    for (j = i; j < s - 1; j++) {
+      dd = mp_sqr(dd, dd);
+      dd = mpmont_reduce(&mm, dd, dd);
+    }
+
+    /* --- Fiddle at the end --- */
+
+    if (MP_CMP(dd, ==, mone))
+      r = mpmont_mul(&mm, r, r, c);
+    c = mp_sqr(c, c);
+    c = mpmont_reduce(&mm, c, c);
+  }
+
+  /* --- Done, so tidy up --- */
+
+  d = mpmont_reduce(&mm, d, r);
+  mp_drop(ainv);
+  mp_drop(r); mp_drop(c);
+  if (dd)
+    mp_drop(dd);
+  mp_drop(mone);
+  mpmont_destroy(&mm);
+
+  return (d);
+}
+
+/*----- Test rig ----------------------------------------------------------*/
+
+#ifdef TEST_RIG
+
+#include <mLib/testrig.h>
+
+static int verify(dstr *v)
+{
+  mp *a = *(mp **)v[0].buf;
+  mp *p = *(mp **)v[1].buf;
+  mp *rr = *(mp **)v[2].buf;
+  mp *r = mp_modsqrt(MP_NEW, a, p);
+  int ok = 0;
+
+  if (!r)
+    ok = 0;
+  else if (MP_CMP(r, ==, rr))
+    ok = 1;
+  else {
+    r = mp_sub(r, p, r);
+    if (MP_CMP(r, ==, rr))
+      ok = 1;
+  }
+
+  if (!ok) {
+    fputs("\n*** fail\n", stderr);
+    fputs("a  = ", stderr); mp_writefile(a, stderr, 10); fputc('\n', stderr);
+    fputs("p  = ", stderr); mp_writefile(p, stderr, 10); fputc('\n', stderr);
+    if (r) {
+      fputs("r  = ", stderr);
+      mp_writefile(r, stderr, 10);
+      fputc('\n', stderr);
+    } else
+      fputs("r  = <undef>\n", stderr);
+    fputs("rr = ", stderr); mp_writefile(rr, stderr, 10); fputc('\n', stderr);
+    ok = 0;
+  }
+
+  mp_drop(a);
+  mp_drop(p);
+  if (r)
+    mp_drop(r);
+  mp_drop(rr);
+  assert(mparena_count(MPARENA_GLOBAL) == 0);
+  return (ok);
+}
+
+static test_chunk tests[] = {
+  { "modsqrt", verify, { &type_mp, &type_mp, &type_mp, 0 } },
+  { 0, 0, { 0 } }
+};
+
+int main(int argc, char *argv[])
+{
+  sub_init();
+  test_run(argc, argv, tests, SRCDIR "/tests/mp");
+  return (0);
+}
+
+#endif
+
+/*----- That's all, folks -------------------------------------------------*/