Split pkt_ctx into a separate enumeration for each of kex and userauth
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index e5d2364..3476d3d 100644 (file)
--- a/ssh.c
+++ b/ssh.c
  * Packet type contexts, so that ssh2_pkt_type can correctly decode
  * the ambiguous type numbers back into the correct type strings.
  */
-#define SSH2_PKTCTX_DHGROUP          0x0001
-#define SSH2_PKTCTX_DHGEX            0x0002
-#define SSH2_PKTCTX_RSAKEX           0x0004
-#define SSH2_PKTCTX_KEX_MASK         0x000F
-#define SSH2_PKTCTX_PUBLICKEY        0x0010
-#define SSH2_PKTCTX_PASSWORD         0x0020
-#define SSH2_PKTCTX_KBDINTER         0x0040
-#define SSH2_PKTCTX_AUTH_MASK        0x00F0
+typedef enum {
+    SSH2_PKTCTX_NOKEX,
+    SSH2_PKTCTX_DHGROUP,
+    SSH2_PKTCTX_DHGEX,
+    SSH2_PKTCTX_RSAKEX
+} Pkt_KCtx;
+typedef enum {
+    SSH2_PKTCTX_NOAUTH,
+    SSH2_PKTCTX_PUBLICKEY,
+    SSH2_PKTCTX_PASSWORD,
+    SSH2_PKTCTX_KBDINTER
+} Pkt_ACtx;
 
 #define SSH2_DISCONNECT_HOST_NOT_ALLOWED_TO_CONNECT 1  /* 0x1 */
 #define SSH2_DISCONNECT_PROTOCOL_ERROR            2    /* 0x2 */
@@ -281,7 +285,8 @@ static unsigned int ssh_tty_parse_boolean(char *s)
 }
 
 #define translate(x) if (type == x) return #x
-#define translatec(x,ctx) if (type == x && (pkt_ctx & ctx)) return #x
+#define translatek(x,ctx) if (type == x && (pkt_kctx == ctx)) return #x
+#define translatea(x,ctx) if (type == x && (pkt_actx == ctx)) return #x
 static char *ssh1_pkt_type(int type)
 {
     translate(SSH1_MSG_DISCONNECT);
@@ -327,7 +332,7 @@ static char *ssh1_pkt_type(int type)
     translate(SSH1_CMSG_AUTH_CCARD_RESPONSE);
     return "unknown";
 }
-static char *ssh2_pkt_type(int pkt_ctx, int type)
+static char *ssh2_pkt_type(Pkt_KCtx pkt_kctx, Pkt_ACtx pkt_actx, int type)
 {
     translate(SSH2_MSG_DISCONNECT);
     translate(SSH2_MSG_IGNORE);
@@ -337,23 +342,23 @@ static char *ssh2_pkt_type(int pkt_ctx, int type)
     translate(SSH2_MSG_SERVICE_ACCEPT);
     translate(SSH2_MSG_KEXINIT);
     translate(SSH2_MSG_NEWKEYS);
-    translatec(SSH2_MSG_KEXDH_INIT, SSH2_PKTCTX_DHGROUP);
-    translatec(SSH2_MSG_KEXDH_REPLY, SSH2_PKTCTX_DHGROUP);
-    translatec(SSH2_MSG_KEX_DH_GEX_REQUEST, SSH2_PKTCTX_DHGEX);
-    translatec(SSH2_MSG_KEX_DH_GEX_GROUP, SSH2_PKTCTX_DHGEX);
-    translatec(SSH2_MSG_KEX_DH_GEX_INIT, SSH2_PKTCTX_DHGEX);
-    translatec(SSH2_MSG_KEX_DH_GEX_REPLY, SSH2_PKTCTX_DHGEX);
-    translatec(SSH2_MSG_KEXRSA_PUBKEY, SSH2_PKTCTX_RSAKEX);
-    translatec(SSH2_MSG_KEXRSA_SECRET, SSH2_PKTCTX_RSAKEX);
-    translatec(SSH2_MSG_KEXRSA_DONE, SSH2_PKTCTX_RSAKEX);
+    translatek(SSH2_MSG_KEXDH_INIT, SSH2_PKTCTX_DHGROUP);
+    translatek(SSH2_MSG_KEXDH_REPLY, SSH2_PKTCTX_DHGROUP);
+    translatek(SSH2_MSG_KEX_DH_GEX_REQUEST, SSH2_PKTCTX_DHGEX);
+    translatek(SSH2_MSG_KEX_DH_GEX_GROUP, SSH2_PKTCTX_DHGEX);
+    translatek(SSH2_MSG_KEX_DH_GEX_INIT, SSH2_PKTCTX_DHGEX);
+    translatek(SSH2_MSG_KEX_DH_GEX_REPLY, SSH2_PKTCTX_DHGEX);
+    translatek(SSH2_MSG_KEXRSA_PUBKEY, SSH2_PKTCTX_RSAKEX);
+    translatek(SSH2_MSG_KEXRSA_SECRET, SSH2_PKTCTX_RSAKEX);
+    translatek(SSH2_MSG_KEXRSA_DONE, SSH2_PKTCTX_RSAKEX);
     translate(SSH2_MSG_USERAUTH_REQUEST);
     translate(SSH2_MSG_USERAUTH_FAILURE);
     translate(SSH2_MSG_USERAUTH_SUCCESS);
     translate(SSH2_MSG_USERAUTH_BANNER);
-    translatec(SSH2_MSG_USERAUTH_PK_OK, SSH2_PKTCTX_PUBLICKEY);
-    translatec(SSH2_MSG_USERAUTH_PASSWD_CHANGEREQ, SSH2_PKTCTX_PASSWORD);
-    translatec(SSH2_MSG_USERAUTH_INFO_REQUEST, SSH2_PKTCTX_KBDINTER);
-    translatec(SSH2_MSG_USERAUTH_INFO_RESPONSE, SSH2_PKTCTX_KBDINTER);
+    translatea(SSH2_MSG_USERAUTH_PK_OK, SSH2_PKTCTX_PUBLICKEY);
+    translatea(SSH2_MSG_USERAUTH_PASSWD_CHANGEREQ, SSH2_PKTCTX_PASSWORD);
+    translatea(SSH2_MSG_USERAUTH_INFO_REQUEST, SSH2_PKTCTX_KBDINTER);
+    translatea(SSH2_MSG_USERAUTH_INFO_RESPONSE, SSH2_PKTCTX_KBDINTER);
     translate(SSH2_MSG_GLOBAL_REQUEST);
     translate(SSH2_MSG_REQUEST_SUCCESS);
     translate(SSH2_MSG_REQUEST_FAILURE);
@@ -770,7 +775,8 @@ struct ssh_tag {
 
     bufchain banner;   /* accumulates banners during do_ssh2_authconn */
 
-    int pkt_ctx;
+    Pkt_KCtx pkt_kctx;
+    Pkt_ACtx pkt_actx;
 
     void *x11auth;
 
@@ -1387,7 +1393,8 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
            }
        }
        log_packet(ssh->logctx, PKT_INCOMING, st->pktin->type,
-                  ssh2_pkt_type(ssh->pkt_ctx, st->pktin->type),
+                  ssh2_pkt_type(ssh->pkt_kctx, ssh->pkt_actx,
+                                st->pktin->type),
                   st->pktin->data+6, st->pktin->length-6,
                   nblanks, &blank);
     }
@@ -1453,7 +1460,8 @@ static int s_wrpkt_prepare(Ssh ssh, struct Packet *pkt, int *offset_p)
 
 static int s_write(Ssh ssh, void *data, int len)
 {
-    log_packet(ssh->logctx, PKT_OUTGOING, -1, NULL, data, len, 0, NULL);
+    if (ssh->logctx)
+       log_packet(ssh->logctx, PKT_OUTGOING, -1, NULL, data, len, 0, NULL);
     return sk_write(ssh->s, (char *)data, len);
 }
 
@@ -1734,7 +1742,7 @@ static int ssh2_pkt_construct(Ssh ssh, struct Packet *pkt)
 
     if (ssh->logctx)
        log_packet(ssh->logctx, PKT_OUTGOING, pkt->data[5],
-                  ssh2_pkt_type(ssh->pkt_ctx, pkt->data[5]),
+                  ssh2_pkt_type(ssh->pkt_kctx, ssh->pkt_actx, pkt->data[5]),
                   pkt->body, pkt->length - (pkt->body - pkt->data),
                   pkt->nblanks, pkt->blanks);
     sfree(pkt->blanks); pkt->blanks = NULL;
@@ -2377,6 +2385,47 @@ static void ssh_fix_verstring(char *str)
     }
 }
 
+/*
+ * Send an appropriate SSH version string.
+ */
+static void ssh_send_verstring(Ssh ssh, char *svers)
+{
+    char *verstring;
+
+    if (ssh->version == 2) {
+       /*
+        * Construct a v2 version string.
+        */
+       verstring = dupprintf("SSH-2.0-%s\015\012", sshver);
+    } else {
+       /*
+        * Construct a v1 version string.
+        */
+       verstring = dupprintf("SSH-%s-%s\012",
+                             (ssh_versioncmp(svers, "1.5") <= 0 ?
+                              svers : "1.5"),
+                             sshver);
+    }
+
+    ssh_fix_verstring(verstring);
+
+    if (ssh->version == 2) {
+       size_t len;
+       /*
+        * Record our version string.
+        */
+       len = strcspn(verstring, "\015\012");
+       ssh->v_c = snewn(len + 1, char);
+       memcpy(ssh->v_c, verstring, len);
+       ssh->v_c[len] = 0;
+    }
+
+    logeventf(ssh, "We claim version: %.*s",
+             strcspn(verstring, "\015\012"), verstring);
+    s_write(ssh, verstring, strlen(verstring));
+    sfree(verstring);
+}
+
 static int do_ssh_init(Ssh ssh, unsigned char c)
 {
     struct do_ssh_init_state {
@@ -2455,65 +2504,43 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
        crStop(0);
     }
 
-    {
-        char *verstring;
+    if (s->proto2 && (ssh->cfg.sshprot >= 2 || !s->proto1))
+       ssh->version = 2;
+    else
+       ssh->version = 1;
 
-        if (s->proto2 && (ssh->cfg.sshprot >= 2 || !s->proto1)) {
-            /*
-             * Construct a v2 version string.
-             */
-            verstring = dupprintf("SSH-2.0-%s\015\012", sshver);
-            ssh->version = 2;
-        } else {
-            /*
-             * Construct a v1 version string.
-             */
-            verstring = dupprintf("SSH-%s-%s\012",
-                                  (ssh_versioncmp(s->version, "1.5") <= 0 ?
-                                   s->version : "1.5"),
-                                  sshver);
-            ssh->version = 1;
-        }
+    logeventf(ssh, "Using SSH protocol version %d", ssh->version);
 
-        ssh_fix_verstring(verstring);
+    /* Send the version string, if we haven't already */
+    if (ssh->cfg.sshprot != 3)
+       ssh_send_verstring(ssh, s->version);
 
-        if (ssh->version == 2) {
-           size_t len;
-            /*
-             * Hash our version string and their version string.
-             */
-           len = strcspn(verstring, "\015\012");
-           ssh->v_c = snewn(len + 1, char);
-           memcpy(ssh->v_c, verstring, len);
-           ssh->v_c[len] = 0;
-           len = strcspn(s->vstring, "\015\012");
-           ssh->v_s = snewn(len + 1, char);
-           memcpy(ssh->v_s, s->vstring, len);
-           ssh->v_s[len] = 0;
+    if (ssh->version == 2) {
+       size_t len;
+       /*
+        * Record their version string.
+        */
+       len = strcspn(s->vstring, "\015\012");
+       ssh->v_s = snewn(len + 1, char);
+       memcpy(ssh->v_s, s->vstring, len);
+       ssh->v_s[len] = 0;
            
-            /*
-             * Initialise SSH-2 protocol.
-             */
-            ssh->protocol = ssh2_protocol;
-            ssh2_protocol_setup(ssh);
-            ssh->s_rdpkt = ssh2_rdpkt;
-        } else {
-            /*
-             * Initialise SSH-1 protocol.
-             */
-            ssh->protocol = ssh1_protocol;
-            ssh1_protocol_setup(ssh);
-            ssh->s_rdpkt = ssh1_rdpkt;
-        }
-        logeventf(ssh, "We claim version: %.*s",
-                  strcspn(verstring, "\015\012"), verstring);
-       s_write(ssh, verstring, strlen(verstring));
-        sfree(verstring);
-       if (ssh->version == 2)
-           do_ssh2_transport(ssh, NULL, -1, NULL);
+       /*
+        * Initialise SSH-2 protocol.
+        */
+       ssh->protocol = ssh2_protocol;
+       ssh2_protocol_setup(ssh);
+       ssh->s_rdpkt = ssh2_rdpkt;
+    } else {
+       /*
+        * Initialise SSH-1 protocol.
+        */
+       ssh->protocol = ssh1_protocol;
+       ssh1_protocol_setup(ssh);
+       ssh->s_rdpkt = ssh1_rdpkt;
     }
-
-    logeventf(ssh, "Using SSH protocol version %d", ssh->version);
+    if (ssh->version == 2)
+       do_ssh2_transport(ssh, NULL, -1, NULL);
 
     update_specials_menu(ssh->frontend);
     ssh->state = SSH_STATE_BEFORE_SIZE;
@@ -2573,7 +2600,9 @@ static void ssh_set_frozen(Ssh ssh, int frozen)
 static void ssh_gotdata(Ssh ssh, unsigned char *data, int datalen)
 {
     /* Log raw data, if we're in that mode. */
-    log_packet(ssh->logctx, PKT_INCOMING, -1, NULL, data, datalen, 0, NULL);
+    if (ssh->logctx)
+       log_packet(ssh->logctx, PKT_INCOMING, -1, NULL, data, datalen,
+                  0, NULL);
 
     crBegin(ssh->ssh_gotdata_crstate);
 
@@ -2798,6 +2827,17 @@ static const char *connect_to_host(Ssh ssh, char *host, int port,
        return err;
     }
 
+    /*
+     * If the SSH version number's fixed, set it now, and if it's SSH-2,
+     * send the version string too.
+     */
+    if (ssh->cfg.sshprot == 0)
+       ssh->version = 1;
+    if (ssh->cfg.sshprot == 3) {
+       ssh->version = 2;
+       ssh_send_verstring(ssh, NULL);
+    }
+
     return NULL;
 }
 
@@ -5148,7 +5188,7 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
        s->maclist = macs, s->nmacs = lenof(macs);
 
   begin_key_exchange:
-    ssh->pkt_ctx &= ~SSH2_PKTCTX_KEX_MASK;
+    ssh->pkt_kctx = SSH2_PKTCTX_NOKEX;
     {
        int i, j, commalist_started;
 
@@ -5597,7 +5637,7 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
          */
         if (!ssh->kex->pdata) {
             logevent("Doing Diffie-Hellman group exchange");
-            ssh->pkt_ctx |= SSH2_PKTCTX_DHGEX;
+            ssh->pkt_kctx = SSH2_PKTCTX_DHGEX;
             /*
              * Work out how big a DH group we will need to allow that
              * much data.
@@ -5622,7 +5662,7 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
             s->kex_init_value = SSH2_MSG_KEX_DH_GEX_INIT;
             s->kex_reply_value = SSH2_MSG_KEX_DH_GEX_REPLY;
         } else {
-            ssh->pkt_ctx |= SSH2_PKTCTX_DHGROUP;
+            ssh->pkt_kctx = SSH2_PKTCTX_DHGROUP;
             ssh->kex_ctx = dh_setup_group(ssh->kex);
             s->kex_init_value = SSH2_MSG_KEXDH_INIT;
             s->kex_reply_value = SSH2_MSG_KEXDH_REPLY;
@@ -5681,7 +5721,7 @@ static int do_ssh2_transport(Ssh ssh, void *vin, int inlen,
     } else {
        logeventf(ssh, "Doing RSA key exchange with hash %s",
                  ssh->kex->hash->text_name);
-       ssh->pkt_ctx |= SSH2_PKTCTX_RSAKEX;
+       ssh->pkt_kctx = SSH2_PKTCTX_RSAKEX;
         /*
          * RSA key exchange. First expect a KEXRSA_PUBKEY packet
          * from the server.
@@ -7037,7 +7077,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
         * just in case it succeeds, and (b) so that we know what
         * authentication methods we can usefully try next.
         */
-       ssh->pkt_ctx &= ~SSH2_PKTCTX_AUTH_MASK;
+       ssh->pkt_actx = SSH2_PKTCTX_NOAUTH;
 
        s->pktout = ssh2_pkt_init(SSH2_MSG_USERAUTH_REQUEST);
        ssh2_pkt_addstring(s->pktout, s->username);
@@ -7171,7 +7211,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                    in_commasep_string("keyboard-interactive", methods, methlen);
            }
 
-           ssh->pkt_ctx &= ~SSH2_PKTCTX_AUTH_MASK;
+           ssh->pkt_actx = SSH2_PKTCTX_NOAUTH;
 
            if (s->can_pubkey && !s->done_agent && s->nkeys) {
 
@@ -7179,8 +7219,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                 * Attempt public-key authentication using a key from Pageant.
                 */
 
-               ssh->pkt_ctx &= ~SSH2_PKTCTX_AUTH_MASK;
-               ssh->pkt_ctx |= SSH2_PKTCTX_PUBLICKEY;
+               ssh->pkt_actx = SSH2_PKTCTX_PUBLICKEY;
 
                logeventf(ssh, "Trying Pageant key #%d", s->keyi);
 
@@ -7327,8 +7366,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                struct ssh2_userkey *key;   /* not live over crReturn */
                char *passphrase;           /* not live over crReturn */
 
-               ssh->pkt_ctx &= ~SSH2_PKTCTX_AUTH_MASK;
-               ssh->pkt_ctx |= SSH2_PKTCTX_PUBLICKEY;
+               ssh->pkt_actx = SSH2_PKTCTX_PUBLICKEY;
 
                s->tried_pubkey_config = TRUE;
 
@@ -7507,8 +7545,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
 
                s->type = AUTH_TYPE_KEYBOARD_INTERACTIVE;
 
-               ssh->pkt_ctx &= ~SSH2_PKTCTX_AUTH_MASK;
-               ssh->pkt_ctx |= SSH2_PKTCTX_KBDINTER;
+               ssh->pkt_actx = SSH2_PKTCTX_KBDINTER;
 
                s->pktout = ssh2_pkt_init(SSH2_MSG_USERAUTH_REQUEST);
                ssh2_pkt_addstring(s->pktout, s->username);
@@ -7650,8 +7687,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                int ret; /* not live over crReturn */
                int changereq_first_time; /* not live over crReturn */
 
-               ssh->pkt_ctx &= ~SSH2_PKTCTX_AUTH_MASK;
-               ssh->pkt_ctx |= SSH2_PKTCTX_PASSWORD;
+               ssh->pkt_actx = SSH2_PKTCTX_PASSWORD;
 
                s->cur_prompt = new_prompts(ssh->frontend);
                s->cur_prompt->to_server = TRUE;
@@ -8522,7 +8558,8 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
     ssh->deferred_len = 0;
     ssh->deferred_size = 0;
     ssh->fallback_cmd = 0;
-    ssh->pkt_ctx = 0;
+    ssh->pkt_kctx = SSH2_PKTCTX_NOKEX;
+    ssh->pkt_actx = SSH2_PKTCTX_NOAUTH;
     ssh->x11auth = NULL;
     ssh->v1_compressing = FALSE;
     ssh->v2_outgoing_sequence = 0;