Implement the Montgomery technique for speeding up modular
authorsimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Fri, 18 Feb 2011 08:25:38 +0000 (08:25 +0000)
committersimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Fri, 18 Feb 2011 08:25:38 +0000 (08:25 +0000)
exponentiation by replacing the modulo operation by a cleverly chosen
multiplication. This was not worth doing in the previous state of the
code (because my multiply was about as slow as my modulo), but now
that multiplication has been sped up by the Karatsuba optimisation,
Montgomery becomes worthwhile.

git-svn-id: svn://svn.tartarus.org/sgt/putty@9094 cda61777-01e9-0310-a592-d414129be87e

sshbn.c

diff --git a/sshbn.c b/sshbn.c
index 65c0629..5620bb0 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -203,7 +203,7 @@ static void internal_sub(const BignumInt *a, const BignumInt *b,
  * Result is returned in the first 2*len words of c.
  */
 #define KARATSUBA_THRESHOLD 50
-static void internal_mul(BignumInt *a, BignumInt *b,
+static void internal_mul(const BignumInt *a, const BignumInt *b,
                         BignumInt *c, int len)
 {
     int i, j;
@@ -348,6 +348,167 @@ static void internal_mul(BignumInt *a, BignumInt *b,
     }
 }
 
+/*
+ * Variant form of internal_mul used for the initial step of
+ * Montgomery reduction. Only bothers outputting 'len' words
+ * (everything above that is thrown away).
+ */
+static void internal_mul_low(const BignumInt *a, const BignumInt *b,
+                             BignumInt *c, int len)
+{
+    int i, j;
+    BignumDblInt t;
+
+    if (len > KARATSUBA_THRESHOLD) {
+
+        /*
+         * Karatsuba-aware version of internal_mul_low. As before, we
+         * express each input value as a shifted combination of two
+         * halves:
+         *
+         *   a = a_1 D + a_0
+         *   b = b_1 D + b_0
+         *
+         * Then the full product is, as before,
+         *
+         *  ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0
+         *
+         * Provided we choose D on the large side (so that a_0 and b_0
+         * are _at least_ as long as a_1 and b_1), we don't need the
+         * topmost term at all, and we only need half of the middle
+         * term. So there's no point in doing the proper Karatsuba
+         * optimisation which computes the middle term using the top
+         * one, because we'd take as long computing the top one as
+         * just computing the middle one directly.
+         *
+         * So instead, we do a much more obvious thing: we call the
+         * fully optimised internal_mul to compute a_0 b_0, and we
+         * recursively call ourself to compute the _bottom halves_ of
+         * a_1 b_0 and a_0 b_1, each of which we add into the result
+         * in the obvious way.
+         *
+         * In other words, there's no actual Karatsuba _optimisation_
+         * in this function; the only benefit in doing it this way is
+         * that we call internal_mul proper for a large part of the
+         * work, and _that_ can optimise its operation.
+         */
+
+        int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */
+        BignumInt *scratch;
+
+        /*
+         * Allocate scratch space for the various bits and pieces
+         * we're going to be adding together. We need botlen*2 words
+         * for a_0 b_0 (though we may end up throwing away its topmost
+         * word), and toplen words for each of a_1 b_0 and a_0 b_1.
+         * That adds up to exactly 2*len.
+         */
+        scratch = snewn(len*2, BignumInt);
+
+        /* a_0 b_0 */
+        internal_mul(a + toplen, b + toplen, scratch + 2*toplen, botlen);
+
+        /* a_1 b_0 */
+        internal_mul_low(a, b + len - toplen, scratch + toplen, toplen);
+
+        /* a_0 b_1 */
+        internal_mul_low(a + len - toplen, b, scratch, toplen);
+
+        /* Copy the bottom half of the big coefficient into place */
+        for (j = 0; j < botlen; j++)
+            c[toplen + j] = scratch[2*toplen + botlen + j];
+
+        /* Add the two small coefficients, throwing away the returned carry */
+        internal_add(scratch, scratch + toplen, scratch, toplen);
+
+        /* And add that to the large coefficient, leaving the result in c. */
+        internal_add(scratch, scratch + 2*toplen + botlen - toplen,
+                     c, toplen);
+
+        /* Free scratch. */
+        for (j = 0; j < len*2; j++)
+            scratch[j] = 0;
+        sfree(scratch);
+
+    } else {
+
+        for (j = 0; j < len; j++)
+            c[j] = 0;
+
+        for (i = len - 1; i >= 0; i--) {
+            t = 0;
+            for (j = len - 1; j >= len - i - 1; j--) {
+                t += MUL_WORD(a[i], (BignumDblInt) b[j]);
+                t += (BignumDblInt) c[i + j + 1 - len];
+                c[i + j + 1 - len] = (BignumInt) t;
+                t = t >> BIGNUM_INT_BITS;
+            }
+        }
+
+    }
+}
+
+/*
+ * Montgomery reduction. Expects x to be a big-endian array of 2*len
+ * BignumInts whose value satisfies 0 <= x < rn (where r = 2^(len *
+ * BIGNUM_INT_BITS) is the Montgomery base). Returns in the same array
+ * a value x' which is congruent to xr^{-1} mod n, and satisfies 0 <=
+ * x' < n.
+ *
+ * 'n' and 'mninv' should be big-endian arrays of 'len' BignumInts
+ * each, containing respectively n and the multiplicative inverse of
+ * -n mod r.
+ *
+ * 'tmp' is an array of at least '3*len' BignumInts used as scratch
+ * space.
+ */
+static void monty_reduce(BignumInt *x, const BignumInt *n,
+                         const BignumInt *mninv, BignumInt *tmp, int len)
+{
+    int i;
+    BignumInt carry;
+
+    /*
+     * Multiply x by (-n)^{-1} mod r. This gives us a value m such
+     * that mn is congruent to -x mod r. Hence, mn+x is an exact
+     * multiple of r, and is also (obviously) congruent to x mod n.
+     */
+    internal_mul_low(x + len, mninv, tmp, len);
+
+    /*
+     * Compute t = (mn+x)/r in ordinary, non-modular, integer
+     * arithmetic. By construction this is exact, and is congruent mod
+     * n to x * r^{-1}, i.e. the answer we want.
+     *
+     * The following multiply leaves that answer in the _most_
+     * significant half of the 'x' array, so then we must shift it
+     * down.
+     */
+    internal_mul(tmp, n, tmp+len, len);
+    carry = internal_add(x, tmp+len, x, 2*len);
+    for (i = 0; i < len; i++)
+        x[len + i] = x[i], x[i] = 0;
+
+    /*
+     * Reduce t mod n. This doesn't require a full-on division by n,
+     * but merely a test and single optional subtraction, since we can
+     * show that 0 <= t < 2n.
+     *
+     * Proof:
+     *  + we computed m mod r, so 0 <= m < r.
+     *  + so 0 <= mn < rn, obviously
+     *  + hence we only need 0 <= x < rn to guarantee that 0 <= mn+x < 2rn
+     *  + yielding 0 <= (mn+x)/r < 2n as required.
+     */
+    if (!carry) {
+        for (i = 0; i < len; i++)
+            if (x[len + i] != n[i])
+                break;
+    }
+    if (carry || i >= len || x[len + i] > n[i])
+        internal_sub(x+len, n, x+len, len);
+}
+
 static void internal_add_shifted(BignumInt *number,
                                 unsigned n, int shift)
 {
@@ -469,14 +630,14 @@ static void internal_mod(BignumInt *a, int alen,
 }
 
 /*
- * Compute (base ^ exp) % mod.
+ * Compute (base ^ exp) % mod. Uses the Montgomery multiplication
+ * technique.
  */
 Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
 {
-    BignumInt *a, *b, *n, *m;
-    int mshift;
-    int mlen, i, j;
-    Bignum base, result;
+    BignumInt *a, *b, *x, *n, *mninv, *tmp;
+    int len, i, j;
+    Bignum base, base2, r, rn, inv, result;
 
     /*
      * The most significant word of mod needs to be non-zero. It
@@ -490,37 +651,64 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
      */
     base = bigmod(base_in, mod);
 
-    /* Allocate m of size mlen, copy mod to m */
-    /* We use big endian internally */
-    mlen = mod[0];
-    m = snewn(mlen, BignumInt);
-    for (j = 0; j < mlen; j++)
-       m[j] = mod[mod[0] - j];
+    /*
+     * mod had better be odd, or we can't do Montgomery multiplication
+     * using a power of two at all.
+     */
+    assert(mod[1] & 1);
 
-    /* Shift m left to make msb bit set */
-    for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
-       if ((m[0] << mshift) & BIGNUM_TOP_BIT)
-           break;
-    if (mshift) {
-       for (i = 0; i < mlen - 1; i++)
-           m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
-       m[mlen - 1] = m[mlen - 1] << mshift;
-    }
+    /*
+     * Compute the inverse of n mod r, for monty_reduce. (In fact we
+     * want the inverse of _minus_ n mod r, but we'll sort that out
+     * below.)
+     */
+    len = mod[0];
+    r = bn_power_2(BIGNUM_INT_BITS * len);
+    inv = modinv(mod, r);
 
-    /* Allocate n of size mlen, copy base to n */
-    n = snewn(mlen, BignumInt);
-    i = mlen - base[0];
-    for (j = 0; j < i; j++)
-       n[j] = 0;
-    for (j = 0; j < (int)base[0]; j++)
-       n[i + j] = base[base[0] - j];
+    /*
+     * Multiply the base by r mod n, to get it into Montgomery
+     * representation.
+     */
+    base2 = modmul(base, r, mod);
+    freebn(base);
+    base = base2;
 
-    /* Allocate a and b of size 2*mlen. Set a = 1 */
-    a = snewn(2 * mlen, BignumInt);
-    b = snewn(2 * mlen, BignumInt);
-    for (i = 0; i < 2 * mlen; i++)
-       a[i] = 0;
-    a[2 * mlen - 1] = 1;
+    rn = bigmod(r, mod);               /* r mod n, i.e. Montgomerified 1 */
+
+    freebn(r);                         /* won't need this any more */
+
+    /*
+     * Set up internal arrays of the right lengths, in big-endian
+     * format, containing the base, the modulus, and the modulus's
+     * inverse.
+     */
+    n = snewn(len, BignumInt);
+    for (j = 0; j < len; j++)
+       n[len - 1 - j] = mod[j + 1];
+
+    mninv = snewn(len, BignumInt);
+    for (j = 0; j < len; j++)
+       mninv[len - 1 - j] = (j < inv[0] ? inv[j + 1] : 0);
+    freebn(inv);         /* we don't need this copy of it any more */
+    /* Now negate mninv mod r, so it's the inverse of -n rather than +n. */
+    x = snewn(len, BignumInt);
+    for (j = 0; j < len; j++)
+        x[j] = 0;
+    internal_sub(x, mninv, mninv, len);
+
+    /* x = snewn(len, BignumInt); */ /* already done above */
+    for (j = 0; j < len; j++)
+       x[len - 1 - j] = (j < base[0] ? base[j + 1] : 0);
+    freebn(base);        /* we don't need this copy of it any more */
+
+    a = snewn(2*len, BignumInt);
+    b = snewn(2*len, BignumInt);
+    for (j = 0; j < len; j++)
+       a[2*len - 1 - j] = (j < rn[0] ? rn[j + 1] : 0);
+    freebn(rn);
+
+    tmp = snewn(3*len, BignumInt);
 
     /* Skip leading zero bits of exp. */
     i = 0;
@@ -536,11 +724,11 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
     /* Main computation */
     while (i < (int)exp[0]) {
        while (j >= 0) {
-           internal_mul(a + mlen, a + mlen, b, mlen);
-           internal_mod(b, mlen * 2, m, mlen, NULL, 0);
+           internal_mul(a + len, a + len, b, len);
+            monty_reduce(b, n, mninv, tmp, len);
            if ((exp[exp[0] - i] & (1 << j)) != 0) {
-               internal_mul(b + mlen, n, a, mlen);
-               internal_mod(a, mlen * 2, m, mlen, NULL, 0);
+                internal_mul(b + len, x, a, len);
+                monty_reduce(a, n, mninv, tmp, len);
            } else {
                BignumInt *t;
                t = a;
@@ -553,38 +741,38 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
        j = BIGNUM_INT_BITS-1;
     }
 
-    /* Fixup result in case the modulus was shifted */
-    if (mshift) {
-       for (i = mlen - 1; i < 2 * mlen - 1; i++)
-           a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift));
-       a[2 * mlen - 1] = a[2 * mlen - 1] << mshift;
-       internal_mod(a, mlen * 2, m, mlen, NULL, 0);
-       for (i = 2 * mlen - 1; i >= mlen; i--)
-           a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift));
-    }
+    /*
+     * Final monty_reduce to get back from the adjusted Montgomery
+     * representation.
+     */
+    monty_reduce(a, n, mninv, tmp, len);
 
     /* Copy result to buffer */
     result = newbn(mod[0]);
-    for (i = 0; i < mlen; i++)
-       result[result[0] - i] = a[i + mlen];
+    for (i = 0; i < len; i++)
+       result[result[0] - i] = a[i + len];
     while (result[0] > 1 && result[result[0]] == 0)
        result[0]--;
 
     /* Free temporary arrays */
-    for (i = 0; i < 2 * mlen; i++)
+    for (i = 0; i < 3 * len; i++)
+       tmp[i] = 0;
+    sfree(tmp);
+    for (i = 0; i < 2 * len; i++)
        a[i] = 0;
     sfree(a);
-    for (i = 0; i < 2 * mlen; i++)
+    for (i = 0; i < 2 * len; i++)
        b[i] = 0;
     sfree(b);
-    for (i = 0; i < mlen; i++)
-       m[i] = 0;
-    sfree(m);
-    for (i = 0; i < mlen; i++)
+    for (i = 0; i < len; i++)
+       mninv[i] = 0;
+    sfree(mninv);
+    for (i = 0; i < len; i++)
        n[i] = 0;
     sfree(n);
-
-    freebn(base);
+    for (i = 0; i < len; i++)
+       x[i] = 0;
+    sfree(x);
 
     return result;
 }