math/Makefile.am, symm/Makefile.am: Use `--no-install' on oddball tests.
[catacomb] / utils / curve25519.sage
index 32ddb17..6fa3cd3 100644 (file)
@@ -1,10 +1,85 @@
 #! /usr/local/bin/sage
 ### -*- mode: python; coding: utf-8 -*-
 
+import hashlib as H
+
 ###--------------------------------------------------------------------------
-### Define the field.
+### Some general utilities.
+
+def hash(*m):
+  h = H.sha512()
+  for i in m: h.update(i)
+  return h.digest()
+
+def ld(v):
+  return 0 + sum(ord(v[i]) << 8*i for i in xrange(len(v)))
+
+def st(x, n):
+  return ''.join(chr((x >> 8*i)&0xff) for i in xrange(n))
+
+def piece_widths_offsets(wd, n):
+  o = [ceil(wd*i/n) for i in xrange(n + 1)]
+  w = [o[i + 1] - o[i] for i in xrange(n)]
+  return w, o
+
+def pieces(x, wd, n, bias = 0):
+
+  ## Figure out widths and offsets.
+  w, o = piece_widths_offsets(wd, n)
+
+  ## First, normalize |n| < bias/2.
+  if bias and n >= bias/2: n -= bias
+
+  ## First, collect the bits.
+  nn = []
+  for i in xrange(n - 1):
+    m = (1 << w[i]) - 1
+    nn.append(x&m)
+    x >>= w[i]
+  nn.append(x)
+
+  ## Now normalize them to the appropriate interval.
+  c = 0
+  for i in xrange(n - 1):
+    b = 1 << (w[i] - 1)
+    if nn[i] >= b:
+      nn[i] -= 2*b
+      nn[i + 1] += 1
+
+  ## And we're done.
+  return nn
+
+def combine(v, wd, n):
+  w, o = piece_widths_offsets(wd, n)
+  return sum(v[i] << o[i] for i in xrange(n))
+
+###--------------------------------------------------------------------------
+### Define the curve.
 
 p = 2^255 - 19; k = GF(p)
+A = k(486662); A0 = (A - 2)/4
+E = EllipticCurve(k, [0, A, 0, 1, 0]); P = E.lift_x(9)
+l = 2^252 + 27742317777372353535851937790883648493
+
+assert is_prime(l)
+assert (l*P).is_zero()
+assert (p + 1 - 8*l)^2 <= 4*p
+
+###--------------------------------------------------------------------------
+### Example points from `Cryptography in NaCl'.
+
+x = ld(map(chr, [0x70,0x07,0x6d,0x0a,0x73,0x18,0xa5,0x7d
+,0x3c,0x16,0xc1,0x72,0x51,0xb2,0x66,0x45
+,0xdf,0x4c,0x2f,0x87,0xeb,0xc0,0x99,0x2a
+,0xb1,0x77,0xfb,0xa5,0x1d,0xb9,0x2c,0x6a]))
+y = ld(map(chr, [0x58,0xab,0x08,0x7e,0x62,0x4a,0x8a,0x4b
+,0x79,0xe1,0x7f,0x8b,0x83,0x80,0x0e,0xe6
+,0x6f,0x3b,0xb1,0x29,0x26,0x18,0xb6,0xfd
+,0x1c,0x2f,0x8b,0x27,0xff,0x88,0xe0,0x6b]))
+X = x*P
+Y = y*P
+Z = x*Y
+assert Z == y*X
 
 ###--------------------------------------------------------------------------
 ### Arithmetic implementation.
@@ -13,6 +88,8 @@ def sqrn(x, n):
   for i in xrange(n): x = x*x
   return x
 
+sqrtm1 = sqrt(k(-1))
+
 def inv(x):
   t2 = sqrn(x, 1)        #   1 | 2
   u = sqrn(t2, 2)        #   3 | 8
@@ -38,6 +115,352 @@ def inv(x):
   t = u*t11              # 265 | 2^255 - 21
   return t
 
+def quosqrt_djb(x, y):
+
+  ## First, some preliminary values.
+  y2 = sqrn(y, 1)        #   1 | 0, 2
+  y3 = y2*y              #   2 | 0, 3
+  xy3 = x*y3             #   3 | 1, 3
+  y4 = sqrn(y2, 1)       #   4 | 0, 4
+  w = xy3*y4             #   5 | 1, 7
+
+  ## Now calculate w^(p - 5)/8.  Notice that (p - 5)/8 =
+  ## (2^255 - 24)/8 = 2^252 - 3.
+  u = sqrn(w, 1)         #   6 | 2
+  t = u*w                #   7 | 3
+  u = sqrn(t, 1)         #   8 | 6
+  t = u*w                #   9 | 7
+  u = sqrn(t, 3)         #  12 | 56
+  t = u*t                #  13 | 63 = 2^6 - 1
+  u = sqrn(t, 6)         #  19 | 2^12 - 2^6
+  t = u*t                #  20 | 2^12 - 1
+  u = sqrn(t, 12)        #  32 | 2^24 - 2^12
+  t = u*t                #  33 | 2^24 - 1
+  u = sqrn(t, 1)         #  34 | 2^25 - 2
+  t = u*w                #  35 | 2^25 - 1
+  u = sqrn(t, 25)        #  60 | 2^50 - 2^25
+  t2p50m1 = u*t          #  61 | 2^50 - 1
+  u = sqrn(t2p50m1, 50)  # 111 | 2^100 - 2^50
+  t = u*t2p50m1          # 112 | 2^100 - 1
+  u = sqrn(t, 100)       # 212 | 2^200 - 2^100
+  t = u*t                # 213 | 2^200 - 1
+  u = sqrn(t, 50)        # 263 | 2^250 - 2^50
+  t = u*t2p50m1          # 264 | 2^250 - 1
+  u = sqrn(t, 2)         # 266 | 2^252 - 4
+  t = u*w                # 267 | 2^252 - 3
+  beta = t*xy3           # 268 |
+
+  ## Now we have beta = (x y^3) (x y^7)^((p - 5)/8) =
+  ## x^((p + 3)/8) y^((7 p - 11)/8) = (x/y)^((p + 3)/8).
+  ## Suppose alpha^2 = x/y.  Then beta^4 = (x/y)^((p + 3)/2) =
+  ## alpha^(p + 3) = alpha^4 = (x/y)^2, so beta^2 = ±x/y.  If
+  ## y beta^2 = x then alpha = beta and we're done; if
+  ## y beta^2 = -x, then alpha = beta sqrt(-1); otherwise x/y
+  ## wasn't actually a square after all.
+  t = y*beta^2
+  if t == x: return beta
+  elif t == -x: return beta*sqrtm1
+  else: raise ValueError, 'not a square'
+
+def quosqrt_mdw(x, y):
+  v = x*y
+
+  ## Now we calculate w = v^{3*2^252 - 8}.  This will be explained later.
+  u = sqrn(v, 1)         #   1 | 2
+  t = u*v                #   2 | 3
+  u = sqrn(t, 2)         #   4 | 12
+  t15 = u*t              #   5 | 15
+  u = sqrn(t15, 1)       #   6 | 30
+  t = u*v                #   7 | 31 = 2^5 - 1
+  u = sqrn(t, 5)         #  12 | 2^10 - 2^5
+  t = u*t                #  13 | 2^10 - 1
+  u = sqrn(t, 10)        #  23 | 2^20 - 2^10
+  u = u*t                #  24 | 2^20 - 1
+  u = sqrn(u, 10)        #  34 | 2^30 - 2^10
+  t = u*t                #  35 | 2^30 - 1
+  u = sqrn(t, 1)         #  36 | 2^31 - 2
+  t = u*v                #  37 | 2^31 - 1
+  u = sqrn(t, 31)        #  68 | 2^62 - 2^31
+  t = u*t                #  69 | 2^62 - 1
+  u = sqrn(t, 62)        # 131 | 2^124 - 2^62
+  t = u*t                # 132 | 2^124 - 1
+  u = sqrn(t, 124)       # 256 | 2^248 - 2^124
+  t = u*t                # 257 | 2^248 - 1
+  u = sqrn(t, 1)         # 258 | 2^249 - 2
+  t = u*v                # 259 | 2^249 - 1
+  t = sqrn(t, 3)         # 262 | 2^252 - 8
+  u = sqrn(t, 1)         # 263 | 2^253 - 16
+  t = u*t                # 264 | 3*2^252 - 24
+  t = t*t15              # 265 | 3*2^252 - 9
+  w = t*v                # 266 | 3*2^252 - 8
+
+  ## Awesome.  Now let me explain.  Let v be a square in GF(p), and let w =
+  ## v^(3*2^252 - 8).  In particular, let's consider
+  ##
+  ##    v^2 w^4 = v^2 v^{3*2^254 - 32} = (v^{2^254 - 10})^3
+  ##
+  ## But 2^254 - 10 = ((2^255 - 19) - 1)/2 = (p - 1)/2.  Since v is a square,
+  ## it has order dividing (p - 1)/2, and therefore v^2 w^4 = 1 and
+  ##
+  ##    w^4 = 1/v^2
+  ##
+  ## That in turn implies that w^2 = ±1/v.  Now, recall that v = x y, and let
+  ## w' = w x.  Then w'^2 = ±x^2/v = ±x/y.  If y w'^2 = x then we set
+  ## z = w', since we have z^2 = x/y; otherwise let z = i w', where i^2 = -1,
+  ## so z^2 = -w^2 = x/y, and we're done.
+  t = w*x
+  u = y*t^2
+  if u == x: return t
+  elif u == -x: return t*sqrtm1
+  else: raise ValueError, 'not a square'
+
+quosqrt = quosqrt_mdw
+
 assert inv(k(9))*9 == 1
+assert 5*quosqrt(k(4), k(5))^2 == 4
+
+###--------------------------------------------------------------------------
+### The Montgomery ladder.
+
+def x25519(n, x1):
+
+  ## Let Q = (x_1 : y_1 : 1) be an input point.  We calculate
+  ## n Q = (x_n : y_n : z_n), returning x_n/z_n (unless z_n = 0,
+  ## in which case we return zero).
+  ##
+  ## We're given that n = 2^254 + n'_254, where 0 <= n'_254 < 2^254.
+  bb = n.bits()
+  x, z = 1, 0
+  u, w = x1, 1
+
+  ## Initially, let i = 255.
+  for i in xrange(len(bb) - 1, -1, -1):
+
+    ## Split n = n_i 2^i + n'_i, where 0 <= n'_i < 2^i, so n_0 = n.
+    ## We have x, z = x_{n_{i+1}}, z_{n_{i+1}}, and
+    ## u, w = x_{n_{i+1}+1}, z_{n_{i+1}+1}.
+    ## Now either n_i = 2 n_{i+1} or n_i = 2 n_{i+1} + 1, depending
+    ## on bit i of n.
+
+    ## Swap (x : z) and (u : w) if bit i of n is set.
+    if bb[i]: x, z, u, w = u, w, x, z
+
+    ## Do the ladder step.
+    xmz, xpz = x - z, x + z
+    umw, upw = u - w, u + w
+    xmz2, xpz2 = xmz^2, xpz^2
+    xpz2mxmz2 = xpz2 - xmz2
+    xmzupw, xpzumw = xmz*upw, xpz*umw
+    x, z = xmz2*xpz2, xpz2mxmz2*(xpz2 + A0*xpz2mxmz2)
+    u, w = (xmzupw + xpzumw)^2, x1*(xmzupw - xpzumw)^2
+
+    ## Finally, unswap.
+    if bb[i]: x, z, u, w = u, w, x, z
+
+  ## Almost done.
+  return x*inv(z)
+
+assert x25519(y, k(9)) == Y[0]
+assert x25519(x, Y[0]) == x25519(y, X[0]) == Z[0]
+
+###--------------------------------------------------------------------------
+### Edwards curve parameters and conversion.
+
+a = k(-1)
+d = -A0/(A0 + 1)
+
+def mont_to_ed(u, v):
+  return sqrt(-A - 2)*u/v, (u - 1)/(u + 1)
+
+def ed_to_mont(x, y):
+  u = (1 + y)/(1 - y)
+  v = sqrt(-A - 2)*u/x
+  return u, v
+
+Bx, By = mont_to_ed(P[0], P[1])
+if Bx.lift()%2: Bx = -Bx
+B = (Bx, By, 1)
+u, v = ed_to_mont(Bx, By)
+
+assert By == k(4/5)
+assert -Bx^2 + By^2 == 1 + d*Bx^2*By^2
+assert u == k(9)
+assert v == P[1] or v == -P[1]
+
+###--------------------------------------------------------------------------
+### Edwards point addition and doubling.
+
+def ed_add((X1, Y1, Z1), (X2, Y2, Z2)):
+  A = Z1*Z2
+  B = A^2
+  C = X1*X2
+  D = Y1*Y2
+  E = d*C*D
+  F = B - E
+  G = B + E
+  X3 = A*F*((X1 + Y1)*(X2 + Y2) - C - D)
+  Y3 = A*G*(D - a*C)
+  Z3 = F*G
+  return X3, Y3, Z3
+
+def ed_dbl((X1, Y1, Z1)):
+  B = (X1 + Y1)^2
+  C = X1^2
+  D = Y1^2
+  E = a*C
+  F = E + D
+  H = Z1^2
+  J = F - 2*H
+  X3 = (B - C - D)*J
+  Y3 = F*(E - D)
+  Z3 = F*J
+  return X3, Y3, Z3
+
+Q = E.random_point()
+R = E.random_point()
+n = ZZ(randint(0, 2^255 - 1))
+m = ZZ(randint(0, 2^255 - 1))
+Qx, Qy = mont_to_ed(Q[0], Q[1])
+Rx, Ry = mont_to_ed(R[0], R[1])
+
+S = Q + R; T = 2*Q
+Sx, Sy, Sz = ed_add((Qx, Qy, 1), (Rx, Ry, 1))
+Tx, Ty, Tz = ed_dbl((Qx, Qy, 1))
+assert (Sx/Sz, Sy/Sz) == mont_to_ed(S[0], S[1])
+assert (Tx/Tz, Ty/Tz) == mont_to_ed(T[0], T[1])
+
+###--------------------------------------------------------------------------
+### Scalar multiplication.
+
+def ed_mul(n, Q):
+  winwd = 4
+  winlim = 1 << winwd
+  winmask = winlim - 1
+  tabsz = winlim/2 + 1
+
+  ## Recode the scalar to roughly-balanced form.
+  nn = [(n >> i)&winmask for i in xrange(0, n.nbits() + winwd, winwd)]
+  for i in xrange(len(nn) - 2, -1, -1):
+    if nn[i] >= winlim/2:
+      nn[i] -= winlim
+      nn[i + 1] += 1
+
+  ## Build the table of small multiples.
+  V = tabsz*[None]
+  V[0] = (0, 1, 1)
+  V[1] = Q
+  V[2] = ed_dbl(V[1])
+  for i in xrange(3, tabsz, 2):
+    V[i] = ed_add(V[i - 1], Q)
+    V[i + 1] = ed_dbl(V[(i + 1)/2])
+
+  ## Do the multiplication.
+  T = V[0]
+  for i in xrange(len(nn) - 1, -1, -1):
+    w = nn[i]
+    for j in xrange(winwd): T = ed_dbl(T)
+    if w >= 0: T = ed_add(T, V[w])
+    else: x, y, z = V[-w]; T = ed_add(T, (-x, y, z))
+
+  ## Done.
+  return T
+
+def ed_simmul(n0, Q0, n1, Q1):
+  winwd = 2
+  winlim = 1 << winwd
+  winmask = winlim - 1
+  tabsz = 1 << 2*winwd
+
+  ## Extract the scalar pieces.
+  nn = [(n0 >> i)&winmask | (((n1 >> i)&winmask) << winwd)
+        for i in xrange(0, max(n0.nbits(), n1.nbits()), winwd)]
+
+  ## Build the table of small linear combinations.
+  V = tabsz*[None]
+  V[0] = (0, 1, 1)
+  V[1] = Q0; V[winlim] = Q1
+  i = 2
+  while i < winlim:
+    V[i] = ed_dbl(V[i/2])
+    V[i*winlim] = ed_dbl(V[i*winlim/2])
+    i <<= 1
+  i = 2
+  while i < tabsz:
+    for j in xrange(1, i):
+      V[i + j] = ed_add(V[i], V[j])
+    i <<= 1
+
+  ## Do the multiplication.
+  T = V[0]
+  for i in xrange(len(nn) - 1, -1, -1):
+    w = nn[i]
+    for j in xrange(winwd): T = ed_dbl(T)
+    T = ed_add(T, V[w])
+
+  ## Done.
+  return T
+
+U = n*Q; V = n*Q + m*R
+Ux, Uy, Uz = ed_mul(n, (Qx, Qy, 1))
+Vx, Vy, Vz = ed_simmul(n, (Qx, Qy, 1), m, (Rx, Ry, 1))
+assert (Ux/Uz, Uy/Uz) == mont_to_ed(U[0], U[1])
+assert (Vx/Vz, Vy/Vz) == mont_to_ed(V[0], V[1])
+
+###--------------------------------------------------------------------------
+### Point encoding.
+
+def ed_encode((X, Y, Z)):
+  x, y = X/Z, Y/Z
+  xx, yy = x.lift(), y.lift()
+  if xx%2: yy += 1 << 255
+  return st(yy, 32)
+
+def ed_decode(s):
+  n = ld(s)
+  bit = (n >> 255)&1
+  y = n&((1 << 255) - 1)
+  y2 = y^2
+  x = quosqrt(y2 - 1, d*y2 + 1)
+  if x.lift()%2 != bit: x = -x
+  return (x, y, 1)
+
+###--------------------------------------------------------------------------
+### EdDSA implementation.
+
+def eddsa_splitkey(k):
+  h = hash(k)
+  a = 2^254 + (ld(h[0:32])&((1 << 254) - 8))
+  h1 = h[32:64]
+  return a, h1
+
+def eddsa_pubkey(k):
+  a, h1 = eddsa_splitkey(k)
+  A = ed_mul(a, B)
+  return ed_encode(A)
+
+def eddsa_sign(k, m):
+  K = eddsa_pubkey(k)
+  a, h1 = eddsa_splitkey(k)
+  r = ld(hash(h1, m))%l
+  A = ed_decode(K)
+  R = ed_mul(r, B)
+  RR = ed_encode(R)
+  S = (r + a*ld(hash(RR, K, m)))%l
+  return RR + st(S, 32)
+
+def eddsa_verify(K, m, sig):
+  A = ed_decode(K)
+  R, S = sig[0:32], ld(sig[32:64])
+  h = ld(hash(R, K, m))%l
+  V = ed_simmul(S, B, h, (-A[0], A[1], A[2]))
+  return ed_encode(V) == R
+
+priv = '1acdbb793b0384934627470d795c3d1dd4d79cea59ef983f295b9b59179cbb28'.decode('hex')
+msg = '7cf34f75c3dac9a804d0fcd09eba9b29c9484e8a018fa9e073042df88e3c56'.decode('hex')
+pub = '3f60c7541afa76c019cf5aa82dcdb088ed9e4ed9780514aefb379dabc844f31a'.decode('hex')
+sig = 'be71ef4806cb041d885effd9e6b0fbb73d65d7cdec47a89c8a994892f4e55a568c4cc78d61f901e80dbb628b86a23ccd594e712b57fa94c2d67ec26634878507'.decode('hex')
+assert pub == eddsa_pubkey(priv)
+assert sig == eddsa_sign(priv, msg)
+assert eddsa_verify(pub, msg, sig)
 
 ###----- That's all, folks --------------------------------------------------