+
+/*
+ * Create the key object used by the RSA key exchange code. The
+ * (data, len) pair is handed unchanged to rsa2_newkey(), and
+ * whatever that returns is passed back to the caller as an opaque
+ * pointer.
+ */
+void *ssh_rsakex_newkey(char *data, int len)
+{
+    void *key = rsa2_newkey(data, len);
+
+    return key;
+}
+
+/*
+ * Dispose of a key object by passing it to rsa2_freekey().
+ * (Presumably `key' came from ssh_rsakex_newkey(); nothing here
+ * checks that.)
+ */
+void ssh_rsakex_freekey(void *key)
+{
+    rsa2_freekey(key);
+    return;
+}
+
+/*
+ * Return the size of the key's modulus, as reported by
+ * bignum_bitcount(). `key' is interpreted as a struct RSAKey.
+ */
+int ssh_rsakex_klen(void *key)
+{
+    return bignum_bitcount(((struct RSAKey *) key)->modulus);
+}
+
+/*
+ * Mask generation for OAEP. XORs the `datalen' bytes at `vdata', in
+ * place, with a mask stream derived from the `seedlen' bytes at
+ * `seed'.
+ *
+ * The mask is produced one hash output at a time: block number n of
+ * the mask is the hash (under `h') of the seed followed by n written
+ * as a 4-byte counter via PUT_32BIT. The last block is truncated if
+ * the data does not fill it.
+ */
+static void oaep_mask(const struct ssh_hash *h, void *seed, int seedlen,
+                      void *vdata, int datalen)
+{
+    unsigned char *dest = (unsigned char *)vdata;
+    unsigned blockno;
+
+    for (blockno = 0; datalen > 0; blockno++) {
+        unsigned char ctr[4], digest[SSH2_KEX_MAX_HASH_LEN];
+        void *ctx;
+        int j, chunk;
+
+        /* Consume a whole hash output, or whatever is left if less. */
+        chunk = datalen;
+        if (chunk > h->hlen)
+            chunk = h->hlen;
+
+        /* The digest buffer must be able to hold one hash output. */
+        assert(h->hlen <= SSH2_KEX_MAX_HASH_LEN);
+
+        /* digest = H(seed || counter) */
+        PUT_32BIT(ctr, blockno);
+        ctx = h->init();
+        h->bytes(ctx, seed, seedlen);
+        h->bytes(ctx, ctr, 4);
+        h->final(ctx, digest);
+
+        for (j = 0; j < chunk; j++)
+            dest[j] ^= digest[j];
+
+        dest += chunk;
+        datalen -= chunk;
+    }
+}
+
+/*
+ * Encrypt `inlen' bytes at `in' using RSAES-OAEP, with hash `h', an
+ * empty label, and the exponent and modulus of the RSA key `key'.
+ * The result is written to `out', which is also used as working
+ * space for the OAEP encoding.
+ *
+ * `outlen' must be exactly the length of the modulus in bytes (k),
+ * and `inlen' must be positive and at most k - 2*hLen - 2. Both
+ * conditions are checked by assertion only.
+ *
+ * In outline, RSAES-OAEP works like this:
+ *
+ *  - a hash-based `mask generation function' turns a chunk of input
+ *    data deterministically into a pseudo-random mask stream.
+ *
+ *  - a random seed is chosen, and the mask generated from it is
+ *    XORed into the plaintext block.
+ *
+ *  - the masked plaintext block is then used to generate a second
+ *    mask, which is XORed into the seed.
+ *
+ *  - the masked seed and masked plaintext block are concatenated
+ *    and RSA-encrypted.
+ *
+ * So the input to the RSA operation looks random and (hopefully)
+ * has none of the exploitable structure that PKCS1-v1_5 has.
+ *
+ * The precise specification is RFC 3447, section 7.1.1; some of the
+ * names used below follow it.
+ */
+void ssh_rsakex_encrypt(const struct ssh_hash *h, unsigned char *in, int inlen,
+                        unsigned char *out, int outlen,
+                        void *key)
+{
+    struct RSAKey *rsa = (struct RSAKey *) key;
+    const int hlen = h->hlen;
+    unsigned char *seed, *db;
+    char *dest;
+    int modbytes, dblen, idx;
+    Bignum plain, cipher;
+
+    /* Length of the RSA modulus in octets (`k' in the RFC). */
+    modbytes = (bignum_bitcount(rsa->modulus) + 7) / 8;
+
+    /* The message may occupy at most k - 2hLen - 2 octets... */
+    assert(inlen > 0 && inlen <= modbytes - 2*hlen - 2);
+
+    /* ...and the output buffer must be exactly k octets long. */
+    assert(outlen == modbytes);
+
+    /*
+     * Layout of the encoded message within `out':
+     *
+     *   out[0]                  a zero byte
+     *   seed = out + 1          hlen bytes of seed
+     *   db   = out + 1 + hlen   the data block DB, filling the rest
+     */
+    seed = out + 1;
+    db = seed + hlen;
+    dblen = outlen - hlen - 1;
+
+    /*
+     * EME-OAEP encoding, step one: write out everything unmasked.
+     */
+    out[0] = 0;
+    /* The seed is random data. */
+    for (idx = 0; idx < hlen; idx++)
+        seed[idx] = random_byte();
+    /* DB begins with the hash of the label; only the empty label is
+     * supported, so that is the hash of no data at all. */
+    h->final(h->init(), db);
+    /* The rest of DB starts out as zeroes... */
+    memset(db + hlen, 0, dblen - hlen);
+    /* ...and then the tail is overwritten with a single 1 octet
+     * followed by the message itself. */
+    out[outlen - inlen - 1] = 1;
+    memcpy(out + outlen - inlen, in, inlen);
+
+    /*
+     * Step two: mask DB using the seed, then mask the seed using
+     * the (already masked) DB.
+     */
+    oaep_mask(h, seed, hlen, db, dblen);
+    oaep_mask(h, db, dblen, seed, hlen);
+
+    /*
+     * `out' now holds exactly the data to be RSA-encrypted. Do the
+     * modular exponentiation and write the result back over `out',
+     * highest-indexed bignum byte first.
+     */
+    plain = bignum_from_bytes(out, outlen);
+    cipher = modpow(plain, rsa->exponent, rsa->modulus);
+    dest = (char *)out;
+    for (idx = outlen - 1; idx >= 0; idx--)
+        *dest++ = bignum_byte(cipher, idx);
+
+    freebn(plain);
+    freebn(cipher);
+}
+
+/*
+ * Key exchange descriptor for the method named "rsa1024-sha1": type
+ * KEXTYPE_RSA, paired with the ssh_sha1 hash. The NULL and zero
+ * fields are left unset here (presumably they are only meaningful
+ * for other key exchange types -- confirm against struct ssh_kex).
+ */
+static const struct ssh_kex ssh_rsa_kex_sha1 = {
+    "rsa1024-sha1",
+    NULL,
+    KEXTYPE_RSA,
+    NULL, NULL,
+    0, 0,
+    &ssh_sha1
+};
+
+/*
+ * Key exchange descriptor for the method named "rsa2048-sha256":
+ * type KEXTYPE_RSA, paired with the ssh_sha256 hash. The NULL and
+ * zero fields are left unset here (presumably they are only
+ * meaningful for other key exchange types -- confirm against
+ * struct ssh_kex).
+ */
+static const struct ssh_kex ssh_rsa_kex_sha256 = {
+    "rsa2048-sha256",
+    NULL,
+    KEXTYPE_RSA,
+    NULL, NULL,
+    0, 0,
+    &ssh_sha256
+};
+
+/*
+ * The RSA key exchange methods defined in this file, SHA-256 variant
+ * first. (Presumably the order expresses preference -- confirm
+ * against the code that consumes struct ssh_kexes.)
+ */
+static const struct ssh_kex *const rsa_kex_list[] = {
+    &ssh_rsa_kex_sha256,
+    &ssh_rsa_kex_sha1,
+};
+
+/*
+ * Exported collection of the RSA key exchange methods: the number of
+ * entries in rsa_kex_list, followed by the list itself.
+ */
+const struct ssh_kexes ssh_rsa_kex = {
+    sizeof rsa_kex_list / sizeof rsa_kex_list[0],
+    rsa_kex_list
+};