Sebastian Kuschel reports that pfd_closing can be called for a socket
[u/mdw/putty] / sshrsa.c
index a735517..4ec95f2 100644 (file)
--- a/sshrsa.c
+++ b/sshrsa.c
 #include "ssh.h"
 #include "misc.h"
 
-#define GET_32BIT(cp) \
-    (((unsigned long)(unsigned char)(cp)[0] << 24) | \
-    ((unsigned long)(unsigned char)(cp)[1] << 16) | \
-    ((unsigned long)(unsigned char)(cp)[2] << 8) | \
-    ((unsigned long)(unsigned char)(cp)[3]))
-
-#define PUT_32BIT(cp, value) { \
-    (cp)[0] = (unsigned char)((value) >> 24); \
-    (cp)[1] = (unsigned char)((value) >> 16); \
-    (cp)[2] = (unsigned char)((value) >> 8); \
-    (cp)[3] = (unsigned char)(value); }
-
-int makekey(unsigned char *data, struct RSAKey *result,
+int makekey(unsigned char *data, int len, struct RSAKey *result,
            unsigned char **keystr, int order)
 {
     unsigned char *p = data;
-    int i;
+    int i, n;
+
+    if (len < 4)
+       return -1;
 
     if (result) {
        result->bits = 0;
@@ -35,36 +26,53 @@ int makekey(unsigned char *data, struct RSAKey *result,
     } else
        p += 4;
 
+    len -= 4;
+
     /*
      * order=0 means exponent then modulus (the keys sent by the
      * server). order=1 means modulus then exponent (the keys
      * stored in a keyfile).
      */
 
-    if (order == 0)
-       p += ssh1_read_bignum(p, result ? &result->exponent : NULL);
+    if (order == 0) {
+       n = ssh1_read_bignum(p, len, result ? &result->exponent : NULL);
+       if (n < 0) return -1;
+       p += n;
+       len -= n;
+    }
+
+    n = ssh1_read_bignum(p, len, result ? &result->modulus : NULL);
+    if (n < 0 || (result && bignum_bitcount(result->modulus) == 0)) return -1;
     if (result)
-       result->bytes = (((p[0] << 8) + p[1]) + 7) / 8;
+       result->bytes = n - 2;
     if (keystr)
        *keystr = p + 2;
-    p += ssh1_read_bignum(p, result ? &result->modulus : NULL);
-    if (order == 1)
-       p += ssh1_read_bignum(p, result ? &result->exponent : NULL);
-
+    p += n;
+    len -= n;
+
+    if (order == 1) {
+       n = ssh1_read_bignum(p, len, result ? &result->exponent : NULL);
+       if (n < 0) return -1;
+       p += n;
+       len -= n;
+    }
     return p - data;
 }
 
-int makeprivate(unsigned char *data, struct RSAKey *result)
+int makeprivate(unsigned char *data, int len, struct RSAKey *result)
 {
-    return ssh1_read_bignum(data, &result->private_exponent);
+    return ssh1_read_bignum(data, len, &result->private_exponent);
 }
 
-void rsaencrypt(unsigned char *data, int length, struct RSAKey *key)
+int rsaencrypt(unsigned char *data, int length, struct RSAKey *key)
 {
     Bignum b1, b2;
     int i;
     unsigned char *p;
 
+    if (key->bytes < length + 4)
+       return 0;                      /* RSA key too short! */
+
     memmove(data + key->bytes - length, data, length);
     data[0] = 0;
     data[1] = 2;
@@ -87,6 +95,8 @@ void rsaencrypt(unsigned char *data, int length, struct RSAKey *key)
 
     freebn(b1);
     freebn(b2);
+
+    return 1;
 }
 
 static void sha512_mpint(SHA512_State * s, Bignum b)
@@ -100,13 +110,87 @@ static void sha512_mpint(SHA512_State * s, Bignum b)
        lenbuf[0] = bignum_byte(b, len);
        SHA512_Bytes(s, lenbuf, 1);
     }
-    memset(lenbuf, 0, sizeof(lenbuf));
+    smemclr(lenbuf, sizeof(lenbuf));
+}
+
+/*
+ * Compute (base ^ exp) % mod, provided mod == p * q, with p,q
+ * distinct primes, and iqmp is the multiplicative inverse of q mod p.
+ * Uses Chinese Remainder Theorem to speed computation up over the
+ * obvious implementation of a single big modpow.
+ */
+Bignum crt_modpow(Bignum base, Bignum exp, Bignum mod,
+                  Bignum p, Bignum q, Bignum iqmp)
+{
+    Bignum pm1, qm1, pexp, qexp, presult, qresult, diff, multiplier, ret0, ret;
+
+    /*
+     * Reduce the exponent mod phi(p) and phi(q), to save time when
+     * exponentiating mod p and mod q respectively. Of course, since p
+     * and q are prime, phi(p) == p-1 and similarly for q.
+     */
+    pm1 = copybn(p);
+    decbn(pm1);
+    qm1 = copybn(q);
+    decbn(qm1);
+    pexp = bigmod(exp, pm1);
+    qexp = bigmod(exp, qm1);
+
+    /*
+     * Do the two modpows.
+     */
+    presult = modpow(base, pexp, p);
+    qresult = modpow(base, qexp, q);
+
+    /*
+     * Recombine the results. We want a value which is congruent to
+     * qresult mod q, and to presult mod p.
+     *
+     * We know that iqmp * q is congruent to 1 * mod p (by definition
+     * of iqmp) and to 0 mod q (obviously). So we start with qresult
+     * (which is congruent to qresult mod both primes), and add on
+     * (presult-qresult) * (iqmp * q) which adjusts it to be congruent
+     * to presult mod p without affecting its value mod q.
+     */
+    if (bignum_cmp(presult, qresult) < 0) {
+        /*
+         * Can't subtract presult from qresult without first adding on
+         * p.
+         */
+        Bignum tmp = presult;
+        presult = bigadd(presult, p);
+        freebn(tmp);
+    }
+    diff = bigsub(presult, qresult);
+    multiplier = bigmul(iqmp, q);
+    ret0 = bigmuladd(multiplier, diff, qresult);
+
+    /*
+     * Finally, reduce the result mod n.
+     */
+    ret = bigmod(ret0, mod);
+
+    /*
+     * Free all the intermediate results before returning.
+     */
+    freebn(pm1);
+    freebn(qm1);
+    freebn(pexp);
+    freebn(qexp);
+    freebn(presult);
+    freebn(qresult);
+    freebn(diff);
+    freebn(multiplier);
+    freebn(ret0);
+
+    return ret;
 }
 
 /*
- * This function is a wrapper on modpow(). It has the same effect
- * as modpow(), but employs RSA blinding to protect against timing
- * attacks.
+ * This function is a wrapper on modpow(). It has the same effect as
+ * modpow(), but employs RSA blinding to protect against timing
+ * attacks and also uses the Chinese Remainder Theorem (implemented
+ * above, in crt_modpow()) to speed up the main operation.
  */
 static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
 {
@@ -189,9 +273,18 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
            bignum_cmp(random, key->modulus) >= 0) {
            freebn(random);
            continue;
-       } else {
-           break;
        }
+
+        /*
+         * Also, make sure it has an inverse mod modulus.
+         */
+        random_inverse = modinv(random, key->modulus);
+        if (!random_inverse) {
+           freebn(random);
+           continue;
+        }
+
+        break;
     }
 
     /*
@@ -208,10 +301,11 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
      * _y^d_, and use the _public_ exponent to compute (y^d)^e = y
      * from it, which is much faster to do.
      */
-    random_encrypted = modpow(random, key->exponent, key->modulus);
-    random_inverse = modinv(random, key->modulus);
+    random_encrypted = crt_modpow(random, key->exponent,
+                                  key->modulus, key->p, key->q, key->iqmp);
     input_blinded = modmul(input, random_encrypted, key->modulus);
-    ret_blinded = modpow(input_blinded, key->private_exponent, key->modulus);
+    ret_blinded = crt_modpow(input_blinded, key->private_exponent,
+                             key->modulus, key->p, key->q, key->iqmp);
     ret = modmul(ret_blinded, random_inverse, key->modulus);
 
     freebn(ret_blinded);
@@ -327,31 +421,46 @@ int rsa_verify(struct RSAKey *key)
     pm1 = copybn(key->p);
     decbn(pm1);
     ed = modmul(key->exponent, key->private_exponent, pm1);
+    freebn(pm1);
     cmp = bignum_cmp(ed, One);
-    sfree(ed);
+    freebn(ed);
     if (cmp != 0)
        return 0;
 
     qm1 = copybn(key->q);
     decbn(qm1);
     ed = modmul(key->exponent, key->private_exponent, qm1);
+    freebn(qm1);
     cmp = bignum_cmp(ed, One);
-    sfree(ed);
+    freebn(ed);
     if (cmp != 0)
        return 0;
 
     /*
      * Ensure p > q.
+     *
+     * I have seen key blobs in the wild which were generated with
+     * p < q, so instead of rejecting the key in this case we
+     * should instead flip them round into the canonical order of
+     * p > q. This also involves regenerating iqmp.
      */
-    if (bignum_cmp(key->p, key->q) <= 0)
-       return 0;
+    if (bignum_cmp(key->p, key->q) <= 0) {
+       Bignum tmp = key->p;
+       key->p = key->q;
+       key->q = tmp;
+
+       freebn(key->iqmp);
+       key->iqmp = modinv(key->q, key->p);
+        if (!key->iqmp)
+            return 0;
+    }
 
     /*
      * Ensure iqmp * q is congruent to 1, modulo p.
      */
     n = modmul(key->iqmp, key->q, key->p);
     cmp = bignum_cmp(n, One);
-    sfree(n);
+    freebn(n);
     if (cmp != 0)
        return 0;
 
@@ -378,13 +487,25 @@ unsigned char *rsa_public_blob(struct RSAKey *key, int *len)
 }
 
 /* Given a public blob, determine its length. */
-int rsa_public_blob_len(void *data)
+int rsa_public_blob_len(void *data, int maxlen)
 {
     unsigned char *p = (unsigned char *)data;
+    int n;
 
+    if (maxlen < 4)
+       return -1;
     p += 4;                           /* length word */
-    p += ssh1_read_bignum(p, NULL);    /* exponent */
-    p += ssh1_read_bignum(p, NULL);    /* modulus */
+    maxlen -= 4;
+
+    n = ssh1_read_bignum(p, maxlen, NULL);    /* exponent */
+    if (n < 0)
+       return -1;
+    p += n;
+
+    n = ssh1_read_bignum(p, maxlen, NULL);    /* modulus */
+    if (n < 0)
+       return -1;
+    p += n;
 
     return p - (unsigned char *)data;
 }
@@ -397,6 +518,12 @@ void freersakey(struct RSAKey *key)
        freebn(key->exponent);
     if (key->private_exponent)
        freebn(key->private_exponent);
+    if (key->p)
+       freebn(key->p);
+    if (key->q)
+       freebn(key->q);
+    if (key->iqmp)
+       freebn(key->iqmp);
     if (key->comment)
        sfree(key->comment);
 }
@@ -410,7 +537,9 @@ static void getstring(char **data, int *datalen, char **p, int *length)
     *p = NULL;
     if (*datalen < 4)
        return;
-    *length = GET_32BIT(*data);
+    *length = toint(GET_32BIT(*data));
+    if (*length < 0)
+        return;
     *datalen -= 4;
     *data += 4;
     if (*datalen < *length)
@@ -432,6 +561,8 @@ static Bignum getmp(char **data, int *datalen)
     return b;
 }
 
+static void rsa2_freekey(void *key);   /* forward reference */
+
 static void *rsa2_newkey(char *data, int len)
 {
     char *p;
@@ -439,8 +570,6 @@ static void *rsa2_newkey(char *data, int len)
     struct RSAKey *rsa;
 
     rsa = snew(struct RSAKey);
-    if (!rsa)
-       return NULL;
     getstring(&data, &len, &p, &slen);
 
     if (!p || slen != 7 || memcmp(p, "ssh-rsa", 7)) {
@@ -450,8 +579,14 @@ static void *rsa2_newkey(char *data, int len)
     rsa->exponent = getmp(&data, &len);
     rsa->modulus = getmp(&data, &len);
     rsa->private_exponent = NULL;
+    rsa->p = rsa->q = rsa->iqmp = NULL;
     rsa->comment = NULL;
 
+    if (!rsa->exponent || !rsa->modulus) {
+        rsa2_freekey(rsa);
+        return NULL;
+    }
+
     return rsa;
 }
 
@@ -574,8 +709,6 @@ static void *rsa2_openssh_createkey(unsigned char **blob, int *len)
     struct RSAKey *rsa;
 
     rsa = snew(struct RSAKey);
-    if (!rsa)
-       return NULL;
     rsa->comment = NULL;
 
     rsa->modulus = getmp(b, len);
@@ -587,13 +720,12 @@ static void *rsa2_openssh_createkey(unsigned char **blob, int *len)
 
     if (!rsa->modulus || !rsa->exponent || !rsa->private_exponent ||
        !rsa->iqmp || !rsa->p || !rsa->q) {
-       sfree(rsa->modulus);
-       sfree(rsa->exponent);
-       sfree(rsa->private_exponent);
-       sfree(rsa->iqmp);
-       sfree(rsa->p);
-       sfree(rsa->q);
-       sfree(rsa);
+        rsa2_freekey(rsa);
+       return NULL;
+    }
+
+    if (!rsa_verify(rsa)) {
+       rsa2_freekey(rsa);
        return NULL;
     }
 
@@ -629,6 +761,18 @@ static int rsa2_openssh_fmtkey(void *key, unsigned char *blob, int len)
     return bloblen;
 }
 
+static int rsa2_pubkey_bits(void *blob, int len)
+{
+    struct RSAKey *rsa;
+    int ret;
+
+    rsa = rsa2_newkey((char *) blob, len);
+    ret = bignum_bitcount(rsa->modulus);
+    rsa2_freekey(rsa);
+
+    return ret;
+}
+
 static char *rsa2_fingerprint(void *key)
 {
     struct RSAKey *rsa = (struct RSAKey *) key;
@@ -710,12 +854,14 @@ static int rsa2_verifysig(void *key, char *sig, int siglen,
        return 0;
     }
     in = getmp(&sig, &siglen);
+    if (!in)
+        return 0;
     out = modpow(in, rsa->exponent, rsa->modulus);
     freebn(in);
 
     ret = 1;
 
-    bytes = bignum_bitcount(rsa->modulus) / 8;
+    bytes = (bignum_bitcount(rsa->modulus)+7) / 8;
     /* Top (partial) byte should be zero. */
     if (bignum_byte(out, bytes - 1) != 0)
        ret = 0;
@@ -756,6 +902,7 @@ static unsigned char *rsa2_sign(void *key, char *data, int datalen,
     SHA_Simple(data, datalen, hash);
 
     nbytes = (bignum_bitcount(rsa->modulus) - 1) / 8;
+    assert(1 <= nbytes - 20 - ASN1_LEN);
     bytes = snewn(nbytes, unsigned char);
 
     bytes[0] = 1;
@@ -794,9 +941,164 @@ const struct ssh_signkey ssh_rsa = {
     rsa2_createkey,
     rsa2_openssh_createkey,
     rsa2_openssh_fmtkey,
+    rsa2_pubkey_bits,
     rsa2_fingerprint,
     rsa2_verifysig,
     rsa2_sign,
     "ssh-rsa",
     "rsa2"
 };
+
+void *ssh_rsakex_newkey(char *data, int len)
+{
+    return rsa2_newkey(data, len);
+}
+
+void ssh_rsakex_freekey(void *key)
+{
+    rsa2_freekey(key);
+}
+
+int ssh_rsakex_klen(void *key)
+{
+    struct RSAKey *rsa = (struct RSAKey *) key;
+
+    return bignum_bitcount(rsa->modulus);
+}
+
+static void oaep_mask(const struct ssh_hash *h, void *seed, int seedlen,
+                     void *vdata, int datalen)
+{
+    unsigned char *data = (unsigned char *)vdata;
+    unsigned count = 0;
+
+    while (datalen > 0) {
+        int i, max = (datalen > h->hlen ? h->hlen : datalen);
+        void *s;
+        unsigned char counter[4], hash[SSH2_KEX_MAX_HASH_LEN];
+
+       assert(h->hlen <= SSH2_KEX_MAX_HASH_LEN);
+        PUT_32BIT(counter, count);
+        s = h->init();
+        h->bytes(s, seed, seedlen);
+        h->bytes(s, counter, 4);
+        h->final(s, hash);
+        count++;
+
+        for (i = 0; i < max; i++)
+            data[i] ^= hash[i];
+
+        data += max;
+        datalen -= max;
+    }
+}
+
+void ssh_rsakex_encrypt(const struct ssh_hash *h, unsigned char *in, int inlen,
+                        unsigned char *out, int outlen,
+                        void *key)
+{
+    Bignum b1, b2;
+    struct RSAKey *rsa = (struct RSAKey *) key;
+    int k, i;
+    char *p;
+    const int HLEN = h->hlen;
+
+    /*
+     * Here we encrypt using RSAES-OAEP. Essentially this means:
+     * 
+     *  - we have a SHA-based `mask generation function' which
+     *    creates a pseudo-random stream of mask data
+     *    deterministically from an input chunk of data.
+     * 
+     *  - we have a random chunk of data called a seed.
+     * 
+     *  - we use the seed to generate a mask which we XOR with our
+     *    plaintext.
+     * 
+     *  - then we use _the masked plaintext_ to generate a mask
+     *    which we XOR with the seed.
+     * 
+     *  - then we concatenate the masked seed and the masked
+     *    plaintext, and RSA-encrypt that lot.
+     * 
+     * The result is that the data input to the encryption function
+     * is random-looking and (hopefully) contains no exploitable
+     * structure such as PKCS1-v1_5 does.
+     * 
+     * For a precise specification, see RFC 3447, section 7.1.1.
+     * Some of the variable names below are derived from that, so
+     * it'd probably help to read it anyway.
+     */
+
+    /* k denotes the length in octets of the RSA modulus. */
+    k = (7 + bignum_bitcount(rsa->modulus)) / 8;
+
+    /* The length of the input data must be at most k - 2hLen - 2. */
+    assert(inlen > 0 && inlen <= k - 2*HLEN - 2);
+
+    /* The length of the output data wants to be precisely k. */
+    assert(outlen == k);
+
+    /*
+     * Now perform EME-OAEP encoding. First set up all the unmasked
+     * output data.
+     */
+    /* Leading byte zero. */
+    out[0] = 0;
+    /* At position 1, the seed: HLEN bytes of random data. */
+    for (i = 0; i < HLEN; i++)
+        out[i + 1] = random_byte();
+    /* At position 1+HLEN, the data block DB, consisting of: */
+    /* The hash of the label (we only support an empty label here) */
+    h->final(h->init(), out + HLEN + 1);
+    /* A bunch of zero octets */
+    memset(out + 2*HLEN + 1, 0, outlen - (2*HLEN + 1));
+    /* A single 1 octet, followed by the input message data. */
+    out[outlen - inlen - 1] = 1;
+    memcpy(out + outlen - inlen, in, inlen);
+
+    /*
+     * Now use the seed data to mask the block DB.
+     */
+    oaep_mask(h, out+1, HLEN, out+HLEN+1, outlen-HLEN-1);
+
+    /*
+     * And now use the masked DB to mask the seed itself.
+     */
+    oaep_mask(h, out+HLEN+1, outlen-HLEN-1, out+1, HLEN);
+
+    /*
+     * Now `out' contains precisely the data we want to
+     * RSA-encrypt.
+     */
+    b1 = bignum_from_bytes(out, outlen);
+    b2 = modpow(b1, rsa->exponent, rsa->modulus);
+    p = (char *)out;
+    for (i = outlen; i--;) {
+       *p++ = bignum_byte(b2, i);
+    }
+    freebn(b1);
+    freebn(b2);
+
+    /*
+     * And we're done.
+     */
+}
+
+static const struct ssh_kex ssh_rsa_kex_sha1 = {
+    "rsa1024-sha1", NULL, KEXTYPE_RSA, NULL, NULL, 0, 0, &ssh_sha1
+};
+
+static const struct ssh_kex ssh_rsa_kex_sha256 = {
+    "rsa2048-sha256", NULL, KEXTYPE_RSA, NULL, NULL, 0, 0, &ssh_sha256
+};
+
+static const struct ssh_kex *const rsa_kex_list[] = {
+    &ssh_rsa_kex_sha256,
+    &ssh_rsa_kex_sha1
+};
+
+const struct ssh_kexes ssh_rsa_kex = {
+    sizeof(rsa_kex_list) / sizeof(*rsa_kex_list),
+    rsa_kex_list
+};