utils/gcm-ref: Fix embarrassing mistakes in comments.
[catacomb] / utils / gcm-ref
CommitLineData
9e6a4409
MW
1#! /usr/bin/python
2### -*- coding: utf-8 -*-
3
4from sys import argv, exit
5
6import catacomb as C
7
8###--------------------------------------------------------------------------
9### Random utilities.
10
def words(s):
  """
  Return the contents of S as space-separated 32-bit words in hex.

  S is cut into consecutive four-byte slices; each slice is converted with
  `C.MP.loadb' and formatted as eight hex digits.  A short final slice is
  converted in the same way, so it is still shown as eight digits.
  """
  pieces = []
  for off in range(0, len(s), 4):
    pieces.append('%08x' % C.MP.loadb(s[off:off + 4]))
  return ' '.join(pieces)
15
def words_64(s):
  """
  Return the contents of S as space-separated 64-bit words in hex.

  S is cut into consecutive eight-byte slices; each slice is converted with
  `C.MP.loadb' and formatted as sixteen hex digits.  A short final slice is
  converted in the same way, so it is still shown as sixteen digits.
  """
  pieces = []
  for off in range(0, len(s), 8):
    pieces.append('%016x' % C.MP.loadb(s[off:off + 8]))
  return ' '.join(pieces)
20
def repmask(val, wd, n):
  """
  Return a mask consisting of N copies of the WD-bit value VAL.

  The result is a `C.GF' value built by repeatedly shifting the accumulated
  mask up by WD bits and ORing in VAL; N = 0 gives zero.
  """
  unit = C.GF(val)
  mask = C.GF(0)
  for _ in range(n):
    mask = (mask << wd) | unit
  return mask
27
def combs(things, k):
  """
  Iterate over all possible combinations of K of the THINGS.

  Each combination is yielded as a fresh list of K items of THINGS, taken
  in increasing index order.  To step from one combination to the next, the
  earliest position which still has room to move is advanced by one, and
  every position before it is reset to its starting index.  For example,
  choosing two of `[1, 2, 3]' yields `[1, 2]', `[1, 3]', `[2, 3]'.

  If K is zero, a single empty combination is yielded.  If K exceeds the
  number of THINGS then there are no combinations, and nothing is yielded.
  """
  n = len(things)

  ## Without this check, the first `yield' below would index off the end of
  ## THINGS.
  if k > n: return

  ## The current selection, as indices into THINGS.  This must be a real
  ## list because it's updated in place.
  ii = list(range(k))

  while True:
    yield [things[i] for i in ii]
    for j in range(k):
      ## Position J may advance as far as the next position's index (or the
      ## end of THINGS, for the last position).
      lim = n if j == k - 1 else ii[j + 1]
      i = ii[j] + 1
      if i < lim:
        ii[j] = i
        break
      ii[j] = j
    else:
      ## Nothing could move: we've produced every combination.
      return
44
## Cache of polynomials already found by `poly', indexed by degree.
POLYMAP = {}

def poly(nbits):
  """
  Return the lexically first irreducible polynomial of degree NBITS of lowest
  weight.

  Candidates always include the terms t^NBITS and 1.  The number K of
  additional middle terms is tried in the order 1, 3, 5, ..., and for each K
  the sets of middle degrees are tried in the order produced by `combs'.
  The first candidate for which `irreduciblep' holds is cached in `POLYMAP'
  and returned.  Raises `ValueError' (with NBITS as its argument) if no
  candidate is irreducible.
  """
  if nbits in POLYMAP:
    return POLYMAP[nbits]

  skeleton = C.GF(0).setbit(nbits).setbit(0)
  middle = range(1, nbits)
  for nterms in range(1, nbits, 2):
    for degrees in combs(middle, nterms):
      extra = C.GF(0)
      for d in degrees:
        extra = extra + C.GF(0).setbit(d)
      cand = skeleton + extra
      if cand.irreduciblep():
        POLYMAP[nbits] = cand
        return cand
  raise ValueError(nbits)
60
def gcm_mangle(x):
  """
  Flip the bits within each byte according to GCM's insane convention.

  Returns a new byte string of the same length as X in which each byte has
  its eight bits in the opposite order; the order of the bytes themselves is
  unchanged.
  """
  out = C.WriteBuffer()
  for ch in x:
    octet = ord(ch)
    flipped = 0
    for pos in range(8):
      ## Bit POS of the input lands at bit 7 - POS of the output.
      flipped = (flipped << 1) | ((octet >> pos)&1)
    out.putu8(flipped)
  return out.contents
73
def endswap_words_32(x):
  """
  End-swap each 32-bit word of X.

  Words are read from X with `getu32b' and written back with `putu32l'; the
  result is the contents of the output buffer.
  """
  src = C.ReadBuffer(x)
  dst = C.WriteBuffer()
  while src.left:
    word = src.getu32b()
    dst.putu32l(word)
  return dst.contents
80
def endswap_words_64(x):
  """
  End-swap each 64-bit word of X.

  Words are read from X with `getu64b' and written back with `putu64l'; the
  result is the contents of the output buffer.
  """
  src = C.ReadBuffer(x)
  dst = C.WriteBuffer()
  while src.left:
    word = src.getu64b()
    dst.putu64l(word)
  return dst.contents
87
def endswap_bytes(x):
  """
  End-swap X by bytes.

  Returns a byte string holding the items of X in the opposite order.
  """
  out = C.WriteBuffer()
  for k in range(len(x) - 1, -1, -1):
    out.put(x[k])
  return out.contents
93
def gfmask(n):
  """Return a `C.GF' value with the low N bits set, i.e., 2^N - 1."""
  top = C.MP(0).setbit(n)
  return C.GF(top - 1)
96
def gcm_mul(x, y):
  """
  Multiply X and Y according to the GCM rules.

  X and Y are byte strings in external format.  Each is bit-flipped with
  `gcm_mangle' and loaded with `C.GF.loadl'; the product is reduced modulo
  the polynomial `poly' chooses for the width of X, stored in `len(X)' bytes
  and bit-flipped back again.  This is the reference result which the demo
  implementations are checked against.
  """
  nbytes = len(x)
  modulus = poly(8*nbytes)
  u = C.GF.loadl(gcm_mangle(x))
  v = C.GF.loadl(gcm_mangle(y))
  product = (u*v)%modulus
  return gcm_mangle(product.storel(nbytes))
104
## Map from style names (as given on the command line) to demo functions.
DEMOMAP = {}

def demo(func):
  """
  Decorator: register FUNC as a demonstration style.

  FUNC must be named `demo_...'.  The rest of its name, with underscores
  replaced by hyphens, becomes its key in `DEMOMAP'.  FUNC itself is
  returned unchanged.
  """
  prefix = 'demo_'
  name = func.__name__
  assert name.startswith(prefix)
  style = name[len(prefix):].replace('_', '-')
  DEMOMAP[style] = func
  return func
111
def iota(i = 0):
  """
  Return a function which hands out consecutive integers starting from I.

  Each call to the returned function (which takes no arguments) returns the
  next integer in sequence: I, I + 1, I + 2, ....  Separate calls to `iota'
  produce independent counters.
  """
  state = {'next': i}
  def step():
    value = state['next']
    state['next'] = value + 1
    return value
  return step
116
117###--------------------------------------------------------------------------
118### Portable table-driven implementation.
119
def shift_left(x):
  """
  Given a field element X (in external format), return X t.

  X is bit-flipped with `gcm_mangle' and loaded with `C.GF.loadl', shifted
  up by one bit, and reduced modulo the polynomial which `poly' chooses for
  the width of X.  The result is converted back to external format and is
  always exactly `len(X)' bytes long.
  """
  w = len(x)
  p = poly(8*w)
  xt = (C.GF.loadl(gcm_mangle(x)) << 1)%p

  ## Store with an explicit width, just as `gcm_mul' does.  Without one, the
  ## value is stored in as few bytes as it needs, so an element whose
  ## high-order bytes happen to be zero would come back shorter than the
  ## operands it's about to be combined with.
  return gcm_mangle(xt.storel(w))
125
def table_common(u, v, flip, getword, ixmask):
  """
  Multiply U by V using table lookup; common for `table-b' and `table-l'.

  This matches the `simple_mulk_...' implementation in `gcm.c'.  One entry
  per bit is the best we can manage if we want a constant-time
  implementation: processing n bits at a time means we need to scan
  (2^n - 1)/n times as much memory.

    * FLIP is a function (assumed to be an involution) on one argument X to
      convert X from external format to table-entry format or back again.

    * GETWORD is a function on one argument B to retrieve the next 32-bit
      chunk of a field element held in a `ReadBuffer'.  Bits within a word
      are processed most-significant first.

    * IXMASK is a mask XORed into table indices to permute the table so that
      its order matches that induced by GETWORD.

  The table is built such that tab[i XOR IXMASK] = U t^i.

  U and V must be the same length.  A trace of the table and of each
  accumulation step is printed as the calculation proceeds; the product is
  returned in external format.
  """
  w = len(u)
  assert w == len(v)
  nbits = 8*w

  ## Build the table of multiples U t^k, reporting each one as we go.
  tab = [None]*nbits
  for k in range(nbits):
    print(';; %9s = %7s = %s' % ('utab[%d]' % k, 'u t^%d' % k, words(u)))
    tab[k ^ ixmask] = flip(u)
    u = shift_left(u)

  ## Scan V a bit at a time, accumulating the table entry for each bit which
  ## is set.  The table index simply counts up; IXMASK has already arranged
  ## the entries to match.
  acc = C.ByteString.zero(w)
  rd = C.ReadBuffer(v)
  idx = 0
  while rd.left:
    word = getword(rd)
    for pos in range(31, -1, -1):
      bit = (word >> pos)&1
      if bit: acc ^= tab[idx]
      print(';; %6s = %d: a <- %s [%9s = %s]' %
            ('v[%d]' % (idx ^ ixmask), bit, words(acc),
             'utab[%d]' % (idx ^ ixmask), words(tab[idx])))
      idx += 1
  return flip(acc)
166
@demo
def demo_table_b(u, v):
  """
  Big-endian table lookup.

  Table entries are kept in external format, words of V are fetched with
  `getu32b', and the table is not permuted.
  """
  def same(x): return x
  def nextword(b): return b.getu32b()
  return table_common(u, v, same, nextword, 0)
171
@demo
def demo_table_l(u, v):
  """
  Little-endian table lookup.

  Table entries have each 32-bit word end-swapped, words of V are fetched
  with `getu32l', and table indices are XORed with 0x18.
  """
  def nextword(b): return b.getu32l()
  return table_common(u, v, endswap_words_32, nextword, 0x18)
9e6a4409
MW
176
177###--------------------------------------------------------------------------
178### Implementation using 64×64->128-bit binary polynomial multiplication.
179
## Tags identifying the intermediate values handed to a presentation
## function.  They are numbered consecutively from zero in the order listed,
## and the order matters: the presentation functions select stages with range
## comparisons such as `TAG_PRODSUM <= tag <= TAG_PRODUCT'.
_i = iota()
(TAG_INPUT_U, TAG_INPUT_V,
 TAG_KPIECE_U, TAG_KPIECE_V,
 TAG_PRODPIECE, TAG_PRODSUM, TAG_PRODUCT,
 TAG_SHIFTED,
 TAG_REDCBITS, TAG_REDCFULL, TAG_REDCMIX,
 TAG_OUTPUT) = tuple(_i() for _n in range(12))
193
def split_gf(x, n):
  """
  Split the byte string X into N-bit pieces, returned as a list of `C.GF'.

  N is given in bits and converted to a whole number of bytes.  Pieces are
  taken from the start of X in order, and each is loaded with `C.GF.loadb'.
  """
  step = n//8
  pieces = []
  for off in range(0, len(x), step):
    pieces.append(C.GF.loadb(x[off:off + step]))
  return pieces
197
def join_gf(xx, n):
  """
  Combine the list XX of N-bit pieces into a single `C.GF' value.

  The first piece ends up most significant: each piece in turn is ORed into
  the bottom of the accumulator after shifting it up by N bits.
  """
  joined = C.GF(0)
  for piece in xx:
    joined = (joined << n) | piece
  return joined
202
def present_gf(x, w, n, what):
  """
  Print the W-bit value X as N-bit hex chunks, labelled WHAT.

  One output line is printed for each 128 bits of width.  The first line
  carries the label WHAT followed by a colon; later lines are indented to
  match.  Within a line, chunks are shown starting with the one at the
  lowest bit offset.
  """
  ## NOTE(review): `hex' here is the builtin, applied to the result of
  ## `storeb'; this presumably relies on that object supporting hex
  ## conversion -- confirm.
  chunkmask = gfmask(n)
  for base in range(0, w, 128):
    first = base == 0
    fields = []
    for off in range(base, base + 128, n):
      if off < w:
        fields.append(' 0x%s' % hex(((x >> off)&chunkmask).storeb(n//8)))
    print(';; %12s%c =%s' %
          (first and what or '', first and ':' or ' ', ''.join(fields)))
215
def present_gf_pclmul(tag, wd, x, w, n, what):
  """
  Presentation function for the `pclmul' demo.

  Individual product pieces (`TAG_PRODPIECE') are suppressed; everything
  else is passed straight to `present_gf'.  WD is accepted for interface
  compatibility with the other presentation functions and is not used.
  """
  if tag == TAG_PRODPIECE:
    return
  present_gf(x, w, n, what)
218
def reverse(x, w):
  """
  Return the W-bit value X with its bytes in the opposite order.

  X is stored as W/8 bytes with `storeb' and the result is read back with
  `C.GF.loadl'.
  """
  octets = x.storeb(w//8)
  return C.GF.loadl(octets)
221
def rev32(x):
  """
  Return X with the bytes of each 32-bit word in the opposite order.

  This is done in two steps using masks from `repmask': first the 16-bit
  halves of each 32-bit word are exchanged, and then the bytes within each
  16-bit half.  The masks are sized from `X.noctets'.
  """
  nbytes = x.noctets
  halves = repmask(0xffff, 32, nbytes//4)
  octets = repmask(0xff, 16, nbytes//2)
  x = ((x&halves) << 16) | ((x >> 16)&halves)
  x = ((x&octets) << 8) | ((x >> 8)&octets)
  return x
229
def rev8(x):
  """
  Return X with the bits of each byte in the opposite order.

  This is done in three steps using masks from `repmask': exchange the
  4-bit halves of each byte, then adjacent 2-bit pairs, then adjacent bits.
  The masks are sized from `X.noctets'.
  """
  nbytes = x.noctets
  for pattern, shift in ((0x0f, 4), (0x33, 2), (0x55, 1)):
    mask = repmask(pattern, 8, nbytes)
    x = ((x&mask) << shift) | ((x >> shift)&mask)
  return x
239
def present_gf_mullp64(tag, wd, x, w, n, what):
  """
  Presentation function for the `vmullp64' demo.

  Product pieces and full reduction sums (`TAG_PRODPIECE', `TAG_REDCFULL')
  are suppressed.  For certain combinations of the operand width WD and the
  stage TAG, X is shown as it stands; otherwise X is stored as W/8 bytes,
  padded with zero bytes to a multiple of eight, and the order of its
  eight-byte chunks is reversed before display.  The displayed width is W
  rounded up to a multiple of 64.
  """
  if tag in (TAG_PRODPIECE, TAG_REDCFULL):
    return

  product_stage = TAG_PRODSUM <= tag <= TAG_PRODUCT
  later_stage = TAG_PRODSUM <= tag < TAG_OUTPUT
  if (wd in (128, 64) and product_stage) or \
     (wd in (96, 192, 256) and later_stage):
    y = x
  else:
    raw = x.storeb(w//8)
    short = len(raw)%8
    if short: raw += C.ByteString.zero(8 - short)
    out = C.WriteBuffer()
    for end in range(len(raw), 0, -8):
      out.put(raw[end - 8:end])
    y = C.GF.loadb(out.contents)
  present_gf(y, (w + 63)&~63, n, what)
256
def present_gf_pmull(tag, wd, x, w, n, what):
  """
  Presentation function for the `pmull' demo.

  Product pieces, full reduction sums and the shifted product
  (`TAG_PRODPIECE', `TAG_REDCFULL', `TAG_SHIFTED') are suppressed.  For the
  V operand and its Karatsuba pieces, the width is rounded up to a multiple
  of 64 bits and each eight-byte chunk is written out twice, doubling the
  width.  Product sums and products are shifted up by one bit.  In every
  displayed case, the value is then bit-reversed within bytes (`rev8') and
  byte-reversed (`reverse') before being passed to `present_gf'.  WD is not
  used.
  """
  if tag in (TAG_PRODPIECE, TAG_REDCFULL, TAG_SHIFTED):
    return

  if tag in (TAG_INPUT_V, TAG_KPIECE_V):
    w = (w + 63)&~63
    src = C.ReadBuffer(x.storeb(w//8))
    dup = C.WriteBuffer()
    while src.left:
      chunk = src.get(8)
      dup.put(chunk).put(chunk)
    x = C.GF.loadb(dup.contents)
    w *= 2
  elif TAG_PRODSUM <= tag <= TAG_PRODUCT:
    x <<= 1

  present_gf(reverse(rev8(x), w), w, n, what)
271
def poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat):
  """
  Multiply U by V, returning the product.

  This is the fallback long multiplication.

  U and V are byte strings; the product is returned as a `C.GF' value.
  Each operand is carved into MULWD-bit pieces, every pair of pieces is
  multiplied, and the partial products are summed at the appropriate
  offsets.  Progress is reported through PRESFN, which is called as
  PRESFN(TAG, WD, VALUE, WIDTH, DISPWD, LABEL): once per product piece
  (`TAG_PRODPIECE'), once per column sum (`TAG_PRODSUM'), and once for the
  final product (`TAG_PRODUCT').  UWHAT and VWHAT are the names used for the
  operands in the labels.
  """

  uw, vw = 8*len(u), 8*len(v)

  ## We start by carving the operands into MULWD-bit pieces.  This is
  ## straightforward except when an operand isn't a whole number of pieces
  ## (e.g., the 96-bit case), where zero bytes are appended to fill out the
  ## final piece.  The shortfall is calculated in bits, so convert it to
  ## bytes before asking for the padding.
  if uw%mulwd:
    pad = (-uw)%mulwd
    u += C.ByteString.zero(pad//8); uw += pad
  if vw%mulwd:
    pad = (-vw)%mulwd
    v += C.ByteString.zero(pad//8); vw += pad
  uu = split_gf(u, mulwd)
  vv = split_gf(v, mulwd)

  ## Report and accumulate the individual product pieces.  Column I collects
  ## the products u_a v_b with a + b = I, where pieces are numbered from the
  ## least significant end (i.e., from the ends of the lists).  The piece of
  ## V is b = J, so J is bounded below by I - ULIM + 1 in order to keep
  ## a = I - J within U.
  x = C.GF(0)
  ulim, vlim = uw//mulwd, vw//mulwd
  for i in range(ulim + vlim - 2, -1, -1):
    t = C.GF(0)
    for j in range(max(0, i - ulim + 1), min(vlim, i + 1)):
      s = uu[ulim - 1 - i + j]*vv[vlim - 1 - j]
      presfn(TAG_PRODPIECE, wd, s, 2*mulwd, dispwd,
             '%s_%d %s_%d' % (uwhat, i - j, vwhat, j))
      t += s
    presfn(TAG_PRODSUM, wd, t, 2*mulwd, dispwd,
           '(%s %s)_%d' % (uwhat, vwhat, ulim + vlim - 2 - i))
    x += t << (mulwd*i)
  presfn(TAG_PRODUCT, wd, x, uw + vw, dispwd, '%s %s' % (uwhat, vwhat))

  return x
305
def poly64_mul_karatsuba(u, v, klimit, presfn, wd,
                         dispwd, mulwd, uwhat, vwhat):
  """
  Multiply U by V, returning the product.

  If the length of U and V is at least KLIMIT, and the operands are otherwise
  suitable, then do Karatsuba--Ofman multiplication; otherwise, delegate to
  `poly64_mul_simple'.

  `Suitable' means that U and V are the same length, and that length (in
  bits) is a multiple of 2 MULWD.  Each operand is cut into two halves; the
  XOR of the halves, the first halves, and the second halves are each
  reported through PRESFN and multiplied recursively, in that order.  The
  combined middle term and the final product are reported too.
  """
  w = 8*len(u)

  suitable = w >= klimit and w == 8*len(v) and w%(2*mulwd) == 0
  if not suitable:
    return poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat)

  hw = w//2
  cut = hw//8
  u0, u1 = u[:cut], u[cut:]
  v0, v1 = v[:cut], v[cut:]

  ## The three half-width subproblems, each with the format used to derive
  ## its labels from the operand names.
  subproblems = ((u0 ^ u1, v0 ^ v1, '%s*'),
                 (u0, v0, '%s0'),
                 (u1, v1, '%s1'))
  products = []
  for upiece, vpiece, fmt in subproblems:
    ulabel, vlabel = fmt % uwhat, fmt % vwhat
    presfn(TAG_KPIECE_U, wd, C.GF.loadb(upiece), hw, dispwd, ulabel)
    presfn(TAG_KPIECE_V, wd, C.GF.loadb(vpiece), hw, dispwd, vlabel)
    products.append(poly64_mul_karatsuba(upiece, vpiece, klimit, presfn, wd,
                                         dispwd, mulwd, ulabel, vlabel))
  uuvv, u0v0, u1v1 = products

  ## Recover the middle term and assemble the full product.
  uvuv = uuvv + u0v0 + u1v1
  presfn(TAG_PRODSUM, wd, uvuv, w, dispwd, '%s!%s' % (uwhat, vwhat))

  x = u1v1 + (uvuv << hw) + (u0v0 << w)
  presfn(TAG_PRODUCT, wd, x, 2*w, dispwd, '%s %s' % (uwhat, vwhat))
  return x
346
def poly64_common(u, v, presfn, dispwd = 32, mulwd = 64, redcwd = 32,
                  klimit = 256):
  """
  Multiply U by V using a primitive 64-bit binary polynomial multiplier.

  Such a multiplier exists as the appallingly-named `pclmul[lh]q[lh]qdq' on
  x86, and as `vmull.p64'/`pmull' on ARM.

  Operands arrive in a `register format', which is a byte-swapped variant of
  the external format.  Implementations differ on the precise details,
  though.

  U and V are byte strings of equal length; the product is returned as a
  byte string of the same length.  PRESFN is called to report each
  intermediate value, with chunk width DISPWD.  MULWD is the width of the
  primitive multiplier, REDCWD is the piece width used during reduction,
  and KLIMIT is the operand width at which Karatsuba multiplication is
  used.
  """

  ## We work in two main phases: first, calculate the full double-width
  ## product; and, second, reduce it modulo the field polynomial.

  w = 8*len(u)
  assert w == 8*len(v)
  p = poly(w)
  presfn(TAG_INPUT_U, w, C.GF.loadb(u), w, dispwd, 'u')
  presfn(TAG_INPUT_V, w, C.GF.loadb(v), w, dispwd, 'v')

  ## So, on to the first part: the multiplication.
  x = poly64_mul_karatsuba(u, v, klimit, presfn, w, dispwd, mulwd, 'u', 'v')

  ## Now we have to shift everything up one bit to account for GCM's crazy
  ## bit ordering.
  y = x << 1
  if w == 96: y >>= 64
  presfn(TAG_SHIFTED, w, y, 2*w, dispwd, 'y')

  ## Now for the reduction.
  ##
  ## Our polynomial has the form p = t^d + r where r = SUM_{0<=i<d} r_i t^i,
  ## with each r_i either 0 or 1.  Because we choose the lexically earliest
  ## irreducible polynomial with the necessary degree, r_i = 1 happens only
  ## for a small number of tiny i.  In our field, we have t^d = r.
  ##
  ## We carve the product into convenient n-bit pieces, for some n dividing d
  ## -- typically n = 32 or 64.  Let d = m n, and write y = SUM_{0<=i<2m} y_i
  ## t^{ni}.  The upper portion, the y_i with i >= m, needs reduction; but
  ## y_i t^{ni} = y_i r t^{n(i-m)}, so we just multiply the top half by r and
  ## add it to the bottom half.  This all depends on r_i = 0 for all i >=
  ## n/2.  We process each nonzero coefficient of r separately, in two
  ## passes.
  ##
  ## Multiplying a chunk y_i by some t^j is the same as shifting it left by j
  ## bits (or would be if GCM weren't backwards, but let's not worry about
  ## that right now).  The high j bits will spill over into the next chunk,
  ## while the low n - j bits will stay where they are.  It's these high bits
  ## which cause trouble -- particularly the high bits of the top chunk,
  ## since we'll add them on to y_m, which will need further reduction.  But
  ## only the topmost j bits will do this.
  ##
  ## The trick is that we do all of the bits which spill over first -- all of
  ## the top j bits in each chunk, for each j -- in one pass, and then a
  ## second pass of all the bits which don't.  Because j, j' < n/2 for any
  ## two nonzero coefficient degrees j and j', we have j + j' < n whence j <
  ## n - j' -- so all of the bits contributed to y_m will be handled in the
  ## second pass when we handle the bits that don't spill over.
  rr = [i for i in range(1, w) if p.testbit(i)]
  piecemask = gfmask(redcwd)
  npieces = w//redcwd             # pieces in a single-width value
  dblbytes = w//4                 # bytes in a double-width value

  ## Handle the spilling bits.
  yy = split_gf(y.storeb(dblbytes), redcwd)
  b = C.GF(0)
  for rj in rr:
    spill = join_gf([(yi << (redcwd - rj))&piecemask
                     for yi in yy[npieces:]],
                    redcwd)
    presfn(TAG_REDCBITS, w, spill, w, dispwd, 'b(%d)' % rj)
    b += spill << (w - redcwd)
  presfn(TAG_REDCFULL, w, b, 2*w, dispwd, 'b')
  s = y + b
  presfn(TAG_REDCMIX, w, s, 2*w, dispwd, 's')

  ## Handle the nonspilling bits.
  ss = split_gf(s.storeb(dblbytes), redcwd)
  a = C.GF(0)
  for rj in rr:
    stay = join_gf([si >> rj for si in ss[npieces:]], redcwd)
    presfn(TAG_REDCBITS, w, stay, w, dispwd, 'a(%d)' % rj)
    a += stay
  presfn(TAG_REDCFULL, w, a, w, dispwd, 'a')

  ## Mix everything together.
  fullmask = gfmask(w)
  z = (s&fullmask) + (s >> w) + a
  presfn(TAG_OUTPUT, w, z, w, dispwd, 'z')

  ## And we're done.
  return z.storeb(w//8)
436
@demo
def demo_pclmul(u, v):
  """
  Multiply U by V, presenting intermediate values with `present_gf_pclmul'.

  All of the other settings are left at the defaults of `poly64_common'.
  """
  product = poly64_common(u, v, presfn = present_gf_pclmul)
  return product
440
@demo
def demo_vmullp64(u, v):
  """
  Multiply U by V, presenting intermediate values with `present_gf_mullp64'.

  Reduction works in 32-bit pieces if the operand width leaves a 32-bit
  remainder modulo 64, and in 64-bit pieces otherwise.
  """
  w = 8*len(u)
  if w%64 == 32: redcwd = 32
  else: redcwd = 64
  return poly64_common(u, v, presfn = present_gf_mullp64, redcwd = redcwd)
446
@demo
def demo_pmull(u, v):
  """
  Multiply U by V, presenting intermediate values with `present_gf_pmull'.

  Reduction works in 32-bit pieces if the operand width leaves a 32-bit
  remainder modulo 64, and in 64-bit pieces otherwise.
  """
  w = 8*len(u)
  if w%64 == 32: redcwd = 32
  else: redcwd = 64
  return poly64_common(u, v, presfn = present_gf_pmull, redcwd = redcwd)
452
453###--------------------------------------------------------------------------
454### @@@ Random debris to be deleted. @@@
455
def cutting_room_floor():
  """
  Leftover experiment: a hand-worked reduction of one fixed 192-bit product.

  Multiplies two hard-coded operands, shifts the product up one bit, and
  applies masks and shifts for the degrees 1, 2 and 7 step by step,
  printing each intermediate value; finally prints the result of `gcm_mul'
  on the same operands.  Nothing in this file calls it.
  """

  def mask32(val):
    ## Six copies of the 32-bit pattern VAL.
    return repmask(val, 32, 6)

  x = C.bytes('cde4bef260d7bcda163547d348b7551195e77022907dd1df')
  y = C.bytes('f7dac5c9941d26d0c6eb14ad568f86edd1dc9268eeee5332')

  u = C.GF.loadb(x)
  v = C.GF.loadb(y)

  g = (u*v) << 1
  print('y = %s' % words(g.storeb(48)))

  ## The low bits of each word, moved up.
  b1 = (g&mask32(0x01)) << 191
  b2 = (g&mask32(0x03)) << 190
  b7 = (g&mask32(0x7f)) << 185
  b = b1 + b2 + b7
  print('b = %s' % words(b.storeb(48)[0:28]))
  h = g + b
  print('w = %s' % words(h.storeb(48)))

  ## The remaining bits of each word, moved up.
  a0 = (h&mask32(0xffffffff)) << 192
  a1 = (h&mask32(0xfffffffe)) << 191
  a2 = (h&mask32(0xfffffffc)) << 190
  a7 = (h&mask32(0xffffff80)) << 185
  a = a0 + a1 + a2 + a7

  print(' a_1 = %s' % words(a1.storeb(48)[0:24]))
  print(' a_2 = %s' % words(a2.storeb(48)[0:24]))
  print(' a_7 = %s' % words(a7.storeb(48)[0:24]))

  print('low+unit = %s' % words((h + a0).storeb(48)[0:24]))
  print(' low+0,2 = %s' % words((h + a0 + a2).storeb(48)[0:24]))
  print(' 1,7 = %s' % words((a1 + a7).storeb(48)[0:24]))

  print('a = %s' % words(a.storeb(48)[0:24]))
  z = h + a
  print('z = %s' % words(z.storeb(48)))

  z = gcm_mul(x, y)
  print('u v mod p = %s' % words(z))
493
494###--------------------------------------------------------------------------
495### Main program.
496
## Usage: gcm-ref STYLE U V
##
## STYLE names one of the demonstrations registered in `DEMOMAP'; U and V are
## the operands, converted by `C.bytes'.  The chosen demonstration prints its
## working, and its answer is checked against the reference `gcm_mul'.

## Check the command line before indexing into it, so that a mistake gets a
## usage message rather than a traceback.  (Calling `exit' with a string
## prints it to stderr and exits with status 1.)
if len(argv) != 4:
  exit('usage: %s STYLE U V' % argv[0])
style = argv[1]
if style not in DEMOMAP:
  exit("%s: unknown style `%s' (expected one of: %s)" %
       (argv[0], style, ', '.join(sorted(DEMOMAP))))
u = C.bytes(argv[2])
v = C.bytes(argv[3])
zz = DEMOMAP[style](u, v)

## Check the answer explicitly: an `assert' would be skipped under `-O'.
if zz != gcm_mul(u, v):
  exit("%s: style `%s' disagrees with reference multiplication" %
       (argv[0], style))
502
503###----- That's all, folks --------------------------------------------------