Re-architected the top level of the SSH protocol handlers.
author simon <simon@cda61777-01e9-0310-a592-d414129be87e>
Wed, 24 Nov 2004 20:35:15 +0000 (20:35 +0000)
committer simon <simon@cda61777-01e9-0310-a592-d414129be87e>
Wed, 24 Nov 2004 20:35:15 +0000 (20:35 +0000)
ssh1_protocol() and ssh2_protocol() are now high-level functions
which see _every_ SSH packet and decide which lower-level function
to pass it to. Also, they each support a dispatch table of simple
handler functions for message types which can arrive at any time.
Results are:

 - ignore, debug and disconnect messages are now handled by the
   dispatch table rather than being warts in the rdpkt functions

 - SSH2_MSG_CHANNEL_WINDOW_ADJUST is handled by the dispatch table,
   which means that do_ssh2_authconn doesn't have to explicitly
   special-case it absolutely every time it waits for a response to
   its latest channel request

 - the top-level SSH2 protocol function chooses whether messages get
   funnelled to the transport layer or the auth/conn layer based on
   the message number ranges defined in the SSH architecture draft -
   so things that should go to auth/conn go there even in the middle
   of a rekey (although a special case is that nothing goes to
   auth/conn until initial kex has finished). This should fix the
   other half of ssh2-kex-data.

git-svn-id: svn://svn.tartarus.org/sgt/putty@4901 cda61777-01e9-0310-a592-d414129be87e

ssh.c

diff --git a/ssh.c b/ssh.c
index 700aa70..3426098 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -512,6 +512,7 @@ struct ssh_rportfwd {
 struct Packet {
     long length;
     int type;
+    unsigned long sequence;
     unsigned char *data;
     unsigned char *body;
     long savedpos;
@@ -529,6 +530,8 @@ static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen,
                          struct Packet *pktin);
 static void ssh2_protocol(Ssh ssh, unsigned char *in, int inlen,
                          struct Packet *pktin);
+static void ssh1_protocol_setup(Ssh ssh);
+static void ssh2_protocol_setup(Ssh ssh);
 static void ssh_size(void *handle, int width, int height);
 static void ssh_special(void *handle, Telnet_Special);
 static int ssh2_try_send(struct ssh_channel *c);
@@ -558,6 +561,8 @@ struct rdpkt2_state_tag {
     struct Packet *pktin;
 };
 
+typedef void (*handler_fn_t)(Ssh ssh, struct Packet *pktin);
+
 struct ssh_tag {
     const struct plug_function_table *fn;
     /* the above field _must_ be first in the structure */
@@ -654,8 +659,8 @@ struct ssh_tag {
     int ssh2_rdpkt_crstate;
     int do_ssh_init_crstate;
     int ssh_gotdata_crstate;
-    int ssh1_protocol_crstate;
     int do_ssh1_login_crstate;
+    int do_ssh1_connection_crstate;
     int do_ssh2_transport_crstate;
     int do_ssh2_authconn_crstate;
 
@@ -667,6 +672,9 @@ struct ssh_tag {
     struct rdpkt1_state_tag rdpkt1_state;
     struct rdpkt2_state_tag rdpkt2_state;
 
+    /* ssh1 and ssh2 use this for different things, but both use it */
+    int protocol_initial_phase_done;
+
     void (*protocol) (Ssh ssh, unsigned char *in, int inlen,
                      struct Packet *pkt);
     struct Packet *(*s_rdpkt) (Ssh ssh, unsigned char **data, int *datalen);
@@ -684,6 +692,12 @@ struct ssh_tag {
      */
     void *agent_response;
     int agent_response_len;
+
+    /*
+     * Dispatch table for packet types that we may have to deal
+     * with at any time.
+     */
+    handler_fn_t packet_dispatch[256];
 };
 
 #define logevent(s) logevent(ssh->frontend, s)
@@ -873,8 +887,6 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
 
     crBegin(ssh->ssh1_rdpkt_crstate);
 
-  next_packet:
-
     st->pktin = ssh_new_packet();
 
     st->pktin->type = 0;
@@ -992,48 +1004,6 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
                   nblanks, &blank);
     }
 
-    if (st->pktin->type == SSH1_SMSG_STDOUT_DATA ||
-       st->pktin->type == SSH1_SMSG_STDERR_DATA ||
-       st->pktin->type == SSH1_MSG_DEBUG ||
-       st->pktin->type == SSH1_SMSG_AUTH_TIS_CHALLENGE ||
-       st->pktin->type == SSH1_SMSG_AUTH_CCARD_CHALLENGE) {
-       long stringlen = GET_32BIT(st->pktin->body);
-       if (stringlen + 4 != st->pktin->length) {
-           bombout(("Received data packet with bogus string length"));
-           ssh_free_packet(st->pktin);
-           crStop(NULL);
-       }
-    }
-
-    if (st->pktin->type == SSH1_MSG_DEBUG) {
-        char *buf, *msg;
-        int msglen;
-
-        ssh_pkt_getstring(st->pktin, &msg, &msglen);
-        buf = dupprintf("Remote debug message: %.*s", msglen, msg);
-       logevent(buf);
-        sfree(buf);
-
-       ssh_free_packet(st->pktin);
-       goto next_packet;
-    } else if (st->pktin->type == SSH1_MSG_IGNORE) {
-       /* do nothing */
-       ssh_free_packet(st->pktin);
-       goto next_packet;
-    }
-
-    if (st->pktin->type == SSH1_MSG_DISCONNECT) {
-       /* log reason code in disconnect message */
-       char *msg;
-       int msglen;
-
-        ssh_pkt_getstring(st->pktin, &msg, &msglen);
-
-       bombout(("Server sent disconnect message:\n\"%.*s\"", msglen, msg));
-       ssh_free_packet(st->pktin);
-       crStop(NULL);
-    }
-
     crFinish(st->pktin);
 }
 
@@ -1043,8 +1013,6 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
 
     crBegin(ssh->ssh2_rdpkt_crstate);
 
-  next_packet:
-
     st->pktin = ssh_new_packet();
 
     st->pktin->type = 0;
@@ -1136,7 +1104,8 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
        ssh_free_packet(st->pktin);
        crStop(NULL);
     }
-    st->incoming_sequence++;          /* whether or not we MACed */
+
+    st->pktin->sequence = st->incoming_sequence++;
 
     /*
      * Decompress packet payload.
@@ -1191,115 +1160,6 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, unsigned char **data, int *datalen)
                   nblanks, &blank);
     }
 
-    switch (st->pktin->type) {
-        /*
-         * These packets we must handle instantly.
-         */
-      case SSH2_MSG_DISCONNECT:
-        {
-            /* log reason code in disconnect message */
-            char *buf, *msg;
-            int nowlen, reason, msglen;
-
-            reason = ssh_pkt_getuint32(st->pktin);
-            ssh_pkt_getstring(st->pktin, &msg, &msglen);
-
-            if (reason > 0 && reason < lenof(ssh2_disconnect_reasons)) {
-                buf = dupprintf("Received disconnect message (%s)",
-                               ssh2_disconnect_reasons[reason]);
-            } else {
-                buf = dupprintf("Received disconnect message (unknown"
-                               " type %d)", reason);
-            }
-            logevent(buf);
-           sfree(buf);
-            buf = dupprintf("Disconnection message text: %n%.*s",
-                           &nowlen, msglen, msg);
-            logevent(buf);
-            bombout(("Server sent disconnect message\ntype %d (%s):\n\"%s\"",
-                     reason,
-                     (reason > 0 && reason < lenof(ssh2_disconnect_reasons)) ?
-                     ssh2_disconnect_reasons[reason] : "unknown",
-                     buf+nowlen));
-           sfree(buf);
-           ssh_free_packet(st->pktin);
-            crStop(NULL);
-        }
-        break;
-      case SSH2_MSG_IGNORE:
-       ssh_free_packet(st->pktin);
-       goto next_packet;
-      case SSH2_MSG_DEBUG:
-       {
-           /* log the debug message */
-           char *buf, *msg;
-           int msglen;
-           int always_display;
-
-           /* XXX maybe we should actually take notice of this */
-            always_display = ssh2_pkt_getbool(st->pktin);
-            ssh_pkt_getstring(st->pktin, &msg, &msglen);
-
-            buf = dupprintf("Remote debug message: %.*s", msglen, msg);
-           logevent(buf);
-            sfree(buf);
-       }
-       ssh_free_packet(st->pktin);
-        goto next_packet;
-
-        /*
-         * These packets we need do nothing about here.
-         */
-      case SSH2_MSG_UNIMPLEMENTED:
-      case SSH2_MSG_SERVICE_REQUEST:
-      case SSH2_MSG_SERVICE_ACCEPT:
-      case SSH2_MSG_KEXINIT:
-      case SSH2_MSG_NEWKEYS:
-      case SSH2_MSG_KEXDH_INIT:
-      case SSH2_MSG_KEXDH_REPLY:
-      /* case SSH2_MSG_KEX_DH_GEX_REQUEST: duplicate case value */
-      /* case SSH2_MSG_KEX_DH_GEX_GROUP: duplicate case value */
-      case SSH2_MSG_KEX_DH_GEX_INIT:
-      case SSH2_MSG_KEX_DH_GEX_REPLY:
-      case SSH2_MSG_USERAUTH_REQUEST:
-      case SSH2_MSG_USERAUTH_FAILURE:
-      case SSH2_MSG_USERAUTH_SUCCESS:
-      case SSH2_MSG_USERAUTH_BANNER:
-      case SSH2_MSG_USERAUTH_PK_OK:
-      /* case SSH2_MSG_USERAUTH_PASSWD_CHANGEREQ: duplicate case value */
-      /* case SSH2_MSG_USERAUTH_INFO_REQUEST: duplicate case value */
-      case SSH2_MSG_USERAUTH_INFO_RESPONSE:
-      case SSH2_MSG_GLOBAL_REQUEST:
-      case SSH2_MSG_REQUEST_SUCCESS:
-      case SSH2_MSG_REQUEST_FAILURE:
-      case SSH2_MSG_CHANNEL_OPEN:
-      case SSH2_MSG_CHANNEL_OPEN_CONFIRMATION:
-      case SSH2_MSG_CHANNEL_OPEN_FAILURE:
-      case SSH2_MSG_CHANNEL_WINDOW_ADJUST:
-      case SSH2_MSG_CHANNEL_DATA:
-      case SSH2_MSG_CHANNEL_EXTENDED_DATA:
-      case SSH2_MSG_CHANNEL_EOF:
-      case SSH2_MSG_CHANNEL_CLOSE:
-      case SSH2_MSG_CHANNEL_REQUEST:
-      case SSH2_MSG_CHANNEL_SUCCESS:
-      case SSH2_MSG_CHANNEL_FAILURE:
-        break;
-
-        /*
-         * For anything else we send SSH2_MSG_UNIMPLEMENTED.
-         */
-      default:
-       {
-           struct Packet *pktout;
-           pktout = ssh2_pkt_init(SSH2_MSG_UNIMPLEMENTED);
-           ssh2_pkt_adduint32(pktout, st->incoming_sequence - 1);
-           /* UNIMPLEMENTED messages MUST appear in the same order as
-            * the messages they respond to. Hence, never queue them. */
-           ssh2_pkt_send_noqueue(ssh, pktout);
-       }
-        break;
-    }
-
     crFinish(st->pktin);
 }
 
@@ -2273,6 +2133,7 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
        logevent("Using SSH protocol version 2");
        sk_write(ssh->s, verstring, strlen(verstring));
        ssh->protocol = ssh2_protocol;
+       ssh2_protocol_setup(ssh);
        ssh->version = 2;
        ssh->s_rdpkt = ssh2_rdpkt;
     } else {
@@ -2290,6 +2151,7 @@ static int do_ssh_init(Ssh ssh, unsigned char c)
        logevent("Using SSH protocol version 1");
        sk_write(ssh->s, verstring, strlen(verstring));
        ssh->protocol = ssh1_protocol;
+       ssh1_protocol_setup(ssh);
        ssh->version = 1;
        ssh->s_rdpkt = ssh1_rdpkt;
     }
@@ -2672,6 +2534,8 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen,
 
     crBegin(ssh->do_ssh1_login_crstate);
 
+    random_init();
+
     if (!pktin)
        crWaitUntil(pktin);
 
@@ -3574,18 +3438,10 @@ void sshfwd_unthrottle(struct ssh_channel *c, int bufsize)
     }
 }
 
-static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen,
-                         struct Packet *pktin)
+static void do_ssh1_connection(Ssh ssh, unsigned char *in, int inlen,
+                              struct Packet *pktin)
 {
-    crBegin(ssh->ssh1_protocol_crstate);
-
-    random_init();
-
-    while (!do_ssh1_login(ssh, in, inlen, pktin)) {
-       crReturnV;
-    }
-    if (ssh->state == SSH_STATE_CLOSED)
-       crReturnV;
+    crBegin(ssh->do_ssh1_connection_crstate);
 
     if (ssh->cfg.agentfwd && agent_exists()) {
        logevent("Requesting agent forwarding");
@@ -4201,6 +4057,74 @@ static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen,
 }
 
 /*
+ * Handle the top-level SSH1 protocol.
+ */
+static void ssh1_msg_debug(Ssh ssh, struct Packet *pktin)
+{
+    char *buf, *msg;
+    int msglen;
+
+    ssh_pkt_getstring(pktin, &msg, &msglen);
+    buf = dupprintf("Remote debug message: %.*s", msglen, msg);
+    logevent(buf);
+    sfree(buf);
+}
+
+static void ssh1_msg_disconnect(Ssh ssh, struct Packet *pktin)
+{
+    /* log reason code in disconnect message */
+    char *msg;
+    int msglen;
+
+    ssh_pkt_getstring(pktin, &msg, &msglen);
+    bombout(("Server sent disconnect message:\n\"%.*s\"", msglen, msg));
+}
+
+void ssh_msg_ignore(Ssh ssh, struct Packet *pktin)
+{
+    /* Do nothing, because we're ignoring it! Duhh. */
+}
+
+static void ssh1_protocol_setup(Ssh ssh)
+{
+    int i;
+
+    /*
+     * Most messages are handled by the coroutines.
+     */
+    for (i = 0; i < 256; i++)
+       ssh->packet_dispatch[i] = NULL;
+
+    /*
+     * These special message types we install handlers for.
+     */
+    ssh->packet_dispatch[SSH1_MSG_DISCONNECT] = ssh1_msg_disconnect;
+    ssh->packet_dispatch[SSH1_MSG_IGNORE] = ssh_msg_ignore;
+    ssh->packet_dispatch[SSH1_MSG_DEBUG] = ssh1_msg_debug;
+}
+
+static void ssh1_protocol(Ssh ssh, unsigned char *in, int inlen,
+                         struct Packet *pktin)
+{
+    if (ssh->state == SSH_STATE_CLOSED)
+       return;
+
+    if (pktin && ssh->packet_dispatch[pktin->type]) {
+       ssh->packet_dispatch[pktin->type](ssh, pktin);
+       return;
+    }
+
+    if (!ssh->protocol_initial_phase_done) {
+       if (do_ssh1_login(ssh, in, inlen, pktin))
+           ssh->protocol_initial_phase_done = TRUE;
+       else
+           return;
+    }
+
+    do_ssh1_connection(ssh, in, inlen, pktin);
+}
+
+/*
  * Utility routine for decoding comma-separated strings in KEXINIT.
  */
 static int in_commasep_string(char *needle, char *haystack, int haylen)
@@ -4812,7 +4736,7 @@ static int do_ssh2_transport(Ssh ssh, unsigned char *in, int inlen,
      * it would only confuse the layer above.
      */
     if (!s->first_kex) {
-       crReturn(0);
+       crReturn(1);
     }
     s->first_kex = 0;
 
@@ -4900,6 +4824,15 @@ static void ssh2_set_window(struct ssh_channel *c, unsigned newwin)
     }
 }
 
+void ssh2_msg_channel_window_adjust(Ssh ssh, struct Packet *pktin)
+{
+    unsigned i = ssh_pkt_getuint32(pktin);
+    struct ssh_channel *c;
+    c = find234(ssh->channels, &i, ssh_channelfind);
+    if (c && !c->closes)
+       c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
+}
+
 /*
  * Handle the SSH2 userauth and connection layers.
  */
@@ -5775,6 +5708,13 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
     ssh->channels = newtree234(ssh_channelcmp);
 
     /*
+     * Set up handlers for some connection protocol messages, so we
+     * don't have to handle them repeatedly in this coroutine.
+     */
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_WINDOW_ADJUST] =
+       ssh2_msg_channel_window_adjust;
+
+    /*
      * Create the main session channel.
      */
     if (!ssh->cfg.ssh_no_shell) {
@@ -5829,17 +5769,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
        ssh2_pkt_adduint32(s->pktout, x11_get_screen_number(ssh->cfg.x11_display));
        ssh2_pkt_send(ssh, s->pktout);
 
-       do {
-           crWaitUntilV(pktin);
-           if (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST) {
-               unsigned i = ssh_pkt_getuint32(pktin);
-               struct ssh_channel *c;
-               c = find234(ssh->channels, &i, ssh_channelfind);
-               if (!c)
-                   continue;          /* nonexistent channel */
-               c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
-           }
-       } while (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST);
+       crWaitUntilV(pktin);
 
        if (pktin->type != SSH2_MSG_CHANNEL_SUCCESS) {
            if (pktin->type != SSH2_MSG_CHANNEL_FAILURE) {
@@ -5995,17 +5925,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                        ssh2_pkt_adduint32(s->pktout, sport);
                        ssh2_pkt_send(ssh, s->pktout);
 
-                       do {
-                           crWaitUntilV(pktin);
-                           if (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST) {
-                               unsigned i = ssh_pkt_getuint32(pktin);
-                               struct ssh_channel *c;
-                               c = find234(ssh->channels, &i, ssh_channelfind);
-                               if (!c)
-                                   continue;/* nonexistent channel */
-                               c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
-                           }
-                       } while (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST);
+                       crWaitUntilV(pktin);
 
                        if (pktin->type != SSH2_MSG_REQUEST_SUCCESS) {
                            if (pktin->type != SSH2_MSG_REQUEST_FAILURE) {
@@ -6036,17 +5956,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
        ssh2_pkt_addbool(s->pktout, 1);        /* want reply */
        ssh2_pkt_send(ssh, s->pktout);
 
-       do {
-           crWaitUntilV(pktin);
-           if (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST) {
-               unsigned i = ssh_pkt_getuint32(pktin);
-               struct ssh_channel *c;
-               c = find234(ssh->channels, &i, ssh_channelfind);
-               if (!c)
-                   continue;          /* nonexistent channel */
-               c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
-           }
-       } while (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST);
+       crWaitUntilV(pktin);
 
        if (pktin->type != SSH2_MSG_CHANNEL_SUCCESS) {
            if (pktin->type != SSH2_MSG_CHANNEL_FAILURE) {
@@ -6088,17 +5998,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
        ssh2_pkt_send(ssh, s->pktout);
        ssh->state = SSH_STATE_INTERMED;
 
-       do {
-           crWaitUntilV(pktin);
-           if (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST) {
-               unsigned i = ssh_pkt_getuint32(pktin);
-               struct ssh_channel *c;
-               c = find234(ssh->channels, &i, ssh_channelfind);
-               if (!c)
-                   continue;          /* nonexistent channel */
-               c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
-           }
-       } while (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST);
+       crWaitUntilV(pktin);
 
        if (pktin->type != SSH2_MSG_CHANNEL_SUCCESS) {
            if (pktin->type != SSH2_MSG_CHANNEL_FAILURE) {
@@ -6155,17 +6055,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
        s->env_left = s->num_env;
 
        while (s->env_left > 0) {
-           do {
-               crWaitUntilV(pktin);
-               if (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST) {
-                   unsigned i = ssh_pkt_getuint32(pktin);
-                   struct ssh_channel *c;
-                   c = find234(ssh->channels, &i, ssh_channelfind);
-                   if (!c)
-                       continue;              /* nonexistent channel */
-                   c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
-               }
-           } while (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST);
+           crWaitUntilV(pktin);
 
            if (pktin->type != SSH2_MSG_CHANNEL_SUCCESS) {
                if (pktin->type != SSH2_MSG_CHANNEL_FAILURE) {
@@ -6224,17 +6114,9 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
            ssh2_pkt_addbool(s->pktout, 1);            /* want reply */
        }
        ssh2_pkt_send(ssh, s->pktout);
-       do {
-           crWaitUntilV(pktin);
-           if (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST) {
-               unsigned i = ssh_pkt_getuint32(pktin);
-               struct ssh_channel *c;
-               c = find234(ssh->channels, &i, ssh_channelfind);
-               if (!c)
-                   continue;          /* nonexistent channel */
-               c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
-           }
-       } while (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST);
+
+       crWaitUntilV(pktin);
+
        if (pktin->type != SSH2_MSG_CHANNEL_SUCCESS) {
            if (pktin->type != SSH2_MSG_CHANNEL_FAILURE) {
                bombout(("Unexpected response to shell/command request:"
@@ -6444,14 +6326,6 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                    crStopV;
                }
                continue;              /* remote sends close; ignore (FIXME) */
-           } else if (pktin->type == SSH2_MSG_CHANNEL_WINDOW_ADJUST) {
-               unsigned i = ssh_pkt_getuint32(pktin);
-               struct ssh_channel *c;
-               c = find234(ssh->channels, &i, ssh_channelfind);
-               if (!c || c->closes)
-                   continue;          /* nonexistent or closing channel */
-               c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
-               s->try_send = TRUE;
            } else if (pktin->type == SSH2_MSG_CHANNEL_OPEN_CONFIRMATION) {
                unsigned i = ssh_pkt_getuint32(pktin);
                struct ssh_channel *c;
@@ -6822,14 +6696,148 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
 }
 
 /*
+ * Handlers for SSH2 messages that might arrive at any moment.
+ */
+void ssh2_msg_disconnect(Ssh ssh, struct Packet *pktin)
+{
+    /* log reason code in disconnect message */
+    char *buf, *msg;
+    int nowlen, reason, msglen;
+
+    reason = ssh_pkt_getuint32(pktin);
+    ssh_pkt_getstring(pktin, &msg, &msglen);
+
+    if (reason > 0 && reason < lenof(ssh2_disconnect_reasons)) {
+       buf = dupprintf("Received disconnect message (%s)",
+                       ssh2_disconnect_reasons[reason]);
+    } else {
+       buf = dupprintf("Received disconnect message (unknown"
+                       " type %d)", reason);
+    }
+    logevent(buf);
+    sfree(buf);
+    buf = dupprintf("Disconnection message text: %n%.*s",
+                   &nowlen, msglen, msg);
+    logevent(buf);
+    bombout(("Server sent disconnect message\ntype %d (%s):\n\"%s\"",
+            reason,
+            (reason > 0 && reason < lenof(ssh2_disconnect_reasons)) ?
+            ssh2_disconnect_reasons[reason] : "unknown",
+            buf+nowlen));
+    sfree(buf);
+}
+
+void ssh2_msg_debug(Ssh ssh, struct Packet *pktin)
+{
+    /* log the debug message */
+    char *buf, *msg;
+    int msglen;
+    int always_display;
+
+    /* XXX maybe we should actually take notice of this */
+    always_display = ssh2_pkt_getbool(pktin);
+    ssh_pkt_getstring(pktin, &msg, &msglen);
+
+    buf = dupprintf("Remote debug message: %.*s", msglen, msg);
+    logevent(buf);
+    sfree(buf);
+}
+
+void ssh2_msg_something_unimplemented(Ssh ssh, struct Packet *pktin)
+{
+    struct Packet *pktout;
+    pktout = ssh2_pkt_init(SSH2_MSG_UNIMPLEMENTED);
+    ssh2_pkt_adduint32(pktout, pktin->sequence);
+    /*
+     * UNIMPLEMENTED messages MUST appear in the same order as the
+     * messages they respond to. Hence, never queue them.
+     */
+    ssh2_pkt_send_noqueue(ssh, pktout);
+}
+
+/*
  * Handle the top-level SSH2 protocol.
  */
+static void ssh2_protocol_setup(Ssh ssh)
+{
+    int i;
+
+    /*
+     * Most messages cause SSH2_MSG_UNIMPLEMENTED.
+     */
+    for (i = 0; i < 256; i++)
+       ssh->packet_dispatch[i] = ssh2_msg_something_unimplemented;
+
+    /*
+     * Any message we actually understand, we set to NULL so that
+     * the coroutines will get it.
+     */
+    ssh->packet_dispatch[SSH2_MSG_UNIMPLEMENTED] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_SERVICE_REQUEST] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_SERVICE_ACCEPT] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_KEXINIT] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_NEWKEYS] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_KEXDH_INIT] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_KEXDH_REPLY] = NULL;
+    /* ssh->packet_dispatch[SSH2_MSG_KEX_DH_GEX_REQUEST] = NULL; duplicate case value */
+    /* ssh->packet_dispatch[SSH2_MSG_KEX_DH_GEX_GROUP] = NULL; duplicate case value */
+    ssh->packet_dispatch[SSH2_MSG_KEX_DH_GEX_INIT] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_KEX_DH_GEX_REPLY] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_USERAUTH_REQUEST] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_USERAUTH_FAILURE] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_USERAUTH_SUCCESS] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_USERAUTH_BANNER] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_USERAUTH_PK_OK] = NULL;
+    /* ssh->packet_dispatch[SSH2_MSG_USERAUTH_PASSWD_CHANGEREQ] = NULL; duplicate case value */
+    /* ssh->packet_dispatch[SSH2_MSG_USERAUTH_INFO_REQUEST] = NULL; duplicate case value */
+    ssh->packet_dispatch[SSH2_MSG_USERAUTH_INFO_RESPONSE] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_GLOBAL_REQUEST] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_REQUEST_SUCCESS] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_REQUEST_FAILURE] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_OPEN] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_OPEN_CONFIRMATION] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_OPEN_FAILURE] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_WINDOW_ADJUST] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_DATA] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_EXTENDED_DATA] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_EOF] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_CLOSE] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_REQUEST] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_SUCCESS] = NULL;
+    ssh->packet_dispatch[SSH2_MSG_CHANNEL_FAILURE] = NULL;
+
+    /*
+     * These special message types we install handlers for.
+     */
+    ssh->packet_dispatch[SSH2_MSG_DISCONNECT] = ssh2_msg_disconnect;
+    ssh->packet_dispatch[SSH2_MSG_IGNORE] = ssh_msg_ignore; /* shared with ssh1 */
+    ssh->packet_dispatch[SSH2_MSG_DEBUG] = ssh2_msg_debug;
+}
+
 static void ssh2_protocol(Ssh ssh, unsigned char *in, int inlen,
                          struct Packet *pktin)
 {
-    if (do_ssh2_transport(ssh, in, inlen, pktin) == 0)
+    if (ssh->state == SSH_STATE_CLOSED)
+       return;
+
+    if (pktin && ssh->packet_dispatch[pktin->type]) {
+       ssh->packet_dispatch[pktin->type](ssh, pktin);
        return;
-    do_ssh2_authconn(ssh, in, inlen, pktin);
+    }
+
+    if (!ssh->protocol_initial_phase_done ||
+       (pktin && pktin->type >= 20 && pktin->type < 50)) {
+       if (do_ssh2_transport(ssh, in, inlen, pktin) &&
+           !ssh->protocol_initial_phase_done) {
+           ssh->protocol_initial_phase_done = TRUE;
+           /*
+            * Allow authconn to initialise itself.
+            */
+           do_ssh2_authconn(ssh, NULL, 0, NULL);
+       }
+    } else {
+       do_ssh2_authconn(ssh, in, inlen, pktin);
+    }
 }
 
 /*
@@ -6885,7 +6893,7 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
     ssh->ssh2_rdpkt_crstate = 0;
     ssh->do_ssh_init_crstate = 0;
     ssh->ssh_gotdata_crstate = 0;
-    ssh->ssh1_protocol_crstate = 0;
+    ssh->do_ssh1_connection_crstate = 0;
     ssh->do_ssh1_login_crstate = 0;
     ssh->do_ssh2_transport_crstate = 0;
     ssh->do_ssh2_authconn_crstate = 0;
@@ -6923,6 +6931,8 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
 
     ssh->protocol = NULL;
 
+    ssh->protocol_initial_phase_done = FALSE;
+
     p = connect_to_host(ssh, host, port, realhost, nodelay, keepalive);
     if (p != NULL)
        return p;