| 1 | # Generate test cases for a bignum implementation. |
| 2 | |
| 3 | import sys |
| 4 | |
# integer square roots
def sqrt(n):
    """Return the integer square root of n, i.e. floor(sqrt(n)).

    Uses the digit-by-digit (base-4) method, so it works on integers of
    any size without going through floating point.  Only plain int
    arithmetic is used, so it behaves the same on Python 2 (where ints
    promote to long automatically) and Python 3.

    Raises ValueError if n is negative (the old version silently
    returned 0 for such input).
    """
    d = int(n)
    if d < 0:
        raise ValueError("square root of negative number")
    root = 0
    # The scan bit must start off as a power of 4 at least as large as
    # n.  hex(d) has at least one character per hex digit of d, so
    # 16**len(hex(d)) -- a power of 16 and hence of 4 -- exceeds d.
    bit = 1 << (len(hex(d)) * 4)
    while bit:
        root >>= 1
        trial = 2 * root + bit
        if trial <= d:
            d -= trial
            root += bit
        bit >>= 2
    return root
| 21 | |
# continued fraction convergents of a rational
def confrac(n, d):
    """Run Euclid's algorithm on (n, d) and return its coefficient pairs.

    Returns a list of pairs (p, q) such that p*n + q*d equals the
    successive remainders of Euclid's algorithm.  The list starts with
    (1, 0) and (0, 1), which stand for n and d themselves.  The last
    pair gives p*n + q*d == 0.  Up to sign, these pairs are the
    continued-fraction convergents of n/d, hence the name.

    Uses floor division (//) so the quotients stay integers on both
    Python 2 and 3; plain '/' yields floats on Python 3.
    """
    coeffs = [(1, 0), (0, 1)]
    while d != 0:
        quot = n // d
        n, d = d, n % d
        (p_prev, q_prev), (p_last, q_last) = coeffs[-2], coeffs[-1]
        coeffs.append((p_prev - quot * p_last, q_prev - quot * q_last))
    return coeffs
| 31 | |
def findprod(target, dir=+1, ratio=(1, 1)):
    """Return (a, b, a*b) with a*b as close as we can get to target.

    Any deviation of the product from target has the sign of dir:
    dir=+1 guarantees a*b >= target, and dir=-1 guarantees a*b <= target.
    The search starts from a:b roughly equal to ratio[0]:ratio[1], and
    the refinement steps may move the pair away from that ratio.

    Method: take the ratio-scaled square root of target as a starting
    point.  Then repeatedly try steps along directions taken from
    confrac(a, b), each of which changes a*b only slightly.  Keep the
    best valid product seen, and stop when a full pass gives no
    improvement.

    Relies on sqrt() and confrac() defined in this file.
    """
    r = sqrt(target * ratio[0] * ratio[1])
    # '//' keeps this integer arithmetic on Python 2 and 3 alike.
    a = r // ratio[1]
    b = r // ratio[0]
    if a*b * dir < target * dir:
        a = a + 1
        b = b + 1
    assert a*b * dir >= target * dir

    best = (a, b, a*b)

    while True:
        improved = False
        a, b = best[:2]

        for c in confrac(a, b):
            # a*c[0] + b*c[1] is one of the Euclid remainders of (a, b),
            # so it is small.  Moving to (a + n*c[1], b + n*c[0]) changes
            # the product by n*(a*c[0] + b*c[1]) + n^2*c[0]*c[1].
            da, db = c[1], c[0]

            # Flip signs so a unit step doesn't land on the wrong side
            # of target.
            if (a+da) * (b+db) * dir < target * dir:
                da, db = -da, -db

            # Multiply up. We want to get as close as we can to a
            # solution of the quadratic equation in n
            #
            #    (a + n da) (b + n db) = target
            # => n^2 da db + n (b da + a db) + (a b - target) = 0
            A, B, C = da*db, b*da + a*db, a*b - target
            # B*B, not B^2: in Python '^' is bitwise XOR, and B^2-4*A*C
            # would parse as B ^ (2 - 4*A*C).
            discrim = B*B - 4*A*C
            # A == 0 means the step is linear rather than quadratic; not
            # handled here.  discrim == 0 is a (possibly exact) double
            # root, so it is kept.
            if discrim < 0 or A == 0:
                continue

            # Bracket the true square root by its floor and, when the
            # discriminant isn't a perfect square, its floor + 1.
            root = sqrt(discrim)
            roots = [root]
            if root * root != discrim:
                roots.append(root + 1)

            for rt in roots:
                for n in ((-B + rt) // (2*A), (-B - rt) // (2*A)):
                    ap = a + da*n
                    bp = b + db*n
                    pp = ap*bp
                    # Keep only positive factors on the required side of
                    # target that strictly beat the best so far.
                    if (ap > 0 and bp > 0
                            and pp * dir >= target * dir
                            and pp * dir < best[2] * dir):
                        best = (ap, bp, pp)
                        improved = True

        if not improved:
            break

    return best
| 91 | |
def hexstr(n):
    """Return n as bare lower-case hex digits, e.g. 255 -> 'ff'.

    The result has no '0x' prefix and no Python 2 'L' long suffix;
    '%x' formatting never emits either, on Python 2 or 3.  A negative
    number comes out as '-' followed by its digits.  The previous
    hex()-and-slice version left negatives as '-0x...'.
    """
    return "%x" % n
| 97 | |
# Tests of multiplication which exercise the propagation of the last
# carry to the very top of the number.  For each i, findprod() picks two
# factors whose product is at least 2^i + 1 and as close to it as it can
# get, once for each of two different size ratios between the factors.
#
# Output goes through sys.stdout.write rather than the Python-2-only
# print statement; the bytes written are the same.
for i in range(1, 4200):
    target = (1 << i) + 1
    for ratio in ((i, i*i + 1), (i, i + 1)):
        a, b, p = findprod(target, +1, ratio)
        sys.stdout.write("mul %s %s %s\n" % (hexstr(a), hexstr(b), hexstr(p)))
| 105 | |
# Simple tests of modpow.  For i = 64, 127, ..., 4096 (step 63) this
# builds an odd i-bit modulus: sqrt(2^(2i-1)) lies between 2^(i-1) and
# 2^i.  It also builds a base and exponent of similar size and emits the
# expected value of base^expt mod modulus.
#
# '//' keeps the exponent computation in integer arithmetic on Python 2
# and 3, and sys.stdout.write replaces the Python-2-only print statement
# without changing the output.
for i in range(64, 4097, 63):
    modulus = sqrt(1 << (2*i - 1)) | 1
    base = sqrt(3 * modulus * modulus) % modulus
    expt = sqrt(modulus * modulus * 2 // 5)
    moduli = [modulus]
    if i <= 1024:
        # Test even moduli, which can't be done by Montgomery.
        moduli.append(modulus - 1)
    for m in moduli:
        sys.stdout.write("pow %s %s %s %s\n" % (
            hexstr(base), hexstr(expt), hexstr(m),
            hexstr(pow(base, expt, m))))