Prevent duplicate sk_close() calls on the same socket when the
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index e32f51e..5ad1ea4 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -19,8 +19,9 @@
                       if ((flags & FLAG_STDERR) && (flags & FLAG_VERBOSE)) \
                       fprintf(stderr, "%s\n", s); }
 
-#define bombout(msg) ( ssh_state = SSH_STATE_CLOSED, sk_close(s), \
-                       s = NULL, connection_fatal msg )
+#define bombout(msg) ( ssh_state = SSH_STATE_CLOSED, \
+                          (s ? sk_close(s), s = NULL : (void)0), \
+                          connection_fatal msg )
 
 #define SSH1_MSG_DISCONNECT                       1    /* 0x1 */
 #define SSH1_SMSG_PUBLIC_KEY                      2    /* 0x2 */
@@ -165,7 +166,7 @@ extern const struct ssh_cipher ssh_des;
 extern const struct ssh_cipher ssh_blowfish_ssh1;
 extern const struct ssh_cipher ssh_blowfish_ssh2;
 
-extern char *x11_init (Socket *, char *, void *, char **);
+extern char *x11_init (Socket *, char *, void *);
 extern void x11_close (Socket);
 extern void x11_send  (Socket , char *, int);
 extern void x11_invent_auth(char *, int, char *, int);
@@ -272,6 +273,7 @@ int (*ssh_get_password)(const char *prompt, char *str, int maxlen) = NULL;
 static char *savedhost;
 static int savedport;
 static int ssh_send_ok;
+static int ssh_echoing, ssh_editing;
 
 static tree234 *ssh_channels;           /* indexed by local id */
 static struct ssh_channel *mainchan;   /* primary session channel */
@@ -1148,7 +1150,14 @@ static void ssh_gotdata(unsigned char *data, int datalen)
 }
 
 static int ssh_receive(Socket skt, int urgent, char *data, int len) {
-    if (!len) {
+    if (urgent==3) {
+        /* A socket error has occurred. */
+        ssh_state = SSH_STATE_CLOSED;
+        sk_close(s);
+        s = NULL;
+        connection_fatal(data);
+        return 0;
+    } else if (!len) {
        /* Connection has closed. */
        ssh_state = SSH_STATE_CLOSED;
        sk_close(s);
@@ -1835,8 +1844,11 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
             crReturnV;
         } else if (pktin.type == SSH1_SMSG_FAILURE) {
             c_write("Server refused to allocate pty\r\n", 32);
+            ssh_editing = ssh_echoing = 1;
         }
        logevent("Allocated pty");
+    } else {
+        ssh_editing = ssh_echoing = 1;
     }
 
     if (cfg.compression) {
@@ -1866,9 +1878,9 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
     if (eof_needed)
         ssh_special(TS_EOF);
 
+    ldisc_send(NULL, 0);               /* cause ldisc to notice changes */
     ssh_send_ok = 1;
     ssh_channels = newtree234(ssh_channelcmp);
-    begin_session();
     while (1) {
        crReturnV;
        if (ispkt) {
@@ -1896,11 +1908,9 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
                                PKT_END);
                    logevent("Rejected X11 connect request");
                } else {
-                    char *rh;
-
                    c = smalloc(sizeof(struct ssh_channel));
 
-                   if ( x11_init(&c->u.x11.s, cfg.x11_display, c, &rh) != NULL ) {
+                   if ( x11_init(&c->u.x11.s, cfg.x11_display, c) != NULL ) {
                      logevent("opening X11 forward connection failed");
                      sfree(c);
                      send_packet(SSH1_MSG_CHANNEL_OPEN_FAILURE,
@@ -2087,14 +2097,14 @@ static int in_commasep_string(char *needle, char *haystack, int haylen) {
 /*
  * SSH2 key creation method.
  */
-static void ssh2_mkkey(Bignum K, char *H, char chr, char *keyspace) {
+static void ssh2_mkkey(Bignum K, char *H, char *sessid, char chr, char *keyspace) {
     SHA_State s;
     /* First 20 bytes. */
     SHA_Init(&s);
     sha_mpint(&s, K);
     SHA_Bytes(&s, H, 20);
     SHA_Bytes(&s, &chr, 1);
-    SHA_Bytes(&s, H, 20);
+    SHA_Bytes(&s, sessid, 20);
     SHA_Final(&s, keyspace);
     /* Next 20 bytes. */
     SHA_Init(&s);
@@ -2124,6 +2134,7 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     static int hostkeylen, siglen;
     static void *hkey;                /* actual host key */
     static unsigned char exchange_hash[20];
+    static unsigned char first_exchange_hash[20];
     static unsigned char keyspace[40];
     static const struct ssh_cipher *preferred_cipher;
     static const struct ssh_compress *preferred_comp;
@@ -2380,8 +2391,10 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     fingerprint = hostkey->fingerprint(hkey);
     verify_ssh_host_key(savedhost, savedport, hostkey->keytype,
                         keystr, fingerprint);
-    logevent("Host key fingerprint is:");
-    logevent(fingerprint);
+    if (first_kex) {                  /* don't bother logging this in rekeys */
+       logevent("Host key fingerprint is:");
+       logevent(fingerprint);
+    }
     sfree(fingerprint);
     sfree(keystr);
     hostkey->freekey(hkey);
@@ -2404,14 +2417,23 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     cscomp->compress_init();
     sccomp->decompress_init();
     /*
-     * Set IVs after keys.
+     * Set IVs after keys. Here we use the exchange hash from the
+     * _first_ key exchange.
      */
-    ssh2_mkkey(K, exchange_hash, 'C', keyspace); cscipher->setcskey(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'D', keyspace); sccipher->setsckey(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'A', keyspace); cscipher->setcsiv(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'B', keyspace); sccipher->setsciv(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'E', keyspace); csmac->setcskey(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'F', keyspace); scmac->setsckey(keyspace);
+    if (first_kex)
+       memcpy(first_exchange_hash, exchange_hash, sizeof(exchange_hash));
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'C', keyspace);
+    cscipher->setcskey(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'D', keyspace);
+    sccipher->setsckey(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'A', keyspace);
+    cscipher->setcsiv(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'B', keyspace);
+    sccipher->setsciv(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'E', keyspace);
+    csmac->setcskey(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'F', keyspace);
+    scmac->setsckey(keyspace);
 
     /*
      * If this is the first key exchange phase, we must pass the
@@ -2435,6 +2457,7 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     do {
         crReturn(1);
     } while (!(ispkt && pktin.type == SSH2_MSG_KEXINIT));
+    logevent("Server initiated key re-exchange");
     goto begin_key_exchange;
 
     crFinish(1);
@@ -2740,9 +2763,12 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                 crReturnV;
             }
             c_write("Server refused to allocate pty\r\n", 32);
+            ssh_editing = ssh_echoing = 1;
         } else {
             logevent("Allocated pty");
         }
+    } else {
+        ssh_editing = ssh_echoing = 1;
     }
 
     /*
@@ -2790,8 +2816,8 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
     /*
      * Transfer data!
      */
+    ldisc_send(NULL, 0);               /* cause ldisc to notice changes */
     ssh_send_ok = 1;
-    begin_session();
     while (1) {
         static int try_send;
        crReturnV;
@@ -2908,10 +2934,9 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                 c = smalloc(sizeof(struct ssh_channel));
 
                 if (typelen == 3 && !memcmp(type, "x11", 3)) {
-                    char *rh;
                     if (!ssh_X11_fwd_enabled)
                         error = "X11 forwarding is not enabled";
-                    else if ( x11_init(&c->u.x11.s, cfg.x11_display, c, &rh) != NULL ) {
+                    else if ( x11_init(&c->u.x11.s, cfg.x11_display, c) != NULL ) {
                         error = "Unable to open an X11 connection";
                     } else {
                         c->type = CHAN_X11;
@@ -3003,6 +3028,8 @@ static char *ssh_init (char *host, int port, char **realhost) {
 #endif
 
     ssh_send_ok = 0;
+    ssh_editing = 0;
+    ssh_echoing = 0;
 
     p = connect_to_host(host, port, realhost);
     if (p != NULL)
@@ -3098,6 +3125,12 @@ static Socket ssh_socket(void) { return s; }
 
 static int ssh_sendok(void) { return ssh_send_ok; }
 
+static int ssh_ldisc(int option) {
+    if (option == LD_ECHO) return ssh_echoing;
+    if (option == LD_EDIT) return ssh_editing;
+    return FALSE;
+}
+
 Backend ssh_backend = {
     ssh_init,
     ssh_send,
@@ -3105,5 +3138,6 @@ Backend ssh_backend = {
     ssh_special,
     ssh_socket,
     ssh_sendok,
+    ssh_ldisc,
     22
 };