Robustness in the face of sudden connection closures: we now make a
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 6abdb28..4243a04 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -515,6 +515,7 @@ static void ssh2_add_channel_data(struct ssh_channel *c, char *buf, int len);
 static void ssh_throttle_all(Ssh ssh, int enable, int bufsize);
 static void ssh2_set_window(struct ssh_channel *c, unsigned newwin);
 static int ssh_sendbuffer(void *handle);
+static void ssh_do_close(Ssh ssh);
 
 struct rdpkt1_state_tag {
     long len, pad, biglen, to_read;
@@ -668,8 +669,7 @@ static void logeventf(Ssh ssh, char *fmt, ...)
 #define bombout(msg) \
     do { \
         char *text = dupprintf msg; \
-        ssh->state = SSH_STATE_CLOSED; \
-        if (ssh->s) { sk_close(ssh->s); ssh->s = NULL; } \
+       ssh_do_close(ssh); \
         logevent(text); \
         connection_fatal(ssh->frontend, "%s", text); \
         sfree(text); \
@@ -2037,15 +2037,41 @@ static void ssh_gotdata(Ssh ssh, unsigned char *data, int datalen)
     crFinishV;
 }
 
-static int ssh_closing(Plug plug, char *error_msg, int error_code,
-                      int calling_back)
+static void ssh_do_close(Ssh ssh)
 {
-    Ssh ssh = (Ssh) plug;
+    int i;
+    struct ssh_channel *c;
+
     ssh->state = SSH_STATE_CLOSED;
     if (ssh->s) {
         sk_close(ssh->s);
         ssh->s = NULL;
     }
+    /*
+     * Now we must shut down any port and X forwardings going
+     * through this connection.
+     */
+    for (i = 0; NULL != (c = index234(ssh->channels, i)); i++) {
+       switch (c->type) {
+         case CHAN_X11:
+           x11_close(c->u.x11.s);
+           break;
+         case CHAN_SOCKDATA:
+           pfd_close(c->u.pfd.s);
+           break;
+       }
+       del234(ssh->channels, c);
+       if (ssh->version == 2)
+           bufchain_clear(&c->v.v2.outbuffer);
+       sfree(c);
+    }
+}
+
+static int ssh_closing(Plug plug, char *error_msg, int error_code,
+                      int calling_back)
+{
+    Ssh ssh = (Ssh) plug;
+    ssh_do_close(ssh);
     if (error_msg) {
        /* A socket error has occurred. */
        logevent(error_msg);
@@ -2061,10 +2087,7 @@ static int ssh_receive(Plug plug, int urgent, char *data, int len)
     Ssh ssh = (Ssh) plug;
     ssh_gotdata(ssh, (unsigned char *)data, len);
     if (ssh->state == SSH_STATE_CLOSED) {
-       if (ssh->s) {
-           sk_close(ssh->s);
-           ssh->s = NULL;
-       }
+       ssh_do_close(ssh);
        return 0;
     }
     return 1;
@@ -2180,7 +2203,7 @@ static void ssh_throttle_all(Ssh ssh, int enable, int bufsize)
            /* Agent channels require no buffer management. */
            break;
          case CHAN_SOCKDATA:
-           pfd_override_throttle(c->u.x11.s, enable);
+           pfd_override_throttle(c->u.pfd.s, enable);
            break;
        }
     }
@@ -2992,6 +3015,11 @@ void sshfwd_close(struct ssh_channel *c)
 {
     Ssh ssh = c->ssh;
 
+    if (ssh->state != SSH_STATE_SESSION) {
+       assert(ssh->state == SSH_STATE_CLOSED);
+       return;
+    }
+
     if (c && !c->closes) {
        /*
         * If the channel's remoteid is -1, we have sent
@@ -3026,6 +3054,11 @@ int sshfwd_write(struct ssh_channel *c, char *buf, int len)
 {
     Ssh ssh = c->ssh;
 
+    if (ssh->state != SSH_STATE_SESSION) {
+       assert(ssh->state == SSH_STATE_CLOSED);
+       return 0;
+    }
+
     if (ssh->version == 1) {
        send_packet(ssh, SSH1_MSG_CHANNEL_DATA,
                    PKT_INT, c->remoteid,
@@ -3048,6 +3081,11 @@ void sshfwd_unthrottle(struct ssh_channel *c, int bufsize)
 {
     Ssh ssh = c->ssh;
 
+    if (ssh->state != SSH_STATE_SESSION) {
+       assert(ssh->state == SSH_STATE_CLOSED);
+       return;
+    }
+
     if (ssh->version == 1) {
        if (c->v.v1.throttling && bufsize < SSH1_BUFFER_LIMIT) {
            c->v.v1.throttling = 0;
@@ -6141,7 +6179,7 @@ static void ssh_free(void *handle)
     sfree(ssh->do_ssh2_authconn_state);
     
     if (ssh->s)
-       sk_close(ssh->s);
+       ssh_do_close(ssh);
     sfree(ssh);
 }