Beginnings of a test suite for the bignum code. The output of
authorsimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Sun, 20 Feb 2011 14:59:00 +0000 (14:59 +0000)
committersimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Sun, 20 Feb 2011 14:59:00 +0000 (14:59 +0000)
testdata/bignum.py is twice the size of the rest of the PuTTY source
put together, so I'm not checking it in.

This reveals bugs in the new multiplication code, which I have yet to
fix.

git-svn-id: svn://svn.tartarus.org/sgt/putty@9097 cda61777-01e9-0310-a592-d414129be87e

sshbn.c
testdata/bignum.py [new file with mode: 0644]

diff --git a/sshbn.c b/sshbn.c
index 7368a43..20244e8 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -247,6 +247,9 @@ static void internal_mul(const BignumInt *a, const BignumInt *b,
         int midlen = botlen + 1;
         BignumInt *scratch;
         BignumDblInt carry;
+#ifdef KARA_DEBUG
+        int i;
+#endif
 
         /*
          * The coefficients a_1 b_1 and a_0 b_0 just avoid overlapping
@@ -254,11 +257,40 @@ static void internal_mul(const BignumInt *a, const BignumInt *b,
          * place.
          */
 
+#ifdef KARA_DEBUG
+        printf("a1,a0 = 0x");
+        for (i = 0; i < len; i++) {
+            if (i == toplen) printf(", 0x");
+            printf("%0*x", BIGNUM_INT_BITS/4, a[i]);
+        }
+        printf("\n");
+        printf("b1,b0 = 0x");
+        for (i = 0; i < len; i++) {
+            if (i == toplen) printf(", 0x");
+            printf("%0*x", BIGNUM_INT_BITS/4, b[i]);
+        }
+        printf("\n");
+#endif
+
         /* a_1 b_1 */
         internal_mul(a, b, c, toplen);
+#ifdef KARA_DEBUG
+        printf("a1b1 = 0x");
+        for (i = 0; i < 2*toplen; i++) {
+            printf("%0*x", BIGNUM_INT_BITS/4, c[i]);
+        }
+        printf("\n");
+#endif
 
         /* a_0 b_0 */
         internal_mul(a + toplen, b + toplen, c + 2*toplen, botlen);
+#ifdef KARA_DEBUG
+        printf("a0b0 = 0x");
+        for (i = 0; i < 2*botlen; i++) {
+            printf("%0*x", BIGNUM_INT_BITS/4, c[2*toplen+i]);
+        }
+        printf("\n");
+#endif
 
         /*
          * We must allocate scratch space for the central coefficient,
@@ -281,14 +313,35 @@ static void internal_mul(const BignumInt *a, const BignumInt *b,
 
         /* compute a_1 + a_0 */
         scratch[0] = internal_add(scratch+1, a+toplen, scratch+1, botlen);
+#ifdef KARA_DEBUG
+        printf("a1plusa0 = 0x");
+        for (i = 0; i < midlen; i++) {
+            printf("%0*x", BIGNUM_INT_BITS/4, scratch[i]);
+        }
+        printf("\n");
+#endif
         /* compute b_1 + b_0 */
         scratch[midlen] = internal_add(scratch+midlen+1, b+toplen,
                                        scratch+midlen+1, botlen);
+#ifdef KARA_DEBUG
+        printf("b1plusb0 = 0x");
+        for (i = 0; i < midlen; i++) {
+            printf("%0*x", BIGNUM_INT_BITS/4, scratch[midlen+i]);
+        }
+        printf("\n");
+#endif
 
         /*
          * Now we can do the third multiplication.
          */
         internal_mul(scratch, scratch + midlen, scratch + 2*midlen, midlen);
+#ifdef KARA_DEBUG
+        printf("a1plusa0timesb1plusb0 = 0x");
+        for (i = 0; i < 2*midlen; i++) {
+            printf("%0*x", BIGNUM_INT_BITS/4, scratch[2*midlen+i]);
+        }
+        printf("\n");
+#endif
 
         /*
          * Now we can reuse the first half of 'scratch' to compute the
@@ -300,9 +353,23 @@ static void internal_mul(const BignumInt *a, const BignumInt *b,
             scratch[2*midlen - 2*toplen + j] = c[j];
         scratch[1] = internal_add(scratch+2, c + 2*toplen,
                                   scratch+2, 2*botlen);
+#ifdef KARA_DEBUG
+        printf("a1b1plusa0b0 = 0x");
+        for (i = 0; i < 2*midlen; i++) {
+            printf("%0*x", BIGNUM_INT_BITS/4, scratch[i]);
+        }
+        printf("\n");
+#endif
 
         internal_sub(scratch + 2*midlen, scratch,
                      scratch + 2*midlen, 2*midlen);
+#ifdef KARA_DEBUG
+        printf("a1b0plusa0b1 = 0x");
+        for (i = 0; i < 2*midlen; i++) {
+            printf("%0*x", BIGNUM_INT_BITS/4, scratch[2*midlen+i]);
+        }
+        printf("\n");
+#endif
 
         /*
          * And now all we need to do is to add that middle coefficient
@@ -320,6 +387,13 @@ static void internal_mul(const BignumInt *a, const BignumInt *b,
             c[j] = (BignumInt)carry;
             carry >>= BIGNUM_INT_BITS;
         }
+#ifdef KARA_DEBUG
+        printf("ab = 0x");
+        for (i = 0; i < 2*len; i++) {
+            printf("%0*x", BIGNUM_INT_BITS/4, c[i]);
+        }
+        printf("\n");
+#endif
 
         /* Free scratch. */
         for (j = 0; j < 4 * midlen; j++)
@@ -1531,3 +1605,110 @@ char *bignum_decimal(Bignum x)
     sfree(workspace);
     return ret;
 }
+
+#ifdef TESTBN
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <ctype.h>
+
+/*
+ * gcc -g -O0 -DTESTBN -o testbn sshbn.c misc.c -I unix -I charset
+ */
+
+void modalfatalbox(char *p, ...)
+{
+    va_list ap;
+    fprintf(stderr, "FATAL ERROR: ");
+    va_start(ap, p);
+    vfprintf(stderr, p, ap);
+    va_end(ap);
+    fputc('\n', stderr);
+    exit(1);
+}
+
+#define fromxdigit(c) ( (c)>'9' ? ((c)&0xDF) - 'A' + 10 : (c) - '0' )
+
+int main(int argc, char **argv)
+{
+    char *buf;
+    int line = 0;
+    int passes = 0, fails = 0;
+
+    while ((buf = fgetline(stdin)) != NULL) {
+        int maxlen = strlen(buf);
+        unsigned char *data = snewn(maxlen, unsigned char);
+        unsigned char *ptrs[4], *q;
+        int ptrnum;
+        char *bufp = buf;
+
+        line++;
+
+        q = data;
+        ptrnum = 0;
+
+        while (*bufp) {
+            char *start, *end;
+            int i;
+
+            while (*bufp && !isxdigit((unsigned char)*bufp))
+                bufp++;
+            start = bufp;
+
+            if (!*bufp)
+                break;
+
+            while (*bufp && isxdigit((unsigned char)*bufp))
+                bufp++;
+            end = bufp;
+
+            if (ptrnum >= lenof(ptrs))
+                break;
+            ptrs[ptrnum++] = q;
+            
+            for (i = -((end - start) & 1); i < end-start; i += 2) {
+                unsigned char val = (i < 0 ? 0 : fromxdigit(start[i]));
+                val = val * 16 + fromxdigit(start[i+1]);
+                *q++ = val;
+            }
+
+            ptrs[ptrnum] = q;
+        }
+
+        if (ptrnum == 3) {
+            Bignum a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
+            Bignum b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
+            Bignum c = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
+            Bignum p = bigmul(a, b);
+
+            if (bignum_cmp(c, p) == 0) {
+                passes++;
+            } else {
+                char *as = bignum_decimal(a);
+                char *bs = bignum_decimal(b);
+                char *cs = bignum_decimal(c);
+                char *ps = bignum_decimal(p);
+                
+                printf("%d: fail: %s * %s gave %s expected %s\n",
+                       line, as, bs, ps, cs);
+                fails++;
+
+                sfree(as);
+                sfree(bs);
+                sfree(cs);
+                sfree(ps);
+            }
+            freebn(a);
+            freebn(b);
+            freebn(c);
+            freebn(p);
+        }
+        sfree(buf);
+        sfree(data);
+    }
+
+    printf("passed %d failed %d total %d\n", passes, fails, passes+fails);
+    return fails != 0;
+}
+
+#endif
diff --git a/testdata/bignum.py b/testdata/bignum.py
new file mode 100644 (file)
index 0000000..2c7fb4e
--- /dev/null
@@ -0,0 +1,82 @@
+# Generate test cases for a bignum implementation.
+
+import sys
+import mathlib
+
+def findprod(target, dir = +1, ratio=(1,1)):
+    # Return two numbers whose product is as close as we can get to
+    # 'target', with any deviation having the sign of 'dir', and in
+    # the same approximate ratio as 'ratio'.
+
+    r = mathlib.sqrt(target * ratio[0] * ratio[1])
+    a = r / ratio[1]
+    b = r / ratio[0]
+    if a*b * dir < target * dir:
+        a = a + 1
+        b = b + 1
+    assert a*b * dir >= target * dir
+
+    best = (a,b,a*b)
+
+    while 1:
+        improved = 0
+        a, b = best[:2]
+
+        terms = mathlib.confracr(a, b, output=None)
+        coeffs = [(1,0),(0,1)]
+        for t in terms:
+            coeffs.append((coeffs[-2][0]-t*coeffs[-1][0],
+                           coeffs[-2][1]-t*coeffs[-1][1]))
+        for c in coeffs:
+            # a*c[0]+b*c[1] is as close as we can get it to zero. So
+            # if we replace a and b with a+c[1] and b+c[0], then that
+            # will be added to our product, along with c[0]*c[1].
+            da, db = c[1], c[0]
+
+            # Flip signs as appropriate.
+            if (a+da) * (b+db) * dir < target * dir:
+                da, db = -da, -db
+
+            # Multiply up. We want to get as close as we can to a
+            # solution of the quadratic equation in n
+            #
+            #    (a + n da) (b + n db) = target
+            # => n^2 da db + n (b da + a db) + (a b - target) = 0
+            A,B,C = da*db, b*da+a*db, a*b-target
+            discrim = B^2-4*A*C
+            if discrim > 0 and A != 0:
+                root = mathlib.sqrt(discrim)
+                vals = []
+                vals.append((-B + root) / (2*A))
+                vals.append((-B - root) / (2*A))
+                if root * root != discrim:
+                    root = root + 1
+                    vals.append((-B + root) / (2*A))
+                    vals.append((-B - root) / (2*A))
+
+                for n in vals:
+                    ap = a + da*n
+                    bp = b + db*n
+                    pp = ap*bp
+                    if pp * dir >= target * dir and pp * dir < best[2]*dir:
+                        best = (ap, bp, pp)
+                        improved = 1
+
+        if not improved:
+            break
+
+    return best
+
+def hexstr(n):
+    s = hex(n)
+    if s[:2] == "0x": s = s[2:]
+    if s[-1:] == "L": s = s[:-1]
+    return s
+
+# Tests of multiplication which exercise the propagation of the last
+# carry to the very top of the number.
+for i in range(1,4200):
+    a, b, p = findprod((1<<i)+1, +1, (i, i*i+1))
+    print hexstr(a), hexstr(b), hexstr(p)
+    a, b, p = findprod((1<<i)+1, +1, (i, i+1))
+    print hexstr(a), hexstr(b), hexstr(p)