utils/gcm-ref: Pull `poly64_mul' and `poly64_redc' out of `poly64_common'.
[catacomb] / utils / gcm-ref
CommitLineData
9e6a4409
MW
1#! /usr/bin/python
2### -*- coding: utf-8 -*-
3
4from sys import argv, exit
5
6import catacomb as C
7
8###--------------------------------------------------------------------------
9### Random utilities.
10
def words(s):
    """
    Format the byte string S as space-separated 8-digit hex numbers.

    S is cut into consecutive 4-byte pieces; each piece is converted to a
    number with `C.MP.loadb' and printed in hex.
    """
    pieces = [s[off:off + 4] for off in xrange(0, len(s), 4)]
    return ' '.join(['%08x' % C.MP.loadb(piece) for piece in pieces])
15
def words_64(s):
    """
    Format the byte string S as space-separated 16-digit hex numbers.

    S is cut into consecutive 8-byte pieces; each piece is converted to a
    number with `C.MP.loadb' and printed in hex.
    """
    pieces = [s[off:off + 8] for off in xrange(0, len(s), 8)]
    return ' '.join(['%016x' % C.MP.loadb(piece) for piece in pieces])
20
def repmask(val, wd, n):
    """Build a `C.GF' mask made of N copies of VAL at WD-bit spacing."""
    unit = C.GF(val)
    mask = C.GF(0)
    for _ in xrange(n):
        ## Make room at the bottom for one more copy and drop it in.
        mask = (mask << wd) | unit
    return mask
27
def combs(things, k):
    """
    Iterate over all possible combinations of K of the THINGS.

    Each combination is yielded as a fresh list of K items, chosen at
    strictly increasing positions in THINGS.  The lowest position advances
    fastest: for four things `abcd' and K = 2 the order is ab, ac, bc, ad,
    bd, cd.  (This is not the order used by `itertools.combinations', and
    `poly' depends on it.)

    If K exceeds the number of THINGS then there are no combinations and
    nothing is yielded.  THINGS must support `len' and indexing.
    """
    n = len(things)

    ## Without this guard the first `yield' would index past the end.
    if k > n: return

    ## The current selection, as positions into THINGS.  This must be a
    ## real list because it's updated in place.
    ii = list(range(k))

    while True:
        yield [things[i] for i in ii]

        ## Advance to the next selection: bump the lowest position which has
        ## room to move, resetting all of the positions below it.
        for j in range(k):
            lim = n if j == k - 1 else ii[j + 1]
            i = ii[j] + 1
            if i < lim:
                ii[j] = i
                break
            ii[j] = j
        else:
            ## Nothing could move: we've produced everything.
            return
44
## Cache of polynomials already found by `poly', indexed by degree.
POLYMAP = {}

def poly(nbits):
    """
    Return the lexically first irreducible polynomial of degree NBITS of lowest
    weight.

    Candidates always include the terms t^NBITS and 1; the search adds K
    middle terms, for K = 1, 3, 5, ..., trying the choices of exponents in
    the order produced by `combs'.  The first candidate which reports itself
    irreducible is remembered in `POLYMAP' and returned.

    Raises `ValueError' (carrying NBITS) if no candidate is irreducible.
    """
    if nbits in POLYMAP: return POLYMAP[nbits]

    base = C.GF(0).setbit(nbits).setbit(0)
    for k in xrange(1, nbits, 2):
        for cc in combs(range(1, nbits), k):
            middle = C.GF(0)
            for c in cc: middle = middle + C.GF(0).setbit(c)
            p = base + middle
            if p.irreduciblep():
                POLYMAP[nbits] = p
                return p
    raise ValueError(nbits)
60
def gcm_mangle(x):
    """
    Return a copy of X with the order of the bits in each byte reversed.

    This converts to and from GCM's bit-ordering convention; applying it
    twice gets back the original string.
    """
    out = C.WriteBuffer()
    for ch in x:
        byte = ord(ch)
        flipped = 0
        for _ in xrange(8):
            ## Peel bits off the bottom of BYTE, pushing each onto the bottom
            ## of FLIPPED, so that the first one off ends up at the top.
            flipped = (flipped << 1) | (byte&1)
            byte >>= 1
        out.putu8(flipped)
    return out.contents
73
def endswap_words_32(x):
    """
    End-swap each 32-bit word of X.

    Each word is read from X with `getu32b' and written back with
    `putu32l'; the result is returned as the buffer's contents.
    """
    src = C.ReadBuffer(x)
    dst = C.WriteBuffer()
    while src.left:
        word = src.getu32b()
        dst.putu32l(word)
    return dst.contents
80
def endswap_words_64(x):
    """
    End-swap each 64-bit word of X.

    Each word is read from X with `getu64b' and written back with
    `putu64l'; the result is returned as the buffer's contents.
    """
    src = C.ReadBuffer(x)
    dst = C.WriteBuffer()
    while src.left:
        word = src.getu64b()
        dst.putu64l(word)
    return dst.contents
87
def endswap_bytes(x):
    """End-swap X by bytes: return its bytes in the opposite order."""
    out = C.WriteBuffer()
    ## Walk the positions of X from last to first, copying one item each time.
    for i in xrange(len(x) - 1, -1, -1):
        out.put(x[i])
    return out.contents
93
def gfmask(n):
    """Return a `C.GF' mask covering the low N bits."""
    ## Setting bit N alone and subtracting one sets every bit below it.
    limit = C.MP(0).setbit(n)
    return C.GF(limit - 1)
96
def gcm_mul(x, y):
    """
    Multiply X and Y according to the GCM rules.

    X and Y are byte strings in external format.  The result has the same
    length as X.  The modulus is `poly' of the bit-length of X.  This is the
    reference result against which the demonstrations are checked.
    """
    w = len(x)
    modulus = poly(8*w)

    ## Undo GCM's bit ordering on the way in...
    u = C.GF.loadl(gcm_mangle(x))
    v = C.GF.loadl(gcm_mangle(y))

    ## ... do the arithmetic, and put the bit ordering back on the way out.
    product = (u*v)%modulus
    return gcm_mangle(product.storel(w))
104
## Map from demonstration style names to the functions which implement them.
DEMOMAP = {}
def demo(func):
    """
    Decorator: register FUNC as a demonstration and return it unchanged.

    FUNC's name must begin `demo_'.  The rest of the name, with underscores
    replaced by hyphens, is the key under which FUNC is entered in
    `DEMOMAP': so `demo_table_b' is registered as `table-b'.
    """
    ## `__name__' means the same as the old `func_name' alias, but it exists
    ## on every version of Python.
    name = func.__name__
    prefix = 'demo_'
    assert(name.startswith(prefix))
    DEMOMAP[name[len(prefix):].replace('_', '-')] = func
    return func
111
def iota(i = 0):
    """
    Return a function which hands out consecutive integers, starting at I.

    Each call to the returned function yields the next integer; separate
    counters made by separate calls to `iota' are independent.
    """
    ## The closure needs somewhere mutable to keep its position.
    state = {'next': i}
    def step():
        value = state['next']
        state['next'] = value + 1
        return value
    return step
116
117###--------------------------------------------------------------------------
118### Portable table-driven implementation.
119
def shift_left(x):
    """
    Given a field element X (in external format), return X t.

    The result is in external format too, and is always the same length as
    X.
    """
    w = len(x)
    p = poly(8*w)
    shifted = (C.GF.loadl(gcm_mangle(x)) << 1)%p

    ## Store at the full element width, as `gcm_mul' does.  The length used
    ## to be left out here, leaving it to `storel' to choose the size of
    ## the result.
    return gcm_mangle(shifted.storel(w))
125
def table_common(u, v, flip, getword, ixmask):
    """
    Multiply U by V using table lookup; common for `table-b' and `table-l'.

    This matches the `simple_mulk_...' implementation in `gcm.c'.  The table
    has one entry per bit of the operand.  That's the best we can manage if
    we want a constant-time implementation: processing n bits at a time
    means we need to scan (2^n - 1)/n times as much memory.

      * FLIP is a function (assumed to be an involution) on one argument X
        to convert X from external format to table-entry format or back
        again.

      * GETWORD is a function on one argument B to retrieve the next 32-bit
        chunk of a field element held in a `ReadBuffer'.  Bits within a word
        are processed most-significant first.

      * IXMASK is a mask XORed into table indices to permute the table so
        that its order matches that induced by GETWORD.

    The table is built such that tab[i XOR IXMASK] = U t^i.  U and V must
    have the same length.  The table and each step of the accumulation are
    printed as we go.  Returns the product in external format.
    """
    w = len(u); assert(w == len(v))
    nbits = 8*w

    ## Build the table: keep multiplying U by t, filing each multiple in its
    ## permuted slot.
    tab = [None]*nbits
    for i in xrange(nbits):
        print(';; %9s = %7s = %s' %
              ('utab[%d]' % i, 'u t^%d' % i, words(u)))
        tab[i ^ ixmask] = flip(u)
        u = shift_left(u)

    ## Scan V a word at a time, from the top bit of each word downwards,
    ## adding in the table entry for each bit which is set.
    a = C.ByteString.zero(w)
    rb = C.ReadBuffer(v)
    i = 0
    while rb.left:
        t = getword(rb)
        for sh in xrange(31, -1, -1):
            bit = (t >> sh)&1
            if bit: a ^= tab[i]
            print(';; %6s = %d: a <- %s [%9s = %s]' %
                  ('v[%d]' % (i ^ ixmask), bit, words(a),
                   'utab[%d]' % (i ^ ixmask), words(tab[i])))
            i += 1

    ## The accumulator is in table-entry format: convert it back.
    return flip(a)
166
@demo
def demo_table_b(u, v):
    """
    Big-endian table lookup.

    Table entries are kept in external format, words are fetched with
    `getu32b', and the table isn't permuted.
    """
    def identity(x): return x
    def getword(b): return b.getu32b()
    return table_common(u, v, identity, getword, 0)
171
@demo
def demo_table_l(u, v):
    """
    Little-endian table lookup.

    Table entries have each 32-bit word end-swapped, words are fetched with
    `getu32l', and table indices are XORed with 0x18 to match.
    """
    def getword(b): return b.getu32l()
    return table_common(u, v, endswap_words_32, getword, 0x18)
9e6a4409
MW
176
177###--------------------------------------------------------------------------
178### Implementation using 64×64->128-bit binary polynomial multiplication.
179
## Tags identifying the kind of value handed to a presentation function.
## They're numbered consecutively from zero in the order written here, and
## the presentation functions compare them as ranges, so the order matters.
_i = iota()
TAG_INPUT_U, TAG_INPUT_V = _i(), _i()
TAG_KPIECE_U, TAG_KPIECE_V = _i(), _i()
TAG_PRODPIECE, TAG_PRODSUM, TAG_PRODUCT = _i(), _i(), _i()
TAG_SHIFTED = _i()
TAG_REDCBITS, TAG_REDCFULL, TAG_REDCMIX = _i(), _i(), _i()
TAG_OUTPUT = _i()
193
def split_gf(x, n):
    """
    Cut the byte string X into N-bit pieces; return them as a list.

    Each piece is converted to a `C.GF' element with `loadb'.  The pieces
    are in the same order as in X.  N is in bits.
    """
    step = n//8
    return [C.GF.loadb(x[off:off + step]) for off in xrange(0, len(x), step)]
197
def join_gf(xx, n):
    """
    Glue the list XX of N-bit pieces together into a single `C.GF' element.

    The first piece ends up in the most significant position.
    """
    acc = C.GF(0)
    for piece in xx:
        acc = (acc << n) | piece
    return acc
202
def present_gf(x, w, n, what):
    """
    Print the W-bit value X as N-bit hex chunks, labelled WHAT.

    One line is printed for each 128 bits of X, starting from the bottom;
    only the first line carries the label.  Within a line, chunks are
    printed starting from the least significant, and chunks which would lie
    at or above bit W are left out.
    """
    m = gfmask(n)
    for i in xrange(0, w, 128):
        firstp = i == 0
        chunks = []
        for j in xrange(i, i + 128, n):
            if j < w:
                chunks.append(' 0x%s' % hex(((x >> j)&m).storeb(n//8)))
        print(';; %12s%c =%s' %
              ((what or '') if firstp else '',
               ':' if firstp else ' ',
               ''.join(chunks)))
215
def present_gf_pclmul(tag, wd, x, w, n, what):
    """
    Presentation function for the `pclmul' demonstration.

    Individual product pieces (`TAG_PRODPIECE') are suppressed; everything
    else is passed unchanged to `present_gf'.  WD is not used.
    """
    if tag == TAG_PRODPIECE: return
    present_gf(x, w, n, what)
218
def reverse(x, w):
    """
    Return the W-bit value X with the order of its bytes reversed.

    X is stored as W/8 bytes with `storeb' and read back with `loadl'.
    """
    octets = x.storeb(w//8)
    return C.GF.loadl(octets)
221
def rev32(x):
    """
    Return X with the bytes within each 32-bit word reversed.

    The masks are sized according to `X.noctets'.
    """
    nbytes = x.noctets
    lo_halves = repmask(0xffff, 32, nbytes//4)
    lo_bytes = repmask(0xff, 16, nbytes//2)

    ## Swap the halves of each word, and then the bytes of each half.
    x = ((x&lo_halves) << 16) | ((x >> 16)&lo_halves)
    x = ((x&lo_bytes) << 8) | ((x >> 8)&lo_bytes)
    return x
229
def rev8(x):
    """
    Return X with the bits within each byte reversed.

    The masks are sized according to `X.noctets'.
    """
    nbytes = x.noctets

    ## Swap nibbles, then pairs of bits within nibbles, then adjacent bits.
    for pattern, shift in [(0x0f, 4), (0x33, 2), (0x55, 1)]:
        m = repmask(pattern, 8, nbytes)
        x = ((x&m) << shift) | ((x >> shift)&m)
    return x
239
def present_gf_vmullp64(tag, wd, x, w, n, what):
    """
    Presentation function for the `vmullp64' demonstration.

    TAG says what kind of value X is, WD is the operand width handed down
    by the caller, and W, N and WHAT are as for `present_gf'.

      * Product pieces and full reduction values (`TAG_PRODPIECE',
        `TAG_REDCFULL') aren't shown at all.

      * For WD of 64 or 128, values tagged from `TAG_PRODSUM' up to and
        including `TAG_PRODUCT' are shown as they are; for WD of 96, 192 or
        256, so are values tagged from `TAG_PRODSUM' up to but not
        including `TAG_OUTPUT'.

      * Anything else is shown with the order of its 64-bit chunks
        reversed.

    Whatever is shown is presented at W rounded up to a multiple of 64 bits.
    """
    if tag in (TAG_PRODPIECE, TAG_REDCFULL): return

    asis = ((wd in (64, 128) and TAG_PRODSUM <= tag <= TAG_PRODUCT) or
            (wd in (96, 192, 256) and TAG_PRODSUM <= tag < TAG_OUTPUT))
    if asis:
        y = x
    else:
        ## Pad the stored bytes at the end to a whole number of 8-byte
        ## chunks, and then copy the chunks out last one first.
        xx = x.storeb(w//8)
        extra = len(xx)%8
        if extra: xx += C.ByteString.zero(8 - extra)
        yb = C.WriteBuffer()
        for end in xrange(len(xx), 0, -8): yb.put(xx[end - 8:end])
        y = C.GF.loadb(yb.contents)

    present_gf(y, (w + 63)&~63, n, what)
256
def present_gf_pmull(tag, wd, x, w, n, what):
    """
    Presentation function for the `pmull' demonstration.

    TAG says what kind of value X is, and W, N and WHAT are as for
    `present_gf'.  WD is not used.

      * Product pieces, full reduction values and shifted values
        (`TAG_PRODPIECE', `TAG_REDCFULL', `TAG_SHIFTED') aren't shown.

      * Values belonging to V (`TAG_INPUT_V', `TAG_KPIECE_V') are widened to
        a multiple of 64 bits and then have every 8-byte chunk written out
        twice, doubling the width.

      * Values tagged from `TAG_PRODSUM' to `TAG_PRODUCT' inclusive are
        shifted left by one bit.

    Whatever is left is then put through `rev8' and `reverse' and shown.
    """
    if tag in (TAG_PRODPIECE, TAG_REDCFULL, TAG_SHIFTED): return

    if tag in (TAG_INPUT_V, TAG_KPIECE_V):
        w = (w + 63)&~63
        src = C.ReadBuffer(x.storeb(w//8))
        dst = C.WriteBuffer()
        while src.left:
            chunk = src.get(8)
            dst.put(chunk)
            dst.put(chunk)
        x = C.GF.loadb(dst.contents)
        w *= 2
    elif TAG_PRODSUM <= tag <= TAG_PRODUCT:
        x <<= 1

    present_gf(reverse(rev8(x), w), w, n, what)
271
def poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat):
    """
    Multiply U by V, returning the product.

    This is the fallback long multiplication.

    U and V are byte strings, which need not be the same length.  They are
    carved into MULWD-bit pieces and every piece of U is multiplied by every
    piece of V.  PRESFN is called to report each product piece, each sum of
    the pieces with the same weight, and the final product; WD and DISPWD
    are passed on to it unchanged, and UWHAT and VWHAT are the operands'
    names to use in the labels.  Returns the product as a `C.GF' element.
    """

    uw, vw = 8*len(u), 8*len(v)

    ## We start by carving the operands into MULWD-bit pieces.  This is
    ## straightforward except when the length isn't a multiple of the piece
    ## size (e.g., the 96-bit case), where the final piece comes up short.
    ## We pad the operand with zero bytes at the end to fill it out, and
    ## shift the extra zero bits off the bottom of the product afterwards.
    ## The pad widths here are in bits, so they need converting to bytes.
    upad = (-uw)%mulwd; u += C.ByteString.zero(upad//8); uw += upad
    vpad = (-vw)%mulwd; v += C.ByteString.zero(vpad//8); vw += vpad
    uu = split_gf(u, mulwd); vv = split_gf(v, mulwd)

    ## Report and accumulate the individual product pieces.
    x = C.GF(0)
    ulim, vlim = uw//mulwd, vw//mulwd
    for i in xrange(ulim + vlim - 2, -1, -1):
        t = C.GF(0)

        ## Counting pieces from the ends of the lists, piece j of V pairs
        ## with piece i - j of U: so we need 0 <= j < vlim and also
        ## 0 <= i - j < ulim.
        for j in xrange(max(0, i - ulim + 1), min(vlim, i + 1)):
            s = uu[ulim - 1 - i + j]*vv[vlim - 1 - j]
            presfn(TAG_PRODPIECE, wd, s, 2*mulwd, dispwd,
                   '%s_%d %s_%d' % (uwhat, i - j, vwhat, j))
            t += s
        presfn(TAG_PRODSUM, wd, t, 2*mulwd, dispwd,
               '(%s %s)_%d' % (uwhat, vwhat, ulim + vlim - 2 - i))
        x += t << (mulwd*i)
    presfn(TAG_PRODUCT, wd, x, uw + vw, dispwd, '%s %s' % (uwhat, vwhat))

    ## Discard the zero bits which the padding introduced.
    return x >> (upad + vpad)
9e6a4409
MW
304
def poly64_mul_karatsuba(u, v, klimit, presfn, wd,
                         dispwd, mulwd, uwhat, vwhat):
    """
    Multiply U by V, returning the product.

    If the length of U and V is at least KLIMIT, and the operands are
    otherwise suitable, then do Karatsuba--Ofman multiplication; otherwise,
    delegate to `poly64_mul_simple'.

    `Suitable' means that U and V are the same length, and that this length
    is a multiple of twice MULWD bits.  The remaining arguments are as for
    `poly64_mul_simple'.  Returns the product as a `C.GF' element.
    """
    w = 8*len(u)

    if w < klimit or w != 8*len(v) or w%(2*mulwd) != 0:
        return poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd,
                                 uwhat, vwhat)

    ## Split each operand in half.  The first half of the string is the
    ## more significant one.
    hw = w//2
    hb = hw//8
    u0, u1 = u[:hb], u[hb:]
    v0, v1 = v[:hb], v[hb:]

    def submul(up, vp, suffix):
        ## Report a pair of half-size operands and multiply them together.
        uname = '%s%s' % (uwhat, suffix)
        vname = '%s%s' % (vwhat, suffix)
        presfn(TAG_KPIECE_U, wd, C.GF.loadb(up), hw, dispwd, uname)
        presfn(TAG_KPIECE_V, wd, C.GF.loadb(vp), hw, dispwd, vname)
        return poly64_mul_karatsuba(up, vp, klimit, presfn, wd, dispwd,
                                    mulwd, uname, vname)

    ## The three half-size products.
    uuvv = submul(u0 ^ u1, v0 ^ v1, '*')
    u0v0 = submul(u0, v0, '0')
    u1v1 = submul(u1, v1, '1')

    ## Recover the middle term, and then assemble the full product.
    uvuv = uuvv + u0v0 + u1v1
    presfn(TAG_PRODSUM, wd, uvuv, w, dispwd, '%s!%s' % (uwhat, vwhat))

    x = u1v1 + (uvuv << hw) + (u0v0 << w)
    presfn(TAG_PRODUCT, wd, x, 2*w, dispwd, '%s %s' % (uwhat, vwhat))
    return x
345
def poly64_mul(u, v, presfn, dispwd, mulwd, klimit, uwhat, vwhat):
    """
    Multiply U by V using a primitive 64-bit binary polynomial multiplier.

    Such a multiplier exists as the appallingly-named `pclmul[lh]q[lh]qdq'
    on x86, and as `vmull.p64'/`pmull' on ARM.

    Operands arrive in a `register format', which is a byte-swapped variant
    of the external format.  Implementations differ on the precise details,
    though.  U and V must be the same length.  Returns the double-precision
    product, as a byte string twice as long as U.
    """
    w = 8*len(u); assert(w == 8*len(v))
    product = poly64_mul_karatsuba(u, v, klimit, presfn,
                                   w, dispwd, mulwd, uwhat, vwhat)

    ## The product is 2 W bits wide, which is W/4 bytes.
    return product.storeb(w//4)
9e6a4409 363
e95b355c
MW
def poly64_redc(y, presfn, dispwd, redcwd):
    """
    Reduce a double-precision product Y modulo the appropriate polynomial.

    The operand arrives in a `register format', which is a byte-swapped
    variant of the external format.  Implementations differ on the precise
    details, though.  The polynomial is `poly' of half the bit-length of Y.
    The work is done in REDCWD-bit chunks, and the intermediate values are
    reported through PRESFN, which is passed DISPWD unchanged.  Returns the
    single-precision reduced value, as a byte string half as long as Y.
    """

    w = 4*len(y)
    p = poly(w)

    ## Our polynomial has the form p = t^d + r where r = SUM_{0<=i<d} r_i
    ## t^i, with each r_i either 0 or 1.  Because we choose the lexically
    ## earliest irreducible polynomial with the necessary degree, r_i = 1
    ## happens only for a small number of tiny i.  In our field, we have
    ## t^d = r.
    ##
    ## We carve the product into convenient n-bit pieces, for some n
    ## dividing d -- typically n = 32 or 64.  Let d = m n, and write y =
    ## SUM_{0<=i<2m} y_i t^{ni}.  The upper portion, the y_i with i >= m,
    ## needs reduction; but y_i t^{ni} = y_i r t^{n(i-m)}, so we just
    ## multiply the top half by r and add it to the bottom half.  This all
    ## depends on r_i = 0 for all i >= n/2.  We process each nonzero
    ## coefficient of r separately, in two passes.
    ##
    ## Multiplying a chunk y_i by some t^j is the same as shifting it left
    ## by j bits (or would be if GCM weren't backwards, but let's not worry
    ## about that right now).  The high j bits will spill over into the
    ## next chunk, while the low n - j bits will stay where they are.  It's
    ## these high bits which cause trouble -- particularly the high bits of
    ## the top chunk, since we'll add them on to y_m, which will need
    ## further reduction.  But only the topmost j bits will do this.
    ##
    ## The trick is that we do all of the bits which spill over first --
    ## all of the top j bits in each chunk, for each j -- in one pass, and
    ## then a second pass of all the bits which don't.  Because j, j' < n/2
    ## for any two nonzero coefficient degrees j and j', we have j + j' < n
    ## whence j < n - j' -- so all of the bits contributed to y_m will be
    ## handled in the second pass when we handle the bits that don't spill
    ## over.
    rr = [i for i in xrange(1, w) if p.testbit(i)]
    chunkmask = gfmask(redcwd)
    nlow = w//redcwd

    ## Handle the spilling bits.
    yy = split_gf(y, redcwd)
    b = C.GF(0)
    for rj in rr:
        br = join_gf([(yi << (redcwd - rj))&chunkmask for yi in yy[nlow:]],
                     redcwd)
        presfn(TAG_REDCBITS, w, br, w, dispwd, 'b(%d)' % rj)
        b += br << (w - redcwd)
    presfn(TAG_REDCFULL, w, b, 2*w, dispwd, 'b')
    s = C.GF.loadb(y) + b
    presfn(TAG_REDCMIX, w, s, 2*w, dispwd, 's')

    ## Handle the nonspilling bits.
    ss = split_gf(s.storeb(w//4), redcwd)
    a = C.GF(0)
    for rj in rr:
        ar = join_gf([si >> rj for si in ss[nlow:]], redcwd)
        presfn(TAG_REDCBITS, w, ar, w, dispwd, 'a(%d)' % rj)
        a += ar
    presfn(TAG_REDCFULL, w, a, w, dispwd, 'a')

    ## Mix everything together.
    fullmask = gfmask(w)
    z = (s&fullmask) + (s >> w) + a
    presfn(TAG_OUTPUT, w, z, w, dispwd, 'z')

    ## And we're done.
    return z.storeb(w//8)
433
e95b355c
MW
def poly64_common(u, v, presfn, dispwd = 32, mulwd = 64,
                  redcwd = 32, klimit = 256):
    """
    Multiply U by V, reporting the intermediate values through PRESFN.

    This is the common driver for the demonstrations based on a primitive
    polynomial multiplier: it reports the inputs, forms the double-width
    product using `poly64_mul', shifts it left by one bit, and reduces it
    using `poly64_redc'.  DISPWD is the chunk width for display, MULWD the
    width of the primitive multiplier, REDCWD the chunk width for
    reduction, and KLIMIT the operand size from which Karatsuba
    multiplication is used.  Returns the reduced product as a byte string.
    """
    w = 8*len(u)
    presfn(TAG_INPUT_U, w, C.GF.loadb(u), w, dispwd, 'u')
    presfn(TAG_INPUT_V, w, C.GF.loadb(v), w, dispwd, 'v')

    product = poly64_mul(u, v, presfn, dispwd, mulwd, klimit, "u", "v")
    shifted = (C.GF.loadb(product) << 1).storeb(w//4)
    return poly64_redc(shifted, presfn, dispwd, redcwd)
443
9e6a4409
MW
@demo
def demo_pclmul(u, v):
    """Multiply as `poly64_common' does, presenting in `pclmul' style."""
    presfn = present_gf_pclmul
    return poly64_common(u, v, presfn = presfn)
447
@demo
def demo_vmullp64(u, v):
    """
    Multiply as `poly64_common' does, presenting in `vmullp64' style.

    Reduction is done in 64-bit chunks, unless the operand width leaves 32
    bits over from a multiple of 64, in which case 32-bit chunks are used.
    """
    w = 8*len(u)
    redcwd = 32 if w%64 == 32 else 64
    return poly64_common(u, v, presfn = present_gf_vmullp64,
                         redcwd = redcwd)
453
@demo
def demo_pmull(u, v):
    """
    Multiply as `poly64_common' does, presenting in `pmull' style.

    Reduction is done in 64-bit chunks, unless the operand width leaves 32
    bits over from a multiple of 64, in which case 32-bit chunks are used.
    """
    w = 8*len(u)
    redcwd = 32 if w%64 == 32 else 64
    return poly64_common(u, v, presfn = present_gf_pmull,
                         redcwd = redcwd)
459
460###--------------------------------------------------------------------------
461### @@@ Random debris to be deleted. @@@
462
def cutting_room_floor():
    """
    Scratch calculation on one fixed pair of 24-byte operands.

    Prints the intermediate values of a hand-worked multiplication and
    reduction, followed by the result from `gcm_mul' for comparison.
    Nothing in this file calls it; it's marked for deletion.
    """

    x = C.bytes('cde4bef260d7bcda163547d348b7551195e77022907dd1df')
    y = C.bytes('f7dac5c9941d26d0c6eb14ad568f86edd1dc9268eeee5332')

    u, v = C.GF.loadb(x), C.GF.loadb(y)

    def mask32(val):
        ## The same bit pattern VAL in each of six 32-bit words.
        return repmask(val, 32, 6)

    def show(fmt, val, n = None):
        ## Print VAL as 48 bytes of words, cut down to the first N if given.
        octets = val.storeb(48)
        if n is not None: octets = octets[0:n]
        print(fmt % words(octets))

    g = (u*v) << 1
    show('y = %s', g)
    b1 = (g&mask32(0x01)) << 191
    b2 = (g&mask32(0x03)) << 190
    b7 = (g&mask32(0x7f)) << 185
    b = b1 + b2 + b7
    show('b = %s', b, 28)
    h = g + b
    show('w = %s', h)

    a0 = (h&mask32(0xffffffff)) << 192
    a1 = (h&mask32(0xfffffffe)) << 191
    a2 = (h&mask32(0xfffffffc)) << 190
    a7 = (h&mask32(0xffffff80)) << 185
    a = a0 + a1 + a2 + a7

    show(' a_1 = %s', a1, 24)
    show(' a_2 = %s', a2, 24)
    show(' a_7 = %s', a7, 24)

    show('low+unit = %s', h + a0, 24)
    show(' low+0,2 = %s', h + a0 + a2, 24)
    show(' 1,7 = %s', a1 + a7, 24)

    show('a = %s', a, 24)
    z = h + a
    show('z = %s', z)

    z = gcm_mul(x, y)
    print('u v mod p = %s' % words(z))
500
501###--------------------------------------------------------------------------
502### Main program.
503
## Usage: gcm-ref STYLE U V
##
## STYLE names one of the demonstrations registered in `DEMOMAP'; U and V are
## the operands, in hex.  The demonstration prints its working as it goes,
## and its answer is checked against the reference `gcm_mul'.
if __name__ == '__main__':

    ## Complain tidily about a malformed command line, rather than letting
    ## an `IndexError' or `KeyError' escape.
    if len(argv) != 4:
        exit('usage: %s STYLE U V' % argv[0])
    style = argv[1]
    if style not in DEMOMAP:
        exit("unknown style `%s': expected one of %s" %
             (style, ', '.join(sorted(DEMOMAP))))

    u = C.bytes(argv[2])
    v = C.bytes(argv[3])
    zz = DEMOMAP[style](u, v)

    ## This check is the point of the exercise, so it mustn't be an
    ## `assert', which `python -O' would quietly discard.
    if zz != gcm_mul(u, v):
        exit("style `%s' disagrees with the reference multiplication" % style)
509
510###----- That's all, folks --------------------------------------------------