Simplify handling of responses to channel requests.
[sgt/putty] / import.c
index a3c2405..05cfdc1 100644 (file)
--- a/import.c
+++ b/import.c
@@ -308,9 +308,10 @@ static int ssh2_read_mpint(void *data, int len, struct mpint_pos *ret)
  */
 
 enum { OSSH_DSA, OSSH_RSA };
+enum { OSSH_ENC_3DES, OSSH_ENC_AES };
 struct openssh_key {
     int type;
-    int encrypted;
+    int encrypted, encryption;
     char iv[32];
     unsigned char *keyblob;
     int keyblob_len, keyblob_size;
@@ -333,7 +334,7 @@ static struct openssh_key *load_openssh_key(const Filename *filename,
     ret->encrypted = 0;
     memset(ret->iv, 0, sizeof(ret->iv));
 
-    fp = f_open(*filename, "r", FALSE);
+    fp = f_open(filename, "r", FALSE);
     if (!fp) {
        errmsg = "unable to open key file";
        goto error;
@@ -357,7 +358,7 @@ static struct openssh_key *load_openssh_key(const Filename *filename,
        errmsg = "unrecognised key type";
        goto error;
     }
-    memset(line, 0, strlen(line));
+    smemclr(line, strlen(line));
     sfree(line);
     line = NULL;
 
@@ -387,21 +388,29 @@ static struct openssh_key *load_openssh_key(const Filename *filename,
                if (!strcmp(p, "ENCRYPTED"))
                    ret->encrypted = 1;
            } else if (!strcmp(line, "DEK-Info")) {
-               int i, j;
-
-               if (strncmp(p, "DES-EDE3-CBC,", 13)) {
-                   errmsg = "ciphers other than DES-EDE3-CBC not supported";
+               int i, j, ivlen;
+
+               if (!strncmp(p, "DES-EDE3-CBC,", 13)) {
+                   ret->encryption = OSSH_ENC_3DES;
+                   ivlen = 8;
+               } else if (!strncmp(p, "AES-128-CBC,", 12)) {
+                   ret->encryption = OSSH_ENC_AES;
+                   ivlen = 16;
+               } else {
+                   errmsg = "unsupported cipher";
                    goto error;
                }
-               p += 13;
-               for (i = 0; i < 8; i++) {
-                   if (1 != sscanf(p, "%2x", &j))
-                       break;
+               p = strchr(p, ',') + 1;/* always non-NULL, by above checks */
+               for (i = 0; i < ivlen; i++) {
+                   if (1 != sscanf(p, "%2x", &j)) {
+                       errmsg = "expected more iv data in DEK-Info";
+                       goto error;
+                   }
                    ret->iv[i] = j;
                    p += 2;
                }
-               if (i < 8) {
-                   errmsg = "expected 16-digit iv in DEK-Info";
+               if (*p) {
+                   errmsg = "more iv data than expected in DEK-Info";
                    goto error;
                }
            }
@@ -433,13 +442,13 @@ static struct openssh_key *load_openssh_key(const Filename *filename,
                     memcpy(ret->keyblob + ret->keyblob_len, out, len);
                     ret->keyblob_len += len;
 
-                    memset(out, 0, sizeof(out));
+                    smemclr(out, sizeof(out));
                 }
 
                p++;
            }
        }
-       memset(line, 0, strlen(line));
+       smemclr(line, strlen(line));
        sfree(line);
        line = NULL;
     }
@@ -454,23 +463,23 @@ static struct openssh_key *load_openssh_key(const Filename *filename,
        goto error;
     }
 
-    memset(base64_bit, 0, sizeof(base64_bit));
+    smemclr(base64_bit, sizeof(base64_bit));
     if (errmsg_p) *errmsg_p = NULL;
     return ret;
 
     error:
     if (line) {
-       memset(line, 0, strlen(line));
+       smemclr(line, strlen(line));
        sfree(line);
        line = NULL;
     }
-    memset(base64_bit, 0, sizeof(base64_bit));
+    smemclr(base64_bit, sizeof(base64_bit));
     if (ret) {
        if (ret->keyblob) {
-            memset(ret->keyblob, 0, ret->keyblob_size);
+            smemclr(ret->keyblob, ret->keyblob_size);
             sfree(ret->keyblob);
         }
-        memset(ret, 0, sizeof(*ret));
+        smemclr(ret, sizeof(*ret));
        sfree(ret);
     }
     if (errmsg_p) *errmsg_p = errmsg;
@@ -485,9 +494,9 @@ int openssh_encrypted(const Filename *filename)
     if (!key)
        return 0;
     ret = key->encrypted;
-    memset(key->keyblob, 0, key->keyblob_size);
+    smemclr(key->keyblob, key->keyblob_size);
     sfree(key->keyblob);
-    memset(key, 0, sizeof(*key));
+    smemclr(key, sizeof(*key));
     sfree(key);
     return ret;
 }
@@ -520,6 +529,10 @@ struct ssh2_userkey *openssh_read(const Filename *filename, char *passphrase,
         *  - let block B equal MD5(A || passphrase || iv)
         *  - block C would be MD5(B || passphrase || iv) and so on
         *  - encryption key is the first N bytes of A || B
+        *
+        * (Note that only 8 bytes of the iv are used for key
+        * derivation, even when the key is encrypted with AES and
+        * hence there are 16 bytes available.)
         */
        struct MD5Context md5c;
        unsigned char keybuf[32];
@@ -538,11 +551,21 @@ struct ssh2_userkey *openssh_read(const Filename *filename, char *passphrase,
        /*
         * Now decrypt the key blob.
         */
-       des3_decrypt_pubkey_ossh(keybuf, (unsigned char *)key->iv,
-                                key->keyblob, key->keyblob_len);
+       if (key->encryption == OSSH_ENC_3DES)
+           des3_decrypt_pubkey_ossh(keybuf, (unsigned char *)key->iv,
+                                    key->keyblob, key->keyblob_len);
+       else {
+           void *ctx;
+           assert(key->encryption == OSSH_ENC_AES);
+           ctx = aes_make_context();
+           aes128_key(ctx, keybuf);
+           aes_iv(ctx, (unsigned char *)key->iv);
+           aes_ssh2_decrypt_blk(ctx, key->keyblob, key->keyblob_len);
+           aes_free_context(ctx);
+       }
 
-        memset(&md5c, 0, sizeof(md5c));
-        memset(keybuf, 0, sizeof(keybuf));
+        smemclr(&md5c, sizeof(md5c));
+        smemclr(keybuf, sizeof(keybuf));
     }
 
     /*
@@ -675,12 +698,12 @@ struct ssh2_userkey *openssh_read(const Filename *filename, char *passphrase,
 
     error:
     if (blob) {
-        memset(blob, 0, blobsize);
+        smemclr(blob, blobsize);
         sfree(blob);
     }
-    memset(key->keyblob, 0, key->keyblob_size);
+    smemclr(key->keyblob, key->keyblob_size);
     sfree(key->keyblob);
-    memset(key, 0, sizeof(*key));
+    smemclr(key, sizeof(*key));
     sfree(key);
     if (errmsg_p) *errmsg_p = errmsg;
     return retval;
@@ -853,6 +876,9 @@ int openssh_write(const Filename *filename, struct ssh2_userkey *key,
 
     /*
      * Encrypt the key.
+     *
+     * For the moment, we still encrypt our OpenSSH keys using
+     * old-style 3DES.
      */
     if (passphrase) {
        /*
@@ -885,15 +911,15 @@ int openssh_write(const Filename *filename, struct ssh2_userkey *key,
         */
        des3_encrypt_pubkey_ossh(keybuf, iv, outblob, outlen);
 
-        memset(&md5c, 0, sizeof(md5c));
-        memset(keybuf, 0, sizeof(keybuf));
+        smemclr(&md5c, sizeof(md5c));
+        smemclr(keybuf, sizeof(keybuf));
     }
 
     /*
      * And save it. We'll use Unix line endings just in case it's
      * subsequently transferred in binary mode.
      */
-    fp = f_open(*filename, "wb", TRUE);      /* ensure Unix line endings */
+    fp = f_open(filename, "wb", TRUE);      /* ensure Unix line endings */
     if (!fp)
        goto error;
     fputs(header, fp);
@@ -910,19 +936,19 @@ int openssh_write(const Filename *filename, struct ssh2_userkey *key,
 
     error:
     if (outblob) {
-        memset(outblob, 0, outlen);
+        smemclr(outblob, outlen);
         sfree(outblob);
     }
     if (spareblob) {
-        memset(spareblob, 0, sparelen);
+        smemclr(spareblob, sparelen);
         sfree(spareblob);
     }
     if (privblob) {
-        memset(privblob, 0, privlen);
+        smemclr(privblob, privlen);
         sfree(privblob);
     }
     if (pubblob) {
-        memset(pubblob, 0, publen);
+        smemclr(pubblob, publen);
         sfree(pubblob);
     }
     return ret;
@@ -1027,7 +1053,7 @@ static struct sshcom_key *load_sshcom_key(const Filename *filename,
     ret->keyblob = NULL;
     ret->keyblob_len = ret->keyblob_size = 0;
 
-    fp = f_open(*filename, "r", FALSE);
+    fp = f_open(filename, "r", FALSE);
     if (!fp) {
        errmsg = "unable to open key file";
        goto error;
@@ -1041,7 +1067,7 @@ static struct sshcom_key *load_sshcom_key(const Filename *filename,
        errmsg = "file does not begin with ssh.com key header";
        goto error;
     }
-    memset(line, 0, strlen(line));
+    smemclr(line, strlen(line));
     sfree(line);
     line = NULL;
 
@@ -1086,7 +1112,7 @@ static struct sshcom_key *load_sshcom_key(const Filename *filename,
                len += line2len - 1;
                assert(!line[len]);
 
-               memset(line2, 0, strlen(line2));
+               smemclr(line2, strlen(line2));
                sfree(line2);
                line2 = NULL;
             }
@@ -1132,7 +1158,7 @@ static struct sshcom_key *load_sshcom_key(const Filename *filename,
                p++;
            }
        }
-       memset(line, 0, strlen(line));
+       smemclr(line, strlen(line));
        sfree(line);
        line = NULL;
     }
@@ -1147,16 +1173,16 @@ static struct sshcom_key *load_sshcom_key(const Filename *filename,
 
     error:
     if (line) {
-       memset(line, 0, strlen(line));
+       smemclr(line, strlen(line));
        sfree(line);
        line = NULL;
     }
     if (ret) {
        if (ret->keyblob) {
-            memset(ret->keyblob, 0, ret->keyblob_size);
+            smemclr(ret->keyblob, ret->keyblob_size);
             sfree(ret->keyblob);
         }
-        memset(ret, 0, sizeof(*ret));
+        smemclr(ret, sizeof(*ret));
        sfree(ret);
     }
     if (errmsg_p) *errmsg_p = errmsg;
@@ -1196,9 +1222,9 @@ int sshcom_encrypted(const Filename *filename, char **comment)
 
     done:
     *comment = dupstr(key->comment);
-    memset(key->keyblob, 0, key->keyblob_size);
+    smemclr(key->keyblob, key->keyblob_size);
     sfree(key->keyblob);
-    memset(key, 0, sizeof(*key));
+    smemclr(key, sizeof(*key));
     sfree(key);
     return answer;
 }
@@ -1364,8 +1390,8 @@ struct ssh2_userkey *sshcom_read(const Filename *filename, char *passphrase,
        des3_decrypt_pubkey_ossh(keybuf, iv, (unsigned char *)ciphertext,
                                 cipherlen);
 
-        memset(&md5c, 0, sizeof(md5c));
-        memset(keybuf, 0, sizeof(keybuf));
+        smemclr(&md5c, sizeof(md5c));
+        smemclr(keybuf, sizeof(keybuf));
 
         /*
          * Hereafter we return WRONG_PASSPHRASE for any parsing
@@ -1468,12 +1494,12 @@ struct ssh2_userkey *sshcom_read(const Filename *filename, char *passphrase,
 
     error:
     if (blob) {
-        memset(blob, 0, blobsize);
+        smemclr(blob, blobsize);
         sfree(blob);
     }
-    memset(key->keyblob, 0, key->keyblob_size);
+    smemclr(key->keyblob, key->keyblob_size);
     sfree(key->keyblob);
-    memset(key, 0, sizeof(*key));
+    smemclr(key, sizeof(*key));
     sfree(key);
     if (errmsg_p) *errmsg_p = errmsg;
     return ret;
@@ -1638,15 +1664,15 @@ int sshcom_write(const Filename *filename, struct ssh2_userkey *key,
        des3_encrypt_pubkey_ossh(keybuf, iv, (unsigned char *)ciphertext,
                                 cipherlen);
 
-        memset(&md5c, 0, sizeof(md5c));
-        memset(keybuf, 0, sizeof(keybuf));
+        smemclr(&md5c, sizeof(md5c));
+        smemclr(keybuf, sizeof(keybuf));
     }
 
     /*
      * And save it. We'll use Unix line endings just in case it's
      * subsequently transferred in binary mode.
      */
-    fp = f_open(*filename, "wb", TRUE);      /* ensure Unix line endings */
+    fp = f_open(filename, "wb", TRUE);      /* ensure Unix line endings */
     if (!fp)
        goto error;
     fputs("---- BEGIN SSH2 ENCRYPTED PRIVATE KEY ----\n", fp);
@@ -1674,15 +1700,15 @@ int sshcom_write(const Filename *filename, struct ssh2_userkey *key,
 
     error:
     if (outblob) {
-        memset(outblob, 0, outlen);
+        smemclr(outblob, outlen);
         sfree(outblob);
     }
     if (privblob) {
-        memset(privblob, 0, privlen);
+        smemclr(privblob, privlen);
         sfree(privblob);
     }
     if (pubblob) {
-        memset(pubblob, 0, publen);
+        smemclr(pubblob, publen);
         sfree(pubblob);
     }
     return ret;