Replace PuTTY's 2-3-4 tree implementation with the shiny new counted
[sgt/putty] / sshzlib.c
index 9e2e6a3..66ab214 100644 (file)
--- a/sshzlib.c
+++ b/sshzlib.c
@@ -70,9 +70,11 @@ static int lz77_init(struct LZ77Context *ctx);
 /*
  * Supply data to be compressed. Will update the private fields of
  * the LZ77Context, and will call literal() and match() to output.
+ * If `compress' is FALSE, it will never emit a match, but will
+ * instead call literal() for everything.
  */
 static void lz77_compress(struct LZ77Context *ctx,
-                          unsigned char *data, int len);
+                          unsigned char *data, int len, int compress);
 
 /*
  * Modifiable parameters.
@@ -93,12 +95,12 @@ static void lz77_compress(struct LZ77Context *ctx,
 
 #define INVALID -1                    /* invalid hash _and_ invalid offset */
 struct WindowEntry {
-    int next, prev;                   /* array indices within the window */
-    int hashval;
+    short next, prev;                 /* array indices within the window */
+    short hashval;
 };
 
 struct HashEntry {
-    int first;                        /* window index of first in chain */
+    short first;                      /* window index of first in chain */
 };
 
 struct Match {
@@ -174,7 +176,7 @@ static void lz77_advance(struct LZ77InternalContext *st,
 #define CHARAT(k) ( (k)<0 ? st->data[(st->winpos+k)&(WINSIZE-1)] : data[k] )
 
 static void lz77_compress(struct LZ77Context *ctx,
-                          unsigned char *data, int len) {
+                          unsigned char *data, int len, int compress) {
     struct LZ77InternalContext *st = ctx->ictx;
     int i, hash, distance, off, nmatch, matchlen, advance;
     struct Match defermatch, matches[MAXMATCH];
@@ -203,7 +205,8 @@ static void lz77_compress(struct LZ77Context *ctx,
     defermatch.len = 0;
     while (len > 0) {
 
-        if (len >= HASHCHARS) {
+        /* Don't even look for a match, if we're not compressing. */
+        if (compress && len >= HASHCHARS) {
             /*
              * Hash the next few characters.
              */
@@ -336,6 +339,7 @@ struct Outbuf {
     unsigned long outbits;
     int noutbits;
     int firstblock;
+    int comp_disabled;
 };
 
 static void outbits(struct Outbuf *out, unsigned long bits, int nbits) {
@@ -389,7 +393,8 @@ static const unsigned char mirrorbytes[256] = {
 };
 
 typedef struct {
-    int code, extrabits, min, max;
+    short code, extrabits;
+    int min, max;
 } coderecord;
 
 static const coderecord lencodes[] = {
@@ -460,6 +465,14 @@ static const coderecord distcodes[] = {
 static void zlib_literal(struct LZ77Context *ectx, unsigned char c) {
     struct Outbuf *out = (struct Outbuf *)ectx->userdata;
 
+    if (out->comp_disabled) {
+        /*
+         * We're in an uncompressed block, so just output the byte.
+         */
+        outbits(out, c, 8);
+        return;
+    }
+
     if (c <= 143) {
         /* 0 through 143 are 8 bits long starting at 00110000. */
         outbits(out, mirrorbytes[0x30 + c], 8);
@@ -473,6 +486,9 @@ static void zlib_match(struct LZ77Context *ectx, int distance, int len) {
     const coderecord *d, *l;
     int i, j, k;
     struct Outbuf *out = (struct Outbuf *)ectx->userdata;
+
+    assert(!out->comp_disabled);
+
     while (len > 0) {
         int thislen;
        
@@ -563,14 +579,56 @@ void zlib_compress_init(void) {
     out = smalloc(sizeof(struct Outbuf));
     out->outbits = out->noutbits = 0;
     out->firstblock = 1;
+    out->comp_disabled = FALSE;
     ectx.userdata = out;
 
     logevent("Initialised zlib (RFC1950) compression");
 }
 
+/*
+ * Turn off actual LZ77 analysis for one block, to facilitate
+ * construction of a precise-length IGNORE packet. Returns the
+ * length adjustment (which is only valid for packets < 65536
+ * bytes, but that seems reasonable enough).
+ */
+int zlib_disable_compression(void) {
+    struct Outbuf *out = (struct Outbuf *)ectx.userdata;
+    int n;
+
+    out->comp_disabled = TRUE;
+
+    n = 0;
+    /*
+     * If this is the first block, we will start by outputting two
+     * header bytes, and then three bits to begin an uncompressed
+     * block. This will cost three bytes (because we will start on
+     * a byte boundary, this is certain).
+     */
+    if (out->firstblock) {
+        n = 3;
+    } else {
+        /*
+         * Otherwise, we will output seven bits to close the
+         * previous static block, and _then_ three bits to begin an
+         * uncompressed block, and then flush the current byte.
+         * This may cost two bytes or three, depending on noutbits.
+         */
+        n += (out->noutbits + 10) / 8;
+    }
+
+    /*
+     * Now we output four bytes for the length / ~length pair in
+     * the uncompressed block.
+     */
+    n += 4;
+
+    return n;
+}
+
 int zlib_compress_block(unsigned char *block, int len,
                        unsigned char **outblock, int *outlen) {
     struct Outbuf *out = (struct Outbuf *)ectx.userdata;
+    int in_block;
 
     out->outbuf = NULL;
     out->outlen = out->outsize = 0;
@@ -583,43 +641,101 @@ int zlib_compress_block(unsigned char *block, int len,
     if (out->firstblock) {
         outbits(out, 0x9C78, 16);
         out->firstblock = 0;
-       /*
-        * Start a Deflate (RFC1951) fixed-trees block. We transmit
-        * a zero bit (BFINAL=0), followed by a zero bit and a one
-        * bit (BTYPE=01). Of course these are in the wrong order
-        * (01 0).
-        */
-       outbits(out, 2, 3);
+
+        in_block = FALSE;
     }
 
-    /*
-     * Do the compression.
-     */
-    lz77_compress(&ectx, block, len);
-    /*
-     * End the block (by transmitting code 256, which is 0000000 in
-     * fixed-tree mode), and transmit some empty blocks to ensure
-     * we have emitted the byte containing the last piece of
-     * genuine data. There are three ways we can do this:
-     * 
-     *  - Minimal flush. Output end-of-block and then open a new
-     *    static block. This takes 9 bits, which is guaranteed to
-     *    flush out the last genuine code in the closed block; but
-     *    allegedly zlib can't handle it.
-     * 
-     *  - Zlib partial flush. Output EOB, open and close an empty
-     *    static block, and _then_ open the new block. This is the
-     *    best zlib can handle.
-     * 
-     *  - Zlib sync flush. Output EOB, then an empty _uncompressed_
-     *    block (000, then sync to byte boundary, then send bytes
-     *    00 00 FF FF). Then open the new block.
-     * 
-     * For the moment, we will use Zlib partial flush.
-     */
-    outbits(out, 0, 7);                       /* close block */
-    outbits(out, 2, 3+7);             /* empty static block */
-    outbits(out, 2, 3);                       /* open new block */
+    if (out->comp_disabled) {
+        if (in_block)
+            outbits(out, 0, 7);                       /* close static block */
+
+        while (len > 0) {
+            int blen = (len < 65535 ? len : 65535);
+
+            /*
+             * Start a Deflate (RFC1951) uncompressed block. We
+             * transmit a zero bit (BFINAL=0), followed by a zero
+             * bit and a one bit (BTYPE=00). Of course these are in
+             * the wrong order (00 0).
+             */
+            outbits(out, 0, 3);
+
+            /*
+             * Output zero bits to align to a byte boundary.
+             */
+            if (out->noutbits)
+                outbits(out, 0, 8 - out->noutbits);
+
+            /*
+             * Output the block length, and then its one's
+             * complement. They're little-endian, so all we need to
+             * do is pass them straight to outbits() with bit count
+             * 16.
+             */
+            outbits(out, blen, 16);
+            outbits(out, blen ^ 0xFFFF, 16);
+
+            /*
+             * Do the `compression': we need to pass the data to
+             * lz77_compress so that it will be taken into account
+             * for subsequent (distance,length) pairs. But
+             * lz77_compress is passed FALSE, which means it won't
+             * actually find (or even look for) any matches; so
+             * every character will be passed straight to
+             * zlib_literal which will spot out->comp_disabled and
+             * emit in the uncompressed format.
+             */
+            lz77_compress(&ectx, block, blen, FALSE);
+
+            len -= blen;
+            block += blen;
+        }
+        outbits(out, 2, 3);                   /* open new block */
+    } else {
+        if (!in_block) {
+            /*
+             * Start a Deflate (RFC1951) fixed-trees block. We
+             * transmit a zero bit (BFINAL=0), followed by a zero
+             * bit and a one bit (BTYPE=01). Of course these are in
+             * the wrong order (01 0).
+             */
+            outbits(out, 2, 3);
+        }
+
+        /*
+         * Do the compression.
+         */
+        lz77_compress(&ectx, block, len, TRUE);
+
+        /*
+         * End the block (by transmitting code 256, which is
+         * 0000000 in fixed-tree mode), and transmit some empty
+         * blocks to ensure we have emitted the byte containing the
+         * last piece of genuine data. There are three ways we can
+         * do this:
+         *
+         *  - Minimal flush. Output end-of-block and then open a
+         *    new static block. This takes 9 bits, which is
+         *    guaranteed to flush out the last genuine code in the
+         *    closed block; but allegedly zlib can't handle it.
+         *
+         *  - Zlib partial flush. Output EOB, open and close an
+         *    empty static block, and _then_ open the new block.
+         *    This is the best zlib can handle.
+         *
+         *  - Zlib sync flush. Output EOB, then an empty
+         *    _uncompressed_ block (000, then sync to byte
+         *    boundary, then send bytes 00 00 FF FF). Then open the
+         *    new block.
+         *
+         * For the moment, we will use Zlib partial flush.
+         */
+        outbits(out, 0, 7);                   /* close block */
+        outbits(out, 2, 3+7);         /* empty static block */
+        outbits(out, 2, 3);                   /* open new block */
+    }
+
+    out->comp_disabled = FALSE;
 
     *outblock = out->outbuf;
     *outlen = out->outlen;
@@ -645,7 +761,7 @@ struct zlib_tableentry;
 
 struct zlib_tableentry {
     unsigned char nbits;
-    int code;
+    short code;
     struct zlib_table *nexttable;
 };
 
@@ -747,6 +863,31 @@ static struct zlib_table *zlib_mktable(unsigned char *lengths, int nlengths) {
                          maxlen < 9 ? maxlen : 9);
 }
 
+static int zlib_freetable(struct zlib_table ** ztab) {
+    struct zlib_table *tab;
+    int code;
+
+    if (ztab == NULL)
+       return -1;
+
+    if (*ztab == NULL)
+       return 0;
+
+    tab = *ztab;
+
+    for (code = 0; code <= tab->mask; code++)
+       if (tab->table[code].nexttable != NULL)
+           zlib_freetable(&tab->table[code].nexttable);
+
+    sfree(tab->table);
+    tab->table = NULL;
+
+    sfree(tab);
+    *ztab = NULL;
+
+    return(0);
+}
+
 static struct zlib_decompress_ctx {
     struct zlib_table *staticlentable, *staticdisttable;
     struct zlib_table *currlentable, *currdisttable, *lenlentable;
@@ -893,8 +1034,8 @@ int zlib_decompress_block(unsigned char *block, int len,
                 dctx.currlentable = zlib_mktable(dctx.lengths, dctx.hlit);
                 dctx.currdisttable = zlib_mktable(dctx.lengths + dctx.hlit,
                                                   dctx.hdist);
-                /* FIXME: zlib_freetable(dctx.lenlentable); */
-                dctx.state = INBLK;
+               zlib_freetable(&dctx.lenlentable);
+               dctx.state = INBLK;
                 break;
             }
             code = zlib_huflookup(&dctx.bits, &dctx.nbits, dctx.lenlentable);
@@ -930,7 +1071,10 @@ int zlib_decompress_block(unsigned char *block, int len,
                zlib_emit_char(code);
             else if (code == 256) {
                 dctx.state = OUTSIDEBLK;
-                /* FIXME: zlib_freetable(both) if not static */
+               if (dctx.currlentable != dctx.staticlentable)
+                   zlib_freetable(&dctx.currlentable);
+               if (dctx.currdisttable != dctx.staticdisttable)
+                   zlib_freetable(&dctx.currdisttable);
             } else if (code < 286) {   /* static tree can give >285; ignore */
                 dctx.state = GOTLENSYM;
                 dctx.sym = code;
@@ -1006,5 +1150,6 @@ const struct ssh_compress ssh_zlib = {
     zlib_compress_init,
     zlib_compress_block,
     zlib_decompress_init,
-    zlib_decompress_block
+    zlib_decompress_block,
+    zlib_disable_compression
 };