Merge branch '2.3.x'
[catacomb] / utils / curve25519.sage
1 #! /usr/local/bin/sage
2 ### -*- mode: python; coding: utf-8 -*-
3
4 import hashlib as H
5
6 ###--------------------------------------------------------------------------
7 ### Some general utilities.
8
def hash(*m):
  ## Return the SHA-512 digest of the concatenation of the byte strings in
  ## M, which are fed to the hash one after another.  (This name shadows
  ## Python's builtin `hash' throughout the file.)
  ctx = H.sha512()
  for chunk in m:
    ctx.update(chunk)
  return ctx.digest()
13
def ld(v):
  ## Decode the byte sequence V (a string, or a list of one-character
  ## strings) as a little-endian unsigned integer.  The accumulator starts
  ## from the literal 0 -- a Sage Integer under the preparser -- so the
  ## result is an Integer even when V is empty.
  acc = 0
  for pos, ch in enumerate(v):
    acc += ord(ch) << 8*pos
  return acc
16
def st(x, n):
  ## Encode the integer X as an N-byte little-endian string, keeping only
  ## the low 8 N bits of X.
  out = []
  for _ in xrange(n):
    out.append(chr(x & 0xff))
    x >>= 8
  return ''.join(out)
19
def piece_widths_offsets(wd, n):
  ## Divide a WD-bit quantity into N pieces of as-equal-as-possible width.
  ## Piece i starts at bit offset o_i = ceil(WD i/N), so o_0 = 0 and
  ## o_N = WD.  Returns (w, o): the list of N piece widths and the list of
  ## N + 1 offsets.
  ##
  ## The ceiling is computed as -floor(-WD i/N) with integer floor division,
  ## which is exact whatever integer type the arguments have; the obvious
  ## `ceil(wd*i/n)' is only right when `/' is exact division, which it isn't
  ## for plain Python 2 integers.
  o = [-(-(wd*i)//n) for i in range(n + 1)]
  w = [o[i + 1] - o[i] for i in range(n)]
  return w, o
24
def pieces(x, wd, n, bias = 0):
  ## Split the integer X into N signed pieces, piece i occupying w_i bits
  ## at offset o_i as laid out by `piece_widths_offsets(WD, N)'.  Every
  ## piece except the last ends up in [-2^(w_i - 1), 2^(w_i - 1)); the last
  ## piece takes whatever is left over.  If BIAS is nonzero and
  ## X >= BIAS/2, X is first replaced by X - BIAS, so an X in [0, BIAS) is
  ## represented by its equivalent in [-BIAS/2, BIAS/2).  Returns the list
  ## of pieces, least significant first; `combine' reassembles them.

  ## Figure out the piece widths (the offsets aren't needed here).
  w = piece_widths_offsets(wd, n)[0]

  ## First, normalize X into [-bias/2, bias/2).  Compare 2 X with BIAS
  ## rather than X with BIAS/2, so the test is exact for odd BIAS even when
  ## the arguments are plain Python integers.
  if bias and 2*x >= bias: x -= bias

  ## Next, peel off the raw (nonnegative) bits of every piece but the last;
  ## the last piece gets everything that remains, including the sign.
  nn = []
  for i in range(n - 1):
    nn.append(x & ((1 << w[i]) - 1))
    x >>= w[i]
  nn.append(x)

  ## Finally, move each non-final piece into its signed interval, carrying
  ## into the next piece up; this leaves the represented value unchanged.
  for i in range(n - 1):
    b = 1 << (w[i] - 1)
    if nn[i] >= b:
      nn[i] -= 2*b
      nn[i + 1] += 1

  return nn
51
def combine(v, wd, n):
  ## Reassemble an integer from the N pieces V (as produced by `pieces'):
  ## piece i is shifted up to its bit offset o_i from
  ## `piece_widths_offsets(WD, N)' and the results are summed.
  offs = piece_widths_offsets(wd, n)[1]
  total = 0
  for i in xrange(n):
    total += v[i] << offs[i]
  return total
55
56 ###--------------------------------------------------------------------------
57 ### Define the curve.
58
## The Montgomery curve E: y^2 = x^3 + A x^2 + x over k = GF(p), with base
## point P (a point with x-coordinate 9) and the prime l.  A0 = (A - 2)/4 is
## the constant used in the ladder's doubling formula.
p = 2^255 - 19; k = GF(p)
A = k(486662); A0 = (A - 2)/4
E = EllipticCurve(k, [0, A, 0, 1, 0]); P = E.lift_x(9)
l = 2^252 + 27742317777372353535851937790883648493

## Sanity checks: l is prime and l P = O (so P has order l), and 8 l lies
## within the Hasse interval, i.e., |p + 1 - 8 l| <= 2 sqrt(p).
assert is_prime(l)
assert (l*P).is_zero()
assert (p + 1 - 8*l)^2 <= 4*p
67
68 ###--------------------------------------------------------------------------
69 ### Example points from `Cryptography in NaCl'.
70
## The two secret scalars x and y from the NaCl example, given as 32-byte
## little-endian strings (bits 0-2 and 255 are already clear, bit 254 set).
## X = x P and Y = y P are the corresponding public points, and both
## parties must arrive at the same shared point Z = x y P.
x = ld('70076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c6a'.decode('hex'))
y = ld('58ab087e624a8a4b79e17f8b83800ee66f3bb1292618b6fd1c2f8b27ff88e06b'.decode('hex'))
X = x*P
Y = y*P
Z = x*Y
assert Z == y*X
83
84 ###--------------------------------------------------------------------------
85 ### Arithmetic implementation.
86
def sqrn(x, n):
  ## Return x^(2^n): square X repeatedly, N times over.
  for _ in xrange(n):
    x = x*x
  return x
90
## A fixed square root of -1 in k (one exists because p = 1 (mod 4)); used
## by `quosqrt' to fix up a root of the wrong sign.
sqrtm1 = k(-1).sqrt()
92
def inv(x):
  ## Return 1/x in GF(p), computed as x^(p - 2) = x^(2^255 - 21) (Fermat's
  ## little theorem), so that inv(0) = 0.  The exponent is built by a fixed
  ## addition chain of 254 squarings and 11 multiplications.  Each trailing
  ## comment gives the running count of field operations so far, and the
  ## exponent e such that the value just computed is x^e.
  t2 = sqrn(x, 1)                       # 1 | 2
  u = sqrn(t2, 2)                       # 3 | 8
  t = u*x                               # 4 | 9
  t11 = t*t2                            # 5 | 11
  u = sqrn(t11, 1)                      # 6 | 22
  t = u*t                               # 7 | 2^5 - 1 = 31
  u = sqrn(t, 5)                        # 12 | 2^10 - 2^5
  t2p10m1 = u*t                         # 13 | 2^10 - 1
  u = sqrn(t2p10m1, 10)                 # 23 | 2^20 - 2^10
  t = u*t2p10m1                         # 24 | 2^20 - 1
  u = sqrn(t, 20)                       # 44 | 2^40 - 2^20
  t = u*t                               # 45 | 2^40 - 1
  u = sqrn(t, 10)                       # 55 | 2^50 - 2^10
  t2p50m1 = u*t2p10m1                   # 56 | 2^50 - 1
  u = sqrn(t2p50m1, 50)                 # 106 | 2^100 - 2^50
  t = u*t2p50m1                         # 107 | 2^100 - 1
  u = sqrn(t, 100)                      # 207 | 2^200 - 2^100
  t = u*t                               # 208 | 2^200 - 1
  u = sqrn(t, 50)                       # 258 | 2^250 - 2^50
  t = u*t2p50m1                         # 259 | 2^250 - 1
  u = sqrn(t, 5)                        # 264 | 2^255 - 2^5
  t = u*t11                             # 265 | 2^255 - 21
  return t
117
def quosqrt(x, y):
  ## Return a square root of x/y in GF(p) without computing 1/y separately:
  ## first beta = x y^3 (x y^7)^((p - 5)/8), then a correction by sqrt(-1)
  ## if needed (see below).  Raises ValueError if x/y is not a square.  If
  ## y = 0, the result is 0 when x = 0, and ValueError otherwise.  The
  ## trailing comments give the running operation count and the exponents
  ## reached: of (x, y) for the preliminary values, and then of w.

  ## First, some preliminary values.
  y2 = sqrn(y, 1)                       # 1 | 0, 2
  y3 = y2*y                             # 2 | 0, 3
  xy3 = x*y3                            # 3 | 1, 3
  y4 = sqrn(y2, 1)                      # 4 | 0, 4
  w = xy3*y4                            # 5 | 1, 7

  ## Now calculate w^((p - 5)/8).  Notice that (p - 5)/8 =
  ## (2^255 - 24)/8 = 2^252 - 3.
  u = sqrn(w, 1)                        # 6 | 2
  t = u*w                               # 7 | 3
  u = sqrn(t, 1)                        # 8 | 6
  t = u*w                               # 9 | 7
  u = sqrn(t, 3)                        # 12 | 56
  t = u*t                               # 13 | 63 = 2^6 - 1
  u = sqrn(t, 6)                        # 19 | 2^12 - 2^6
  t = u*t                               # 20 | 2^12 - 1
  u = sqrn(t, 12)                       # 32 | 2^24 - 2^12
  t = u*t                               # 33 | 2^24 - 1
  u = sqrn(t, 1)                        # 34 | 2^25 - 2
  t = u*w                               # 35 | 2^25 - 1
  u = sqrn(t, 25)                       # 60 | 2^50 - 2^25
  t2p50m1 = u*t                         # 61 | 2^50 - 1
  u = sqrn(t2p50m1, 50)                 # 111 | 2^100 - 2^50
  t = u*t2p50m1                         # 112 | 2^100 - 1
  u = sqrn(t, 100)                      # 212 | 2^200 - 2^100
  t = u*t                               # 213 | 2^200 - 1
  u = sqrn(t, 50)                       # 263 | 2^250 - 2^50
  t = u*t2p50m1                         # 264 | 2^250 - 1
  u = sqrn(t, 2)                        # 266 | 2^252 - 4
  t = u*w                               # 267 | 2^252 - 3
  beta = t*xy3                          # 268 |

  ## Now we have beta = (x y^3) (x y^7)^((p - 5)/8) =
  ## x^((p + 3)/8) y^((7 p - 11)/8) = (x/y)^((p + 3)/8).
  ## Suppose alpha^2 = x/y.  Then beta^4 = (x/y)^((p + 3)/2) =
  ## alpha^(p + 3) = alpha^4 = (x/y)^2, so beta^2 = ±x/y.  If
  ## y beta^2 = x then alpha = beta and we're done; if
  ## y beta^2 = -x, then alpha = beta sqrt(-1); otherwise x/y
  ## wasn't actually a square after all.
  t = y*beta^2
  if t == x: return beta
  elif t == -x: return beta*sqrtm1
  else: raise ValueError, 'not a square'
164
## Quick checks of the field-arithmetic helpers.
assert 9*inv(k(9)) == 1
assert quosqrt(k(4), k(5))^2*5 == 4
167
168 ###--------------------------------------------------------------------------
169 ### The Montgomery ladder.
170
## The ladder's doubling step uses the constant (A - 2)/4.
A0 = (A - 2)/4

def x25519(n, x1):
  ## Return the x-coordinate of n Q, where Q is a point on E with
  ## x-coordinate X1, using the Montgomery ladder in projective coordinates
  ## (x : z).  The result is x_n/z_n; if z_n = 0 the result is zero, since
  ## inv(0) = 0.
  ##
  ## N is used exactly as given (no clamping is applied here) and every bit
  ## of it is processed, from the most significant down.  N may be any
  ## nonnegative integer, and X1 anything coercible into k: without the
  ## coercion an integer X1 would make the whole ladder run over ZZ with no
  ## modular reduction, with numbers doubling in size at every step.
  n = ZZ(n)
  x1 = k(x1)
  bb = n.bits()

  ## Invariant: before processing bit i, (x : z) represents n_{i+1} Q and
  ## (u : w) represents (n_{i+1} + 1) Q, where n_{i+1} = floor(n/2^(i+1)).
  ## Initially n_{i+1} = 0, so we start from the point at infinity and Q.
  x, z = 1, 0
  u, w = x1, 1

  for i in xrange(len(bb) - 1, -1, -1):

    ## n_i is 2 n_{i+1} or 2 n_{i+1} + 1, depending on bit i of n.  If the
    ## bit is set, swap the points, so that the step below computes
    ## 2 (n_{i+1} + 1) Q and (2 n_{i+1} + 1) Q; the swap is undone afterwards.
    if bb[i]: x, z, u, w = u, w, x, z

    ## Ladder step: (x : z) <- 2 (x : z) and (u : w) <- (x : z) + (u : w).
    ## The two points always differ by +/- Q, so the differential addition
    ## needs only x1.
    xmz, xpz = x - z, x + z
    umw, upw = u - w, u + w
    xmz2, xpz2 = xmz^2, xpz^2
    xpz2mxmz2 = xpz2 - xmz2                 # = 4 x z
    xmzupw, xpzumw = xmz*upw, xpz*umw
    x, z = xmz2*xpz2, xpz2mxmz2*(xpz2 + A0*xpz2mxmz2)
    u, w = (xmzupw + xpzumw)^2, x1*(xmzupw - xpzumw)^2

    ## Undo the swap.
    if bb[i]: x, z, u, w = u, w, x, z

  return x*inv(z)
210
## The ladder must agree with Sage's own curve arithmetic on the NaCl
## example points.
assert Y[0] == x25519(y, k(9))
assert x25519(x, Y[0]) == Z[0]
assert x25519(y, X[0]) == Z[0]
213
214 ###--------------------------------------------------------------------------
215 ### Edwards curve parameters and conversion.
216
## Twisted Edwards curve a x^2 + y^2 = 1 + d x^2 y^2 related to E:
## a = -1 and d = -(A - 2)/(A + 2) = -121665/121666.
a = k(-1)
d = -(A - 2)/(A + 2)
219
def mont_to_ed(u, v):
  ## Map the affine point (u, v) on the Montgomery curve E to the twisted
  ## Edwards curve: x = sqrt(-A - 2) u/v and y = (u - 1)/(u + 1).  Returns
  ## the affine pair (x, y).
  scale = sqrt(-A - 2)
  x = scale*u/v
  y = (u - 1)/(u + 1)
  return x, y
222
def ed_to_mont(x, y):
  ## Inverse of `mont_to_ed': map the affine Edwards point (x, y) back to
  ## the Montgomery curve, u = (1 + y)/(1 - y) and v = sqrt(-A - 2) u/x.
  ## Returns the affine pair (u, v).
  mu = (1 + y)/(1 - y)
  return mu, sqrt(-A - 2)*mu/x
227
## The Edwards base point B, derived from P; the sign of Bx is chosen so that
## its integer representative is even.  (u, v) maps B back to E.
Bx, By = mont_to_ed(P[0], P[1])
if Bx.lift()%2: Bx = -Bx
B = (Bx, By, 1)
u, v = ed_to_mont(Bx, By)

## Checks: B has y = 4/5, lies on the Edwards curve, and maps back to the
## Montgomery point with u = 9 and v = +/- the y-coordinate of P.
assert By == k(4/5)
assert -Bx^2 + By^2 == 1 + d*Bx^2*By^2
assert u == k(9)
assert v == P[1] or v == -P[1]
237
238 ###--------------------------------------------------------------------------
239 ### Edwards point addition and doubling.
240
def ed_add(P1, P2):
  ## Add two points of the twisted Edwards curve given in projective
  ## coordinates (X : Y : Z), using the `add-2008-bbjlp' formulae from the
  ## Explicit-Formulas Database.  Returns the sum as a triple (X3, Y3, Z3).
  X1, Y1, Z1 = P1
  X2, Y2, Z2 = P2
  zz = Z1*Z2
  zz2 = zz^2
  xx = X1*X2
  yy = Y1*Y2
  dxxyy = d*xx*yy
  f = zz2 - dxxyy
  g = zz2 + dxxyy
  return (zz*f*((X1 + Y1)*(X2 + Y2) - xx - yy),
          zz*g*(yy - a*xx),
          f*g)
253
def ed_dbl(P1):
  ## Double a projective twisted Edwards point (X1 : Y1 : Z1) using the
  ## `dbl-2008-bbjlp' formulae from the Explicit-Formulas Database.  Returns
  ## the result as a triple (X3, Y3, Z3).
  X1, Y1, Z1 = P1
  xx = X1^2
  yy = Y1^2
  axx = a*xx
  f = axx + yy
  j = f - 2*Z1^2
  return ((X1 + Y1)^2 - xx - yy)*j, f*(axx - yy), f*j
266
## Check the Edwards addition and doubling formulae against Sage's own
## arithmetic on E, using random points Q and R carried across with
## `mont_to_ed'.  (The random scalars n and m are used further down; the
## statement order fixes how the random-number stream is consumed.)
Q = E.random_point()
R = E.random_point()
n = ZZ(randint(0, 2^255 - 1))
m = ZZ(randint(0, 2^255 - 1))
Qx, Qy = mont_to_ed(Q[0], Q[1])
Rx, Ry = mont_to_ed(R[0], R[1])

S = Q + R; T = 2*Q
Sx, Sy, Sz = ed_add((Qx, Qy, 1), (Rx, Ry, 1))
Tx, Ty, Tz = ed_dbl((Qx, Qy, 1))
assert (Sx/Sz, Sy/Sz) == mont_to_ed(S[0], S[1])
assert (Tx/Tz, Ty/Tz) == mont_to_ed(T[0], T[1])
279
280 ###--------------------------------------------------------------------------
281 ### Scalar multiplication.
282
def ed_mul(n, Q):
  ## Return n Q for a nonnegative integer N and a projective Edwards point
  ## Q, using a fixed WINWD-bit window with signed digits.
  ##
  ## Table sizes and indices are computed with `//': in Sage, Integer/Integer
  ## is a Rational, which isn't a proper list index or repeat count.

  winwd = 4
  winlim = 1 << winwd
  winmask = winlim - 1
  half = winlim//2
  tabsz = half + 1
  n = ZZ(n)

  ## Recode the scalar into base-WINLIM digits.  The top window lies wholly
  ## above N's top bit, so it starts out as zero and can absorb a carry.  A
  ## digit >= HALF becomes digit - WINLIM, with a carry into the next digit
  ## up.  Because digits are visited from the top down, a digit may receive
  ## its carry after being adjusted, so the final digits lie in
  ## [-HALF, HALF]: hence HALF + 1 table entries.
  nn = [(n >> i)&winmask for i in xrange(0, n.nbits() + winwd, winwd)]
  for i in xrange(len(nn) - 2, -1, -1):
    if nn[i] >= half:
      nn[i] -= winlim
      nn[i + 1] += 1

  ## Build the table V[j] = j Q for 0 <= j <= HALF: odd multiples by adding
  ## Q, even ones by doubling.
  V = tabsz*[None]
  V[0] = (0, 1, 1)
  V[1] = Q
  V[2] = ed_dbl(V[1])
  for i in xrange(3, tabsz, 2):
    V[i] = ed_add(V[i - 1], Q)
    V[i + 1] = ed_dbl(V[(i + 1)//2])

  ## Do the multiplication, most significant digit first; a negative digit
  ## adds the negation (-x, y, z) of the table entry.
  T = V[0]
  for i in xrange(len(nn) - 1, -1, -1):
    dig = nn[i]
    for _ in xrange(winwd): T = ed_dbl(T)
    if dig >= 0: T = ed_add(T, V[dig])
    else: x, y, z = V[-dig]; T = ed_add(T, (-x, y, z))

  return T
315
def ed_simmul(n0, Q0, n1, Q1):
  ## Return n0 Q0 + n1 Q1 for nonnegative integers N0, N1 and projective
  ## Edwards points Q0, Q1, processing WINWD-bit windows of both scalars
  ## together so that the doublings are shared.
  ##
  ## Table indices are computed with `//': in Sage, Integer/Integer is a
  ## Rational, which isn't a proper list index.

  winwd = 2
  winlim = 1 << winwd
  winmask = winlim - 1
  tabsz = 1 << 2*winwd
  n0, n1 = ZZ(n0), ZZ(n1)

  ## Extract the scalar pieces: window i of N0 goes in the low WINWD bits of
  ## nn[i], and window i of N1 in the next WINWD bits.
  nn = [(n0 >> i)&winmask | (((n1 >> i)&winmask) << winwd)
        for i in xrange(0, max(n0.nbits(), n1.nbits()), winwd)]

  ## Build the table V[j0 + WINLIM j1] = j0 Q0 + j1 Q1.  First fill the
  ## power-of-two entries by doubling; then each entry i + j, where i is a
  ## power of two and 0 < j < i, as V[i] + V[j].
  V = tabsz*[None]
  V[0] = (0, 1, 1)
  V[1] = Q0; V[winlim] = Q1
  i = 2
  while i < winlim:
    V[i] = ed_dbl(V[i//2])
    V[i*winlim] = ed_dbl(V[i*winlim//2])
    i <<= 1
  i = 2
  while i < tabsz:
    for j in xrange(1, i):
      V[i + j] = ed_add(V[i], V[j])
    i <<= 1

  ## Do the multiplication: WINWD doublings per window, then add the entry.
  T = V[0]
  for i in xrange(len(nn) - 1, -1, -1):
    for _ in xrange(winwd): T = ed_dbl(T)
    T = ed_add(T, V[nn[i]])

  return T
350
## Check single and simultaneous scalar multiplication against Sage,
## reusing the random points Q, R and scalars n, m from above.
U = n*Q; V = n*Q + m*R
Ux, Uy, Uz = ed_mul(n, (Qx, Qy, 1))
Vx, Vy, Vz = ed_simmul(n, (Qx, Qy, 1), m, (Rx, Ry, 1))
assert (Ux/Uz, Uy/Uz) == mont_to_ed(U[0], U[1])
assert (Vx/Vz, Vy/Vz) == mont_to_ed(V[0], V[1])
356
357 ###--------------------------------------------------------------------------
358 ### Point encoding.
359
def ed_encode(P):
  ## Encode the projective Edwards point P = (X : Y : Z) as 32 bytes: the
  ## affine y-coordinate, little-endian, with the low bit of the affine x
  ## stored in bit 255.
  X, Y, Z = P
  x = X/Z
  y = Y/Z
  sign = x.lift() & 1
  return st(y.lift() + (sign << 255), 32)
365
def ed_decode(s):
  ## Decode the 32-byte string S produced by `ed_encode': bits 0-254 hold
  ## the affine y-coordinate, little-endian, and bit 255 the low bit of x.
  ## Returns the point as a projective triple (x, y, 1) with x and y in k.
  ##
  ## Following RFC 8032 (section 5.1.3), invalid or non-canonical encodings
  ## are rejected with ValueError: the wrong length, y >= p, y with no
  ## matching x (the error from `quosqrt'), or x = 0 with the sign bit set.
  if len(s) != 32: raise ValueError('point encoding must be 32 bytes')
  n = ld(s)
  bit = (n >> 255)&1
  yy = n&((1 << 255) - 1)
  if yy >= p: raise ValueError('y coordinate out of range')

  ## Recover x from -x^2 + y^2 = 1 + d x^2 y^2, i.e.,
  ## x^2 = (y^2 - 1)/(d y^2 + 1).
  y = k(yy)
  y2 = y^2
  x = quosqrt(y2 - 1, d*y2 + 1)

  ## Pick the root whose low bit matches the sign bit.
  if x == 0 and bit: raise ValueError('sign bit set for x = 0')
  if x.lift()%2 != bit: x = -x
  return (x, y, 1)
374
375 ###--------------------------------------------------------------------------
376 ### EdDSA implementation.
377
def eddsa_splitkey(k):
  ## Expand the private key K: hash it with SHA-512, clamp the first half
  ## into the secret scalar a (bits 0-2 and 255 clear, bit 254 set), and
  ## keep the second half as the nonce prefix.  Returns (a, prefix).
  digest = hash(k)
  clamped = ld(digest[0:32]) & ((1 << 254) - 8)
  return 2^254 + clamped, digest[32:64]
383
def eddsa_pubkey(k):
  ## Return the 32-byte encoding of the public key a B, where a is the
  ## secret scalar derived from the private key K.
  return ed_encode(ed_mul(eddsa_splitkey(k)[0], B))
388
def eddsa_sign(k, m):
  ## Sign the message M with the private key K, returning the 64-byte
  ## signature R || S.
  ##
  ## The key is expanded only once, and the public key is derived directly
  ## from the scalar (going through `eddsa_pubkey' would hash the key a
  ## second time).  There's no need to decode our own freshly encoded
  ## public key, either.
  a, h1 = eddsa_splitkey(k)
  K = ed_encode(ed_mul(a, B))

  ## Deterministic nonce r = H(h1 || M) mod l, and commitment R = r B.
  r = ld(hash(h1, m))%l
  RR = ed_encode(ed_mul(r, B))

  ## S = r + H(R || K || M) a  (mod l).
  S = (r + a*ld(hash(RR, K, m)))%l
  return RR + st(S, 32)
398
def eddsa_verify(K, m, sig):
  ## Check the signature SIG on the message M under the encoded public key
  ## K, returning True if it is valid and False otherwise.
  ##
  ## Following RFC 8032 (section 5.1.7), malformed inputs make the signature
  ## invalid.  K and SIG must have the right lengths, and K must decode to a
  ## curve point.  S must also be reduced modulo l; otherwise S + l, S + 2 l,
  ## ... would verify too, making signatures malleable.
  if len(K) != 32 or len(sig) != 64: return False
  try: A = ed_decode(K)
  except ValueError: return False
  R, S = sig[0:32], ld(sig[32:64])
  if S >= l: return False

  ## Accept iff S B - h A encodes to R, where h = H(R || K || M) mod l.
  h = ld(hash(R, K, m))%l
  V = ed_simmul(S, B, h, (-A[0], A[1], A[2]))
  return ed_encode(V) == R
405
## Known-answer test: a private key, message, public key and signature (all
## hex-encoded), checked against key derivation, signing and verification.
priv = '1acdbb793b0384934627470d795c3d1dd4d79cea59ef983f295b9b59179cbb28'.decode('hex')
msg = '7cf34f75c3dac9a804d0fcd09eba9b29c9484e8a018fa9e073042df88e3c56'.decode('hex')
pub = '3f60c7541afa76c019cf5aa82dcdb088ed9e4ed9780514aefb379dabc844f31a'.decode('hex')
sig = 'be71ef4806cb041d885effd9e6b0fbb73d65d7cdec47a89c8a994892f4e55a568c4cc78d61f901e80dbb628b86a23ccd594e712b57fa94c2d67ec26634878507'.decode('hex')
assert pub == eddsa_pubkey(priv)
assert sig == eddsa_sign(priv, msg)
assert eddsa_verify(pub, msg, sig)
413
414 ###----- That's all, folks --------------------------------------------------