Gather up another utility.
[u/mdw/catacomb] / mp-modsqrt.c
index a3a6b1f..f9e4b0f 100644 (file)
@@ -1,6 +1,6 @@
 /* -*-c-*-
  *
- * $Id: mp-modsqrt.c,v 1.1 2000/06/22 19:01:31 mdw Exp $
+ * $Id: mp-modsqrt.c,v 1.5 2004/04/08 01:36:15 mdw Exp $
  *
  * Compute square roots modulo a prime
  *
  * MA 02111-1307, USA.
  */
 
-/*----- Revision history --------------------------------------------------* 
- *
- * $Log: mp-modsqrt.c,v $
- * Revision 1.1  2000/06/22 19:01:31  mdw
- * Compute square roots in a prime field.
- *
- */
-
 /*----- Header files ------------------------------------------------------*/
 
 #include "fibrand.h"
@@ -75,8 +67,7 @@ mp *mp_modsqrt(mp *d, mp *a, mp *p)
   /* --- Cope if %$a \not\in Q_p$% --- */
 
   if (mp_jacobi(a, p) != 1) {
-    if (d)
-      mp_drop(d);
+    mp_drop(d);
     return (0);
   }
 
@@ -94,8 +85,7 @@ mp *mp_modsqrt(mp *d, mp *a, mp *p)
 
   /* --- Find the inverse of %$a$% --- */
 
-  ainv = MP_NEW;
-  mp_gcd(0, &ainv, 0, a, p);
+  ainv = mp_modinv(MP_NEW, a, p);
   
   /* --- Split %$p - 1$% into a power of two and an odd number --- */
 
@@ -105,10 +95,13 @@ mp *mp_modsqrt(mp *d, mp *a, mp *p)
   /* --- Now to really get going --- */
 
   mpmont_create(&mm, p);
+  b = mpmont_mul(&mm, b, b, mm.r2);
   c = mpmont_expr(&mm, b, b, t);
   t = mp_add(t, t, MP_ONE);
   t = mp_lsr(t, t, 1);
-  r = mpmont_expr(&mm, t, a, t);
+  dd = mpmont_mul(&mm, MP_NEW, a, mm.r2);
+  r = mpmont_expr(&mm, t, dd, t);
+  mp_drop(dd);
   ainv = mpmont_mul(&mm, ainv, ainv, mm.r2);
 
   mone = mp_sub(MP_NEW, p, mm.r);
@@ -132,7 +125,7 @@ mp *mp_modsqrt(mp *d, mp *a, mp *p)
 
     /* --- Fiddle at the end --- */
 
-    if (MP_CMP(dd, ==, mone))
+    if (MP_EQ(dd, mone))
       r = mpmont_mul(&mm, r, r, c);
     c = mp_sqr(c, c);
     c = mpmont_reduce(&mm, c, c);
@@ -143,8 +136,7 @@ mp *mp_modsqrt(mp *d, mp *a, mp *p)
   d = mpmont_reduce(&mm, d, r);
   mp_drop(ainv);
   mp_drop(r); mp_drop(c);
-  if (dd)
-    mp_drop(dd);
+  mp_drop(dd);
   mp_drop(mone);
   mpmont_destroy(&mm);
 
@@ -167,11 +159,11 @@ static int verify(dstr *v)
 
   if (!r)
     ok = 0;
-  else if (MP_CMP(r, ==, rr))
+  else if (MP_EQ(r, rr))
     ok = 1;
   else {
     r = mp_sub(r, p, r);
-    if (MP_CMP(r, ==, rr))
+    if (MP_EQ(r, rr))
       ok = 1;
   }
 
@@ -191,8 +183,7 @@ static int verify(dstr *v)
 
   mp_drop(a);
   mp_drop(p);
-  if (r)
-    mp_drop(r);
+  mp_drop(r);
   mp_drop(rr);
   assert(mparena_count(MPARENA_GLOBAL) == 0);
   return (ok);