Implement the Chinese Remainder Theorem optimisation for speeding up
authorsimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Fri, 18 Feb 2011 08:25:39 +0000 (08:25 +0000)
committersimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Fri, 18 Feb 2011 08:25:39 +0000 (08:25 +0000)
RSA private key operations by making use of the fact that we know the
factors of the modulus.

git-svn-id: svn://svn.tartarus.org/sgt/putty@9095 cda61777-01e9-0310-a592-d414129be87e

ssh.h
sshbn.c
sshrsa.c

diff --git a/ssh.h b/ssh.h
index 3d02ad7..6acd878 100644 (file)
--- a/ssh.h
+++ b/ssh.h
@@ -447,6 +447,8 @@ int ssh1_write_bignum(void *data, Bignum bn);
 Bignum biggcd(Bignum a, Bignum b);
 unsigned short bignum_mod_short(Bignum number, unsigned short modulus);
 Bignum bignum_add_long(Bignum number, unsigned long addend);
+Bignum bigadd(Bignum a, Bignum b);
+Bignum bigsub(Bignum a, Bignum b);
 Bignum bigmul(Bignum a, Bignum b);
 Bignum bigmuladd(Bignum a, Bignum b, Bignum addend);
 Bignum bigdiv(Bignum a, Bignum b);
diff --git a/sshbn.c b/sshbn.c
index 5620bb0..7368a43 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -1191,6 +1191,69 @@ Bignum bigmul(Bignum a, Bignum b)
 }
 
 /*
+ * Simple addition.
+ */
+Bignum bigadd(Bignum a, Bignum b)
+{
+    int alen = a[0], blen = b[0];
+    int rlen = (alen > blen ? alen : blen) + 1;
+    int i, maxspot;
+    Bignum ret;
+    BignumDblInt carry;
+
+    ret = newbn(rlen);
+
+    carry = 0;
+    maxspot = 0;
+    for (i = 1; i <= rlen; i++) {
+        carry += (i <= (int)a[0] ? a[i] : 0);
+        carry += (i <= (int)b[0] ? b[i] : 0);
+        ret[i] = (BignumInt) carry & BIGNUM_INT_MASK;
+        carry >>= BIGNUM_INT_BITS;
+        if (ret[i] != 0 && i > maxspot)
+            maxspot = i;
+    }
+    ret[0] = maxspot;
+
+    return ret;
+}
+
+/*
+ * Subtraction. Returns a-b, or NULL if the result would come out
+ * negative (recall that this entire bignum module only handles
+ * positive numbers).
+ */
+Bignum bigsub(Bignum a, Bignum b)
+{
+    int alen = a[0], blen = b[0];
+    int rlen = (alen > blen ? alen : blen);
+    int i, maxspot;
+    Bignum ret;
+    BignumDblInt carry;
+
+    ret = newbn(rlen);
+
+    carry = 1;
+    maxspot = 0;
+    for (i = 1; i <= rlen; i++) {
+        carry += (i <= (int)a[0] ? a[i] : 0);
+        carry += (i <= (int)b[0] ? b[i] ^ BIGNUM_INT_MASK : BIGNUM_INT_MASK);
+        ret[i] = (BignumInt) carry & BIGNUM_INT_MASK;
+        carry >>= BIGNUM_INT_BITS;
+        if (ret[i] != 0 && i > maxspot)
+            maxspot = i;
+    }
+    ret[0] = maxspot;
+
+    if (!carry) {
+        freebn(ret);
+        return NULL;
+    }
+
+    return ret;
+}
+
+/*
  * Create a bignum which is the bitmask covering another one. That
  * is, the smallest integer which is >= N and is also one less than
  * a power of two.
index 3c0feaf..0c1b2ef 100644 (file)
--- a/sshrsa.c
+++ b/sshrsa.c
@@ -114,9 +114,83 @@ static void sha512_mpint(SHA512_State * s, Bignum b)
 }
 
 /*
- * This function is a wrapper on modpow(). It has the same effect
- * as modpow(), but employs RSA blinding to protect against timing
- * attacks.
+ * Compute (base ^ exp) % mod, provided mod == p * q, with p,q
+ * distinct primes, and iqmp is the multiplicative inverse of q mod p.
+ * Uses Chinese Remainder Theorem to speed computation up over the
+ * obvious implementation of a single big modpow.
+ */
+Bignum crt_modpow(Bignum base, Bignum exp, Bignum mod,
+                  Bignum p, Bignum q, Bignum iqmp)
+{
+    Bignum pm1, qm1, pexp, qexp, presult, qresult, diff, multiplier, ret0, ret;
+
+    /*
+     * Reduce the exponent mod phi(p) and phi(q), to save time when
+     * exponentiating mod p and mod q respectively. Of course, since p
+     * and q are prime, phi(p) == p-1 and similarly for q.
+     */
+    pm1 = copybn(p);
+    decbn(pm1);
+    qm1 = copybn(q);
+    decbn(qm1);
+    pexp = bigmod(exp, pm1);
+    qexp = bigmod(exp, qm1);
+
+    /*
+     * Do the two modpows.
+     */
+    presult = modpow(base, pexp, p);
+    qresult = modpow(base, qexp, q);
+
+    /*
+     * Recombine the results. We want a value which is congruent to
+     * qresult mod q, and to presult mod p.
+     *
+     * We know that iqmp * q is congruent to 1 * mod p (by definition
+     * of iqmp) and to 0 mod q (obviously). So we start with qresult
+     * (which is congruent to qresult mod both primes), and add on
+     * (presult-qresult) * (iqmp * q) which adjusts it to be congruent
+     * to presult mod p without affecting its value mod q.
+     */
+    if (bignum_cmp(presult, qresult) < 0) {
+        /*
+         * Can't subtract presult from qresult without first adding on
+         * p.
+         */
+        Bignum tmp = presult;
+        presult = bigadd(presult, p);
+        freebn(tmp);
+    }
+    diff = bigsub(presult, qresult);
+    multiplier = bigmul(iqmp, q);
+    ret0 = bigmuladd(multiplier, diff, qresult);
+
+    /*
+     * Finally, reduce the result mod n.
+     */
+    ret = bigmod(ret0, mod);
+
+    /*
+     * Free all the intermediate results before returning.
+     */
+    freebn(pm1);
+    freebn(qm1);
+    freebn(pexp);
+    freebn(qexp);
+    freebn(presult);
+    freebn(qresult);
+    freebn(diff);
+    freebn(multiplier);
+    freebn(ret0);
+
+    return ret;
+}
+
+/*
+ * This function is a wrapper on modpow(). It has the same effect as
+ * modpow(), but employs RSA blinding to protect against timing
+ * attacks and also uses the Chinese Remainder Theorem (implemented
+ * above, in crt_modpow()) to speed up the main operation.
  */
 static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
 {
@@ -218,10 +292,12 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
      * _y^d_, and use the _public_ exponent to compute (y^d)^e = y
      * from it, which is much faster to do.
      */
-    random_encrypted = modpow(random, key->exponent, key->modulus);
+    random_encrypted = crt_modpow(random, key->exponent,
+                                  key->modulus, key->p, key->q, key->iqmp);
     random_inverse = modinv(random, key->modulus);
     input_blinded = modmul(input, random_encrypted, key->modulus);
-    ret_blinded = modpow(input_blinded, key->private_exponent, key->modulus);
+    ret_blinded = crt_modpow(input_blinded, key->private_exponent,
+                             key->modulus, key->p, key->q, key->iqmp);
     ret = modmul(ret_blinded, random_inverse, key->modulus);
 
     freebn(ret_blinded);