X-Git-Url: https://git.distorted.org.uk/u/mdw/putty/blobdiff_plain/7108a872e03aff0fbc4dfb4b7f0f9718d45463b2..HEAD:/sshrsa.c diff --git a/sshrsa.c b/sshrsa.c index 12229e63..4ec95f23 100644 --- a/sshrsa.c +++ b/sshrsa.c @@ -110,13 +110,87 @@ static void sha512_mpint(SHA512_State * s, Bignum b) lenbuf[0] = bignum_byte(b, len); SHA512_Bytes(s, lenbuf, 1); } - memset(lenbuf, 0, sizeof(lenbuf)); + smemclr(lenbuf, sizeof(lenbuf)); } /* - * This function is a wrapper on modpow(). It has the same effect - * as modpow(), but employs RSA blinding to protect against timing - * attacks. + * Compute (base ^ exp) % mod, provided mod == p * q, with p,q + * distinct primes, and iqmp is the multiplicative inverse of q mod p. + * Uses Chinese Remainder Theorem to speed computation up over the + * obvious implementation of a single big modpow. + */ +Bignum crt_modpow(Bignum base, Bignum exp, Bignum mod, + Bignum p, Bignum q, Bignum iqmp) +{ + Bignum pm1, qm1, pexp, qexp, presult, qresult, diff, multiplier, ret0, ret; + + /* + * Reduce the exponent mod phi(p) and phi(q), to save time when + * exponentiating mod p and mod q respectively. Of course, since p + * and q are prime, phi(p) == p-1 and similarly for q. + */ + pm1 = copybn(p); + decbn(pm1); + qm1 = copybn(q); + decbn(qm1); + pexp = bigmod(exp, pm1); + qexp = bigmod(exp, qm1); + + /* + * Do the two modpows. + */ + presult = modpow(base, pexp, p); + qresult = modpow(base, qexp, q); + + /* + * Recombine the results. We want a value which is congruent to + * qresult mod q, and to presult mod p. + * + * We know that iqmp * q is congruent to 1 * mod p (by definition + * of iqmp) and to 0 mod q (obviously). So we start with qresult + * (which is congruent to qresult mod both primes), and add on + * (presult-qresult) * (iqmp * q) which adjusts it to be congruent + * to presult mod p without affecting its value mod q. + */ + if (bignum_cmp(presult, qresult) < 0) { + /* + * Can't subtract presult from qresult without first adding on + * p. + */ + Bignum tmp = presult; + presult = bigadd(presult, p); + freebn(tmp); + } + diff = bigsub(presult, qresult); + multiplier = bigmul(iqmp, q); + ret0 = bigmuladd(multiplier, diff, qresult); + + /* + * Finally, reduce the result mod n. + */ + ret = bigmod(ret0, mod); + + /* + * Free all the intermediate results before returning. + */ + freebn(pm1); + freebn(qm1); + freebn(pexp); + freebn(qexp); + freebn(presult); + freebn(qresult); + freebn(diff); + freebn(multiplier); + freebn(ret0); + + return ret; +} + +/* + * This function is a wrapper on modpow(). It has the same effect as + * modpow(), but employs RSA blinding to protect against timing + * attacks and also uses the Chinese Remainder Theorem (implemented + * above, in crt_modpow()) to speed up the main operation. */ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key) { @@ -199,9 +273,18 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key) bignum_cmp(random, key->modulus) >= 0) { freebn(random); continue; - } else { - break; } + + /* + * Also, make sure it has an inverse mod modulus. + */ + random_inverse = modinv(random, key->modulus); + if (!random_inverse) { + freebn(random); + continue; + } + + break; } /* @@ -218,10 +301,11 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key) * _y^d_, and use the _public_ exponent to compute (y^d)^e = y * from it, which is much faster to do. */ - random_encrypted = modpow(random, key->exponent, key->modulus); - random_inverse = modinv(random, key->modulus); + random_encrypted = crt_modpow(random, key->exponent, + key->modulus, key->p, key->q, key->iqmp); input_blinded = modmul(input, random_encrypted, key->modulus); - ret_blinded = modpow(input_blinded, key->private_exponent, key->modulus); + ret_blinded = crt_modpow(input_blinded, key->private_exponent, + key->modulus, key->p, key->q, key->iqmp); ret = modmul(ret_blinded, random_inverse, key->modulus); freebn(ret_blinded); @@ -337,31 +421,46 @@ int rsa_verify(struct RSAKey *key) pm1 = copybn(key->p); decbn(pm1); ed = modmul(key->exponent, key->private_exponent, pm1); + freebn(pm1); cmp = bignum_cmp(ed, One); - sfree(ed); + freebn(ed); if (cmp != 0) return 0; qm1 = copybn(key->q); decbn(qm1); ed = modmul(key->exponent, key->private_exponent, qm1); + freebn(qm1); cmp = bignum_cmp(ed, One); - sfree(ed); + freebn(ed); if (cmp != 0) return 0; /* * Ensure p > q. + * + * I have seen key blobs in the wild which were generated with + * p < q, so instead of rejecting the key in this case we + * should instead flip them round into the canonical order of + * p > q. This also involves regenerating iqmp. */ - if (bignum_cmp(key->p, key->q) <= 0) - return 0; + if (bignum_cmp(key->p, key->q) <= 0) { + Bignum tmp = key->p; + key->p = key->q; + key->q = tmp; + + freebn(key->iqmp); + key->iqmp = modinv(key->q, key->p); + if (!key->iqmp) + return 0; + } /* * Ensure iqmp * q is congruent to 1, modulo p. */ n = modmul(key->iqmp, key->q, key->p); cmp = bignum_cmp(n, One); - sfree(n); + freebn(n); if (cmp != 0) return 0; @@ -419,6 +518,12 @@ void freersakey(struct RSAKey *key) freebn(key->exponent); if (key->private_exponent) freebn(key->private_exponent); + if (key->p) + freebn(key->p); + if (key->q) + freebn(key->q); + if (key->iqmp) + freebn(key->iqmp); if (key->comment) sfree(key->comment); } @@ -432,7 +537,9 @@ static void getstring(char **data, int *datalen, char **p, int *length) *p = NULL; if (*datalen < 4) return; - *length = GET_32BIT(*data); + *length = toint(GET_32BIT(*data)); + if (*length < 0) + return; *datalen -= 4; *data += 4; if (*datalen < *length) @@ -454,6 +561,8 @@ static Bignum getmp(char **data, int *datalen) return b; } +static void rsa2_freekey(void *key); /* forward reference */ + static void *rsa2_newkey(char *data, int len) { char *p; @@ -461,8 +570,6 @@ static void *rsa2_newkey(char *data, int len) struct RSAKey *rsa; rsa = snew(struct RSAKey); - if (!rsa) - return NULL; getstring(&data, &len, &p, &slen); if (!p || slen != 7 || memcmp(p, "ssh-rsa", 7)) { @@ -472,8 +579,14 @@ static void *rsa2_newkey(char *data, int len) rsa->exponent = getmp(&data, &len); rsa->modulus = getmp(&data, &len); rsa->private_exponent = NULL; + rsa->p = rsa->q = rsa->iqmp = NULL; rsa->comment = NULL; + if (!rsa->exponent || !rsa->modulus) { + rsa2_freekey(rsa); + return NULL; + } + return rsa; } @@ -596,8 +709,6 @@ static void *rsa2_openssh_createkey(unsigned char **blob, int *len) struct RSAKey *rsa; rsa = snew(struct RSAKey); - if (!rsa) - return NULL; rsa->comment = NULL; rsa->modulus = getmp(b, len); @@ -609,13 +720,12 @@ static void *rsa2_openssh_createkey(unsigned char **blob, int *len) if (!rsa->modulus || !rsa->exponent || !rsa->private_exponent || !rsa->iqmp || !rsa->p || !rsa->q) { - sfree(rsa->modulus); - sfree(rsa->exponent); - sfree(rsa->private_exponent); - sfree(rsa->iqmp); - sfree(rsa->p); - sfree(rsa->q); - sfree(rsa); + rsa2_freekey(rsa); + return NULL; + } + + if (!rsa_verify(rsa)) { + rsa2_freekey(rsa); return NULL; } @@ -744,6 +854,8 @@ static int rsa2_verifysig(void *key, char *sig, int siglen, return 0; } in = getmp(&sig, &siglen); + if (!in) + return 0; out = modpow(in, rsa->exponent, rsa->modulus); freebn(in);