Implement the Montgomery technique for speeding up modular
[u/mdw/putty] / sshrsa.c
index 6db265e..3c0feaf 100644 (file)
--- a/sshrsa.c
+++ b/sshrsa.c
@@ -352,9 +352,20 @@ int rsa_verify(struct RSAKey *key)
 
     /*
      * Ensure p > q.
+     *
+     * I have seen key blobs in the wild which were generated with
+     * p < q, so instead of rejecting the key in this case we
+     * should instead flip them round into the canonical order of
+     * p > q. This also involves regenerating iqmp.
      */
-    if (bignum_cmp(key->p, key->q) <= 0)
-       return 0;
+    if (bignum_cmp(key->p, key->q) <= 0) {
+       Bignum tmp = key->p;
+       key->p = key->q;
+       key->q = tmp;
+
+       freebn(key->iqmp);
+       key->iqmp = modinv(key->q, key->p);
+    }
 
     /*
      * Ensure iqmp * q is congruent to 1, modulo p.
@@ -419,6 +430,12 @@ void freersakey(struct RSAKey *key)
        freebn(key->exponent);
     if (key->private_exponent)
        freebn(key->private_exponent);
+    if (key->p)
+       freebn(key->p);
+    if (key->q)
+       freebn(key->q);
+    if (key->iqmp)
+       freebn(key->iqmp);
     if (key->comment)
        sfree(key->comment);
 }
@@ -472,6 +489,7 @@ static void *rsa2_newkey(char *data, int len)
     rsa->exponent = getmp(&data, &len);
     rsa->modulus = getmp(&data, &len);
     rsa->private_exponent = NULL;
+    rsa->p = rsa->q = rsa->iqmp = NULL;
     rsa->comment = NULL;
 
     return rsa;
@@ -863,8 +881,9 @@ static void oaep_mask(const struct ssh_hash *h, void *seed, int seedlen,
     while (datalen > 0) {
         int i, max = (datalen > h->hlen ? h->hlen : datalen);
         void *s;
-        unsigned char counter[4], hash[h->hlen];
+        unsigned char counter[4], hash[SSH2_KEX_MAX_HASH_LEN];
 
+       assert(h->hlen <= SSH2_KEX_MAX_HASH_LEN);
         PUT_32BIT(counter, count);
         s = h->init();
         h->bytes(s, seed, seedlen);
@@ -960,7 +979,7 @@ void ssh_rsakex_encrypt(const struct ssh_hash *h, unsigned char *in, int inlen,
      */
     b1 = bignum_from_bytes(out, outlen);
     b2 = modpow(b1, rsa->exponent, rsa->modulus);
-    p = out;
+    p = (char *)out;
     for (i = outlen; i--;) {
        *p++ = bignum_byte(b2, i);
     }