Add a preference list for SSH-2 key exchange algorithms, on a new "Kex" panel
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 422897f..126e6ff 100644 (file)
--- a/ssh.c
+++ b/ssh.c
  * Packet type contexts, so that ssh2_pkt_type can correctly decode
  * the ambiguous type numbers back into the correct type strings.
  */
-#define SSH2_PKTCTX_DHGROUP1         0x0001
+#define SSH2_PKTCTX_DHGROUP          0x0001  /* DH with a standard group (not GEX) */
 #define SSH2_PKTCTX_DHGEX            0x0002
 #define SSH2_PKTCTX_PUBLICKEY        0x0010
 #define SSH2_PKTCTX_PASSWORD         0x0020
@@ -162,7 +162,7 @@ static const char *const ssh2_disconnect_reasons[] = {
 #define BUG_CHOKES_ON_RSA                        8
 #define BUG_SSH2_RSA_PADDING                    16
 #define BUG_SSH2_DERIVEKEY                       32
-#define BUG_SSH2_DH_GEX                          64
+/* 64 was BUG_SSH2_DH_GEX, now spare */
 #define BUG_SSH2_PK_SESSIONID                   128
 
 #define translate(x) if (type == x) return #x
@@ -222,8 +222,8 @@ static char *ssh2_pkt_type(int pkt_ctx, int type)
     translate(SSH2_MSG_SERVICE_ACCEPT);
     translate(SSH2_MSG_KEXINIT);
     translate(SSH2_MSG_NEWKEYS);
-    translatec(SSH2_MSG_KEXDH_INIT, SSH2_PKTCTX_DHGROUP1);
-    translatec(SSH2_MSG_KEXDH_REPLY, SSH2_PKTCTX_DHGROUP1);
+    translatec(SSH2_MSG_KEXDH_INIT, SSH2_PKTCTX_DHGROUP);
+    translatec(SSH2_MSG_KEXDH_REPLY, SSH2_PKTCTX_DHGROUP);
     translatec(SSH2_MSG_KEX_DH_GEX_REQUEST, SSH2_PKTCTX_DHGEX);
     translatec(SSH2_MSG_KEX_DH_GEX_GROUP, SSH2_PKTCTX_DHGEX);
     translatec(SSH2_MSG_KEX_DH_GEX_INIT, SSH2_PKTCTX_DHGEX);
@@ -360,11 +360,6 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
 #define SSH_MAX_BACKLOG 32768
 #define OUR_V2_WINSIZE 16384
 
-const static struct ssh_kex *kex_algs[] = {
-    &ssh_diffiehellman_gex,
-    &ssh_diffiehellman
-};
-
 const static struct ssh_signkey *hostkey_algs[] = { &ssh_rsa, &ssh_dss };
 
@@ -2057,13 +2052,28 @@ static void ssh_detect_bugs(Ssh ssh, char *vstring)
        ssh->remote_bugs |= BUG_SSH2_PK_SESSIONID;
        logevent("We believe remote version has SSH2 public-key-session-ID bug");
     }
+}
 
-    if (ssh->cfg.sshbug_dhgex2 == FORCE_ON) {
-       /*
-        * User specified the SSH2 DH GEX bug.
-        */
-       ssh->remote_bugs |= BUG_SSH2_DH_GEX;
-       logevent("We believe remote version has SSH2 DH group exchange bug");
+/*
+ * The `software version' part of an SSH version string is required
+ * to contain no spaces or minus signs; overwrite any with '_' in place.
+ */
+static void ssh_fix_verstring(char *str)
+{
+    /* Eat "SSH-<protoversion>-"; the asserts require that prefix. */
+    assert(*str == 'S'); str++;
+    assert(*str == 'S'); str++;
+    assert(*str == 'H'); str++;
+    assert(*str == '-'); str++;
+    while (*str && *str != '-') str++;
+    assert(*str == '-'); str++;
+
+    /* Convert minus signs and spaces in the remaining string into
+     * underscores. */
+    while (*str) {
+        if (*str == '-' || *str == ' ')
+            *str = '_';
+        str++;
     }
 }
 
@@ -2126,7 +2136,7 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
     ssh->rdpkt2_state.incoming_sequence = 0;
 
     s->vstring[s->vslen] = 0;
-    s->vstring[strcspn(s->vstring, "\r\n")] = '\0';/* remove EOL chars */
+    s->vstring[strcspn(s->vstring, "\015\012")] = '\0';/* remove EOL chars */
     {
        char *vlog;
        vlog = snewn(20 + s->vslen, char);
@@ -2154,46 +2164,60 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
        crStop(0);
     }
 
-    if (s->proto2 && (ssh->cfg.sshprot >= 2 || !s->proto1)) {
-       /*
-        * Use v2 protocol.
-        */
-       char verstring[80], vlog[100];
-       sprintf(verstring, "SSH-2.0-%s", sshver);
-       SHA_Init(&ssh->exhashbase);
-       /*
-        * Hash our version string and their version string.
-        */
-       sha_string(&ssh->exhashbase, verstring, strlen(verstring));
-       sha_string(&ssh->exhashbase, s->vstring, strcspn(s->vstring, "\r\n"));
-       sprintf(vlog, "We claim version: %s", verstring);
-       logevent(vlog);
-       strcat(verstring, "\012");
-       logevent("Using SSH protocol version 2");
-       sk_write(ssh->s, verstring, strlen(verstring));
-       ssh->protocol = ssh2_protocol;
-       ssh2_protocol_setup(ssh);
-       ssh->version = 2;
-       ssh->s_rdpkt = ssh2_rdpkt;
-    } else {
-       /*
-        * Use v1 protocol.
-        */
-       char verstring[80], vlog[100];
-       sprintf(verstring, "SSH-%s-%s",
-               (ssh_versioncmp(s->version, "1.5") <= 0 ? s->version : "1.5"),
-               sshver);
-       sprintf(vlog, "We claim version: %s", verstring);
-       logevent(vlog);
-       strcat(verstring, "\012");
-
-       logevent("Using SSH protocol version 1");
+    {
+        char *verstring;
+
+        if (s->proto2 && (ssh->cfg.sshprot >= 2 || !s->proto1)) {
+            /*
+             * Construct a v2 version string.
+             */
+            verstring = dupprintf("SSH-2.0-%s\015\012", sshver);
+            ssh->version = 2;
+        } else {
+            /*
+             * Construct a v1 version string.
+             */
+            verstring = dupprintf("SSH-%s-%s\012",
+                                  (ssh_versioncmp(s->version, "1.5") <= 0 ?
+                                   s->version : "1.5"),
+                                  sshver);
+            ssh->version = 1;
+        }
+
+        ssh_fix_verstring(verstring);
+
+        if (ssh->version == 2) {
+            /*
+             * Hash our version string and their version string.
+             */
+            SHA_Init(&ssh->exhashbase);
+            sha_string(&ssh->exhashbase, verstring,
+                       strcspn(verstring, "\015\012"));
+            sha_string(&ssh->exhashbase, s->vstring,
+                       strcspn(s->vstring, "\015\012"));
+
+            /*
+             * Initialise SSHv2 protocol.
+             */
+            ssh->protocol = ssh2_protocol;
+            ssh2_protocol_setup(ssh);
+            ssh->s_rdpkt = ssh2_rdpkt;
+        } else {
+            /*
+             * Initialise SSHv1 protocol.
+             */
+            ssh->protocol = ssh1_protocol;
+            ssh1_protocol_setup(ssh);
+            ssh->s_rdpkt = ssh1_rdpkt;
+        }
+        logeventf(ssh, "We claim version: %.*s",
+                  (int)strcspn(verstring, "\015\012"), verstring); /* %.*s needs int */
        sk_write(ssh->s, verstring, strlen(verstring));
-       ssh->protocol = ssh1_protocol;
-       ssh1_protocol_setup(ssh);
-       ssh->version = 1;
-       ssh->s_rdpkt = ssh1_rdpkt;
+        sfree(verstring);
     }
+
+    logeventf(ssh, "Using SSH protocol version %d", ssh->version);
+
     update_specials_menu(ssh->frontend);
     ssh->state = SSH_STATE_BEFORE_SIZE;
     ssh->pinger = pinger_new(&ssh->cfg, &ssh_backend, ssh);
@@ -2576,8 +2600,6 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen,
 
     crBegin(ssh->do_ssh1_login_crstate);
 
-    random_init();
-
     if (!pktin)
        crWaitUntil(pktin);
 
@@ -2725,7 +2747,7 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen,
 
        /* Warn about chosen cipher if necessary. */
        if (warn)
-           askcipher(ssh->frontend, cipher_string, 0);
+           askalg(ssh->frontend, "cipher", cipher_string);
     }
 
     switch (s->cipher_type) {
@@ -4285,6 +4307,8 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
        int hostkeylen, siglen;
        void *hkey;                    /* actual host key */
        unsigned char exchange_hash[20];
+       int n_preferred_kex;
+       const struct ssh_kex *preferred_kex[KEX_MAX];
        int n_preferred_ciphers;
        const struct ssh2_ciphers *preferred_ciphers[CIPHER_MAX];
        const struct ssh_compress *preferred_comp;
@@ -4299,12 +4323,42 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
     s->csmac_tobe = s->scmac_tobe = NULL;
     s->cscomp_tobe = s->sccomp_tobe = NULL;
 
-    random_init();
     s->first_kex = 1;
 
     {
        int i;
        /*
+        * Set up the preferred key exchange. (NULL => warn below here)
+        */
+       s->n_preferred_kex = 0;
+       for (i = 0; i < KEX_MAX; i++) {
+           switch (ssh->cfg.ssh_kexlist[i]) {
+             case KEX_DHGEX:
+               s->preferred_kex[s->n_preferred_kex++] =
+                   &ssh_diffiehellman_gex;
+               break;
+             case KEX_DHGROUP14:
+               s->preferred_kex[s->n_preferred_kex++] =
+                   &ssh_diffiehellman_group14;
+               break;
+             case KEX_DHGROUP1:
+               s->preferred_kex[s->n_preferred_kex++] =
+                   &ssh_diffiehellman_group1;
+               break;
+             case KEX_WARN:   /* kex-list marker, not the cipher one */
+               /* Flag for later. Don't bother if it's the last in
+                * the list. */
+               if (i < KEX_MAX - 1) {
+                   s->preferred_kex[s->n_preferred_kex++] = NULL;
+               }
+               break;
+           }
+       }
+    }
+
+    {
+       int i;
+       /*
         * Set up the preferred ciphers. (NULL => warn below here)
         */
        s->n_preferred_ciphers = 0;
@@ -4353,7 +4407,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
 
   begin_key_exchange:
     {
-       int i, j, cipherstr_started;
+       int i, j, commalist_started;
 
        /*
         * Enable queueing of outgoing auth- or connection-layer
@@ -4374,13 +4428,14 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
            ssh2_pkt_addbyte(s->pktout, (unsigned char) random_byte());
        /* List key exchange algorithms. */
        ssh2_pkt_addstring_start(s->pktout);
-       for (i = 0; i < lenof(kex_algs); i++) {
-           if (kex_algs[i] == &ssh_diffiehellman_gex &&
-               (ssh->remote_bugs & BUG_SSH2_DH_GEX))
-               continue;
-           ssh2_pkt_addstring_str(s->pktout, kex_algs[i]->name);
-           if (i < lenof(kex_algs) - 1)
+       commalist_started = 0;
+       for (i = 0; i < s->n_preferred_kex; i++) {
+           const struct ssh_kex *k = s->preferred_kex[i];
+           if (!k) continue;          /* warning flag */
+           if (commalist_started)
                ssh2_pkt_addstring_str(s->pktout, ",");
+           ssh2_pkt_addstring_str(s->pktout, k->name);
+           commalist_started = 1;
        }
        /* List server host key algorithms. */
        ssh2_pkt_addstring_start(s->pktout);
@@ -4391,28 +4446,28 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
        }
        /* List client->server encryption algorithms. */
        ssh2_pkt_addstring_start(s->pktout);
-       cipherstr_started = 0;
+       commalist_started = 0;
        for (i = 0; i < s->n_preferred_ciphers; i++) {
            const struct ssh2_ciphers *c = s->preferred_ciphers[i];
            if (!c) continue;          /* warning flag */
            for (j = 0; j < c->nciphers; j++) {
-               if (cipherstr_started)
+               if (commalist_started)
                    ssh2_pkt_addstring_str(s->pktout, ",");
                ssh2_pkt_addstring_str(s->pktout, c->list[j]->name);
-               cipherstr_started = 1;
+               commalist_started = 1;
            }
        }
        /* List server->client encryption algorithms. */
        ssh2_pkt_addstring_start(s->pktout);
-       cipherstr_started = 0;
+       commalist_started = 0;
        for (i = 0; i < s->n_preferred_ciphers; i++) {
            const struct ssh2_ciphers *c = s->preferred_ciphers[i];
            if (!c) continue; /* warning flag */
            for (j = 0; j < c->nciphers; j++) {
-               if (cipherstr_started)
+               if (commalist_started)
                    ssh2_pkt_addstring_str(s->pktout, ",");
                ssh2_pkt_addstring_str(s->pktout, c->list[j]->name);
-               cipherstr_started = 1;
+               commalist_started = 1;
            }
        }
        /* List client->server MAC algorithms. */
@@ -4493,15 +4548,26 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
        s->sccomp_tobe = NULL;
        pktin->savedpos += 16;          /* skip garbage cookie */
        ssh_pkt_getstring(pktin, &str, &len);    /* key exchange algorithms */
-       for (i = 0; i < lenof(kex_algs); i++) {
-           if (kex_algs[i] == &ssh_diffiehellman_gex &&
-               (ssh->remote_bugs & BUG_SSH2_DH_GEX))
-               continue;
-           if (in_commasep_string(kex_algs[i]->name, str, len)) {
-               ssh->kex = kex_algs[i];
+       s->warn = 0;
+       for (i = 0; i < s->n_preferred_kex; i++) {
+           const struct ssh_kex *k = s->preferred_kex[i];
+           if (!k) {
+               s->warn = 1;
+           } else if (in_commasep_string(k->name, str, len)) {
+               ssh->kex = k;
+           }
+           if (ssh->kex) {
+               if (s->warn)
+                   askalg(ssh->frontend, "key-exchange algorithm",
+                          ssh->kex->name);
                break;
            }
        }
+       if (!ssh->kex) {   /* str is (str,len)-delimited; may lack a NUL */
+           bombout(("Couldn't agree a key exchange algorithm (available: %.*s)",
+                    str ? (int)len : (int)strlen("(null)"), str ? str : "(null)"));
+           crStop(0);
+       }
        ssh_pkt_getstring(pktin, &str, &len);    /* host key algorithms */
        for (i = 0; i < lenof(hostkey_algs); i++) {
            if (in_commasep_string(hostkey_algs[i]->name, str, len)) {
@@ -4525,7 +4591,8 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
            }
            if (s->cscipher_tobe) {
                if (s->warn)
-                   askcipher(ssh->frontend, s->cscipher_tobe->name, 1);
+                   askalg(ssh->frontend, "client-to-server cipher",
+                          s->cscipher_tobe->name);
                break;
            }
        }
@@ -4551,7 +4618,8 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
            }
            if (s->sccipher_tobe) {
                if (s->warn)
-                   askcipher(ssh->frontend, s->sccipher_tobe->name, 2);
+                   askalg(ssh->frontend, "server-to-client cipher",
+                          s->sccipher_tobe->name);
                break;
            }
        }
@@ -4616,7 +4684,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
      * If we're doing Diffie-Hellman group exchange, start by
      * requesting a group.
      */
-    if (ssh->kex == &ssh_diffiehellman_gex) {
+    if (!ssh->kex->pdata) {   /* NULL pdata selects group exchange; else a standard group */
        logevent("Doing Diffie-Hellman group exchange");
        ssh->pkt_ctx |= SSH2_PKTCTX_DHGEX;
        /*
@@ -4639,14 +4707,16 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
            bombout(("unable to read mp-ints from incoming group packet"));
            crStop(0);
        }
-       ssh->kex_ctx = dh_setup_group(s->p, s->g);
+       ssh->kex_ctx = dh_setup_gex(s->p, s->g);
        s->kex_init_value = SSH2_MSG_KEX_DH_GEX_INIT;
        s->kex_reply_value = SSH2_MSG_KEX_DH_GEX_REPLY;
     } else {
-       ssh->pkt_ctx |= SSH2_PKTCTX_DHGROUP1;
-       ssh->kex_ctx = dh_setup_group1();
+       ssh->pkt_ctx |= SSH2_PKTCTX_DHGROUP;
+       ssh->kex_ctx = dh_setup_group(ssh->kex);
        s->kex_init_value = SSH2_MSG_KEXDH_INIT;
        s->kex_reply_value = SSH2_MSG_KEXDH_REPLY;
+       logeventf(ssh, "Using Diffie-Hellman with standard group \"%s\"",
+                 ssh->kex->groupname);
     }
 
     logevent("Doing Diffie-Hellman key exchange");
@@ -4735,7 +4805,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
 
     /*
      * We've sent client NEWKEYS, so create and initialise
-     * client-to-servere session keys.
+     * client-to-server session keys.
      */
     if (ssh->cs_cipher_ctx)
        ssh->cscipher->free_context(ssh->cs_cipher_ctx);
@@ -7145,6 +7215,8 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
     if (p != NULL)
        return p;
 
+    random_ref();   /* balanced by random_unref() at the end of ssh_free() */
+    /* NOTE(review): skipped on the error return above - confirm no caller ssh_free()s a failed init. */
     return NULL;
 }
 
@@ -7199,12 +7271,14 @@ static void ssh_free(void *handle)
            sfree(c);
        }
        freetree234(ssh->channels);
+       ssh->channels = NULL;
     }
 
     if (ssh->rportfwds) {
        while ((pf = delpos234(ssh->rportfwds, 0)) != NULL)
            sfree(pf);
        freetree234(ssh->rportfwds);
+       ssh->rportfwds = NULL;
     }
     sfree(ssh->deferred_send_data);
     if (ssh->x11auth)
@@ -7220,9 +7294,11 @@ static void ssh_free(void *handle)
     if (ssh->s)
        ssh_do_close(ssh);
     expire_timer_context(ssh);
-    sfree(ssh);
     if (ssh->pinger)
        pinger_free(ssh->pinger);
+    sfree(ssh);   /* only after the last read of ssh->pinger above */
+
+    random_unref();
 }
 
 /*