Make memory management uniform: _everything_ now goes through the
[u/mdw/putty] / sshpubk.c
index 3e4de1f..377be3a 100644 (file)
--- a/sshpubk.c
+++ b/sshpubk.c
@@ -1,19 +1,21 @@
 /*
- * Read SSH public keys from files.
- *
- * First implementation: only supports unencrypted SSH 1.1 format
- * RSA keys. Encryption, and SSH 2 DSS keys, to be supported later.
+ * Generic SSH public-key handling operations. In particular,
+ * reading of SSH public-key files, and also the generic `sign'
+ * operation for ssh2 (which checks the type of the key and
+ * dispatches to the appropriate key-type specific function).
  */
 
 #include <stdio.h>
-
-#include <stdio.h> /* FIXME */
-#include <stdarg.h> /* FIXME */
-#include <windows.h> /* FIXME */
-#include "putty.h" /* FIXME */
+#include <stdlib.h>
 
 #include "ssh.h"
 
+#define PUT_32BIT(cp, value) do { \
+  (cp)[3] = (value); \
+  (cp)[2] = (value) >> 8; \
+  (cp)[1] = (value) >> 16; \
+  (cp)[0] = (value) >> 24; } while (0)
+
 #define GET_32BIT(cp) \
     (((unsigned long)(unsigned char)(cp)[0] << 24) | \
     ((unsigned long)(unsigned char)(cp)[1] << 16) | \
 
 #define rsa_signature "SSH PRIVATE KEY FILE FORMAT 1.1\n"
 
-int loadrsakey(char *filename, struct RSAKey *key, char *passphrase) {
-    FILE *fp;
+#define BASE64_TOINT(x) ( (x)-'A'<26 ? (x)-'A'+0 :\
+                          (x)-'a'<26 ? (x)-'a'+26 :\
+                          (x)-'0'<10 ? (x)-'0'+52 :\
+                          (x)=='+' ? 62 : \
+                          (x)=='/' ? 63 : 0 )
+
+static int loadrsakey_main(FILE *fp, struct RSAKey *key, struct RSAAux *aux,
+                           char **commentptr, char *passphrase) {
     unsigned char buf[16384];
     unsigned char keybuf[16];
     int len;
     int i, j, ciphertype;
     int ret = 0;
     struct MD5Context md5c;
+    char *comment;
 
-    fp = fopen(filename, "rb");
-    if (!fp)
-        goto end;
-
-    /* Slurp the whole file into a buffer. */
+    /* Slurp the whole file (minus the header) into a buffer. */
     len = fread(buf, 1, sizeof(buf), fp);
     fclose(fp);
     if (len < 0 || len == sizeof(buf))
         goto end;                      /* file too big or not read */
 
-    if (len < sizeof(rsa_signature) ||
-        memcmp(buf, rsa_signature, sizeof(rsa_signature)) != 0)
-        goto end;                      /* failure to have sig at front */
+    i = 0;
 
-    i = sizeof(rsa_signature);
+    /*
+     * A zero byte. (The signature includes a terminating NUL.)
+     */
+    if (len-i < 1 || buf[i] != 0)
+        goto end;
+    i++;
 
-    /* Next, one byte giving encryption type, and one reserved uint32. */
+    /* One byte giving encryption type, and one reserved uint32. */
     if (len-i < 1)
         goto end;
     ciphertype = buf[i];
@@ -69,12 +77,19 @@ int loadrsakey(char *filename, struct RSAKey *key, char *passphrase) {
     j = GET_32BIT(buf+i);
     i += 4;
     if (len-i < j) goto end;
-    key->comment = malloc(j+1);
-    if (key->comment) {
-        memcpy(key->comment, buf+i, j);
-        key->comment[j] = '\0';
+    comment = smalloc(j+1);
+    if (comment) {
+        memcpy(comment, buf+i, j);
+        comment[j] = '\0';
     }
     i += j;
+    if (commentptr)
+        *commentptr = comment;
+    if (key)
+        key->comment = comment;
+    if (!key) {
+        return ciphertype != 0;
+    }
 
     /*
      * Decrypt remainder of buffer.
@@ -97,10 +112,19 @@ int loadrsakey(char *filename, struct RSAKey *key, char *passphrase) {
 
     /*
      * After that, we have one further bignum which is our
-     * decryption modulus, and then we're done.
+     * decryption exponent, and then the three auxiliary values
+     * (iqmp, q, p).
      */
     i += makeprivate(buf+i, key);
     if (len-i < 0) goto end;
+    if (aux) {
+        i += ssh1_read_bignum(buf+i, &aux->iqmp);
+        if (len-i < 0) goto end;
+        i += ssh1_read_bignum(buf+i, &aux->q);
+        if (len-i < 0) goto end;
+        i += ssh1_read_bignum(buf+i, &aux->p);
+        if (len-i < 0) goto end;
+    }
 
     ret = 1;
     end:
@@ -108,24 +132,147 @@ int loadrsakey(char *filename, struct RSAKey *key, char *passphrase) {
     return ret;
 }
 
+int loadrsakey(char *filename, struct RSAKey *key, struct RSAAux *aux,
+               char *passphrase) {
+    FILE *fp;
+    unsigned char buf[64];
+
+    fp = fopen(filename, "rb");
+    if (!fp)
+        return 0;                      /* doesn't even exist */
+
+    /*
+     * Read the first line of the file and see if it's a v1 private
+     * key file.
+     */
+    if (fgets(buf, sizeof(buf), fp) &&
+        !strcmp(buf, rsa_signature)) {
+        return loadrsakey_main(fp, key, aux, NULL, passphrase);
+    }
+
+    /*
+     * Otherwise, we have nothing. Return empty-handed.
+     */
+    fclose(fp);
+    return 0;
+}
+
 /*
- * See whether an RSA key is encrypted.
+ * See whether an RSA key is encrypted. Return its comment field as
+ * well.
  */
-int rsakey_encrypted(char *filename) {
+int rsakey_encrypted(char *filename, char **comment) {
     FILE *fp;
-    unsigned char buf[1+sizeof(rsa_signature)];
-    int len;
+    unsigned char buf[64];
 
     fp = fopen(filename, "rb");
     if (!fp)
         return 0;                      /* doesn't even exist */
 
-    /* Slurp the whole file into a buffer. */
-    len = fread(buf, 1, sizeof(buf), fp);
+    /*
+     * Read the first line of the file and see if it's a v1 private
+     * key file.
+     */
+    if (fgets(buf, sizeof(buf), fp) &&
+        !strcmp(buf, rsa_signature)) {
+        return loadrsakey_main(fp, NULL, NULL, comment, NULL);
+    }
     fclose(fp);
-    if (len < sizeof(buf))
-        return 0;                      /* not even valid */
-    if (buf[sizeof(buf)-1])
-        return 1;                      /* encrypted */
-    return 0;
+    return 0;                          /* wasn't the right kind of file */
+}
+
+/*
+ * Save an RSA key file. Return nonzero on success.
+ */
+int saversakey(char *filename, struct RSAKey *key, struct RSAAux *aux,
+               char *passphrase) {
+    unsigned char buf[16384];
+    unsigned char keybuf[16];
+    struct MD5Context md5c;
+    unsigned char *p, *estart;
+    FILE *fp;
+
+    /*
+     * Write the initial signature.
+     */
+    p = buf;
+    memcpy(p, rsa_signature, sizeof(rsa_signature));
+    p += sizeof(rsa_signature);
+
+    /*
+     * One byte giving encryption type, and one reserved (zero)
+     * uint32.
+     */
+    *p++ = (passphrase ? SSH_CIPHER_3DES : 0);
+    PUT_32BIT(p, 0); p += 4;
+
+    /*
+     * An ordinary SSH 1 public key consists of: a uint32
+     * containing the bit count, then two bignums containing the
+     * modulus and exponent respectively.
+     */
+    PUT_32BIT(p, ssh1_bignum_bitcount(key->modulus)); p += 4;
+    p += ssh1_write_bignum(p, key->modulus);
+    p += ssh1_write_bignum(p, key->exponent);
+
+    /*
+     * A string containing the comment field.
+     */
+    if (key->comment) {
+        PUT_32BIT(p, strlen(key->comment)); p += 4;
+        memcpy(p, key->comment, strlen(key->comment));
+        p += strlen(key->comment);
+    } else {
+        PUT_32BIT(p, 0); p += 4;
+    }
+
+    /*
+     * The encrypted portion starts here.
+     */
+    estart = p;
+
+    /*
+     * Two bytes, then the same two bytes repeated.
+     */
+    *p++ = random_byte();
+    *p++ = random_byte();
+    p[0] = p[-2]; p[1] = p[-1]; p += 2;
+
+    /*
+     * Four more bignums: the decryption exponent, then iqmp, then
+     * q, then p.
+     */
+    p += ssh1_write_bignum(p, key->private_exponent);
+    p += ssh1_write_bignum(p, aux->iqmp);
+    p += ssh1_write_bignum(p, aux->q);
+    p += ssh1_write_bignum(p, aux->p);
+
+    /*
+     * Now write zeros until the encrypted portion is a multiple of
+     * 8 bytes.
+     */
+    while ((p-estart) % 8)
+        *p++ = '\0';
+
+    /*
+     * Now encrypt the encrypted portion.
+     */
+    if (passphrase) {
+        MD5Init(&md5c);
+        MD5Update(&md5c, passphrase, strlen(passphrase));
+        MD5Final(keybuf, &md5c);
+        des3_encrypt_pubkey(keybuf, estart, p-estart);
+        memset(keybuf, 0, sizeof(keybuf));    /* burn the evidence */
+    }
+
+    /*
+     * Done. Write the result to the file.
+     */
+    fp = fopen(filename, "wb");
+    if (fp) {
+        int ret = (fwrite(buf, 1, p-buf, fp) == (size_t)(p-buf));
+        ret = ret && (fclose(fp) == 0);
+        return ret;
+    } else
+        return 0;
 }