Increase FONT_MAXNO from 0x2f to 0x40, to ensure the fonts[] array
[u/mdw/putty] / sshrsa.c
index 2dc09d1..7fb9694 100644 (file)
--- a/sshrsa.c
+++ b/sshrsa.c
@@ -110,13 +110,87 @@ static void sha512_mpint(SHA512_State * s, Bignum b)
        lenbuf[0] = bignum_byte(b, len);
        SHA512_Bytes(s, lenbuf, 1);
     }
-    memset(lenbuf, 0, sizeof(lenbuf));
+    smemclr(lenbuf, sizeof(lenbuf));
 }
 
 /*
- * This function is a wrapper on modpow(). It has the same effect
- * as modpow(), but employs RSA blinding to protect against timing
- * attacks.
+ * Compute (base ^ exp) % mod, provided mod == p * q, with p,q
+ * distinct primes, and iqmp is the multiplicative inverse of q mod p.
+ * Uses Chinese Remainder Theorem to speed computation up over the
+ * obvious implementation of a single big modpow.
+ */
+Bignum crt_modpow(Bignum base, Bignum exp, Bignum mod,
+                  Bignum p, Bignum q, Bignum iqmp)
+{
+    Bignum pm1, qm1, pexp, qexp, presult, qresult, diff, multiplier, ret0, ret;
+
+    /*
+     * Reduce the exponent mod phi(p) and phi(q), to save time when
+     * exponentiating mod p and mod q respectively. Of course, since p
+     * and q are prime, phi(p) == p-1 and similarly for q.
+     */
+    pm1 = copybn(p);
+    decbn(pm1);
+    qm1 = copybn(q);
+    decbn(qm1);
+    pexp = bigmod(exp, pm1);
+    qexp = bigmod(exp, qm1);
+
+    /*
+     * Do the two modpows.
+     */
+    presult = modpow(base, pexp, p);
+    qresult = modpow(base, qexp, q);
+
+    /*
+     * Recombine the results. We want a value which is congruent to
+     * qresult mod q, and to presult mod p.
+     *
+     * We know that iqmp * q is congruent to 1 * mod p (by definition
+     * of iqmp) and to 0 mod q (obviously). So we start with qresult
+     * (which is congruent to qresult mod both primes), and add on
+     * (presult-qresult) * (iqmp * q) which adjusts it to be congruent
+     * to presult mod p without affecting its value mod q.
+     */
+    if (bignum_cmp(presult, qresult) < 0) {
+        /*
+         * Can't subtract presult from qresult without first adding on
+         * p.
+         */
+        Bignum tmp = presult;
+        presult = bigadd(presult, p);
+        freebn(tmp);
+    }
+    diff = bigsub(presult, qresult);
+    multiplier = bigmul(iqmp, q);
+    ret0 = bigmuladd(multiplier, diff, qresult);
+
+    /*
+     * Finally, reduce the result mod n.
+     */
+    ret = bigmod(ret0, mod);
+
+    /*
+     * Free all the intermediate results before returning.
+     */
+    freebn(pm1);
+    freebn(qm1);
+    freebn(pexp);
+    freebn(qexp);
+    freebn(presult);
+    freebn(qresult);
+    freebn(diff);
+    freebn(multiplier);
+    freebn(ret0);
+
+    return ret;
+}
+
+/*
+ * This function is a wrapper on modpow(). It has the same effect as
+ * modpow(), but employs RSA blinding to protect against timing
+ * attacks and also uses the Chinese Remainder Theorem (implemented
+ * above, in crt_modpow()) to speed up the main operation.
  */
 static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
 {
@@ -218,10 +292,12 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
      * _y^d_, and use the _public_ exponent to compute (y^d)^e = y
      * from it, which is much faster to do.
      */
-    random_encrypted = modpow(random, key->exponent, key->modulus);
+    random_encrypted = crt_modpow(random, key->exponent,
+                                  key->modulus, key->p, key->q, key->iqmp);
     random_inverse = modinv(random, key->modulus);
     input_blinded = modmul(input, random_encrypted, key->modulus);
-    ret_blinded = modpow(input_blinded, key->private_exponent, key->modulus);
+    ret_blinded = crt_modpow(input_blinded, key->private_exponent,
+                             key->modulus, key->p, key->q, key->iqmp);
     ret = modmul(ret_blinded, random_inverse, key->modulus);
 
     freebn(ret_blinded);
@@ -337,6 +413,7 @@ int rsa_verify(struct RSAKey *key)
     pm1 = copybn(key->p);
     decbn(pm1);
     ed = modmul(key->exponent, key->private_exponent, pm1);
+    freebn(pm1);
     cmp = bignum_cmp(ed, One);
     sfree(ed);
     if (cmp != 0)
@@ -345,6 +422,7 @@ int rsa_verify(struct RSAKey *key)
     qm1 = copybn(key->q);
     decbn(qm1);
     ed = modmul(key->exponent, key->private_exponent, qm1);
+    freebn(qm1);
     cmp = bignum_cmp(ed, One);
     sfree(ed);
     if (cmp != 0)
@@ -352,9 +430,20 @@ int rsa_verify(struct RSAKey *key)
 
     /*
      * Ensure p > q.
+     *
+     * I have seen key blobs in the wild which were generated with
+     * p < q, so instead of rejecting the key in this case we
+     * should instead flip them round into the canonical order of
+     * p > q. This also involves regenerating iqmp.
      */
-    if (bignum_cmp(key->p, key->q) <= 0)
-       return 0;
+    if (bignum_cmp(key->p, key->q) <= 0) {
+       Bignum tmp = key->p;
+       key->p = key->q;
+       key->q = tmp;
+
+       freebn(key->iqmp);
+       key->iqmp = modinv(key->q, key->p);
+    }
 
     /*
      * Ensure iqmp * q is congruent to 1, modulo p.
@@ -419,6 +508,12 @@ void freersakey(struct RSAKey *key)
        freebn(key->exponent);
     if (key->private_exponent)
        freebn(key->private_exponent);
+    if (key->p)
+       freebn(key->p);
+    if (key->q)
+       freebn(key->q);
+    if (key->iqmp)
+       freebn(key->iqmp);
     if (key->comment)
        sfree(key->comment);
 }
@@ -432,7 +527,9 @@ static void getstring(char **data, int *datalen, char **p, int *length)
     *p = NULL;
     if (*datalen < 4)
        return;
-    *length = GET_32BIT(*data);
+    *length = toint(GET_32BIT(*data));
+    if (*length < 0)
+        return;
     *datalen -= 4;
     *data += 4;
     if (*datalen < *length)
@@ -472,6 +569,7 @@ static void *rsa2_newkey(char *data, int len)
     rsa->exponent = getmp(&data, &len);
     rsa->modulus = getmp(&data, &len);
     rsa->private_exponent = NULL;
+    rsa->p = rsa->q = rsa->iqmp = NULL;
     rsa->comment = NULL;
 
     return rsa;
@@ -961,7 +1059,7 @@ void ssh_rsakex_encrypt(const struct ssh_hash *h, unsigned char *in, int inlen,
      */
     b1 = bignum_from_bytes(out, outlen);
     b2 = modpow(b1, rsa->exponent, rsa->modulus);
-    p = out;
+    p = (char *)out;
     for (i = outlen; i--;) {
        *p++ = bignum_byte(b2, i);
     }