Merged SSH1 robustness changes from 0.55 release branch on to trunk.
[u/mdw/putty] / sshbn.c
diff --git a/sshbn.c b/sshbn.c
index a51c3a5..d32eb1b 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -3,6 +3,7 @@
  */
 
 #include <stdio.h>
+#include <assert.h>
 #include <stdlib.h>
 #include <string.h>
 
@@ -15,6 +16,10 @@ typedef unsigned long long BignumDblInt;
 #define BIGNUM_TOP_BIT   0x80000000UL
 #define BIGNUM_INT_BITS  32
 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
+#define DIVMOD_WORD(q, r, hi, lo, w) \
+    __asm__("div %2" : \
+           "=d" (r), "=a" (q) : \
+           "r" (w), "d" (hi), "a" (lo))
 #else
 typedef unsigned short BignumInt;
 typedef unsigned long BignumDblInt;
@@ -22,6 +27,11 @@ typedef unsigned long BignumDblInt;
 #define BIGNUM_TOP_BIT   0x8000U
 #define BIGNUM_INT_BITS  16
 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
+#define DIVMOD_WORD(q, r, hi, lo, w) do { \
+    BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \
+    q = n / w; \
+    r = n % w; \
+} while (0)
 #endif
 
 #define BIGNUM_INT_BYTES (BIGNUM_INT_BITS / 8)
@@ -124,7 +134,7 @@ static void internal_add_shifted(BignumInt *number,
     int bshift = shift % BIGNUM_INT_BITS;
     BignumDblInt addend;
 
-    addend = n << bshift;
+    addend = (BignumDblInt)n << bshift;
 
     while (addend) {
        addend += number[word];
@@ -175,13 +185,11 @@ static void internal_mod(BignumInt *a, int alen,
            ai1 = a[i + 1];
 
        /* Find q = h:a[i] / m0 */
-       t = ((BignumDblInt) h << BIGNUM_INT_BITS) + a[i];
-       q = t / m0;
-       r = t % m0;
+       DIVMOD_WORD(q, r, h, a[i], m0);
 
        /* Refine our estimate of q by looking at
           h:a[i]:a[i+1] / m0:m1 */
-       t = (BignumDblInt) m1 * (BignumDblInt) q;
+       t = MUL_WORD(m1, q);
        if (t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) {
            q--;
            t -= m1;
@@ -193,7 +201,7 @@ static void internal_mod(BignumInt *a, int alen,
        /* Subtract q * m from a[i...] */
        c = 0;
        for (k = mlen - 1; k >= 0; k--) {
-           t = (BignumDblInt) q * (BignumDblInt) m[k];
+           t = MUL_WORD(q, m[k]);
            t += c;
            c = t >> BIGNUM_INT_BITS;
            if ((BignumInt) t > a[i + k])
@@ -219,16 +227,25 @@ static void internal_mod(BignumInt *a, int alen,
 
 /*
  * Compute (base ^ exp) % mod.
- * The base MUST be smaller than the modulus.
- * The most significant word of mod MUST be non-zero.
- * We assume that the result array is the same size as the mod array.
  */
-Bignum modpow(Bignum base, Bignum exp, Bignum mod)
+Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
 {
     BignumInt *a, *b, *n, *m;
     int mshift;
     int mlen, i, j;
-    Bignum result;
+    Bignum base, result;
+
+    /*
+     * The most significant word of mod needs to be non-zero. It
+     * should already be, but let's make sure.
+     */
+    assert(mod[mod[0]] != 0);
+
+    /*
+     * Make sure the base is smaller than the modulus, by reducing
+     * it modulo the modulus if not.
+     */
+    base = bigmod(base_in, mod);
 
     /* Allocate m of size mlen, copy mod to m */
     /* We use big endian internally */
@@ -324,6 +341,8 @@ Bignum modpow(Bignum base, Bignum exp, Bignum mod)
        n[i] = 0;
     sfree(n);
 
+    freebn(base);
+
     return result;
 }
 
@@ -521,19 +540,25 @@ Bignum bignum_from_bytes(const unsigned char *data, int nbytes)
 
 /*
  * Read an ssh1-format bignum from a data buffer. Return the number
- * of bytes consumed.
+ * of bytes consumed, or -1 if there wasn't enough data.
  */
-int ssh1_read_bignum(const unsigned char *data, Bignum * result)
+int ssh1_read_bignum(const unsigned char *data, int len, Bignum * result)
 {
     const unsigned char *p = data;
     int i;
     int w, b;
 
+    if (len < 2)
+       return -1;
+
     w = 0;
     for (i = 0; i < 2; i++)
        w = (w << 8) + *p++;
     b = (w + 7) / 8;                  /* bits -> bytes */
 
+    if (len < b+2)
+       return -1;
+
     if (!result)                      /* just return length */
        return b + 2;
 
@@ -722,6 +747,7 @@ Bignum bigmuladd(Bignum a, Bignum b, Bignum addend)
     }
     ret[0] = maxspot;
 
+    sfree(workspace);
     return ret;
 }
 
@@ -807,7 +833,7 @@ unsigned short bignum_mod_short(Bignum number, unsigned short modulus)
     r = 0;
     mod = modulus;
     for (i = number[0]; i > 0; i--)
-       r = (r * 65536 + number[i]) % mod;
+       r = (r * (BIGNUM_TOP_BIT % mod) * 2 + number[i] % mod) % mod;
     return (unsigned short) r;
 }
 
@@ -901,6 +927,7 @@ Bignum modinv(Bignum number, Bignum modulus)
        x = bigmuladd(q, xp, t);
        sign = -sign;
        freebn(t);
+       freebn(q);
     }
 
     freebn(b);
@@ -1002,5 +1029,6 @@ char *bignum_decimal(Bignum x)
     /*
      * Done.
      */
+    sfree(workspace);
     return ret;
 }