Add tests of modpow.
authorsimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Sun, 20 Feb 2011 15:27:48 +0000 (15:27 +0000)
committersimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Sun, 20 Feb 2011 15:27:48 +0000 (15:27 +0000)
git-svn-id: svn://svn.tartarus.org/sgt/putty@9100 cda61777-01e9-0310-a592-d414129be87e

sshbn.c
testdata/bignum.py

diff --git a/sshbn.c b/sshbn.c
index 3534361..bfc34fb 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -1737,6 +1737,9 @@ char *bignum_decimal(Bignum x)
 
 /*
  * gcc -g -O0 -DTESTBN -o testbn sshbn.c misc.c -I unix -I charset
+ *
+ * Then feed to this program's standard input the output of
+ * testdata/bignum.py .
  */
 
 void modalfatalbox(char *p, ...)
@@ -1761,7 +1764,7 @@ int main(int argc, char **argv)
     while ((buf = fgetline(stdin)) != NULL) {
         int maxlen = strlen(buf);
         unsigned char *data = snewn(maxlen, unsigned char);
-        unsigned char *ptrs[4], *q;
+        unsigned char *ptrs[5], *q;
         int ptrnum;
         char *bufp = buf;
 
@@ -1770,6 +1773,11 @@ int main(int argc, char **argv)
         q = data;
         ptrnum = 0;
 
+        while (*bufp && !isspace((unsigned char)*bufp))
+            bufp++;
+        if (bufp)
+            *bufp++ = '\0';
+
         while (*bufp) {
             char *start, *end;
             int i;
@@ -1798,11 +1806,17 @@ int main(int argc, char **argv)
             ptrs[ptrnum] = q;
         }
 
-        if (ptrnum == 3) {
-            Bignum a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
-            Bignum b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
-            Bignum c = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
-            Bignum p = bigmul(a, b);
+        if (!strcmp(buf, "mul")) {
+            Bignum a, b, c, p;
+
+            if (ptrnum != 3) {
+                printf("%d: mul with %d parameters, expected 3\n", line);
+                exit(1);
+            }
+            a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
+            b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
+            c = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
+            p = bigmul(a, b);
 
             if (bignum_cmp(c, p) == 0) {
                 passes++;
@@ -1825,7 +1839,49 @@ int main(int argc, char **argv)
             freebn(b);
             freebn(c);
             freebn(p);
+        } else if (!strcmp(buf, "pow")) {
+            Bignum base, expt, modulus, expected, answer;
+
+            if (ptrnum != 4) {
+                printf("%d: mul with %d parameters, expected 3\n", line);
+                exit(1);
+            }
+
+            base = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
+            expt = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
+            modulus = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
+            expected = bignum_from_bytes(ptrs[3], ptrs[4]-ptrs[3]);
+            answer = modpow(base, expt, modulus);
+
+            if (bignum_cmp(expected, answer) == 0) {
+                passes++;
+            } else {
+                char *as = bignum_decimal(base);
+                char *bs = bignum_decimal(expt);
+                char *cs = bignum_decimal(modulus);
+                char *ds = bignum_decimal(answer);
+                char *ps = bignum_decimal(expected);
+                
+                printf("%d: fail: %s ^ %s mod %s gave %s expected %s\n",
+                       line, as, bs, cs, ds, ps);
+                fails++;
+
+                sfree(as);
+                sfree(bs);
+                sfree(cs);
+                sfree(ds);
+                sfree(ps);
+            }
+            freebn(base);
+            freebn(expt);
+            freebn(modulus);
+            freebn(expected);
+            freebn(answer);
+        } else {
+            printf("%d: unrecognised test keyword: '%s'\n", line, buf);
+            exit(1);
         }
+
         sfree(buf);
         sfree(data);
     }
index 2c7fb4e..f781bea 100644 (file)
@@ -77,6 +77,13 @@ def hexstr(n):
 # carry to the very top of the number.
 for i in range(1,4200):
     a, b, p = findprod((1<<i)+1, +1, (i, i*i+1))
-    print hexstr(a), hexstr(b), hexstr(p)
+    print "mul", hexstr(a), hexstr(b), hexstr(p)
     a, b, p = findprod((1<<i)+1, +1, (i, i+1))
-    print hexstr(a), hexstr(b), hexstr(p)
+    print "mul", hexstr(a), hexstr(b), hexstr(p)
+
+# Simple tests of modpow.
+for i in range(64, 4097, 63):
+    modulus = mathlib.sqrt(1<<(2*i-1)) | 1
+    base = mathlib.sqrt(3*modulus*modulus) % modulus
+    expt = mathlib.sqrt(modulus*modulus*2/5)
+    print "pow", hexstr(base), hexstr(expt), hexstr(modulus), hexstr(pow(base, expt, modulus))