When iterating over all channels for a dead SSH connection, don't miss out channels in the CHAN_SOCKDATA_DORMANT state: their port-forwarding sockets need to be closed too.
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 0616eff..a7f1c6b 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -590,6 +590,17 @@ struct ssh_channel {
      * A channel is completely finished with when all four bits are set.
      */
     int closes;
+
+    /*
+     * This flag indicates that a close is pending on the outgoing
+     * side of the channel: that is, wherever we're getting the data
+     * for this channel has sent us some data followed by EOF. We
+     * can't actually close the channel until we've finished sending
+     * the data, so we set this flag instead to remind us to
+     * initiate the closing process once our buffer is clear.
+     */
+    int pending_close;
+
     /*
      * True if this channel is causing the underlying connection to be
      * throttled.
@@ -930,6 +941,13 @@ struct ssh_tag {
      * Fully qualified host name, which we need if doing GSSAPI.
      */
     char *fullhostname;
+
+#ifndef NO_GSSAPI
+    /*
+     * GSSAPI libraries for this session.
+     */
+    struct ssh_gss_liblist *gsslibs;
+#endif
 };
 
 #define logevent(s) logevent(ssh->frontend, s)
@@ -2844,6 +2862,7 @@ static int ssh_do_close(Ssh ssh, int notify_exit)
                x11_close(c->u.x11.s);
                break;
              case CHAN_SOCKDATA:
+             case CHAN_SOCKDATA_DORMANT:
                pfd_close(c->u.pfd.s);
                break;
            }
@@ -4153,7 +4172,7 @@ void sshfwd_close(struct ssh_channel *c)
     if (ssh->state == SSH_STATE_CLOSED)
        return;
 
-    if (c && !c->closes) {
+    if (!c->closes) {
        /*
         * If halfopen is true, we have sent
         * CHANNEL_OPEN for this channel, but it hasn't even been
@@ -4165,14 +4184,42 @@ void sshfwd_close(struct ssh_channel *c)
            if (ssh->version == 1) {
                send_packet(ssh, SSH1_MSG_CHANNEL_CLOSE, PKT_INT, c->remoteid,
                            PKT_END);
+               c->closes = 1;                 /* sent MSG_CLOSE */
            } else {
-               struct Packet *pktout;
-               pktout = ssh2_pkt_init(SSH2_MSG_CHANNEL_CLOSE);
-               ssh2_pkt_adduint32(pktout, c->remoteid);
-               ssh2_pkt_send(ssh, pktout);
+               int bytes_to_send = bufchain_size(&c->v.v2.outbuffer);
+               if (bytes_to_send > 0) {
+                   /*
+                    * If we still have unsent data in our outgoing
+                    * buffer for this channel, we can't actually
+                    * initiate a close operation yet or that data
+                    * will be lost. Instead, set the pending_close
+                    * flag so that when we do clear the buffer
+                    * we'll start closing the channel.
+                    */
+                   char logmsg[160] = {'\0'};
+                   sprintf(
+                           logmsg,
+                           "Forwarded port pending to be closed : "
+                           "%d bytes remaining",
+                           bytes_to_send);
+                   logevent(logmsg);
+
+                   c->pending_close = TRUE;
+               } else {
+                   /*
+                    * No locally buffered data, so we can send the
+                    * close message immediately.
+                    */
+                   struct Packet *pktout;
+                   pktout = ssh2_pkt_init(SSH2_MSG_CHANNEL_CLOSE);
+                   ssh2_pkt_adduint32(pktout, c->remoteid);
+                   ssh2_pkt_send(ssh, pktout);
+                   c->closes = 1;                     /* sent MSG_CLOSE */
+                   logevent("Nothing left to send, closing channel");
+               }
            }
        }
-       c->closes = 1;                 /* sent MSG_CLOSE */
+
        if (c->type == CHAN_X11) {
            c->u.x11.s = NULL;
            logevent("Forwarded X11 connection terminated");
@@ -4311,6 +4358,7 @@ static void ssh_rportfwd_succfail(Ssh ssh, struct Packet *pktin, void *ctx)
 
        rpf = del234(ssh->rportfwds, pf);
        assert(rpf == pf);
+       pf->pfrec->remote = NULL;
        free_rportfwd(pf);
     }
 }
@@ -4487,6 +4535,8 @@ static void ssh_setup_portfwd(Ssh ssh, const Config *cfg)
            logeventf(ssh, "Cancelling %s", message);
            sfree(message);
 
+           /* epf->remote or epf->local may be NULL if setting up a
+            * forwarding failed. */
            if (epf->remote) {
                struct ssh_rportfwd *rpf = epf->remote;
                struct Packet *pktout;
@@ -4696,6 +4746,7 @@ static void ssh1_smsg_x11_open(Ssh ssh, struct Packet *pktin)
            c->halfopen = FALSE;
            c->localid = alloc_channel_id(ssh);
            c->closes = 0;
+           c->pending_close = FALSE;
            c->throttling_conn = 0;
            c->type = CHAN_X11; /* identify channel type */
            add234(ssh->channels, c);
@@ -4725,6 +4776,7 @@ static void ssh1_smsg_agent_open(Ssh ssh, struct Packet *pktin)
        c->halfopen = FALSE;
        c->localid = alloc_channel_id(ssh);
        c->closes = 0;
+       c->pending_close = FALSE;
        c->throttling_conn = 0;
        c->type = CHAN_AGENT;   /* identify channel type */
        c->u.a.lensofar = 0;
@@ -4779,6 +4831,7 @@ static void ssh1_msg_port_open(Ssh ssh, struct Packet *pktin)
            c->halfopen = FALSE;
            c->localid = alloc_channel_id(ssh);
            c->closes = 0;
+           c->pending_close = FALSE;
            c->throttling_conn = 0;
            c->type = CHAN_SOCKDATA;    /* identify channel type */
            add234(ssh->channels, c);
@@ -6345,7 +6398,7 @@ static int ssh2_try_send(struct ssh_channel *c)
     return bufchain_size(&c->v.v2.outbuffer);
 }
 
-static void ssh2_try_send_and_unthrottle(struct ssh_channel *c)
+static void ssh2_try_send_and_unthrottle(Ssh ssh, struct ssh_channel *c)
 {
     int bufsize;
     if (c->closes)
@@ -6369,6 +6422,19 @@ static void ssh2_try_send_and_unthrottle(struct ssh_channel *c)
            break;
        }
     }
+
+    /*
+     * If we've emptied the channel's output buffer and there's a
+     * pending close event, start the channel-closing procedure.
+     */
+    if (c->pending_close && bufchain_size(&c->v.v2.outbuffer) == 0) {
+       struct Packet *pktout;
+       pktout = ssh2_pkt_init(SSH2_MSG_CHANNEL_CLOSE);
+       ssh2_pkt_adduint32(pktout, c->remoteid);
+       ssh2_pkt_send(ssh, pktout);
+       c->closes = 1;
+       c->pending_close = FALSE;
+    }
 }
 
 /*
@@ -6379,6 +6445,7 @@ static void ssh2_channel_init(struct ssh_channel *c)
     Ssh ssh = c->ssh;
     c->localid = alloc_channel_id(ssh);
     c->closes = 0;
+    c->pending_close = FALSE;
     c->throttling_conn = FALSE;
     c->v.v2.locwindow = c->v.v2.locmaxwin = c->v.v2.remlocwin =
        ssh->cfg.ssh_simple ? OUR_V2_BIGWIN : OUR_V2_WINSIZE;
@@ -6562,7 +6629,7 @@ static void ssh2_msg_channel_window_adjust(Ssh ssh, struct Packet *pktin)
        return;
     if (!c->closes) {
        c->v.v2.remwindow += ssh_pkt_getuint32(pktin);
-       ssh2_try_send_and_unthrottle(c);
+       ssh2_try_send_and_unthrottle(ssh, c);
     }
 }
 
@@ -7128,12 +7195,14 @@ static void ssh2_msg_channel_open(Ssh ssh, struct Packet *pktin)
 }
 
 /*
- * Buffer banner messages for later display at some convenient point.
+ * Buffer banner messages for later display at some convenient point,
+ * if we're going to display them.
  */
 static void ssh2_msg_userauth_banner(Ssh ssh, struct Packet *pktin)
 {
     /* Arbitrary limit to prevent unbounded inflation of buffer */
-    if (bufchain_size(&ssh->banner) <= 131072) {
+    if (ssh->cfg.ssh_show_banner &&
+       bufchain_size(&ssh->banner) <= 131072) {
        char *banner = NULL;
        int size = 0;
        ssh_pkt_getstring(pktin, &banner, &size);
@@ -7479,6 +7548,9 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
        }
 
        while (1) {
+           char *methods = NULL;
+           int methlen = 0;
+
            /*
             * Wait for the result of the last authentication request.
             */
@@ -7528,8 +7600,6 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
             * helpfully try next.
             */
            if (pktin->type == SSH2_MSG_USERAUTH_FAILURE) {
-               char *methods;
-               int methlen;
                ssh_pkt_getstring(pktin, &methods, &methlen);
                if (!ssh2_pkt_getbool(pktin)) {
                    /*
@@ -7585,11 +7655,12 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                    in_commasep_string("password", methods, methlen);
                s->can_keyb_inter = ssh->cfg.try_ki_auth &&
                    in_commasep_string("keyboard-interactive", methods, methlen);
-#ifndef NO_GSSAPI              
-               ssh_gss_init();
+#ifndef NO_GSSAPI
+               if (!ssh->gsslibs)
+                   ssh->gsslibs = ssh_gss_setup(&ssh->cfg);
                s->can_gssapi = ssh->cfg.try_gssapi_auth &&
                    in_commasep_string("gssapi-with-mic", methods, methlen) &&
-                   n_ssh_gss_libraries > 0;
+                   ssh->gsslibs->nlibraries > 0;
 #endif
            }
 
@@ -7941,9 +8012,9 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                    s->gsslib = NULL;
                    for (i = 0; i < ngsslibs; i++) {
                        int want_id = ssh->cfg.ssh_gsslist[i];
-                       for (j = 0; j < n_ssh_gss_libraries; j++)
-                           if (ssh_gss_libraries[j].id == want_id) {
-                               s->gsslib = &ssh_gss_libraries[j];
+                       for (j = 0; j < ssh->gsslibs->nlibraries; j++)
+                           if (ssh->gsslibs->libraries[j].id == want_id) {
+                               s->gsslib = &ssh->gsslibs->libraries[j];
                                goto got_gsslib;   /* double break */
                            }
                    }
@@ -8493,11 +8564,16 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                sfree(s->password);
 
            } else {
+               char *str = dupprintf("No supported authentication methods available"
+                                     " (server sent: %.*s)",
+                                     methlen, methods);
 
-               ssh_disconnect(ssh, NULL,
+               ssh_disconnect(ssh, str,
                               "No supported authentication methods available",
                               SSH2_DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
                               FALSE);
+               sfree(str);
+
                crStopV;
 
            }
@@ -8936,7 +9012,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
             * Try to send data on all channels if we can.
             */
            for (i = 0; NULL != (c = index234(ssh->channels, i)); i++)
-               ssh2_try_send_and_unthrottle(c);
+               ssh2_try_send_and_unthrottle(ssh, c);
        }
     }
 
@@ -9218,6 +9294,10 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
     ssh->max_data_size = parse_blocksize(ssh->cfg.ssh_rekey_data);
     ssh->kex_in_progress = FALSE;
 
+#ifndef NO_GSSAPI
+    ssh->gsslibs = NULL;
+#endif
+
     p = connect_to_host(ssh, host, port, realhost, nodelay, keepalive);
     if (p != NULL)
        return p;
@@ -9278,6 +9358,7 @@ static void ssh_free(void *handle)
                    x11_close(c->u.x11.s);
                break;
              case CHAN_SOCKDATA:
+             case CHAN_SOCKDATA_DORMANT:
                if (c->u.pfd.s != NULL)
                    pfd_close(c->u.pfd.s);
                break;
@@ -9290,7 +9371,7 @@ static void ssh_free(void *handle)
 
     if (ssh->rportfwds) {
        while ((pf = delpos234(ssh->rportfwds, 0)) != NULL)
-           sfree(pf);
+           free_rportfwd(pf);
        freetree234(ssh->rportfwds);
        ssh->rportfwds = NULL;
     }
@@ -9314,6 +9395,10 @@ static void ssh_free(void *handle)
     if (ssh->pinger)
        pinger_free(ssh->pinger);
     bufchain_clear(&ssh->queued_incoming_data);
+#ifndef NO_GSSAPI
+    if (ssh->gsslibs)
+       ssh_gss_cleanup(ssh->gsslibs);
+#endif
     sfree(ssh);
 
     random_unref();