Prevent duplicate sk_close() calls on the same socket when the
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 2ef8ac4..5ad1ea4 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -19,8 +19,9 @@
                       if ((flags & FLAG_STDERR) && (flags & FLAG_VERBOSE)) \
                       fprintf(stderr, "%s\n", s); }
 
-#define bombout(msg) ( ssh_state = SSH_STATE_CLOSED, sk_close(s), \
-                       s = NULL, connection_fatal msg )
+#define bombout(msg) ( ssh_state = SSH_STATE_CLOSED, \
+                          (s ? sk_close(s), s = NULL : (void)0), \
+                          connection_fatal msg )
 
 #define SSH1_MSG_DISCONNECT                       1    /* 0x1 */
 #define SSH1_SMSG_PUBLIC_KEY                      2    /* 0x2 */
@@ -1152,11 +1153,11 @@ static int ssh_receive(Socket skt, int urgent, char *data, int len) {
     if (urgent==3) {
         /* A socket error has occurred. */
         ssh_state = SSH_STATE_CLOSED;
+        sk_close(s);
         s = NULL;
         connection_fatal(data);
-        len = 0;
-    }
-    if (!len) {
+        return 0;
+    } else if (!len) {
        /* Connection has closed. */
        ssh_state = SSH_STATE_CLOSED;
        sk_close(s);
@@ -2096,14 +2097,14 @@ static int in_commasep_string(char *needle, char *haystack, int haylen) {
 /*
  * SSH2 key creation method.
  */
-static void ssh2_mkkey(Bignum K, char *H, char chr, char *keyspace) {
+static void ssh2_mkkey(Bignum K, char *H, char *sessid, char chr, char *keyspace) {
     SHA_State s;
     /* First 20 bytes. */
     SHA_Init(&s);
     sha_mpint(&s, K);
     SHA_Bytes(&s, H, 20);
     SHA_Bytes(&s, &chr, 1);
-    SHA_Bytes(&s, H, 20);
+    SHA_Bytes(&s, sessid, 20);
     SHA_Final(&s, keyspace);
     /* Next 20 bytes. */
     SHA_Init(&s);
@@ -2133,6 +2134,7 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     static int hostkeylen, siglen;
     static void *hkey;                /* actual host key */
     static unsigned char exchange_hash[20];
+    static unsigned char first_exchange_hash[20];
     static unsigned char keyspace[40];
     static const struct ssh_cipher *preferred_cipher;
     static const struct ssh_compress *preferred_comp;
@@ -2389,8 +2391,10 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     fingerprint = hostkey->fingerprint(hkey);
     verify_ssh_host_key(savedhost, savedport, hostkey->keytype,
                         keystr, fingerprint);
-    logevent("Host key fingerprint is:");
-    logevent(fingerprint);
+    if (first_kex) {                  /* don't bother logging this in rekeys */
+       logevent("Host key fingerprint is:");
+       logevent(fingerprint);
+    }
     sfree(fingerprint);
     sfree(keystr);
     hostkey->freekey(hkey);
@@ -2413,14 +2417,23 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     cscomp->compress_init();
     sccomp->decompress_init();
     /*
-     * Set IVs after keys.
+     * Set IVs after keys. Here we use the exchange hash from the
+     * _first_ key exchange.
      */
-    ssh2_mkkey(K, exchange_hash, 'C', keyspace); cscipher->setcskey(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'D', keyspace); sccipher->setsckey(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'A', keyspace); cscipher->setcsiv(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'B', keyspace); sccipher->setsciv(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'E', keyspace); csmac->setcskey(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'F', keyspace); scmac->setsckey(keyspace);
+    if (first_kex)
+       memcpy(first_exchange_hash, exchange_hash, sizeof(exchange_hash));
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'C', keyspace);
+    cscipher->setcskey(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'D', keyspace);
+    sccipher->setsckey(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'A', keyspace);
+    cscipher->setcsiv(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'B', keyspace);
+    sccipher->setsciv(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'E', keyspace);
+    csmac->setcskey(keyspace);
+    ssh2_mkkey(K, exchange_hash, first_exchange_hash, 'F', keyspace);
+    scmac->setsckey(keyspace);
 
     /*
      * If this is the first key exchange phase, we must pass the
@@ -2444,6 +2457,7 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     do {
         crReturn(1);
     } while (!(ispkt && pktin.type == SSH2_MSG_KEXINIT));
+    logevent("Server initiated key re-exchange");
     goto begin_key_exchange;
 
     crFinish(1);