Implement Zlib compression, in both SSH1 and SSH2.
[u/mdw/putty] / ssh.c
diff --git a/ssh.c b/ssh.c
index 3eddd7b..26379e1 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -51,6 +51,7 @@
 #define SSH1_CMSG_EXIT_CONFIRMATION               33   /* 0x21 */
 #define SSH1_MSG_IGNORE                           32   /* 0x20 */
 #define SSH1_MSG_DEBUG                            36   /* 0x24 */
+#define SSH1_CMSG_REQUEST_COMPRESSION             37   /* 0x25 */
 #define SSH1_CMSG_AUTH_TIS                        39   /* 0x27 */
 #define SSH1_SMSG_AUTH_TIS_CHALLENGE              40   /* 0x28 */
 #define SSH1_CMSG_AUTH_TIS_RESPONSE               41   /* 0x29 */
@@ -186,10 +187,19 @@ const static struct ssh_mac *macs[] = {
 const static struct ssh_mac *buggymacs[] = {
     &ssh_sha1_buggy, &ssh_md5, &ssh_mac_none };
 
+static void ssh_comp_none_init(void) { }
+static int ssh_comp_none_block(unsigned char *block, int len,
+                              unsigned char **outblock, int *outlen) {
+    return 0;
+}
 const static struct ssh_compress ssh_comp_none = {
-    "none"
+    "none",
+    ssh_comp_none_init, ssh_comp_none_block,
+    ssh_comp_none_init, ssh_comp_none_block
 };
-const static struct ssh_compress *compressions[] = { &ssh_comp_none };
+extern const struct ssh_compress ssh_zlib;
+const static struct ssh_compress *compressions[] = {
+    &ssh_zlib, &ssh_comp_none };
 
 /*
  * 2-3-4 tree storing channels.
@@ -226,6 +236,7 @@ static SHA_State exhash;
 static Socket s = NULL;
 
 static unsigned char session_key[32];
+static int ssh1_compressing;
 static const struct ssh_cipher *cipher = NULL;
 static const struct ssh_cipher *cscipher = NULL;
 static const struct ssh_cipher *sccipher = NULL;
@@ -373,9 +384,6 @@ next_packet:
     debug(("\r\n"));
 #endif
 
-    pktin.type = pktin.data[st->pad];
-    pktin.body = pktin.data + st->pad + 1;
-
     st->realcrc = crc32(pktin.data, st->biglen-4);
     st->gotcrc = GET_32BIT(pktin.data+st->biglen-4);
     if (st->gotcrc != st->realcrc) {
@@ -383,6 +391,39 @@ next_packet:
         crReturn(0);
     }
 
+    pktin.body = pktin.data + st->pad + 1;
+
+    if (ssh1_compressing) {
+       unsigned char *decompblk;
+       int decomplen;
+#if 0
+       int i;
+       debug(("Packet payload pre-decompression:\n"));
+       for (i = -1; i < pktin.length; i++)
+           debug(("  %02x", (unsigned char)pktin.body[i]));
+       debug(("\r\n"));
+#endif
+       zlib_decompress_block(pktin.body-1, pktin.length+1,
+                             &decompblk, &decomplen);
+
+       if (pktin.maxlen < st->pad + decomplen) {
+           pktin.maxlen = st->pad + decomplen;
+           pktin.data = realloc(pktin.data, pktin.maxlen+APIEXTRA);
+           if (!pktin.data)
+               fatalbox("Out of memory");
+       }
+
+       memcpy(pktin.body-1, decompblk, decomplen);
+       free(decompblk);
+       pktin.length = decomplen-1;
+#if 0
+       debug(("Packet payload post-decompression:\n"));
+       for (i = -1; i < pktin.length; i++)
+           debug(("  %02x", (unsigned char)pktin.body[i]));
+       debug(("\r\n"));
+#endif
+    }
+
     if (pktin.type == SSH1_SMSG_STDOUT_DATA ||
         pktin.type == SSH1_SMSG_STDERR_DATA ||
         pktin.type == SSH1_MSG_DEBUG ||
@@ -395,6 +436,8 @@ next_packet:
         }
     }
 
+    pktin.type = pktin.body[-1];
+
     if (pktin.type == SSH1_MSG_DEBUG) {
        /* log debug message */
        char buf[80];
@@ -515,6 +558,34 @@ next_packet:
     }
     st->incoming_sequence++;               /* whether or not we MACed */
 
+    /*
+     * Decompress packet payload.
+     */
+    {
+       unsigned char *newpayload;
+       int newlen;
+       if (sccomp && sccomp->decompress(pktin.data+5, pktin.length-5,
+                                        &newpayload, &newlen)) {
+           if (pktin.maxlen < newlen+5) {
+               pktin.maxlen = newlen+5;
+               pktin.data = (pktin.data == NULL ? malloc(pktin.maxlen+APIEXTRA) :
+                             realloc(pktin.data, pktin.maxlen+APIEXTRA));
+               if (!pktin.data)
+                   fatalbox("Out of memory");
+           }
+           pktin.length = 5 + newlen;
+           memcpy(pktin.data+5, newpayload, newlen);
+#if 0
+           debug(("Post-decompression payload:\r\n"));
+           for (st->i = 0; st->i < newlen; st->i++)
+               debug(("  %02x", (unsigned char)pktin.data[5+st->i]));
+           debug(("\r\n"));
+#endif
+
+           free(newpayload);
+       }
+    }
+
     pktin.savedpos = 6;
     pktin.type = pktin.data[5];
 
@@ -524,7 +595,7 @@ next_packet:
     crFinish(0);
 }
 
-static void s_wrpkt_start(int type, int len) {
+static void ssh1_pktout_size(int len) {
     int pad, biglen;
 
     len += 5;                         /* type and CRC */
@@ -546,20 +617,46 @@ static void s_wrpkt_start(int type, int len) {
        if (!pktout.data)
            fatalbox("Out of memory");
     }
+    pktout.body = pktout.data+4+pad+1;
+}
 
+static void s_wrpkt_start(int type, int len) {
+    ssh1_pktout_size(len);
     pktout.type = type;
-    pktout.body = pktout.data+4+pad+1;
 }
 
 static void s_wrpkt(void) {
     int pad, len, biglen, i;
     unsigned long crc;
 
+    pktout.body[-1] = pktout.type;
+
+    if (ssh1_compressing) {
+       unsigned char *compblk;
+       int complen;
+#if 0
+       debug(("Packet payload pre-compression:\n"));
+       for (i = -1; i < pktout.length; i++)
+           debug(("  %02x", (unsigned char)pktout.body[i]));
+       debug(("\r\n"));
+#endif
+       zlib_compress_block(pktout.body-1, pktout.length+1,
+                           &compblk, &complen);
+       ssh1_pktout_size(complen-1);
+       memcpy(pktout.body-1, compblk, complen);
+       free(compblk);
+#if 0
+       debug(("Packet payload post-compression:\n"));
+       for (i = -1; i < pktout.length; i++)
+           debug(("  %02x", (unsigned char)pktout.body[i]));
+       debug(("\r\n"));
+#endif
+    }
+
     len = pktout.length + 5;          /* type and CRC */
     pad = 8 - (len%8);
     biglen = len + pad;
 
-    pktout.body[-1] = pktout.type;
     for (i=0; i<pad; i++)
        pktout.data[i+4] = random_byte();
     crc = crc32(pktout.data+4, biglen-4);
@@ -768,6 +865,26 @@ static void ssh2_pkt_send(void) {
     static unsigned long outgoing_sequence = 0;
 
     /*
+     * Compress packet payload.
+     */
+#if 0
+    debug(("Pre-compression payload:\r\n"));
+    for (i = 5; i < pktout.length; i++)
+       debug(("  %02x", (unsigned char)pktout.data[i]));
+    debug(("\r\n"));
+#endif
+    {
+       unsigned char *newpayload;
+       int newlen;
+       if (cscomp && cscomp->compress(pktout.data+5, pktout.length-5,
+                                      &newpayload, &newlen)) {
+           pktout.length = 5;
+           ssh2_pkt_adddata(newpayload, newlen);
+           free(newpayload);
+       }
+    }
+
+    /*
      * Add padding. At least four bytes, and must also bring total
      * length (minus MAC) up to a multiple of the block size.
      */
@@ -1255,20 +1372,20 @@ static int do_ssh1_login(unsigned char *in, int inlen, int ispkt)
            c_write("\r\n", 2);
            username[strcspn(username, "\n\r")] = '\0';
        } else {
-           char stuff[200];
            strncpy(username, cfg.username, 99);
            username[99] = '\0';
-            if ((flags & FLAG_VERBOSE) || (flags & FLAG_INTERACTIVE)) {
-               sprintf(stuff, "Sent username \"%s\".\r\n", username);
-                c_write(stuff, strlen(stuff));
-           }
        }
 
        send_packet(SSH1_CMSG_USER, PKT_STR, username, PKT_END);
        {
-           char userlog[20+sizeof(username)];
+           char userlog[22+sizeof(username)];
            sprintf(userlog, "Sent username \"%s\"", username);
            logevent(userlog);
+            if (flags & FLAG_INTERACTIVE &&
+                (!((flags & FLAG_STDERR) && (flags & FLAG_VERBOSE)))) {
+               strcat(userlog, "\r\n");
+                c_write(userlog, strlen(userlog));
+           }
        }
     }
 
@@ -1637,6 +1754,21 @@ static void ssh1_protocol(unsigned char *in, int inlen, int ispkt) {
        logevent("Allocated pty");
     }
 
+    if (cfg.compression) {
+        send_packet(SSH1_CMSG_REQUEST_COMPRESSION, PKT_INT, 6, PKT_END);
+        do { crReturnV; } while (!ispkt);
+        if (pktin.type != SSH1_SMSG_SUCCESS && pktin.type != SSH1_SMSG_FAILURE) {
+            bombout(("Protocol confusion"));
+            crReturnV;
+        } else if (pktin.type == SSH1_SMSG_FAILURE) {
+            c_write("Server refused to compress\r\n", 32);
+        }
+       logevent("Started compression");
+       ssh1_compressing = TRUE;
+       zlib_compress_init();
+       zlib_decompress_init();
+    }
+
     if (*cfg.remote_cmd)
         send_packet(SSH1_CMSG_EXEC_CMD, PKT_STR, cfg.remote_cmd, PKT_END);
     else
@@ -1845,12 +1977,13 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     static unsigned char exchange_hash[20];
     static unsigned char keyspace[40];
     static const struct ssh_cipher *preferred_cipher;
+    static const struct ssh_compress *preferred_comp;
 
     crBegin;
     random_init();
 
     /*
-     * Set up the preferred cipher.
+     * Set up the preferred cipher and compression.
      */
     if (cfg.cipher == CIPHER_BLOWFISH) {
         preferred_cipher = &ssh_blowfish_ssh2;
@@ -1863,6 +1996,10 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
         /* Shouldn't happen, but we do want to initialise to _something_. */
         preferred_cipher = &ssh_3des_ssh2;
     }
+    if (cfg.compression)
+       preferred_comp = &ssh_zlib;
+    else
+       preferred_comp = &ssh_comp_none;
 
     /*
      * Be prepared to work around the buggy MAC problem.
@@ -1925,16 +2062,18 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     }
     /* List client->server compression algorithms. */
     ssh2_pkt_addstring_start();
-    for (i = 0; i < lenof(compressions); i++) {
-        ssh2_pkt_addstring_str(compressions[i]->name);
-        if (i < lenof(compressions)-1)
+    for (i = 0; i < lenof(compressions)+1; i++) {
+        const struct ssh_compress *c = i==0 ? preferred_comp : compressions[i-1];
+        ssh2_pkt_addstring_str(c->name);
+        if (i < lenof(compressions))
             ssh2_pkt_addstring_str(",");
     }
     /* List server->client compression algorithms. */
     ssh2_pkt_addstring_start();
-    for (i = 0; i < lenof(compressions); i++) {
-        ssh2_pkt_addstring_str(compressions[i]->name);
-        if (i < lenof(compressions)-1)
+    for (i = 0; i < lenof(compressions)+1; i++) {
+        const struct ssh_compress *c = i==0 ? preferred_comp : compressions[i-1];
+        ssh2_pkt_addstring_str(c->name);
+        if (i < lenof(compressions))
             ssh2_pkt_addstring_str(",");
     }
     /* List client->server languages. Empty list. */
@@ -2007,16 +2146,18 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
         }
     }
     ssh2_pkt_getstring(&str, &len);    /* client->server compression */
-    for (i = 0; i < lenof(compressions); i++) {
-        if (in_commasep_string(compressions[i]->name, str, len)) {
-            cscomp_tobe = compressions[i];
+    for (i = 0; i < lenof(compressions)+1; i++) {
+        const struct ssh_compress *c = i==0 ? preferred_comp : compressions[i-1];
+        if (in_commasep_string(c->name, str, len)) {
+            cscomp_tobe = c;
             break;
         }
     }
     ssh2_pkt_getstring(&str, &len);    /* server->client compression */
-    for (i = 0; i < lenof(compressions); i++) {
-        if (in_commasep_string(compressions[i]->name, str, len)) {
-            sccomp_tobe = compressions[i];
+    for (i = 0; i < lenof(compressions)+1; i++) {
+        const struct ssh_compress *c = i==0 ? preferred_comp : compressions[i-1];
+        if (in_commasep_string(c->name, str, len)) {
+            sccomp_tobe = c;
             break;
         }
     }
@@ -2105,6 +2246,8 @@ static int do_ssh2_transport(unsigned char *in, int inlen, int ispkt)
     scmac = scmac_tobe;
     cscomp = cscomp_tobe;
     sccomp = sccomp_tobe;
+    cscomp->compress_init();
+    sccomp->decompress_init();
     /*
      * Set IVs after keys.
      */