Make modinv able to return NULL if its inputs are not coprime, and
[u/mdw/putty] / sshbn.c
diff --git a/sshbn.c b/sshbn.c
index 677b121..a206783 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -624,6 +624,7 @@ static void internal_mod(BignumInt *a, int alen,
     int i, k;
 
     m0 = m[0];
+    assert(m0 >> (BIGNUM_INT_BITS-1) == 1);
     if (mlen > 1)
        m1 = m[1];
     else
@@ -815,20 +816,15 @@ Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod)
        result[0]--;
 
     /* Free temporary arrays */
-    for (i = 0; i < 2 * mlen; i++)
-       a[i] = 0;
+    smemclr(a, 2 * mlen * sizeof(*a));
     sfree(a);
-    for (i = 0; i < scratchlen; i++)
-       scratch[i] = 0;
+    smemclr(scratch, scratchlen * sizeof(*scratch));
     sfree(scratch);
-    for (i = 0; i < 2 * mlen; i++)
-       b[i] = 0;
+    smemclr(b, 2 * mlen * sizeof(*b));
     sfree(b);
-    for (i = 0; i < mlen; i++)
-       m[i] = 0;
+    smemclr(m, mlen * sizeof(*m));
     sfree(m);
-    for (i = 0; i < mlen; i++)
-       n[i] = 0;
+    smemclr(n, mlen * sizeof(*n));
     sfree(n);
 
     freebn(base);
@@ -873,6 +869,7 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
     len = mod[0];
     r = bn_power_2(BIGNUM_INT_BITS * len);
     inv = modinv(mod, r);
+    assert(inv); /* cannot fail, since mod is odd and r is a power of 2 */
 
     /*
      * Multiply the base by r mod n, to get it into Montgomery
@@ -965,23 +962,17 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
        result[0]--;
 
     /* Free temporary arrays */
-    for (i = 0; i < scratchlen; i++)
-       scratch[i] = 0;
+    smemclr(scratch, scratchlen * sizeof(*scratch));
     sfree(scratch);
-    for (i = 0; i < 2 * len; i++)
-       a[i] = 0;
+    smemclr(a, 2 * len * sizeof(*a));
     sfree(a);
-    for (i = 0; i < 2 * len; i++)
-       b[i] = 0;
+    smemclr(b, 2 * len * sizeof(*b));
     sfree(b);
-    for (i = 0; i < len; i++)
-       mninv[i] = 0;
+    smemclr(mninv, len * sizeof(*mninv));
     sfree(mninv);
-    for (i = 0; i < len; i++)
-       n[i] = 0;
+    smemclr(n, len * sizeof(*n));
     sfree(n);
-    for (i = 0; i < len; i++)
-       x[i] = 0;
+    smemclr(x, len * sizeof(*x));
     sfree(x);
 
     return result;
@@ -999,6 +990,12 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod)
     int pqlen, mlen, rlen, i, j;
     Bignum result;
 
+    /*
+     * The most significant word of mod needs to be non-zero. It
+     * should already be, but let's make sure.
+     */
+    assert(mod[mod[0]] != 0);
+
     /* Allocate m of size mlen, copy mod to m */
     /* We use big endian internally */
     mlen = mod[0];
@@ -1071,20 +1068,15 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod)
        result[0]--;
 
     /* Free temporary arrays */
-    for (i = 0; i < scratchlen; i++)
-       scratch[i] = 0;
+    smemclr(scratch, scratchlen * sizeof(*scratch));
     sfree(scratch);
-    for (i = 0; i < 2 * pqlen; i++)
-       a[i] = 0;
+    smemclr(a, 2 * pqlen * sizeof(*a));
     sfree(a);
-    for (i = 0; i < mlen; i++)
-       m[i] = 0;
+    smemclr(m, mlen * sizeof(*m));
     sfree(m);
-    for (i = 0; i < pqlen; i++)
-       n[i] = 0;
+    smemclr(n, pqlen * sizeof(*n));
     sfree(n);
-    for (i = 0; i < pqlen; i++)
-       o[i] = 0;
+    smemclr(o, pqlen * sizeof(*o));
     sfree(o);
 
     return result;
@@ -1103,6 +1095,12 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient)
     int mshift;
     int plen, mlen, i, j;
 
+    /*
+     * The most significant word of mod needs to be non-zero. It
+     * should already be, but let's make sure.
+     */
+    assert(mod[mod[0]] != 0);
+
     /* Allocate m of size mlen, copy mod to m */
     /* We use big endian internally */
     mlen = mod[0];
@@ -1154,11 +1152,9 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient)
     }
 
     /* Free temporary arrays */
-    for (i = 0; i < mlen; i++)
-       m[i] = 0;
+    smemclr(m, mlen * sizeof(*m));
     sfree(m);
-    for (i = 0; i < plen; i++)
-       n[i] = 0;
+    smemclr(n, plen * sizeof(*n));
     sfree(n);
 }
 
@@ -1405,8 +1401,7 @@ Bignum bigmuladd(Bignum a, Bignum b, Bignum addend)
     }
     ret[0] = maxspot;
 
-    for (i = 0; i < wslen; i++)
-        workspace[i] = 0;
+    smemclr(workspace, wslen * sizeof(*workspace));
     sfree(workspace);
     return ret;
 }
@@ -1636,9 +1631,22 @@ Bignum modinv(Bignum number, Bignum modulus)
     Bignum x = copybn(One);
     int sign = +1;
 
+    assert(number[number[0]] != 0);
+    assert(modulus[modulus[0]] != 0);
+
     while (bignum_cmp(b, One) != 0) {
-       Bignum t = newbn(b[0]);
-       Bignum q = newbn(a[0]);
+       Bignum t, q;
+
+        if (bignum_cmp(b, Zero) == 0) {
+            /*
+             * Found a common factor between the inputs, so we cannot
+             * return a modular inverse at all.
+             */
+            return NULL;
+        }
+
+        t = newbn(b[0]);
+       q = newbn(a[0]);
        bigdivmod(a, b, t, q);
        while (t[0] > 1 && t[t[0]] == 0)
            t[0]--;
@@ -1757,6 +1765,7 @@ char *bignum_decimal(Bignum x)
     /*
      * Done.
      */
+    smemclr(workspace, x[0] * sizeof(*workspace));
     sfree(workspace);
     return ret;
 }