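Fix `disconnect': arrange that we keep track of when we're expecting the network connection to close (ssh->close_expected), so that a close by the server that we were not expecting is reported as a fatal "Server unexpectedly closed network connection" error instead of being treated as a normal exit.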
[sgt/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 32d2b00..8ece19a 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -561,7 +561,7 @@ static void ssh2_add_channel_data(struct ssh_channel *c, char *buf, int len);
 static void ssh_throttle_all(Ssh ssh, int enable, int bufsize);
 static void ssh2_set_window(struct ssh_channel *c, unsigned newwin);
 static int ssh_sendbuffer(void *handle);
-static void ssh_do_close(Ssh ssh);
+static int ssh_do_close(Ssh ssh, int notify_exit);
 static unsigned long ssh_pkt_getuint32(struct Packet *pkt);
 static int ssh2_pkt_getbool(struct Packet *pkt);
 static void ssh_pkt_getstring(struct Packet *pkt, char **p, int *length);
@@ -642,6 +642,7 @@ struct ssh_tag {
     tree234 *channels;                /* indexed by local id */
     struct ssh_channel *mainchan;      /* primary session channel */
     int exitcode;
+    int close_expected;
 
     tree234 *rportfwds, *portfwds;
 
@@ -772,7 +773,7 @@ static void logeventf(Ssh ssh, const char *fmt, ...)
 #define bombout(msg) \
     do { \
         char *text = dupprintf msg; \
-       ssh_do_close(ssh); \
+       ssh_do_close(ssh, FALSE); \
         logevent(text); \
         connection_fatal(ssh->frontend, "%s", text); \
         sfree(text); \
@@ -2368,16 +2369,19 @@ static void ssh_gotdata(Ssh ssh, unsigned char *data, int datalen)
     crFinishV;
 }
 
-static void ssh_do_close(Ssh ssh)
+static int ssh_do_close(Ssh ssh, int notify_exit)
 {
-    int i;
+    int i, ret = 0;
     struct ssh_channel *c;
 
     ssh->state = SSH_STATE_CLOSED;
     if (ssh->s) {
         sk_close(ssh->s);
         ssh->s = NULL;
-       notify_remote_exit(ssh->frontend);
+        if (notify_exit)
+            notify_remote_exit(ssh->frontend);
+        else
+            ret = 1;
     }
     /*
      * Now we must shut down any port and X forwardings going
@@ -2399,20 +2403,29 @@ static void ssh_do_close(Ssh ssh)
            sfree(c);
        }
     }
+
+    return ret;
 }
 
 static int ssh_closing(Plug plug, const char *error_msg, int error_code,
                       int calling_back)
 {
     Ssh ssh = (Ssh) plug;
-    ssh_do_close(ssh);
+    int need_notify = ssh_do_close(ssh, FALSE);
+
+    if (!error_msg && !ssh->close_expected) {
+        error_msg = "Server unexpectedly closed network connection";
+    }
+
     if (error_msg) {
        /* A socket error has occurred. */
        logevent(error_msg);
        connection_fatal(ssh->frontend, "%s", error_msg);
     } else {
-       /* Otherwise, the remote side closed the connection normally. */
+        logevent("Server closed network connection");
     }
+    if (need_notify)
+        notify_remote_exit(ssh->frontend);
     return 0;
 }
 
@@ -2421,7 +2434,7 @@ static int ssh_receive(Plug plug, int urgent, char *data, int len)
     Ssh ssh = (Ssh) plug;
     ssh_gotdata(ssh, (unsigned char *)data, len);
     if (ssh->state == SSH_STATE_CLOSED) {
-       ssh_do_close(ssh);
+       ssh_do_close(ssh, TRUE);
        return 0;
     }
     return 1;
@@ -2922,6 +2935,7 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen,
                     * Terminate.
                     */
                    logevent("No username provided. Abandoning session.");
+                   ssh->close_expected = TRUE;
                     ssh_closing((Plug)ssh, NULL, 0, 0);
                    crStop(1);
                }
@@ -3270,6 +3284,7 @@ static int do_ssh1_login(Ssh ssh, unsigned char *in, int inlen,
                            PKT_END);
                logevent("Unable to authenticate");
                connection_fatal(ssh->frontend, "Unable to authenticate");
+               ssh->close_expected = TRUE;
                 ssh_closing((Plug)ssh, NULL, 0, 0);
                crStop(1);
            }
@@ -4322,6 +4337,7 @@ static void ssh1_smsg_exit_status(Ssh ssh, struct Packet *pktin)
      * encrypted packet, we close the session once
      * we've sent EXIT_CONFIRMATION.
      */
+    ssh->close_expected = TRUE;
     ssh_closing((Plug)ssh, NULL, 0, 0);
 }
 
@@ -5600,6 +5616,7 @@ static void ssh2_msg_channel_close(Ssh ssh, struct Packet *pktin)
        ssh2_pkt_addstring(s->pktout, "en");    /* language tag */
        ssh2_pkt_send_noqueue(ssh, s->pktout);
 #endif
+       ssh->close_expected = TRUE;
        ssh_closing((Plug)ssh, NULL, 0, 0);
     }
 }
@@ -5700,6 +5717,7 @@ static void ssh2_msg_channel_request(Ssh ssh, struct Packet *pktin)
        ssh2_pkt_addstring(pktout, "en");       /* language tag */
        ssh2_pkt_send_noqueue(ssh, pktout);
        connection_fatal(ssh->frontend, "%s", buf);
+       ssh->close_expected = TRUE;
        ssh_closing((Plug)ssh, NULL, 0, 0);
        return;
     }
@@ -6053,6 +6071,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                     * Terminate.
                     */
                    logevent("No username provided. Abandoning session.");
+                   ssh->close_expected = TRUE;
                     ssh_closing((Plug)ssh, NULL, 0, 0);
                    crStopV;
                }
@@ -6614,6 +6633,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                        logevent("Unable to authenticate");
                        connection_fatal(ssh->frontend,
                                         "Unable to authenticate");
+                       ssh->close_expected = TRUE;
                         ssh_closing((Plug)ssh, NULL, 0, 0);
                        crStopV;
                    }
@@ -6812,6 +6832,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen,
                                   " methods available");
                ssh2_pkt_addstring(s->pktout, "en");    /* language tag */
                ssh2_pkt_send_noqueue(ssh, s->pktout);
+               ssh->close_expected = TRUE;
                 ssh_closing((Plug)ssh, NULL, 0, 0);
                crStopV;
            }
@@ -7393,6 +7414,7 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
     ssh->kex_ctx = NULL;
     ssh->hostkey = NULL;
     ssh->exitcode = -1;
+    ssh->close_expected = FALSE;
     ssh->state = SSH_STATE_PREPACKET;
     ssh->size_needed = FALSE;
     ssh->eof_needed = FALSE;
@@ -7548,7 +7570,7 @@ static void ssh_free(void *handle)
        ssh->crcda_ctx = NULL;
     }
     if (ssh->s)
-       ssh_do_close(ssh);
+       ssh_do_close(ssh, TRUE);
     expire_timer_context(ssh);
     if (ssh->pinger)
        pinger_free(ssh->pinger);