Experimental Rlogin support, thanks to Delian Delchev. Local flow
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index a7f4130..f75041b 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -51,6 +51,7 @@
 #define SSH1_CMSG_EXIT_CONFIRMATION               33   /* 0x21 */
 #define SSH1_MSG_IGNORE                           32   /* 0x20 */
 #define SSH1_MSG_DEBUG                            36   /* 0x24 */
+#define SSH1_CMSG_REQUEST_COMPRESSION             37   /* 0x25 */
 #define SSH1_CMSG_AUTH_TIS                        39   /* 0x27 */
 #define SSH1_SMSG_AUTH_TIS_CHALLENGE              40   /* 0x28 */
 #define SSH1_CMSG_AUTH_TIS_RESPONSE               41   /* 0x29 */
 #define SSH2_MSG_CHANNEL_SUCCESS                  99   /* 0x63 */
 #define SSH2_MSG_CHANNEL_FAILURE                  100  /* 0x64 */
 
+#define SSH2_DISCONNECT_HOST_NOT_ALLOWED_TO_CONNECT 1  /* 0x1 */
+#define SSH2_DISCONNECT_PROTOCOL_ERROR            2    /* 0x2 */
+#define SSH2_DISCONNECT_KEY_EXCHANGE_FAILED       3    /* 0x3 */
+#define SSH2_DISCONNECT_HOST_AUTHENTICATION_FAILED 4   /* 0x4 */
+#define SSH2_DISCONNECT_MAC_ERROR                 5    /* 0x5 */
+#define SSH2_DISCONNECT_COMPRESSION_ERROR         6    /* 0x6 */
+#define SSH2_DISCONNECT_SERVICE_NOT_AVAILABLE     7    /* 0x7 */
+#define SSH2_DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED 8 /* 0x8 */
+#define SSH2_DISCONNECT_HOST_KEY_NOT_VERIFIABLE   9    /* 0x9 */
+#define SSH2_DISCONNECT_CONNECTION_LOST           10   /* 0xa */
+#define SSH2_DISCONNECT_BY_APPLICATION            11   /* 0xb */
+
 #define SSH2_OPEN_ADMINISTRATIVELY_PROHIBITED     1    /* 0x1 */
 #define SSH2_OPEN_CONNECT_FAILED                  2    /* 0x2 */
 #define SSH2_OPEN_UNKNOWN_CHANNEL_TYPE            3    /* 0x3 */
@@ -158,8 +171,8 @@ const static struct ssh_cipher *ciphers[] = { &ssh_blowfish_ssh2, &ssh_3des_ssh2
 extern const struct ssh_kex ssh_diffiehellman;
 const static struct ssh_kex *kex_algs[] = { &ssh_diffiehellman };
 
-extern const struct ssh_hostkey ssh_dss;
-const static struct ssh_hostkey *hostkey_algs[] = { &ssh_dss };
+extern const struct ssh_signkey ssh_dss;
+const static struct ssh_signkey *hostkey_algs[] = { &ssh_dss };
 
 extern const struct ssh_mac ssh_md5, ssh_sha1, ssh_sha1_buggy;
 
@@ -174,10 +187,19 @@ const static struct ssh_mac *macs[] = {
 const static struct ssh_mac *buggymacs[] = {
     &ssh_sha1_buggy, &ssh_md5, &ssh_mac_none };
 
+static void ssh_comp_none_init(void) { }
+static int ssh_comp_none_block(unsigned char *block, int len,
+                              unsigned char **outblock, int *outlen) {
+    return 0;
+}
 const static struct ssh_compress ssh_comp_none = {
-    "none"
+    "none",
+    ssh_comp_none_init, ssh_comp_none_block,
+    ssh_comp_none_init, ssh_comp_none_block
 };
-const static struct ssh_compress *compressions[] = { &ssh_comp_none };
+extern const struct ssh_compress ssh_zlib;
+const static struct ssh_compress *compressions[] = {
+    &ssh_zlib, &ssh_comp_none };
 
 /*
  * 2-3-4 tree storing channels.
@@ -209,11 +231,13 @@ struct Packet {
     long maxlen;
 };
 
-static SHA_State exhash;
+static SHA_State exhash, exhashbase;
 
 static Socket s = NULL;
 
 static unsigned char session_key[32];
+static int ssh1_compressing;
+static int ssh_agentfwd_enabled;
 static const struct ssh_cipher *cipher = NULL;
 static const struct ssh_cipher *cscipher = NULL;
 static const struct ssh_cipher *sccipher = NULL;
@@ -222,7 +246,7 @@ static const struct ssh_mac *scmac = NULL;
 static const struct ssh_compress *cscomp = NULL;
 static const struct ssh_compress *sccomp = NULL;
 static const struct ssh_kex *kex = NULL;
-static const struct ssh_hostkey *hostkey = NULL;
+static const struct ssh_signkey *hostkey = NULL;
 int (*ssh_get_password)(const char *prompt, char *str, int maxlen) = NULL;
 
 static char *savedhost;
@@ -233,13 +257,14 @@ static tree234 *ssh_channels;           /* indexed by local id */
 static struct ssh_channel *mainchan;   /* primary session channel */
 
 static enum {
+    SSH_STATE_PREPACKET,
     SSH_STATE_BEFORE_SIZE,
     SSH_STATE_INTERMED,
     SSH_STATE_SESSION,
     SSH_STATE_CLOSED
-} ssh_state = SSH_STATE_BEFORE_SIZE;
+} ssh_state = SSH_STATE_PREPACKET;
 
-static int size_needed = FALSE;
+static int size_needed = FALSE, eof_needed = FALSE;
 
 static struct Packet pktin = { 0, 0, NULL, NULL, 0 };
 static struct Packet pktout = { 0, 0, NULL, NULL, 0 };
@@ -249,6 +274,7 @@ static void (*ssh_protocol)(unsigned char *in, int inlen, int ispkt);
 static void ssh1_protocol(unsigned char *in, int inlen, int ispkt);
 static void ssh2_protocol(unsigned char *in, int inlen, int ispkt);
 static void ssh_size(void);
+static void ssh_special (Telnet_Special);
 
 static int (*s_rdpkt)(unsigned char **data, int *datalen);
 
@@ -331,8 +357,8 @@ next_packet:
 
     if (pktin.maxlen < st->biglen) {
        pktin.maxlen = st->biglen;
-       pktin.data = (pktin.data == NULL ? malloc(st->biglen+APIEXTRA) :
-                     realloc(pktin.data, st->biglen+APIEXTRA));
+       pktin.data = (pktin.data == NULL ? smalloc(st->biglen+APIEXTRA) :
+                     srealloc(pktin.data, st->biglen+APIEXTRA));
        if (!pktin.data)
            fatalbox("Out of memory");
     }
@@ -361,9 +387,6 @@ next_packet:
     debug(("\r\n"));
 #endif
 
-    pktin.type = pktin.data[st->pad];
-    pktin.body = pktin.data + st->pad + 1;
-
     st->realcrc = crc32(pktin.data, st->biglen-4);
     st->gotcrc = GET_32BIT(pktin.data+st->biglen-4);
     if (st->gotcrc != st->realcrc) {
@@ -371,6 +394,40 @@ next_packet:
         crReturn(0);
     }
 
+    pktin.body = pktin.data + st->pad + 1;
+
+    if (ssh1_compressing) {
+       unsigned char *decompblk;
+       int decomplen;
+#if 0
+       int i;
+       debug(("Packet payload pre-decompression:\n"));
+       for (i = -1; i < pktin.length; i++)
+           debug(("  %02x", (unsigned char)pktin.body[i]));
+       debug(("\r\n"));
+#endif
+       zlib_decompress_block(pktin.body-1, pktin.length+1,
+                             &decompblk, &decomplen);
+
+       if (pktin.maxlen < st->pad + decomplen) {
+           pktin.maxlen = st->pad + decomplen;
+           pktin.data = srealloc(pktin.data, pktin.maxlen+APIEXTRA);
+            pktin.body = pktin.data + st->pad + 1;
+           if (!pktin.data)
+               fatalbox("Out of memory");
+       }
+
+       memcpy(pktin.body-1, decompblk, decomplen);
+       sfree(decompblk);
+       pktin.length = decomplen-1;
+#if 0
+       debug(("Packet payload post-decompression:\n"));
+       for (i = -1; i < pktin.length; i++)
+           debug(("  %02x", (unsigned char)pktin.body[i]));
+       debug(("\r\n"));
+#endif
+    }
+
     if (pktin.type == SSH1_SMSG_STDOUT_DATA ||
         pktin.type == SSH1_SMSG_STDERR_DATA ||
         pktin.type == SSH1_MSG_DEBUG ||
@@ -383,6 +440,8 @@ next_packet:
         }
     }
 
+    pktin.type = pktin.body[-1];
+
     if (pktin.type == SSH1_MSG_DEBUG) {
        /* log debug message */
        char buf[80];
@@ -419,8 +478,8 @@ next_packet:
 
     if (pktin.maxlen < st->cipherblk) {
        pktin.maxlen = st->cipherblk;
-       pktin.data = (pktin.data == NULL ? malloc(st->cipherblk+APIEXTRA) :
-                     realloc(pktin.data, st->cipherblk+APIEXTRA));
+       pktin.data = (pktin.data == NULL ? smalloc(st->cipherblk+APIEXTRA) :
+                     srealloc(pktin.data, st->cipherblk+APIEXTRA));
        if (!pktin.data)
            fatalbox("Out of memory");
     }
@@ -467,8 +526,8 @@ next_packet:
      */
     if (pktin.maxlen < st->packetlen+st->maclen) {
        pktin.maxlen = st->packetlen+st->maclen;
-       pktin.data = (pktin.data == NULL ? malloc(pktin.maxlen+APIEXTRA) :
-                     realloc(pktin.data, pktin.maxlen+APIEXTRA));
+       pktin.data = (pktin.data == NULL ? smalloc(pktin.maxlen+APIEXTRA) :
+                     srealloc(pktin.data, pktin.maxlen+APIEXTRA));
        if (!pktin.data)
            fatalbox("Out of memory");
     }
@@ -503,6 +562,34 @@ next_packet:
     }
     st->incoming_sequence++;               /* whether or not we MACed */
 
+    /*
+     * Decompress packet payload.
+     */
+    {
+       unsigned char *newpayload;
+       int newlen;
+       if (sccomp && sccomp->decompress(pktin.data+5, pktin.length-5,
+                                        &newpayload, &newlen)) {
+           if (pktin.maxlen < newlen+5) {
+               pktin.maxlen = newlen+5;
+               pktin.data = (pktin.data == NULL ? smalloc(pktin.maxlen+APIEXTRA) :
+                             srealloc(pktin.data, pktin.maxlen+APIEXTRA));
+               if (!pktin.data)
+                   fatalbox("Out of memory");
+           }
+           pktin.length = 5 + newlen;
+           memcpy(pktin.data+5, newpayload, newlen);
+#if 0
+           debug(("Post-decompression payload:\r\n"));
+           for (st->i = 0; st->i < newlen; st->i++)
+               debug(("  %02x", (unsigned char)pktin.data[5+st->i]));
+           debug(("\r\n"));
+#endif
+
+           sfree(newpayload);
+       }
+    }
+
     pktin.savedpos = 6;
     pktin.type = pktin.data[5];
 
@@ -512,7 +599,7 @@ next_packet:
     crFinish(0);
 }
 
-static void s_wrpkt_start(int type, int len) {
+static void ssh1_pktout_size(int len) {
     int pad, biglen;
 
     len += 5;                         /* type and CRC */
@@ -525,29 +612,55 @@ static void s_wrpkt_start(int type, int len) {
 #ifdef MSCRYPTOAPI
        /* Allocate enough buffer space for extra block
         * for MS CryptEncrypt() */
-       pktout.data = (pktout.data == NULL ? malloc(biglen+12) :
-                      realloc(pktout.data, biglen+12));
+       pktout.data = (pktout.data == NULL ? smalloc(biglen+12) :
+                      srealloc(pktout.data, biglen+12));
 #else
-       pktout.data = (pktout.data == NULL ? malloc(biglen+4) :
-                      realloc(pktout.data, biglen+4));
+       pktout.data = (pktout.data == NULL ? smalloc(biglen+4) :
+                      srealloc(pktout.data, biglen+4));
 #endif
        if (!pktout.data)
            fatalbox("Out of memory");
     }
+    pktout.body = pktout.data+4+pad+1;
+}
 
+static void s_wrpkt_start(int type, int len) {
+    ssh1_pktout_size(len);
     pktout.type = type;
-    pktout.body = pktout.data+4+pad+1;
 }
 
 static void s_wrpkt(void) {
     int pad, len, biglen, i;
     unsigned long crc;
 
+    pktout.body[-1] = pktout.type;
+
+    if (ssh1_compressing) {
+       unsigned char *compblk;
+       int complen;
+#if 0
+       debug(("Packet payload pre-compression:\n"));
+       for (i = -1; i < pktout.length; i++)
+           debug(("  %02x", (unsigned char)pktout.body[i]));
+       debug(("\r\n"));
+#endif
+       zlib_compress_block(pktout.body-1, pktout.length+1,
+                           &compblk, &complen);
+       ssh1_pktout_size(complen-1);
+       memcpy(pktout.body-1, compblk, complen);
+       sfree(compblk);
+#if 0
+       debug(("Packet payload post-compression:\n"));
+       for (i = -1; i < pktout.length; i++)
+           debug(("  %02x", (unsigned char)pktout.body[i]));
+       debug(("\r\n"));
+#endif
+    }
+
     len = pktout.length + 5;          /* type and CRC */
     pad = 8 - (len%8);
     biglen = len + pad;
 
-    pktout.body[-1] = pktout.type;
     for (i=0; i<pad; i++)
        pktout.data[i+4] = random_byte();
     crc = crc32(pktout.data+4, biglen-4);
@@ -685,8 +798,8 @@ static void ssh2_pkt_adddata(void *data, int len) {
     pktout.length += len;
     if (pktout.maxlen < pktout.length) {
         pktout.maxlen = pktout.length + 256;
-       pktout.data = (pktout.data == NULL ? malloc(pktout.maxlen+APIEXTRA) :
-                       realloc(pktout.data, pktout.maxlen+APIEXTRA));
+       pktout.data = (pktout.data == NULL ? smalloc(pktout.maxlen+APIEXTRA) :
+                       srealloc(pktout.data, pktout.maxlen+APIEXTRA));
         if (!pktout.data)
             fatalbox("Out of memory");
     }
@@ -728,7 +841,7 @@ static void ssh2_pkt_addstring(char *data) {
 static char *ssh2_mpint_fmt(Bignum b, int *len) {
     unsigned char *p;
     int i, n = b[0];
-    p = malloc(n * 2 + 1);
+    p = smalloc(n * 2 + 1);
     if (!p)
         fatalbox("out of memory");
     p[0] = 0;
@@ -749,13 +862,33 @@ static void ssh2_pkt_addmp(Bignum b) {
     p = ssh2_mpint_fmt(b, &len);
     ssh2_pkt_addstring_start();
     ssh2_pkt_addstring_data(p, len);
-    free(p);
+    sfree(p);
 }
 static void ssh2_pkt_send(void) {
     int cipherblk, maclen, padding, i;
     static unsigned long outgoing_sequence = 0;
 
     /*
+     * Compress packet payload.
+     */
+#if 0
+    debug(("Pre-compression payload:\r\n"));
+    for (i = 5; i < pktout.length; i++)
+       debug(("  %02x", (unsigned char)pktout.data[i]));
+    debug(("\r\n"));
+#endif
+    {
+       unsigned char *newpayload;
+       int newlen;
+       if (cscomp && cscomp->compress(pktout.data+5, pktout.length-5,
+                                      &newpayload, &newlen)) {
+           pktout.length = 5;
+           ssh2_pkt_adddata(newpayload, newlen);
+           sfree(newpayload);
+       }
+    }
+
+    /*
      * Add padding. At least four bytes, and must also bring total
      * length (minus MAC) up to a multiple of the block size.
      */
@@ -795,7 +928,7 @@ void bndebug(char *string, Bignum b) {
     for (i = 0; i < len; i++)
         debug((" %02x", p[i]));
     debug(("\r\n"));
-    free(p);
+    sfree(p);
 }
 #endif
 
@@ -804,7 +937,7 @@ static void sha_mpint(SHA_State *s, Bignum b) {
     int len;
     p = ssh2_mpint_fmt(b, &len);
     sha_string(s, p, len);
-    free(p);
+    sfree(p);
 }
 
 /*
@@ -849,6 +982,7 @@ static Bignum ssh2_pkt_getmp(void) {
         else
             b[j/2+1] |= ((unsigned char)p[i]);
     }
+    while (b[0] > 1 && b[b[0]] == 0) b[0]--;
     return b;
 }
 
@@ -894,6 +1028,7 @@ static int do_ssh_init(unsigned char c) {
            break;
     }
 
+    ssh_agentfwd_enabled = FALSE;
     rdpkt2_state.incoming_sequence = 0;
 
     *vsp = 0;
@@ -910,12 +1045,12 @@ static int do_ssh_init(unsigned char c) {
          * This is a v2 server. Begin v2 protocol.
          */
         char *verstring = "SSH-2.0-PuTTY";
-        SHA_Init(&exhash);
+        SHA_Init(&exhashbase);
         /*
          * Hash our version string and their version string.
          */
-        sha_string(&exhash, verstring, strlen(verstring));
-        sha_string(&exhash, vstring, strcspn(vstring, "\r\n"));
+        sha_string(&exhashbase, verstring, strlen(verstring));
+        sha_string(&exhashbase, vstring, strcspn(vstring, "\r\n"));
         sprintf(vstring, "%s\n", verstring);
         sprintf(vlog, "We claim version: %s", verstring);
         logevent(vlog);
@@ -939,6 +1074,7 @@ static int do_ssh_init(unsigned char c) {
         ssh_version = 1;
         s_rdpkt = ssh1_rdpkt;
     }
+    ssh_state = SSH_STATE_BEFORE_SIZE;
 
     crFinish(0);
 }
@@ -985,14 +1121,22 @@ static void ssh_gotdata(unsigned char *data, int datalen)
     crFinishV;
 }
 
-static int ssh_receive(Socket s, int urgent, char *data, int len) {
+static int ssh_receive(Socket skt, int urgent, char *data, int len) {
     if (!len) {
        /* Connection has closed. */
+       ssh_state = SSH_STATE_CLOSED;
        sk_close(s);
        s = NULL;
        return 0;
     }
     ssh_gotdata (data, len);
+    if (ssh_state == SSH_STATE_CLOSED) {
+        if (s) {
+            sk_close(s);
+            s = NULL;
+        }
+        return 0;
+    }
     return 1;
 }
 
@@ -1010,7 +1154,7 @@ static char *connect_to_host(char *host, int port, char **realhost)
     int FWport;
 #endif
 
-    savedhost = malloc(1+strlen(host));
+    savedhost = smalloc(1+strlen(host));
     if (!savedhost)
        fatalbox("Out of memory");
     strcpy(savedhost, host);
@@ -1040,7 +1184,7 @@ static char *connect_to_host(char *host, int port, char **realhost)
     /*
      * Open socket.
      */
-    s = sk_new(addr, port, ssh_receive);
+    s = sk_new(addr, port, 0, ssh_receive);
     if ( (err = sk_socket_error(s)) )
        return err;
 
@@ -1116,7 +1260,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
 
     len = (hostkey.bytes > servkey.bytes ? hostkey.bytes : servkey.bytes);
 
-    rsabuf = malloc(len);
+    rsabuf = smalloc(len);
     if (!rsabuf)
        fatalbox("Out of memory");
 
@@ -1129,13 +1273,13 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
          */
         int len = rsastr_len(&hostkey);
         char fingerprint[100];
-        char *keystr = malloc(len);
+        char *keystr = smalloc(len);
         if (!keystr)
             fatalbox("Out of memory");
         rsastr_fmt(keystr, &hostkey);
         rsa_fingerprint(fingerprint, sizeof(fingerprint), &hostkey);
         verify_ssh_host_key(savedhost, savedport, "rsa", keystr, fingerprint);
-        free(keystr);
+        sfree(keystr);
     }
 
     for (i=0; i<32; i++) {
@@ -1177,7 +1321,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
 
     logevent("Trying to enable encryption...");
 
-    free(rsabuf);
+    sfree(rsabuf);
 
     cipher = cipher_type == SSH_CIPHER_BLOWFISH ? &ssh_blowfish_ssh1 :
              cipher_type == SSH_CIPHER_DES ? &ssh_des :
@@ -1235,20 +1379,20 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
            c_write("\r\n", 2);
            username[strcspn(username, "\n\r")] = '\0';
        } else {
-           char stuff[200];
            strncpy(username, cfg.username, 99);
            username[99] = '\0';
-            if ((flags & FLAG_VERBOSE) || (flags & FLAG_INTERACTIVE)) {
-               sprintf(stuff, "Sent username \"%s\".\r\n", username);
-                c_write(stuff, strlen(stuff));
-           }
        }
 
        send_packet(SSH1_CMSG_USER, PKT_STR, username, PKT_END);
        {
-           char userlog[20+sizeof(username)];
+           char userlog[22+sizeof(username)];
            sprintf(userlog, "Sent username \"%s\"", username);
            logevent(userlog);
+            if (flags & FLAG_INTERACTIVE &&
+                (!((flags & FLAG_STDERR) && (flags & FLAG_VERBOSE)))) {
+               strcat(userlog, "\r\n");
+                c_write(userlog, strlen(userlog));
+           }
        }
     }
 
@@ -1321,7 +1465,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                         len += ssh1_bignum_length(challenge);
                         len += 16;     /* session id */
                         len += 4;      /* response format */
-                        agentreq = malloc(4 + len);
+                        agentreq = smalloc(4 + len);
                         PUT_32BIT(agentreq, len);
                         q = agentreq + 4;
                         *q++ = SSH_AGENTC_RSA_CHALLENGE;
@@ -1333,13 +1477,13 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                         memcpy(q, session_id, 16); q += 16;
                         PUT_32BIT(q, 1);   /* response format */
                         agent_query(agentreq, len+4, &ret, &retlen);
-                        free(agentreq);
+                        sfree(agentreq);
                         if (ret) {
                             if (ret[4] == SSH_AGENT_RSA_RESPONSE) {
                                 logevent("Sending Pageant's response");
                                 send_packet(SSH1_CMSG_AUTH_RSA_RESPONSE,
                                             PKT_DATA, ret+5, 16, PKT_END);
-                                free(ret);
+                                sfree(ret);
                                 crWaitUntil(ispkt);
                                 if (pktin.type == SSH1_SMSG_SUCCESS) {
                                     logevent("Pageant's response accepted");
@@ -1354,7 +1498,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                                     logevent("Pageant's response not accepted");
                             } else {
                                 logevent("Pageant failed to answer challenge");
-                                free(ret);
+                                sfree(ret);
                             }
                         } else {
                             logevent("No reply received from Pageant");
@@ -1434,7 +1578,7 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
                 goto tryauth;
             }
             sprintf(prompt, "Passphrase for key \"%.100s\": ", comment);
-            free(comment);
+            sfree(comment);
         }
 
        if (ssh_get_password) {
@@ -1595,8 +1739,10 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
             crReturnV;
         } else if (pktin.type == SSH1_SMSG_FAILURE) {
             logevent("Agent forwarding refused");
-        } else
+        } else {
             logevent("Agent forwarding enabled");
+           ssh_agentfwd_enabled = TRUE;
+       }
     }
 
     if (!cfg.nopty) {
@@ -1617,6 +1763,21 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
        logevent("Allocated pty");
     }
 
+    if (cfg.compression) {
+        send_packet(SSH1_CMSG_REQUEST_COMPRESSION, PKT_INT, 6, PKT_END);
+        do { crReturnV; } while (!ispkt);
+        if (pktin.type != SSH1_SMSG_SUCCESS && pktin.type != SSH1_SMSG_FAILURE) {
+            bombout(("Protocol confusion"));
+            crReturnV;
+        } else if (pktin.type == SSH1_SMSG_FAILURE) {
+            c_write("Server refused to compress\r\n", 32);
+        }
+       logevent("Started compression");
+       ssh1_compressing = TRUE;
+       zlib_compress_init();
+       zlib_decompress_init();
+    }
+
     if (*cfg.remote_cmd)
         send_packet(SSH1_CMSG_EXEC_CMD, PKT_STR, cfg.remote_cmd, PKT_END);
     else
@@ -1626,6 +1787,8 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
     ssh_state = SSH_STATE_SESSION;
     if (size_needed)
        ssh_size();
+    if (eof_needed)
+        ssh_special(TS_EOF);
 
     ssh_send_ok = 1;
     ssh_channels = newtree234(ssh_channelcmp);
@@ -1641,29 +1804,39 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
            } else if (pktin.type == SSH1_MSG_DISCONNECT) {
                 ssh_state = SSH_STATE_CLOSED;
                logevent("Received disconnect request");
+                crReturnV;
             } else if (pktin.type == SSH1_SMSG_AGENT_OPEN) {
                 /* Remote side is trying to open a channel to talk to our
                  * agent. Give them back a local channel number. */
-                unsigned i = 1;
+                unsigned i;
                 struct ssh_channel *c;
                 enum234 e;
-                for (c = first234(ssh_channels, &e); c; c = next234(&e)) {
-                    if (c->localid > i)
-                        break;         /* found a free number */
-                    i = c->localid + 1;
-                }
-                c = malloc(sizeof(struct ssh_channel));
-                c->remoteid = GET_32BIT(pktin.body);
-                c->localid = i;
-                c->closes = 0;
-                c->type = SSH1_SMSG_AGENT_OPEN;   /* identify channel type */
-                c->u.a.lensofar = 0;
-                add234(ssh_channels, c);
-                send_packet(SSH1_MSG_CHANNEL_OPEN_CONFIRMATION,
-                            PKT_INT, c->remoteid, PKT_INT, c->localid,
-                            PKT_END);
-            } else if (pktin.type == SSH1_MSG_CHANNEL_CLOSE ||
-                       pktin.type == SSH1_MSG_CHANNEL_CLOSE_CONFIRMATION) {
+
+               /* Refuse if agent forwarding is disabled. */
+               if (!ssh_agentfwd_enabled) {
+                   send_packet(SSH1_MSG_CHANNEL_OPEN_FAILURE,
+                               PKT_INT, GET_32BIT(pktin.body),
+                               PKT_END);
+               } else {
+                   i = 1;
+                   for (c = first234(ssh_channels, &e); c; c = next234(&e)) {
+                       if (c->localid > i)
+                           break;     /* found a free number */
+                       i = c->localid + 1;
+                   }
+                   c = smalloc(sizeof(struct ssh_channel));
+                   c->remoteid = GET_32BIT(pktin.body);
+                   c->localid = i;
+                   c->closes = 0;
+                   c->type = SSH1_SMSG_AGENT_OPEN;/* identify channel type */
+                   c->u.a.lensofar = 0;
+                   add234(ssh_channels, c);
+                   send_packet(SSH1_MSG_CHANNEL_OPEN_CONFIRMATION,
+                               PKT_INT, c->remoteid, PKT_INT, c->localid,
+                               PKT_END);
+               }
+           } else if (pktin.type == SSH1_MSG_CHANNEL_CLOSE ||
+                      pktin.type == SSH1_MSG_CHANNEL_CLOSE_CONFIRMATION) {
                 /* Remote side closes a channel. */
                 unsigned i = GET_32BIT(pktin.body);
                 struct ssh_channel *c;
@@ -1675,7 +1848,7 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
                     c->closes |= closetype;
                     if (c->closes == 3) {
                         del234(ssh_channels, c);
-                        free(c);
+                        sfree(c);
                     }
                 }
             } else if (pktin.type == SSH1_MSG_CHANNEL_DATA) {
@@ -1697,7 +1870,7 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
                             }
                             if (c->u.a.lensofar == 4) {
                                 c->u.a.totallen = 4 + GET_32BIT(c->u.a.msglen);
-                                c->u.a.message = malloc(c->u.a.totallen);
+                                c->u.a.message = smalloc(c->u.a.totallen);
                                 memcpy(c->u.a.message, c->u.a.msglen, 4);
                             }
                             if (c->u.a.lensofar >= 4 && len > 0) {
@@ -1723,8 +1896,8 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
                                             PKT_DATA, sentreply, replylen,
                                             PKT_END);
                                 if (reply)
-                                    free(reply);
-                                free(c->u.a.message);
+                                    sfree(reply);
+                                sfree(c->u.a.message);
                                 c->u.a.lensofar = 0;
                             }
                         }
@@ -1821,15 +1994,19 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     static const struct ssh_compress *sccomp_tobe = NULL;
     static char *hostkeydata, *sigdata, *keystr, *fingerprint;
     static int hostkeylen, siglen;
+    static void *hkey;                /* actual host key */
     static unsigned char exchange_hash[20];
     static unsigned char keyspace[40];
     static const struct ssh_cipher *preferred_cipher;
+    static const struct ssh_compress *preferred_comp;
+    static int first_kex;
 
     crBegin;
     random_init();
+    first_kex = 1;
 
     /*
-     * Set up the preferred cipher.
+     * Set up the preferred cipher and compression.
      */
     if (cfg.cipher == CIPHER_BLOWFISH) {
         preferred_cipher = &ssh_blowfish_ssh2;
@@ -1842,6 +2019,10 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
         /* Shouldn't happen, but we do want to initialise to _something_. */
         preferred_cipher = &ssh_3des_ssh2;
     }
+    if (cfg.compression)
+       preferred_comp = &ssh_zlib;
+    else
+       preferred_comp = &ssh_comp_none;
 
     /*
      * Be prepared to work around the buggy MAC problem.
@@ -1904,16 +2085,18 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     }
     /* List client->server compression algorithms. */
     ssh2_pkt_addstring_start();
-    for (i = 0; i < lenof(compressions); i++) {
-        ssh2_pkt_addstring_str(compressions[i]->name);
-        if (i < lenof(compressions)-1)
+    for (i = 0; i < lenof(compressions)+1; i++) {
+        const struct ssh_compress *c = i==0 ? preferred_comp : compressions[i-1];
+        ssh2_pkt_addstring_str(c->name);
+        if (i < lenof(compressions))
             ssh2_pkt_addstring_str(",");
     }
     /* List server->client compression algorithms. */
     ssh2_pkt_addstring_start();
-    for (i = 0; i < lenof(compressions); i++) {
-        ssh2_pkt_addstring_str(compressions[i]->name);
-        if (i < lenof(compressions)-1)
+    for (i = 0; i < lenof(compressions)+1; i++) {
+        const struct ssh_compress *c = i==0 ? preferred_comp : compressions[i-1];
+        ssh2_pkt_addstring_str(c->name);
+        if (i < lenof(compressions))
             ssh2_pkt_addstring_str(",");
     }
     /* List client->server languages. Empty list. */
@@ -1924,7 +2107,10 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     ssh2_pkt_addbool(FALSE);
     /* Reserved. */
     ssh2_pkt_adduint32(0);
+
+    exhash = exhashbase;
     sha_string(&exhash, pktout.data+5, pktout.length-5);
+
     ssh2_pkt_send();
 
     if (!ispkt) crWaitUntil(ispkt);
@@ -1986,16 +2172,18 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
         }
     }
     ssh2_pkt_getstring(&str, &len);    /* client->server compression */
-    for (i = 0; i < lenof(compressions); i++) {
-        if (in_commasep_string(compressions[i]->name, str, len)) {
-            cscomp_tobe = compressions[i];
+    for (i = 0; i < lenof(compressions)+1; i++) {
+        const struct ssh_compress *c = i==0 ? preferred_comp : compressions[i-1];
+        if (in_commasep_string(c->name, str, len)) {
+            cscomp_tobe = c;
             break;
         }
     }
     ssh2_pkt_getstring(&str, &len);    /* server->client compression */
-    for (i = 0; i < lenof(compressions); i++) {
-        if (in_commasep_string(compressions[i]->name, str, len)) {
-            sccomp_tobe = compressions[i];
+    for (i = 0; i < lenof(compressions)+1; i++) {
+        const struct ssh_compress *c = i==0 ? preferred_comp : compressions[i-1];
+        if (in_commasep_string(c->name, str, len)) {
+            sccomp_tobe = c;
             break;
         }
     }
@@ -2041,8 +2229,8 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     debug(("\r\n"));
 #endif
 
-    hostkey->setkey(hostkeydata, hostkeylen);
-    if (!hostkey->verifysig(sigdata, siglen, exchange_hash, 20)) {
+    hkey = hostkey->newkey(hostkeydata, hostkeylen);
+    if (!hostkey->verifysig(hkey, sigdata, siglen, exchange_hash, 20)) {
         bombout(("Server failed host key check"));
         crReturn(0);
     }
@@ -2060,14 +2248,15 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
      * Authenticate remote host: verify host key. (We've already
      * checked the signature of the exchange hash.)
      */
-    keystr = hostkey->fmtkey();
-    fingerprint = hostkey->fingerprint();
+    keystr = hostkey->fmtkey(hkey);
+    fingerprint = hostkey->fingerprint(hkey);
     verify_ssh_host_key(savedhost, savedport, hostkey->keytype,
                         keystr, fingerprint);
     logevent("Host key fingerprint is:");
     logevent(fingerprint);
-    free(fingerprint);
-    free(keystr);
+    sfree(fingerprint);
+    sfree(keystr);
+    hostkey->freekey(hkey);
 
     /*
      * Send SSH2_MSG_NEWKEYS.
@@ -2084,17 +2273,32 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     scmac = scmac_tobe;
     cscomp = cscomp_tobe;
     sccomp = sccomp_tobe;
+    cscomp->compress_init();
+    sccomp->decompress_init();
     /*
      * Set IVs after keys.
      */
     ssh2_mkkey(K, exchange_hash, 'C', keyspace); cscipher->setcskey(keyspace);
-    ssh2_mkkey(K, exchange_hash, 'D', keyspace); cscipher->setsckey(keyspace);
+    ssh2_mkkey(K, exchange_hash, 'D', keyspace); sccipher->setsckey(keyspace);
     ssh2_mkkey(K, exchange_hash, 'A', keyspace); cscipher->setcsiv(keyspace);
     ssh2_mkkey(K, exchange_hash, 'B', keyspace); sccipher->setsciv(keyspace);
     ssh2_mkkey(K, exchange_hash, 'E', keyspace); csmac->setcskey(keyspace);
     ssh2_mkkey(K, exchange_hash, 'F', keyspace); scmac->setsckey(keyspace);
 
     /*
+     * If this is the first key exchange phase, we must pass the
+     * SSH2_MSG_NEWKEYS packet to the next layer, not because it
+     * wants to see it but because it will need time to initialise
+     * itself before it sees an actual packet. In subsequent key
+     * exchange phases, we don't pass SSH2_MSG_NEWKEYS on, because
+     * it would only confuse the layer above.
+     */
+    if (!first_kex) {
+        crReturn(0);
+    }
+    first_kex = 0;
+
+    /*
      * Now we're encrypting. Begin returning 1 to the protocol main
      * function so that other things can run on top of the
      * transport. If we ever see a KEXINIT, we must go back to the
@@ -2263,7 +2467,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
     /*
      * So now create a channel with a session in it.
      */
-    mainchan = malloc(sizeof(struct ssh_channel));
+    mainchan = smalloc(sizeof(struct ssh_channel));
     mainchan->localid = 100;           /* as good as any */
     ssh2_pkt_init(SSH2_MSG_CHANNEL_OPEN);
     ssh2_pkt_addstring("session");
@@ -2364,6 +2568,8 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
     ssh_state = SSH_STATE_SESSION;
     if (size_needed)
        ssh_size();
+    if (eof_needed)
+        ssh_special(TS_EOF);
 
     /*
      * Transfer data!
@@ -2402,6 +2608,7 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
            } else if (pktin.type == SSH2_MSG_DISCONNECT) {
                 ssh_state = SSH_STATE_CLOSED;
                logevent("Received disconnect message");
+                crReturnV;
            } else if (pktin.type == SSH2_MSG_CHANNEL_REQUEST) {
                 continue;              /* exit status et al; ignore (FIXME?) */
            } else if (pktin.type == SSH2_MSG_CHANNEL_EOF) {
@@ -2417,10 +2624,12 @@ static void do_ssh2_authconn(unsigned char *in, int inlen, int ispkt)
                 if (1 /* FIXME: "all channels are closed" */) {
                     logevent("All channels closed. Disconnecting");
                     ssh2_pkt_init(SSH2_MSG_DISCONNECT);
+                    ssh2_pkt_adduint32(SSH2_DISCONNECT_BY_APPLICATION);
+                    ssh2_pkt_addstring("All open channels closed");
+                    ssh2_pkt_addstring("en");   /* language tag */
                     ssh2_pkt_send();
                     ssh_state = SSH_STATE_CLOSED;
-                    sk_close(s);
-                    s = NULL;
+                    crReturnV;
                 }
                 continue;              /* remote sends close; ignore (FIXME) */
            } else if (pktin.type == SSH2_MSG_CHANNEL_WINDOW_ADJUST) {
@@ -2513,7 +2722,7 @@ static char *ssh_init (char *host, int port, char **realhost) {
  * Called to send data down the Telnet connection.
  */
 static void ssh_send (char *buf, int len) {
-    if (s == NULL)
+    if (s == NULL || ssh_protocol == NULL)
        return;
 
     ssh_protocol(buf, len, 0);
@@ -2525,6 +2734,7 @@ static void ssh_send (char *buf, int len) {
 static void ssh_size(void) {
     switch (ssh_state) {
       case SSH_STATE_BEFORE_SIZE:
+      case SSH_STATE_PREPACKET:
       case SSH_STATE_CLOSED:
        break;                         /* do nothing */
       case SSH_STATE_INTERMED:
@@ -2548,6 +2758,7 @@ static void ssh_size(void) {
                 ssh2_pkt_send();
             }
         }
+        break;
     }
 }
 
@@ -2558,6 +2769,15 @@ static void ssh_size(void) {
  */
 static void ssh_special (Telnet_Special code) {
     if (code == TS_EOF) {
+        if (ssh_state != SSH_STATE_SESSION) {
+            /*
+             * Buffer the EOF in case we are pre-SESSION, so we can
+             * send it as soon as we reach SESSION.
+             */
+            if (code == TS_EOF)
+                eof_needed = TRUE;
+            return;
+        }
         if (ssh_version == 1) {
             send_packet(SSH1_CMSG_EOF, PKT_END);
         } else {
@@ -2567,6 +2787,8 @@ static void ssh_special (Telnet_Special code) {
         }
         logevent("Sent EOF message");
     } else if (code == TS_PING) {
+        if (ssh_state == SSH_STATE_CLOSED || ssh_state == SSH_STATE_PREPACKET)
+            return;
         if (ssh_version == 1) {
             send_packet(SSH1_MSG_IGNORE, PKT_STR, "", PKT_END);
         } else {