#! /usr/local/bin/sage
### -*- mode: python; coding: utf-8 -*-
+import hashlib as H
+
###--------------------------------------------------------------------------
### Some general utilities.
+def hash(*m):
+  ## Return the SHA-512 digest of the concatenation of the strings in M,
+  ## feeding them to the hash one piece at a time.  (Note that this shadows
+  ## Python's builtin `hash'.)
+  ctx = H.sha512()
+  for chunk in m:
+    ctx.update(chunk)
+  return ctx.digest()
+
def ld(v):
  ## Decode the byte string V as a little-endian integer.  The leading `0 +'
  ## presumably keeps the result a Sage Integer even when V is empty, since
  ## callers use methods such as .nbits() on it.
  return 0 + sum(ord(ch) << 8*i for i, ch in enumerate(v))
t = u*t11 # 265 | 2^255 - 21
return t
-def quosqrt(x, y):
+def quosqrt_djb(x, y):
## First, some preliminary values.
y2 = sqrn(y, 1) # 1 | 0, 2
elif t == -x: return beta*sqrtm1
else: raise ValueError, 'not a square'
+def quosqrt_mdw(x, y):
+ ## Return a square root z of x/y, i.e. some z with y z^2 = x, computed
+ ## with one fixed exponentiation and no separate field inversion.  Which
+ ## of the two roots comes back is not specified; callers that care (e.g.
+ ## ed_decode) fix the sign themselves.  Returns 0 when x = 0, and raises
+ ## ValueError if x/y is not a square (including y = 0 with x != 0).
+ v = x*y
+
+ ## Now we calculate w = v^{3*2^252 - 8}. This will be explained later.
+ ## Each comment gives the running count of multiplications/squarings so
+ ## far, then the exponent e such that the value just computed is v^e;
+ ## sqrn(t, n) computes t^(2^n), as those exponents show.
+ u = sqrn(v, 1) # 1 | 2
+ t = u*v # 2 | 3
+ u = sqrn(t, 2) # 4 | 12
+ t15 = u*t # 5 | 15
+ u = sqrn(t15, 1) # 6 | 30
+ t = u*v # 7 | 31 = 2^5 - 1
+ u = sqrn(t, 5) # 12 | 2^10 - 2^5
+ t = u*t # 13 | 2^10 - 1
+ u = sqrn(t, 10) # 23 | 2^20 - 2^10
+ u = u*t # 24 | 2^20 - 1
+ u = sqrn(u, 10) # 34 | 2^30 - 2^10
+ t = u*t # 35 | 2^30 - 1
+ u = sqrn(t, 1) # 36 | 2^31 - 2
+ t = u*v # 37 | 2^31 - 1
+ u = sqrn(t, 31) # 68 | 2^62 - 2^31
+ t = u*t # 69 | 2^62 - 1
+ u = sqrn(t, 62) # 131 | 2^124 - 2^62
+ t = u*t # 132 | 2^124 - 1
+ u = sqrn(t, 124) # 256 | 2^248 - 2^124
+ t = u*t # 257 | 2^248 - 1
+ u = sqrn(t, 1) # 258 | 2^249 - 2
+ t = u*v # 259 | 2^249 - 1
+ t = sqrn(t, 3) # 262 | 2^252 - 8
+ u = sqrn(t, 1) # 263 | 2^253 - 16
+ t = u*t # 264 | 3*2^252 - 24
+ t = t*t15 # 265 | 3*2^252 - 9
+ w = t*v # 266 | 3*2^252 - 8
+
+ ## Awesome. Now let me explain. Let v be a nonzero square in GF(p), and
+ ## let w = v^(3*2^252 - 8). In particular, let's consider
+ ##
+ ## v^2 w^4 = v^2 v^{3*2^254 - 32} = (v^{2^254 - 10})^3
+ ##
+ ## But 2^254 - 10 = ((2^255 - 19) - 1)/2 = (p - 1)/2. Since v is a nonzero
+ ## square, its order divides (p - 1)/2, and therefore v^2 w^4 = 1 and
+ ##
+ ## w^4 = 1/v^2
+ ##
+ ## That in turn implies that w^2 = ±1/v. Now, recall that v = x y, and let
+ ## w' = w x. Then w'^2 = ±x^2/v = ±x/y. If y w'^2 = x then we set
+ ## z = w', since we have z^2 = x/y; otherwise let z = i w', where i^2 = -1,
+ ## so z^2 = -w'^2 = x/y, and we're done.
+ ##
+ ## If instead x/y is not a square, then neither is v, so v^{(p-1)/2} = -1,
+ ## giving w^2 = ±i/v and y w'^2 = ±i x. That matches neither x nor -x
+ ## (unless x = 0), so we report failure. If y = 0 then v = w' = 0, which
+ ## only works out when x = 0 as well.
+ t = w*x
+ u = y*t^2
+ if u == x: return t
+ elif u == -x: return t*sqrtm1
+ else: raise ValueError, 'not a square'
+
+## Select the square-root implementation: callers below (e.g. ed_decode)
+## use the name `quosqrt'.
+quosqrt = quosqrt_mdw
+
+## Quick sanity checks: inv(9) really is 1/9, and quosqrt(4, 5)^2 = 4/5.
assert inv(k(9))*9 == 1
assert 5*quosqrt(k(4), k(5))^2 == 4
###--------------------------------------------------------------------------
### The Montgomery ladder.
-A0 = (A - 2)/4
-
def x25519(n, x1):
## Let Q = (x_1 : y_1 : 1) be an input point. We calculate
assert x25519(y, k(9)) == Y[0]
assert x25519(x, Y[0]) == x25519(y, X[0]) == Z[0]
+###--------------------------------------------------------------------------
+### Edwards curve parameters and conversion.
+
+## Twisted Edwards curve a x^2 + y^2 = 1 + d x^2 y^2 with a = -1.
+a = k(-1)
+## The Montgomery coefficient satisfies A = 2 (a + d)/(a - d), so for a = -1
+## we get d = -(A - 2)/(A + 2). This equals -A0/(A0 + 1) with A0 = (A - 2)/4;
+## for Curve25519's A = 486662 it is -121665/121666. d is computed from A
+## directly so that it doesn't depend on A0, whose definition this change
+## removes from the Montgomery-ladder section. It is coerced into k to
+## match a.
+d = k(-(A - 2)/(A + 2))
+
+def mont_to_ed(u, v):
+  ## Map the Montgomery point (u, v) to the Edwards point
+  ##   (x, y) = (sqrt(-A - 2) u/v, (u - 1)/(u + 1)).
+  ## Undefined (division by zero) when v = 0 or u = -1.
+  scale = sqrt(-A - 2)
+  x = scale*u/v
+  y = (u - 1)/(u + 1)
+  return x, y
+
+def ed_to_mont(x, y):
+  ## Inverse of mont_to_ed: map the Edwards point (x, y) back to
+  ##   (u, v) = ((1 + y)/(1 - y), sqrt(-A - 2) u/x).
+  ## Undefined (division by zero) when y = 1 or x = 0.
+  mu = (1 + y)/(1 - y)
+  return mu, sqrt(-A - 2)*mu/x
+
+## Map P (presumably the Montgomery base point; the u = 9 check below confirms
+## its u-coordinate) to Edwards coordinates, choosing the sign of x so that x
+## is even, i.e. its ed_encode encoding has bit 255 clear.
+Bx, By = mont_to_ed(P[0], P[1])
+if Bx.lift()%2: Bx = -Bx
+B = (Bx, By, 1)
+## Map back again, to check that ed_to_mont undoes mont_to_ed.
+u, v = ed_to_mont(Bx, By)
+
+## Checks: B has y = 4/5 and lies on -x^2 + y^2 = 1 + d x^2 y^2. It maps
+## back to u = 9, and to v = ±P[1], because the sign of x was chosen
+## independently of v.
+assert By == k(4/5)
+assert -Bx^2 + By^2 == 1 + d*Bx^2*By^2
+assert u == k(9)
+assert v == P[1] or v == -P[1]
+
+###--------------------------------------------------------------------------
+### Edwards point addition and doubling.
+
+def ed_add(P1, P2):
+  ## Return P1 + P2 for points (X : Y : Z) in projective coordinates on the
+  ## curve a x^2 + y^2 = 1 + d x^2 y^2, where x = X/Z and y = Y/Z.
+  ## This is the add-2008-bbjlp formula from the Explicit-Formulas Database.
+  ## Its intermediate names A..G are noted alongside; the locals are renamed
+  ## so they don't shadow the module-level A, B and E.
+  X1, Y1, Z1 = P1
+  X2, Y2, Z2 = P2
+  zz = Z1*Z2                            # A
+  zz2 = zz^2                            # B
+  xx = X1*X2                            # C
+  yy = Y1*Y2                            # D
+  dxxyy = d*xx*yy                       # E
+  f = zz2 - dxxyy                       # F
+  g = zz2 + dxxyy                       # G
+  X3 = zz*f*((X1 + Y1)*(X2 + Y2) - xx - yy)
+  Y3 = zz*g*(yy - a*xx)
+  Z3 = f*g
+  return X3, Y3, Z3
+
+def ed_dbl(P):
+  ## Return 2 P for a point P = (X : Y : Z) in projective coordinates.
+  ## This is the dbl-2008-bbjlp formula from the Explicit-Formulas Database.
+  ## Its intermediate names are noted alongside; the locals are renamed so
+  ## they don't shadow the module-level B, E and the hashlib alias H.
+  X1, Y1, Z1 = P
+  s = (X1 + Y1)^2                       # B
+  xx = X1^2                             # C
+  yy = Y1^2                             # D
+  axx = a*xx                            # E
+  f = axx + yy                          # F
+  zz = Z1^2                             # H
+  j = f - 2*zz                          # J
+  X3 = (s - xx - yy)*j
+  Y3 = f*(axx - yy)
+  Z3 = f*j
+  return X3, Y3, Z3
+
+## Randomized cross-check of ed_add and ed_dbl against the group law on E
+## (presumably the Montgomery curve, as a Sage elliptic curve).
+##   - Pick random points Q and R and map them to Edwards form.
+##   - Compare Q + R and 2 Q with the Edwards results.
+## mont_to_ed divides by v and u + 1, so it is not defined everywhere; random
+## points hit such cases only with negligible probability. The random scalars
+## n and m are used by the scalar-multiplication checks further down.
+Q = E.random_point()
+R = E.random_point()
+n = ZZ(randint(0, 2^255 - 1))
+m = ZZ(randint(0, 2^255 - 1))
+Qx, Qy = mont_to_ed(Q[0], Q[1])
+Rx, Ry = mont_to_ed(R[0], R[1])
+
+S = Q + R; T = 2*Q
+Sx, Sy, Sz = ed_add((Qx, Qy, 1), (Rx, Ry, 1))
+Tx, Ty, Tz = ed_dbl((Qx, Qy, 1))
+assert (Sx/Sz, Sy/Sz) == mont_to_ed(S[0], S[1])
+assert (Tx/Tz, Ty/Tz) == mont_to_ed(T[0], T[1])
+
+###--------------------------------------------------------------------------
+### Scalar multiplication.
+
+def ed_mul(n, Q):
+  ## Return n Q as a projective triple, for an Edwards point Q = (X, Y, Z).
+  ## n is expected to be a non-negative Sage Integer; it must support
+  ## .nbits().
+  ##
+  ## We use a fixed window of winwd bits over a signed-digit recoding of n,
+  ## so the table only needs 0 Q, Q, ..., (winlim/2) Q. Negative digits are
+  ## handled by negating x.
+  winwd = 4
+  winlim = 1 << winwd
+  winmask = winlim - 1
+  ## Use floor division: these values are a list size, a loop bound and
+  ## list indices. Under Sage, `/' on integers yields a Rational (and a
+  ## float under Python 3), neither of which is usable there.
+  half = winlim//2
+  tabsz = half + 1
+
+  ## Recode the scalar into base-2^winwd digits in the range -half .. half.
+  ## The digit list extends past the top bit of n, so the most significant
+  ## digit starts at 0 and can absorb the final carry.
+  nn = [(n >> i)&winmask for i in xrange(0, n.nbits() + winwd, winwd)]
+  for i in xrange(len(nn) - 2, -1, -1):
+    if nn[i] >= half:
+      nn[i] -= winlim
+      nn[i + 1] += 1
+
+  ## Build the table V[j] = j Q for 0 <= j <= half. Odd entries add Q to
+  ## the previous entry; even entries double V[j/2]. The identity is made of
+  ## field elements so that results are always field elements.
+  V = tabsz*[None]
+  V[0] = (k(0), k(1), k(1))
+  V[1] = Q
+  V[2] = ed_dbl(V[1])
+  for i in xrange(3, tabsz, 2):
+    V[i] = ed_add(V[i - 1], Q)
+    V[i + 1] = ed_dbl(V[(i + 1)//2])
+
+  ## Do the multiplication, most significant digit first (Horner's rule).
+  T = V[0]
+  for w in reversed(nn):
+    for j in xrange(winwd): T = ed_dbl(T)
+    if w >= 0: T = ed_add(T, V[w])
+    else: x, y, z = V[-w]; T = ed_add(T, (-x, y, z))
+
+  ## Done.
+  return T
+
+def ed_simmul(n0, Q0, n1, Q1):
+  ## Return n0 Q0 + n1 Q1 as a projective triple, for Edwards points Q0, Q1
+  ## and non-negative Sage Integers n0, n1 (they must support .nbits()).
+  ## Both scalars are processed together, winwd bits at a time (Straus'
+  ## method). Each step doubles winwd times and then adds one precomputed
+  ## combination i0 Q0 + i1 Q1.
+  winwd = 2
+  winlim = 1 << winwd
+  winmask = winlim - 1
+  tabsz = 1 << 2*winwd
+
+  ## Extract the scalar pieces. Digit i packs winwd bits of n0 (low half)
+  ## with the matching winwd bits of n1 (high half), so it indexes V
+  ## directly.
+  nn = [(n0 >> i)&winmask | (((n1 >> i)&winmask) << winwd)
+        for i in xrange(0, max(n0.nbits(), n1.nbits()), winwd)]
+
+  ## Build the table V[i0 + winlim*i1] = i0 Q0 + i1 Q1, for 0 <= i0, i1 <
+  ## winlim.
+  ##   - First fill the entries at powers of two, by doubling.
+  ##   - Then fill every other entry as a sum of two existing ones.
+  ## Indices use floor division so they stay integers under Sage and
+  ## Python 3. The identity is made of field elements, so even n0 = n1 = 0
+  ## yields a point that ed_encode can handle.
+  V = tabsz*[None]
+  V[0] = (k(0), k(1), k(1))
+  V[1] = Q0; V[winlim] = Q1
+  i = 2
+  while i < winlim:
+    V[i] = ed_dbl(V[i//2])
+    V[i*winlim] = ed_dbl(V[i*winlim//2])
+    i <<= 1
+  i = 2
+  while i < tabsz:
+    for j in xrange(1, i):
+      V[i + j] = ed_add(V[i], V[j])
+    i <<= 1
+
+  ## Do the multiplication, most significant digit first.
+  T = V[0]
+  for w in reversed(nn):
+    for j in xrange(winwd): T = ed_dbl(T)
+    T = ed_add(T, V[w])
+
+  ## Done.
+  return T
+
+## Cross-check ed_mul and ed_simmul against scalar multiplication on E. This
+## uses the random points Q, R and scalars n, m chosen in the add/dbl checks.
+U = n*Q; V = n*Q + m*R
+Ux, Uy, Uz = ed_mul(n, (Qx, Qy, 1))
+Vx, Vy, Vz = ed_simmul(n, (Qx, Qy, 1), m, (Rx, Ry, 1))
+assert (Ux/Uz, Uy/Uz) == mont_to_ed(U[0], U[1])
+assert (Vx/Vz, Vy/Vz) == mont_to_ed(V[0], V[1])
+
+###--------------------------------------------------------------------------
+### Point encoding.
+
+def ed_encode(P):
+  ## Encode the projective point P = (X, Y, Z) as 32 bytes. The integer
+  ## written is the affine y-coordinate, plus 2^255 if the affine x is odd.
+  ## st (defined elsewhere) presumably stores it little-endian, as the
+  ## inverse of ld.
+  X, Y, Z = P
+  x = X/Z
+  y = Y/Z
+  sign = x.lift()%2
+  return st(y.lift() + (sign << 255), 32)
+
+def ed_decode(s):
+  ## Decode the encoding S (as produced by ed_encode) into an Edwards point
+  ## (x, y, 1), following RFC 8032, section 5.1.3. Raises ValueError if S
+  ## is not the canonical encoding of a curve point.
+  n = ld(s)
+  bit = (n >> 255)&1
+  yy = n&((1 << 255) - 1)
+
+  ## The y-coordinate must be fully reduced; otherwise one point would have
+  ## several valid encodings. From here on, work with y as a field element.
+  y = k(yy)
+  if y.lift() != yy: raise ValueError('non-canonical point encoding')
+
+  ## Recover x from the curve equation -x^2 + y^2 = 1 + d x^2 y^2, i.e.
+  ## x^2 = (y^2 - 1)/(d y^2 + 1). quosqrt raises ValueError if this has no
+  ## square root.
+  y2 = y^2
+  x = quosqrt(y2 - 1, d*y2 + 1)
+
+  ## Pick the root whose parity matches the sign bit. x = 0 has no odd
+  ## twin, so a set sign bit is invalid there.
+  if x == 0 and bit: raise ValueError('invalid point encoding')
+  if x.lift()%2 != bit: x = -x
+  return (x, y, 1)
+
+###--------------------------------------------------------------------------
+### EdDSA implementation.
+
+def eddsa_splitkey(k):
+  ## Hash the secret key K with SHA-512 and split the digest.
+  ##   - Low 32 bytes, read little-endian: the secret scalar. Bits 0-2 and
+  ##     255 are cleared and bit 254 is set.
+  ##   - High 32 bytes: the prefix used to derive signing nonces.
+  digest = hash(k)
+  lo, hi = digest[:32], digest[32:]
+  scalar = (ld(lo)&((1 << 254) - 8)) | (1 << 254)
+  return scalar, hi
+
+def eddsa_pubkey(k):
+  ## Return the encoded public key for the secret key K: the 32-byte
+  ## encoding of a B, where a is the clamped secret scalar.
+  scalar = eddsa_splitkey(k)[0]
+  return ed_encode(ed_mul(scalar, B))
+
+def eddsa_sign(k, m):
+  ## Return the 64-byte signature R || S on message M under the secret key K
+  ## (RFC 8032, section 5.1.6). l is presumably the order of B.
+  a, h1 = eddsa_splitkey(k)
+  ## Derive the public key from the secret scalar directly, rather than
+  ## hashing K a second time via eddsa_pubkey. (The original also decoded K
+  ## into an unused point, which cost a square root.)
+  K = ed_encode(ed_mul(a, B))
+  ## Deterministic nonce: r = H(prefix || M) mod l, and R = r B.
+  r = ld(hash(h1, m))%l
+  RR = ed_encode(ed_mul(r, B))
+  ## S = r + H(R || K || M) a  (mod l).
+  S = (r + a*ld(hash(RR, K, m)))%l
+  return RR + st(S, 32)
+
+def eddsa_verify(K, m, sig):
+  ## Check the signature SIG on message M against the encoded public key K
+  ## (RFC 8032, section 5.1.7). Returns True or False. If K cannot be
+  ## decoded, the ValueError raised by ed_decode propagates, as before.
+  A = ed_decode(K)
+  ## A signature is exactly 64 bytes; previously any trailing junk after a
+  ## valid signature was ignored and accepted.
+  if len(sig) != 64: return False
+  R, S = sig[0:32], ld(sig[32:64])
+  ## S must be fully reduced. Otherwise S + l would verify as well, making
+  ## signatures malleable.
+  if S >= l: return False
+  h = ld(hash(R, K, m))%l
+  ## Check that S B - h A encodes to R, i.e. that S B = R + h A.
+  V = ed_simmul(S, B, h, (-A[0], A[1], A[2]))
+  return ed_encode(V) == R
+
+## Known-answer test: a fixed secret key, message, public key and signature
+## (hex-encoded; `.decode('hex')' is Python 2's hex decoder). Check that key
+## derivation and the deterministic signer reproduce them exactly, and that
+## the verifier accepts the signature.
+priv = '1acdbb793b0384934627470d795c3d1dd4d79cea59ef983f295b9b59179cbb28'.decode('hex')
+msg = '7cf34f75c3dac9a804d0fcd09eba9b29c9484e8a018fa9e073042df88e3c56'.decode('hex')
+pub = '3f60c7541afa76c019cf5aa82dcdb088ed9e4ed9780514aefb379dabc844f31a'.decode('hex')
+sig = 'be71ef4806cb041d885effd9e6b0fbb73d65d7cdec47a89c8a994892f4e55a568c4cc78d61f901e80dbb628b86a23ccd594e712b57fa94c2d67ec26634878507'.decode('hex')
+assert pub == eddsa_pubkey(priv)
+assert sig == eddsa_sign(priv, msg)
+assert eddsa_verify(pub, msg, sig)
+
###----- That's all, folks --------------------------------------------------