mp-modsqrt: Always return the smaller possible square root.
[u/mdw/catacomb] / mp-modsqrt.c
index f9e4b0f..1cacacd 100644 (file)
@@ -51,6 +51,9 @@
  *             work if %$p$% is composite: you must factor the modulus, take
  *             a square root mod each factor, and recombine the results
  *             using the Chinese Remainder Theorem.
+ *
+ *             We guarantee that the square root returned is the smallest
+ *             one (i.e., the `positive' square root).
  */
 
 mp *mp_modsqrt(mp *d, mp *a, mp *p)
@@ -131,9 +134,14 @@ mp *mp_modsqrt(mp *d, mp *a, mp *p)
     c = mpmont_reduce(&mm, c, c);
   }
 
-  /* --- Done, so tidy up --- */
+  /* --- Done, so tidy up --- *
+   *
+   * Canonify the answer.
+   */
 
   d = mpmont_reduce(&mm, d, r);
+  r = mp_sub(r, p, d);
+  if (MP_CMP(r, <, d)) { mp *tt = r; r = d; d = tt; }
   mp_drop(ainv);
   mp_drop(r); mp_drop(c);
   mp_drop(dd);
@@ -161,11 +169,6 @@ static int verify(dstr *v)
     ok = 0;
   else if (MP_EQ(r, rr))
     ok = 1;
-  else {
-    r = mp_sub(r, p, r);
-    if (MP_EQ(r, rr))
-      ok = 1;
-  }
 
   if (!ok) {
     fputs("\n*** fail\n", stderr);