Merge branch 'master' of git+ssh://metalzone.distorted.org.uk/~mdw/public-git/catacomb/
[u/mdw/catacomb] / mp-modsqrt.c
index df60aba..1cacacd 100644 (file)
@@ -1,6 +1,6 @@
 /* -*-c-*-
  *
- * $Id: mp-modsqrt.c,v 1.4 2001/06/16 12:56:38 mdw Exp $
+ * $Id: mp-modsqrt.c,v 1.5 2004/04/08 01:36:15 mdw Exp $
  *
  * Compute square roots modulo a prime
  *
  * MA 02111-1307, USA.
  */
 
-/*----- Revision history --------------------------------------------------* 
- *
- * $Log: mp-modsqrt.c,v $
- * Revision 1.4  2001/06/16 12:56:38  mdw
- * Fixes for interface change to @mpmont_expr@ and @mpmont_mexpr@.
- *
- * Revision 1.3  2001/02/03 12:00:29  mdw
- * Now @mp_drop@ checks its argument is non-NULL before attempting to free
- * it.  Note that the macro version @MP_DROP@ doesn't do this.
- *
- * Revision 1.2  2000/10/08 12:02:21  mdw
- * Use @MP_EQ@ instead of @MP_CMP@.
- *
- * Revision 1.1  2000/06/22 19:01:31  mdw
- * Compute square roots in a prime field.
- *
- */
-
 /*----- Header files ------------------------------------------------------*/
 
 #include "fibrand.h"
@@ -69,6 +51,9 @@
  *             work if %$p$% is composite: you must factor the modulus, take
  *             a square root mod each factor, and recombine the results
  *             using the Chinese Remainder Theorem.
+ *
+ *             We guarantee that the square root returned is the smallest
+ *             one (i.e., the `positive' square root).
  */
 
 mp *mp_modsqrt(mp *d, mp *a, mp *p)
@@ -103,8 +88,7 @@ mp *mp_modsqrt(mp *d, mp *a, mp *p)
 
   /* --- Find the inverse of %$a$% --- */
 
-  ainv = MP_NEW;
-  mp_gcd(0, &ainv, 0, a, p);
+  ainv = mp_modinv(MP_NEW, a, p);
   
   /* --- Split %$p - 1$% into a power of two and an odd number --- */
 
@@ -150,9 +134,14 @@ mp *mp_modsqrt(mp *d, mp *a, mp *p)
     c = mpmont_reduce(&mm, c, c);
   }
 
-  /* --- Done, so tidy up --- */
+  /* --- Done, so tidy up --- *
+   *
+   * Canonify the answer.
+   */
 
   d = mpmont_reduce(&mm, d, r);
+  r = mp_sub(r, p, d);
+  if (MP_CMP(r, <, d)) { mp *tt = r; r = d; d = tt; }
   mp_drop(ainv);
   mp_drop(r); mp_drop(c);
   mp_drop(dd);
@@ -180,11 +169,6 @@ static int verify(dstr *v)
     ok = 0;
   else if (MP_EQ(r, rr))
     ok = 1;
-  else {
-    r = mp_sub(r, p, r);
-    if (MP_EQ(r, rr))
-      ok = 1;
-  }
 
   if (!ok) {
     fputs("\n*** fail\n", stderr);