Validate newly created DSA keys more carefully. Don't want a structure
[u/mdw/putty] / sshdss.c
index 6487d75..eba03aa 100644 (file)
--- a/sshdss.c
+++ b/sshdss.c
@@ -82,6 +82,8 @@ static Bignum get160(char **data, int *datalen)
     return b;
 }
 
+static void dss_freekey(void *key);    /* forward reference */
+
 static void *dss_newkey(char *data, int len)
 {
     char *p;
@@ -89,8 +91,6 @@ static void *dss_newkey(char *data, int len)
     struct dss_key *dss;
 
     dss = snew(struct dss_key);
-    if (!dss)
-       return NULL;
     getstring(&data, &len, &p, &slen);
 
 #ifdef DEBUG_DSS
@@ -111,6 +111,14 @@ static void *dss_newkey(char *data, int len)
     dss->q = getmp(&data, &len);
     dss->g = getmp(&data, &len);
     dss->y = getmp(&data, &len);
+    dss->x = NULL;
+
+    if (!dss->p || !dss->q || !dss->g || !dss->y ||
+        !bignum_cmp(dss->q, Zero) || !bignum_cmp(dss->p, Zero)) {
+        /* Invalid key. */
+        dss_freekey(dss);
+        return NULL;
+    }
 
     return dss;
 }
@@ -118,10 +126,16 @@ static void *dss_newkey(char *data, int len)
 static void dss_freekey(void *key)
 {
     struct dss_key *dss = (struct dss_key *) key;
-    freebn(dss->p);
-    freebn(dss->q);
-    freebn(dss->g);
-    freebn(dss->y);
+    if (dss->p)
+        freebn(dss->p);
+    if (dss->q)
+        freebn(dss->q);
+    if (dss->g)
+        freebn(dss->g);
+    if (dss->y)
+        freebn(dss->y);
+    if (dss->x)
+        freebn(dss->x);
     sfree(dss);
 }
 
@@ -384,7 +398,13 @@ static void *dss_createkey(unsigned char *pub_blob, int pub_len,
     Bignum ytest;
 
     dss = dss_newkey((char *) pub_blob, pub_len);
+    if (!dss)
+        return NULL;
     dss->x = getmp(&pb, &priv_len);
+    if (!dss->x) {
+        dss_freekey(dss);
+        return NULL;
+    }
 
     /*
      * Check the obsolete hash in the old DSS key format.
@@ -423,8 +443,6 @@ static void *dss_openssh_createkey(unsigned char **blob, int *len)
     struct dss_key *dss;
 
     dss = snew(struct dss_key);
-    if (!dss)
-       return NULL;
 
     dss->p = getmp(b, len);
     dss->q = getmp(b, len);
@@ -432,14 +450,11 @@ static void *dss_openssh_createkey(unsigned char **blob, int *len)
     dss->y = getmp(b, len);
     dss->x = getmp(b, len);
 
-    if (!dss->p || !dss->q || !dss->g || !dss->y || !dss->x) {
-       freebn(dss->p);
-       freebn(dss->q);
-       freebn(dss->g);
-       freebn(dss->y);
-       freebn(dss->x);
-       sfree(dss);
-       return NULL;
+    if (!dss->p || !dss->q || !dss->g || !dss->y || !dss->x ||
+        !bignum_cmp(dss->q, Zero) || !bignum_cmp(dss->p, Zero)) {
+        /* Invalid key. */
+        dss_freekey(dss);
+        return NULL;
     }
 
     return dss;
@@ -479,6 +494,8 @@ static int dss_pubkey_bits(void *blob, int len)
     int ret;
 
     dss = dss_newkey((char *) blob, len);
+    if (!dss)
+        return -1;
     ret = bignum_bitcount(dss->p);
     dss_freekey(dss);