Merged SSH1 robustness changes from the 0.55 release branch onto trunk.
[u/mdw/putty] / pageant.c
index 3ac9ca7..2ebf3ab 100644 (file)
--- a/pageant.c
+++ b/pageant.c
@@ -119,8 +119,8 @@ static gsi_fn_t getsecurityinfo;
  */
 static void *make_keylist1(int *length);
 static void *make_keylist2(int *length);
-static void *get_keylist1(void);
-static void *get_keylist2(void);
+static void *get_keylist1(int *length);
+static void *get_keylist2(int *length);
 
 /*
  * We need this to link with the RSA code, because rsaencrypt()
@@ -414,7 +414,7 @@ static void add_keyfile(Filename filename)
     {
        void *blob;
        unsigned char *keylist, *p;
-       int i, nkeys, bloblen;
+       int i, nkeys, bloblen, keylistlen;
 
        if (type == SSH_KEYTYPE_SSH1) {
            if (!rsakey_pubblob(&filename, &blob, &bloblen, NULL)) {
@@ -422,7 +422,7 @@ static void add_keyfile(Filename filename)
                           MB_OK | MB_ICONERROR);
                return;
            }
-           keylist = get_keylist1();
+           keylist = get_keylist1(&keylistlen);
        } else {
            unsigned char *blob2;
            blob = ssh2_userkey_loadpub(&filename, NULL, &bloblen, NULL);
@@ -438,11 +438,17 @@ static void add_keyfile(Filename filename)
            sfree(blob);
            blob = blob2;
 
-           keylist = get_keylist2();
+           keylist = get_keylist2(&keylistlen);
        }
        if (keylist) {
+           if (keylistlen < 4) {
+               MessageBox(NULL, "Received broken key list?!", APPNAME,
+                          MB_OK | MB_ICONERROR);
+               return;
+           }
            nkeys = GET_32BIT(keylist);
            p = keylist + 4;
+           keylistlen -= 4;
 
            for (i = 0; i < nkeys; i++) {
                if (!memcmp(blob, p, bloblen)) {
@@ -452,12 +458,48 @@ static void add_keyfile(Filename filename)
                    return;
                }
                /* Now skip over public blob */
-               if (type == SSH_KEYTYPE_SSH1)
-                   p += rsa_public_blob_len(p);
-               else
-                   p += 4 + GET_32BIT(p);
+               if (type == SSH_KEYTYPE_SSH1) {
+                   int n = rsa_public_blob_len(p, keylistlen);
+                   if (n < 0) {
+                       MessageBox(NULL, "Received broken key list?!", APPNAME,
+                                  MB_OK | MB_ICONERROR);
+                       return;
+                   }
+                   p += n;
+                   keylistlen -= n;
+               } else {
+                   int n;
+                   if (keylistlen < 4) {
+                       MessageBox(NULL, "Received broken key list?!", APPNAME,
+                                  MB_OK | MB_ICONERROR);
+                       return;
+                   }
+                   n = 4 + GET_32BIT(p);
+                   if (keylistlen < n) {
+                       MessageBox(NULL, "Received broken key list?!", APPNAME,
+                                  MB_OK | MB_ICONERROR);
+                       return;
+                   }
+                   p += n;
+                   keylistlen -= n;
+               }
                /* Now skip over comment field */
-               p += 4 + GET_32BIT(p);
+               {
+                   int n;
+                   if (keylistlen < 4) {
+                       MessageBox(NULL, "Received broken key list?!", APPNAME,
+                                  MB_OK | MB_ICONERROR);
+                       return;
+                   }
+                   n = 4 + GET_32BIT(p);
+                   if (keylistlen < n) {
+                       MessageBox(NULL, "Received broken key list?!", APPNAME,
+                                  MB_OK | MB_ICONERROR);
+                       return;
+                   }
+                   p += n;
+                   keylistlen -= n;
+               }
            }
 
            sfree(keylist);
@@ -608,8 +650,8 @@ static void add_keyfile(Filename filename)
                                                keybloblen);
            PUT_32BIT(request + reqlen, clen);
            memcpy(request + reqlen + 4, skey->comment, clen);
-           PUT_32BIT(request, reqlen - 4);
            reqlen += clen + 4;
+           PUT_32BIT(request, reqlen - 4);
 
            ret = agent_query(request, reqlen, &vresponse, &resplen,
                              NULL, NULL);
@@ -731,7 +773,7 @@ static void *make_keylist2(int *length)
  * calling make_keylist1 (if that's us) or sending a message to the
  * primary Pageant (if it's not).
  */
-static void *get_keylist1(void)
+static void *get_keylist1(int *length)
 {
     void *ret;
 
@@ -751,8 +793,11 @@ static void *get_keylist1(void)
        ret = snewn(resplen-5, unsigned char);
        memcpy(ret, response+5, resplen-5);
        sfree(response);
+
+       if (length)
+           *length = resplen-5;
     } else {
-       ret = make_keylist1(NULL);
+       ret = make_keylist1(length);
     }
     return ret;
 }
@@ -762,7 +807,7 @@ static void *get_keylist1(void)
  * calling make_keylist2 (if that's us) or sending a message to the
  * primary Pageant (if it's not).
  */
-static void *get_keylist2(void)
+static void *get_keylist2(int *length)
 {
     void *ret;
 
@@ -783,8 +828,11 @@ static void *get_keylist2(void)
        ret = snewn(resplen-5, unsigned char);
        memcpy(ret, response+5, resplen-5);
        sfree(response);
+
+       if (length)
+           *length = resplen-5;
     } else {
-       ret = make_keylist2(NULL);
+       ret = make_keylist2(length);
     }
     return ret;
 }
@@ -796,11 +844,19 @@ static void answer_msg(void *msg)
 {
     unsigned char *p = msg;
     unsigned char *ret = msg;
+    unsigned char *msgend;
     int type;
 
     /*
+     * Get the message length.
+     */
+    msgend = p + 4 + GET_32BIT(p);
+
+    /*
      * Get the message type.
      */
+    if (msgend < p+5)
+       goto failure;
     type = p[4];
 
     p += 5;
@@ -857,12 +913,28 @@ static void answer_msg(void *msg)
            int i, len;
 
            p += 4;
-           p += ssh1_read_bignum(p, &reqkey.exponent);
-           p += ssh1_read_bignum(p, &reqkey.modulus);
-           p += ssh1_read_bignum(p, &challenge);
+           i = ssh1_read_bignum(p, msgend - p, &reqkey.exponent);
+           if (i < 0)
+               goto failure;
+           p += i;
+           i = ssh1_read_bignum(p, msgend - p, &reqkey.modulus);
+           if (i < 0)
+               goto failure;
+           p += i;
+           i = ssh1_read_bignum(p, msgend - p, &challenge);
+           if (i < 0)
+               goto failure;
+           p += i;
+           if (msgend < p+16) {
+               freebn(reqkey.exponent);
+               freebn(reqkey.modulus);
+               freebn(challenge);
+               goto failure;
+           }
            memcpy(response_source + 32, p, 16);
            p += 16;
-           if (GET_32BIT(p) != 1 ||
+           if (msgend < p+4 ||
+               GET_32BIT(p) != 1 ||
                (key = find234(rsakeys, &reqkey, NULL)) == NULL) {
                freebn(reqkey.exponent);
                freebn(reqkey.modulus);
@@ -904,12 +976,20 @@ static void answer_msg(void *msg)
            unsigned char *data, *signature;
            int datalen, siglen, len;
 
+           if (msgend < p+4)
+               goto failure;
            b.len = GET_32BIT(p);
            p += 4;
+           if (msgend < p+b.len)
+               goto failure;
            b.blob = p;
            p += b.len;
+           if (msgend < p+4)
+               goto failure;
            datalen = GET_32BIT(p);
            p += 4;
+           if (msgend < p+datalen)
+               goto failure;
            data = p;
            key = find234(ssh2keys, &b, cmpkeys_ssh2_asymm);
            if (!key)
@@ -931,15 +1011,64 @@ static void answer_msg(void *msg)
        {
            struct RSAKey *key;
            char *comment;
-            int commentlen;
+            int n, commentlen;
+
            key = snew(struct RSAKey);
            memset(key, 0, sizeof(struct RSAKey));
-           p += makekey(p, key, NULL, 1);
-           p += makeprivate(p, key);
-           p += ssh1_read_bignum(p, &key->iqmp);       /* p^-1 mod q */
-           p += ssh1_read_bignum(p, &key->p);  /* p */
-           p += ssh1_read_bignum(p, &key->q);  /* q */
+
+           n = makekey(p, msgend - p, key, NULL, 1);
+           if (n < 0) {
+               freersakey(key);
+               sfree(key);
+               goto failure;
+           }
+           p += n;
+
+           n = makeprivate(p, msgend - p, key);
+           if (n < 0) {
+               freersakey(key);
+               sfree(key);
+               goto failure;
+           }
+           p += n;
+
+           n = ssh1_read_bignum(p, msgend - p, &key->iqmp);  /* p^-1 mod q */
+           if (n < 0) {
+               freersakey(key);
+               sfree(key);
+               goto failure;
+           }
+           p += n;
+
+           n = ssh1_read_bignum(p, msgend - p, &key->p);  /* p */
+           if (n < 0) {
+               freersakey(key);
+               sfree(key);
+               goto failure;
+           }
+           p += n;
+
+           n = ssh1_read_bignum(p, msgend - p, &key->q);  /* q */
+           if (n < 0) {
+               freersakey(key);
+               sfree(key);
+               goto failure;
+           }
+           p += n;
+
+           if (msgend < p+4) {
+               freersakey(key);
+               sfree(key);
+               goto failure;
+           }
             commentlen = GET_32BIT(p);
+
+           if (msgend < p+commentlen) {
+               freersakey(key);
+               sfree(key);
+               goto failure;
+           }
+
            comment = snewn(commentlen+1, char);
            if (comment) {
                memcpy(comment, p + 4, commentlen);
@@ -968,12 +1097,17 @@ static void answer_msg(void *msg)
            int alglen, commlen;
            int bloblen;
 
-           key = snew(struct ssh2_userkey);
 
+           if (msgend < p+4)
+               goto failure;
            alglen = GET_32BIT(p);
            p += 4;
+           if (msgend < p+alglen)
+               goto failure;
            alg = p;
            p += alglen;
+
+           key = snew(struct ssh2_userkey);
            /* Add further algorithm names here. */
            if (alglen == 7 && !memcmp(alg, "ssh-rsa", 7))
                key->alg = &ssh_rsa;
@@ -984,18 +1118,32 @@ static void answer_msg(void *msg)
                goto failure;
            }
 
-           bloblen =
-               GET_32BIT((unsigned char *) msg) - (p -
-                                                   (unsigned char *) msg -
-                                                   4);
+           bloblen = msgend - p;
            key->data = key->alg->openssh_createkey(&p, &bloblen);
            if (!key->data) {
                sfree(key);
                goto failure;
            }
+
+           /*
+            * p has been advanced by openssh_createkey, but
+            * certainly not _beyond_ the end of the buffer.
+            */
+           assert(p <= msgend);
+
+           if (msgend < p+4) {
+               key->alg->freekey(key->data);
+               sfree(key);
+               goto failure;
+           }
            commlen = GET_32BIT(p);
            p += 4;
 
+           if (msgend < p+commlen) {
+               key->alg->freekey(key->data);
+               sfree(key);
+               goto failure;
+           }
            comment = snewn(commlen + 1, char);
            if (comment) {
                memcpy(comment, p, commlen);
@@ -1023,8 +1171,12 @@ static void answer_msg(void *msg)
         */
        {
            struct RSAKey reqkey, *key;
+           int n;
+
+           n = makekey(p, msgend - p, &reqkey, NULL, 0);
+           if (n < 0)
+               goto failure;
 
-           p += makekey(p, &reqkey, NULL, 0);
            key = find234(rsakeys, &reqkey, NULL);
            freebn(reqkey.exponent);
            freebn(reqkey.modulus);
@@ -1049,10 +1201,16 @@ static void answer_msg(void *msg)
            struct ssh2_userkey *key;
            struct blob b;
 
+           if (msgend < p+4)
+               goto failure;
            b.len = GET_32BIT(p);
            p += 4;
+
+           if (msgend < p+b.len)
+               goto failure;
            b.blob = p;
            p += b.len;
+
            key = find234(ssh2keys, &b, cmpkeys_ssh2_asymm);
            if (!key)
                goto failure;