Fix a segfault (non-security-critical - null dereference for
[u/mdw/putty] / sshzlib.c
index 58069b6..91f5537 100644 (file)
--- a/sshzlib.c
+++ b/sshzlib.c
@@ -126,7 +126,7 @@ static int lz77_init(struct LZ77Context *ctx)
     struct LZ77InternalContext *st;
     int i;
 
-    st = (struct LZ77InternalContext *) smalloc(sizeof(*st));
+    st = snew(struct LZ77InternalContext);
     if (!st)
        return 0;
 
@@ -354,7 +354,7 @@ static void outbits(struct Outbuf *out, unsigned long bits, int nbits)
     while (out->noutbits >= 8) {
        if (out->outlen >= out->outsize) {
            out->outsize = out->outlen + 64;
-           out->outbuf = srealloc(out->outbuf, out->outsize);
+           out->outbuf = sresize(out->outbuf, out->outsize, unsigned char);
        }
        out->outbuf[out->outlen++] = (unsigned char) (out->outbits & 0xFF);
        out->outbits >>= 8;
@@ -583,13 +583,13 @@ static void zlib_match(struct LZ77Context *ectx, int distance, int len)
 void *zlib_compress_init(void)
 {
     struct Outbuf *out;
-    struct LZ77Context *ectx = smalloc(sizeof(struct LZ77Context));
+    struct LZ77Context *ectx = snew(struct LZ77Context);
 
     lz77_init(ectx);
     ectx->literal = zlib_literal;
     ectx->match = zlib_match;
 
-    out = smalloc(sizeof(struct Outbuf));
+    out = snew(struct Outbuf);
     out->outbits = out->noutbits = 0;
     out->firstblock = 1;
     out->comp_disabled = FALSE;
@@ -610,7 +610,7 @@ void zlib_compress_cleanup(void *handle)
  * length adjustment (which is only valid for packets < 65536
  * bytes, but that seems reasonable enough).
  */
-int zlib_disable_compression(void *handle)
+static int zlib_disable_compression(void *handle)
 {
     struct LZ77Context *ectx = (struct LZ77Context *)handle;
     struct Outbuf *out = (struct Outbuf *) ectx->userdata;
@@ -806,11 +806,11 @@ static struct zlib_table *zlib_mkonetab(int *codes, unsigned char *lengths,
                                        int nsyms,
                                        int pfx, int pfxbits, int bits)
 {
-    struct zlib_table *tab = smalloc(sizeof(struct zlib_table));
+    struct zlib_table *tab = snew(struct zlib_table);
     int pfxmask = (1 << pfxbits) - 1;
     int nbits, i, j, code;
 
-    tab->table = smalloc((1 << bits) * sizeof(struct zlib_tableentry));
+    tab->table = snewn(1 << bits, struct zlib_tableentry);
     tab->mask = (1 << bits) - 1;
 
     for (code = 0; code <= tab->mask; code++) {
@@ -941,8 +941,7 @@ struct zlib_decompress_ctx {
 
 void *zlib_decompress_init(void)
 {
-    struct zlib_decompress_ctx *dctx =
-       smalloc(sizeof(struct zlib_decompress_ctx));
+    struct zlib_decompress_ctx *dctx = snew(struct zlib_decompress_ctx);
     unsigned char lengths[288];
 
     memset(lengths, 8, 144);
@@ -974,7 +973,7 @@ void zlib_decompress_cleanup(void *handle)
     sfree(dctx);
 }
 
-int zlib_huflookup(unsigned long *bitsp, int *nbitsp,
+static int zlib_huflookup(unsigned long *bitsp, int *nbitsp,
                   struct zlib_table *tab)
 {
     unsigned long bits = *bitsp;
@@ -993,6 +992,16 @@ int zlib_huflookup(unsigned long *bitsp, int *nbitsp,
            *nbitsp = nbits;
            return ent->code;
        }
+
+       if (!tab) {
+           /*
+            * There was a missing entry in the table, presumably
+            * due to an invalid Huffman table description, and the
+            * subsequent data has attempted to use the missing
+            * entry. Return a decoding failure.
+            */
+           return -2;
+       }
     }
 }
 
@@ -1002,7 +1011,7 @@ static void zlib_emit_char(struct zlib_decompress_ctx *dctx, int c)
     dctx->winpos = (dctx->winpos + 1) & (WINSIZE - 1);
     if (dctx->outlen >= dctx->outsize) {
        dctx->outsize = dctx->outlen + 512;
-       dctx->outblk = srealloc(dctx->outblk, dctx->outsize);
+       dctx->outblk = sresize(dctx->outblk, dctx->outsize, unsigned char);
     }
     dctx->outblk[dctx->outlen++] = c;
 }
@@ -1100,6 +1109,8 @@ int zlib_decompress_block(void *handle, unsigned char *block, int len,
                zlib_huflookup(&dctx->bits, &dctx->nbits, dctx->lenlentable);
            if (code == -1)
                goto finished;
+           if (code == -2)
+               goto decode_error;
            if (code < 16)
                dctx->lengths[dctx->lenptr++] = code;
            else {
@@ -1129,6 +1140,8 @@ int zlib_decompress_block(void *handle, unsigned char *block, int len,
                zlib_huflookup(&dctx->bits, &dctx->nbits, dctx->currlentable);
            if (code == -1)
                goto finished;
+           if (code == -2)
+               goto decode_error;
            if (code < 256)
                zlib_emit_char(dctx, code);
            else if (code == 256) {
@@ -1161,6 +1174,8 @@ int zlib_decompress_block(void *handle, unsigned char *block, int len,
                               dctx->currdisttable);
            if (code == -1)
                goto finished;
+           if (code == -2)
+               goto decode_error;
            dctx->state = GOTDISTSYM;
            dctx->sym = code;
            break;
@@ -1214,8 +1229,13 @@ int zlib_decompress_block(void *handle, unsigned char *block, int len,
   finished:
     *outblock = dctx->outblk;
     *outlen = dctx->outlen;
-
     return 1;
+
+  decode_error:
+    sfree(dctx->outblk);
+    *outblock = dctx->outblk = NULL;
+    *outlen = 0;
+    return 0;
 }
 
 const struct ssh_compress ssh_zlib = {