Add support for RFC 4432 RSA key exchange, the patch for which has been
[u/mdw/putty] / sshrsa.c
index f684c2a..6db265e 100644 (file)
--- a/sshrsa.c
+++ b/sshrsa.c
 #include "ssh.h"
 #include "misc.h"
 
-#define GET_32BIT(cp) \
-    (((unsigned long)(unsigned char)(cp)[0] << 24) | \
-    ((unsigned long)(unsigned char)(cp)[1] << 16) | \
-    ((unsigned long)(unsigned char)(cp)[2] << 8) | \
-    ((unsigned long)(unsigned char)(cp)[3]))
-
-#define PUT_32BIT(cp, value) { \
-    (cp)[0] = (unsigned char)((value) >> 24); \
-    (cp)[1] = (unsigned char)((value) >> 16); \
-    (cp)[2] = (unsigned char)((value) >> 8); \
-    (cp)[3] = (unsigned char)(value); }
-
 int makekey(unsigned char *data, int len, struct RSAKey *result,
            unsigned char **keystr, int order)
 {
@@ -54,7 +42,7 @@ int makekey(unsigned char *data, int len, struct RSAKey *result,
     }
 
     n = ssh1_read_bignum(p, len, result ? &result->modulus : NULL);
-    if (n < 0 || bignum_bitcount(result->modulus) == 0) return -1;
+    if (n < 0 || (result && bignum_bitcount(result->modulus) == 0)) return -1;
     if (result)
        result->bytes = n - 2;
     if (keystr)
@@ -802,6 +790,7 @@ static unsigned char *rsa2_sign(void *key, char *data, int datalen,
     SHA_Simple(data, datalen, hash);
 
     nbytes = (bignum_bitcount(rsa->modulus) - 1) / 8;
+    assert(1 <= nbytes - 20 - ASN1_LEN);
     bytes = snewn(nbytes, unsigned char);
 
     bytes[0] = 1;
@@ -847,3 +836,156 @@ const struct ssh_signkey ssh_rsa = {
     "ssh-rsa",
     "rsa2"
 };
+
+void *ssh_rsakex_newkey(char *data, int len)
+{
+    return rsa2_newkey(data, len);
+}
+
+void ssh_rsakex_freekey(void *key)
+{
+    rsa2_freekey(key);
+}
+
+int ssh_rsakex_klen(void *key)
+{
+    struct RSAKey *rsa = (struct RSAKey *) key;
+
+    return bignum_bitcount(rsa->modulus);
+}
+
+static void oaep_mask(const struct ssh_hash *h, void *seed, int seedlen,
+                     void *vdata, int datalen)
+{
+    unsigned char *data = (unsigned char *)vdata;
+    unsigned count = 0;
+
+    while (datalen > 0) {
+        int i, max = (datalen > h->hlen ? h->hlen : datalen);
+        void *s;
+        unsigned char counter[4], hash[h->hlen];
+
+        PUT_32BIT(counter, count);
+        s = h->init();
+        h->bytes(s, seed, seedlen);
+        h->bytes(s, counter, 4);
+        h->final(s, hash);
+        count++;
+
+        for (i = 0; i < max; i++)
+            data[i] ^= hash[i];
+
+        data += max;
+        datalen -= max;
+    }
+}
+
+void ssh_rsakex_encrypt(const struct ssh_hash *h, unsigned char *in, int inlen,
+                        unsigned char *out, int outlen,
+                        void *key)
+{
+    Bignum b1, b2;
+    struct RSAKey *rsa = (struct RSAKey *) key;
+    int k, i;
+    char *p;
+    const int HLEN = h->hlen;
+
+    /*
+     * Here we encrypt using RSAES-OAEP. Essentially this means:
+     * 
+     *  - we have a SHA-based `mask generation function' which
+     *    creates a pseudo-random stream of mask data
+     *    deterministically from an input chunk of data.
+     * 
+     *  - we have a random chunk of data called a seed.
+     * 
+     *  - we use the seed to generate a mask which we XOR with our
+     *    plaintext.
+     * 
+     *  - then we use _the masked plaintext_ to generate a mask
+     *    which we XOR with the seed.
+     * 
+     *  - then we concatenate the masked seed and the masked
+     *    plaintext, and RSA-encrypt that lot.
+     * 
+     * The result is that the data input to the encryption function
+     * is random-looking and (hopefully) contains no exploitable
+     * structure such as PKCS1-v1_5 does.
+     * 
+     * For a precise specification, see RFC 3447, section 7.1.1.
+     * Some of the variable names below are derived from that, so
+     * it'd probably help to read it anyway.
+     */
+
+    /* k denotes the length in octets of the RSA modulus. */
+    k = (7 + bignum_bitcount(rsa->modulus)) / 8;
+
+    /* The length of the input data must be at most k - 2hLen - 2. */
+    assert(inlen > 0 && inlen <= k - 2*HLEN - 2);
+
+    /* The length of the output data wants to be precisely k. */
+    assert(outlen == k);
+
+    /*
+     * Now perform EME-OAEP encoding. First set up all the unmasked
+     * output data.
+     */
+    /* Leading byte zero. */
+    out[0] = 0;
+    /* At position 1, the seed: HLEN bytes of random data. */
+    for (i = 0; i < HLEN; i++)
+        out[i + 1] = random_byte();
+    /* At position 1+HLEN, the data block DB, consisting of: */
+    /* The hash of the label (we only support an empty label here) */
+    h->final(h->init(), out + HLEN + 1);
+    /* A bunch of zero octets */
+    memset(out + 2*HLEN + 1, 0, outlen - (2*HLEN + 1));
+    /* A single 1 octet, followed by the input message data. */
+    out[outlen - inlen - 1] = 1;
+    memcpy(out + outlen - inlen, in, inlen);
+
+    /*
+     * Now use the seed data to mask the block DB.
+     */
+    oaep_mask(h, out+1, HLEN, out+HLEN+1, outlen-HLEN-1);
+
+    /*
+     * And now use the masked DB to mask the seed itself.
+     */
+    oaep_mask(h, out+HLEN+1, outlen-HLEN-1, out+1, HLEN);
+
+    /*
+     * Now `out' contains precisely the data we want to
+     * RSA-encrypt.
+     */
+    b1 = bignum_from_bytes(out, outlen);
+    b2 = modpow(b1, rsa->exponent, rsa->modulus);
+    p = out;
+    for (i = outlen; i--;) {
+       *p++ = bignum_byte(b2, i);
+    }
+    freebn(b1);
+    freebn(b2);
+
+    /*
+     * And we're done.
+     */
+}
+
+static const struct ssh_kex ssh_rsa_kex_sha1 = {
+    "rsa1024-sha1", NULL, KEXTYPE_RSA, NULL, NULL, 0, 0, &ssh_sha1
+};
+
+static const struct ssh_kex ssh_rsa_kex_sha256 = {
+    "rsa2048-sha256", NULL, KEXTYPE_RSA, NULL, NULL, 0, 0, &ssh_sha256
+};
+
+static const struct ssh_kex *const rsa_kex_list[] = {
+    &ssh_rsa_kex_sha256,
+    &ssh_rsa_kex_sha1
+};
+
+const struct ssh_kexes ssh_rsa_kex = {
+    sizeof(rsa_kex_list) / sizeof(*rsa_kex_list),
+    rsa_kex_list
+};