utils/gcm-ref (present_gf_vmullp64): Add `v' prefix to match front end.
[catacomb] / utils / gcm-ref
CommitLineData
9e6a4409
MW
1#! /usr/bin/python
2### -*- coding: utf-8 -*-
3
4from sys import argv, exit
5
6import catacomb as C
7
8###--------------------------------------------------------------------------
9### Random utilities.
10
def words(s):
  """
  Render the byte string S as space-separated hex words.

  S is cut into consecutive 4-byte chunks, each read big-endian and printed
  as an 8-digit hex number.
  """
  out = []
  for off in xrange(0, len(s), 4):
    out.append('%08x' % C.MP.loadb(s[off:off + 4]))
  return ' '.join(out)
15
def words_64(s):
  """
  Render the byte string S as space-separated 64-bit hex words.

  S is cut into consecutive 8-byte chunks, each read big-endian and printed
  as a 16-digit hex number.
  """
  chunks = [s[off:off + 8] for off in xrange(0, len(s), 8)]
  return ' '.join(['%016x' % C.MP.loadb(ch) for ch in chunks])
20
def repmask(val, wd, n):
  """
  Return a `GF' mask made of N back-to-back copies of the WD-bit value VAL.

  Copy 0 occupies the low WD bits, copy 1 the next WD bits, and so on.
  """
  unit = C.GF(val)
  acc = C.GF(0)
  for _ in xrange(n):
    acc = (acc << wd) | unit
  return acc
27
def combs(things, k):
  """
  Iterate over all possible combinations of K of the THINGS.

  Each combination is yielded as a fresh list whose elements keep their
  relative order from THINGS.  Index sets are produced in colexicographic
  order, e.g., combs([1, 2, 3], 2) yields [1, 2], [1, 3], [2, 3].

  If K is zero, a single empty combination is yielded; if K exceeds
  len(THINGS), there are no combinations and nothing is yielded.
  """
  n = len(things)

  ## No way to choose more things than there are: stop cleanly rather than
  ## indexing off the end of THINGS.
  if k > n: return

  ## The index vector is updated in place, so it must be a real list (a
  ## Python 3 `range' object is immutable).
  ii = list(range(k))
  while True:
    yield [things[i] for i in ii]

    ## Step to the next index set: find the lowest index that can move up
    ## without colliding with the one above it, bump it, and reset every
    ## index below it to its smallest possible value.
    for j in range(k):
      lim = n if j == k - 1 else ii[j + 1]
      i = ii[j] + 1
      if i < lim:
        ii[j] = i
        break
      ii[j] = j
    else:
      ## Nothing could advance: that was the last combination.
      return
44
POLYMAP = {}

def poly(nbits):
  """
  Return the lexically first irreducible polynomial of degree NBITS of lowest
  weight.

  Candidates have the form t^NBITS + (middle terms) + 1.  The number of
  middle terms is odd (1, 3, 5, ...), because a polynomial with an even
  number of terms has t + 1 as a factor.  For each count, the middle-term
  degrees are tried in the order `combs' produces them.  Results are
  memoized in POLYMAP.  Raises ValueError if no candidate is irreducible.
  """
  if nbits in POLYMAP: return POLYMAP[nbits]

  base = C.GF(0).setbit(nbits).setbit(0)
  middle = range(1, nbits)
  for nterms in xrange(1, nbits, 2):
    for degs in combs(middle, nterms):
      cand = base
      for d in degs: cand = cand + C.GF(0).setbit(d)
      if cand.irreduciblep():
        POLYMAP[nbits] = cand
        return cand
  raise ValueError(nbits)
60
def gcm_mangle(x):
  """
  Flip the bits within each byte according to GCM's insane convention.

  Every byte of X has its bit order reversed (bit 0 <-> bit 7, and so on);
  byte order is unchanged.  The operation is its own inverse, and is used
  with `loadl'/`storel' to move between GCM's external format and ordinary
  polynomial order.
  """
  out = C.WriteBuffer()
  for ch in x:
    byte = ord(ch)
    rev = 0
    for _ in xrange(8):
      rev = (rev << 1) | (byte&1)
      byte >>= 1
    out.putu8(rev)
  return out.contents
73
def endswap_words_32(x):
  """Return X with the byte order of every 32-bit word reversed."""
  src = C.ReadBuffer(x)
  dst = C.WriteBuffer()
  while src.left:
    word = src.getu32b()
    dst.putu32l(word)
  return dst.contents
80
def endswap_words_64(x):
  """Return X with the byte order of every 64-bit word reversed."""
  src = C.ReadBuffer(x)
  dst = C.WriteBuffer()
  while src.left:
    word = src.getu64b()
    dst.putu64l(word)
  return dst.contents
87
def endswap_bytes(x):
  """Return X with its whole byte order reversed (last byte first)."""
  out = C.WriteBuffer()
  n = len(x)
  for i in xrange(n):
    out.put(x[n - 1 - i])
  return out.contents
93
def gfmask(n):
  """Return the `GF' value with exactly the low N bits set (i.e., 2^N - 1)."""
  top = C.MP(0).setbit(n)
  return C.GF(top - 1)
96
def gcm_mul(x, y):
  """
  Multiply X and Y according to the GCM rules.

  This is the reference against which the demonstrations are checked: both
  operands (external format, LEN(X) bytes) are converted to polynomials,
  multiplied modulo `poly(8*LEN(X))', and converted back to a byte string
  of the same length.
  """
  nbytes = len(x)
  modulus = poly(8*nbytes)
  xp = C.GF.loadl(gcm_mangle(x))
  yp = C.GF.loadl(gcm_mangle(y))
  return gcm_mangle(((xp*yp)%modulus).storel(nbytes))
104
DEMOMAP = {}
def demo(func):
  """
  Decorator: register FUNC as a demonstration style in DEMOMAP.

  FUNC's name must start with `demo_'.  The remainder of the name, with
  underscores turned into hyphens, is the style name used on the command
  line (e.g., `demo_table_b' registers as `table-b').  FUNC is returned
  unchanged.  Raises ValueError if the name lacks the `demo_' prefix.
  """
  ## `__name__' works on every Python version (`func_name' is Python 2
  ## only).
  name = func.__name__

  ## An explicit check rather than `assert', so that it still applies under
  ## `python -O'.
  if not name.startswith('demo_'):
    raise ValueError("demo function name `%s' lacks `demo_' prefix" % name)
  DEMOMAP[name[len('demo_'):].replace('_', '-')] = func
  return func
111
def iota(i = 0):
  """
  Return a counter function.

  Successive calls of the returned function give I, I + 1, I + 2, ...
  Each call to `iota' makes an independent counter.
  """
  state = {'n': i}
  def next():
    cur = state['n']
    state['n'] = cur + 1
    return cur
  return next
116
117###--------------------------------------------------------------------------
118### Portable table-driven implementation.
119
def shift_left(x):
  """
  Given a field element X (in external format), return X t.

  The result is in external format and has the same length as X.  The
  arithmetic is done modulo `poly(8*len(X))'.
  """
  w = len(x)
  p = poly(8*w)
  xt = (C.GF.loadl(gcm_mangle(x)) << 1)%p

  ## Store exactly W bytes, as `gcm_mul' does.  Without an explicit length,
  ## `storel' may emit fewer bytes when the top byte of the result is zero
  ## (believed to use the minimum width).  That would shrink the element:
  ## the next call would then use a different polynomial, and the table
  ## entries would no longer have matching lengths.
  return gcm_mangle(xt.storel(w))
125
def table_common(u, v, flip, getword, ixmask):
  """
  Multiply U by V using table lookup; common for `table-b' and `table-l'.

  This matches the `simple_mulk_...' implementation in `gcm.c'.  One entry
  per bit is the best we can manage if we want a constant-time
  implementation: processing n bits at a time means we need to scan
  (2^n - 1)/n times as much memory.

    * FLIP is a function (assumed to be an involution) on one argument X to
      convert X from external format to table-entry format or back again.

    * GETWORD is a function on one argument B to retrieve the next 32-bit
      chunk of a field element held in a `ReadBuffer'.  Bits within a word
      are processed most-significant first.

    * IXMASK is a mask XORed into table indices to permute the table so that
      its order matches that induced by GETWORD.

  The table is built such that tab[i XOR IXMASK] = U t^i.  The table and
  each accumulation step are printed as `;;' trace lines.  Returns the
  product in external format.
  """
  nbytes = len(u); assert(nbytes == len(v))
  nbits = 8*nbytes

  ## Build the table of U t^k, stored in table-entry format at the permuted
  ## index.
  tab = [None]*nbits
  cur = u
  for k in xrange(nbits):
    print(';; %9s = %7s = %s' % ('utab[%d]' % k, 'u t^%d' % k, words(cur)))
    tab[k ^ ixmask] = flip(cur)
    cur = shift_left(cur)

  ## Walk V's bits in GETWORD order; entry K of the table pairs with the
  ## K-th bit visited.
  acc = C.ByteString.zero(nbytes)
  vbuf = C.ReadBuffer(v)
  k = 0
  while vbuf.left:
    word = getword(vbuf)
    for shift in xrange(31, -1, -1):
      bit = (word >> shift)&1
      entry = tab[k]
      if bit: acc ^= entry
      print(';; %6s = %d: a <- %s [%9s = %s]' %
            ('v[%d]' % (k ^ ixmask), bit, words(acc),
             'utab[%d]' % (k ^ ixmask), words(entry)))
      k += 1
  return flip(acc)
166
@demo
def demo_table_b(u, v):
  """Big-endian table lookup."""
  ## Table entries are kept in external format, words are read big-endian,
  ## and so no index permutation is needed.
  return table_common(u, v,
                      flip = lambda x: x,
                      getword = lambda buf: buf.getu32b(),
                      ixmask = 0)
171
@demo
def demo_table_l(u, v):
  """Little-endian table lookup."""
  ## Table entries have each 32-bit word byte-swapped and words are read
  ## little-endian, so the K-th bit visited is bit K XOR 0x18 (byte 3
  ## first, then 2, 1, 0) of each word.
  return table_common(u, v,
                      flip = endswap_words_32,
                      getword = lambda buf: buf.getu32l(),
                      ixmask = 0x18)
9e6a4409
MW
176
177###--------------------------------------------------------------------------
178### Implementation using 64×64->128-bit binary polynomial multiplication.
179
## Tags identifying the intermediate values passed to the `present_gf_...'
## functions.  The presenters compare tags by range (e.g.,
## TAG_PRODSUM <= tag <= TAG_PRODUCT), so the numbering order matters.
_i = iota()
TAG_INPUT_U, TAG_INPUT_V = _i(), _i()
TAG_KPIECE_U, TAG_KPIECE_V = _i(), _i()
TAG_PRODPIECE, TAG_PRODSUM, TAG_PRODUCT = _i(), _i(), _i()
TAG_SHIFTED = _i()
TAG_REDCBITS, TAG_REDCFULL, TAG_REDCMIX = _i(), _i(), _i()
TAG_OUTPUT = _i()
193
def split_gf(x, n):
  """
  Cut the byte string X into N-bit pieces and return them as `GF' values.

  Pieces are read big-endian and listed in the order they appear in X.  N
  is assumed to be a multiple of 8.
  """
  step = n//8
  pieces = []
  for off in xrange(0, len(x), step):
    pieces.append(C.GF.loadb(x[off:off + step]))
  return pieces
197
def join_gf(xx, n):
  """
  Concatenate the N-bit `GF' pieces XX into one value.

  The first piece ends up most significant; this inverts `split_gf'.
  """
  acc = C.GF(0)
  for piece in xx:
    acc = (acc << n) | piece
  return acc
202
def present_gf(x, w, n, what):
  """
  Print the W-bit value X as a `;;' trace, in N-bit hex chunks.

  Each output line covers 128 bits, least-significant chunk first.  The
  label WHAT (followed by `:') appears only on the first line.
  """
  mask = gfmask(n)
  for row in xrange(0, w, 128):
    if row == 0: label, sep = what or '', ':'
    else: label, sep = '', ' '
    cells = [' 0x%s' % hex(((x >> j)&mask).storeb(n//8))
             for j in xrange(row, row + 128, n) if j < w]
    print(';; %12s%c =%s' % (label, sep, ''.join(cells)))
215
def present_gf_pclmul(tag, wd, x, w, n, what):
  """Presentation hook for `pclmul': show every value except product pieces."""
  if tag == TAG_PRODPIECE: return
  present_gf(x, w, n, what)
218
def reverse(x, w):
  """Return the W-bit value X with its byte order reversed."""
  nbytes = w//8
  return C.GF.loadl(x.storeb(nbytes))
221
def rev32(x):
  """
  Reverse the order of the bytes within each 32-bit word of X.

  The first step swaps the 16-bit halves of each word, and the second swaps
  the bytes within each half.
  """
  ## Round the word count up.  X's octet count need not be a multiple of
  ## four, and truncating it would leave the masks short of the topmost
  ## partial word, silently dropping its bytes.
  nwords = (x.noctets + 3)//4
  m_ffff = repmask(0xffff, 32, nwords)
  m_ff = repmask(0xff, 16, 2*nwords)
  x = ((x&m_ffff) << 16) | ((x >> 16)&m_ffff)
  x = ((x&m_ff) << 8) | ((x >> 8)&m_ff)
  return x
229
def rev8(x):
  """
  Reverse the order of the bits within each byte of X.

  This swaps nibbles, then bit pairs, then single bits.
  """
  nbytes = x.noctets
  for shift, pattern in ((4, 0x0f), (2, 0x33), (1, 0x55)):
    mask = repmask(pattern, 8, nbytes)
    x = ((x&mask) << shift) | ((x >> shift)&mask)
  return x
239
def present_gf_vmullp64(tag, wd, x, w, n, what):
  """
  Presentation hook for the `vmull.p64' flavour.

  Product pieces and full reduction terms are not shown.  Product sums and
  products (for 64- and 128-bit fields), or everything from the product
  sums up to but excluding the output (for 96-, 192- and 256-bit fields),
  are shown as they are.  Any other value is zero-padded at its low end to
  a whole number of 64-bit chunks, and those chunks are displayed in
  reverse order.
  """
  if tag in (TAG_PRODPIECE, TAG_REDCFULL): return

  if wd in (64, 128): native = TAG_PRODSUM <= tag <= TAG_PRODUCT
  elif wd in (96, 192, 256): native = TAG_PRODSUM <= tag < TAG_OUTPUT
  else: native = False

  if native:
    y = x
  else:
    raw = x.storeb(w//8)
    pad = -len(raw)%8
    if pad: raw += C.ByteString.zero(pad)
    out = C.WriteBuffer()
    for off in reversed(xrange(0, len(raw), 8)): out.put(raw[off:off + 8])
    y = C.GF.loadb(out.contents)
  present_gf(y, (w + 63)&~63, n, what)
256
def present_gf_pmull(tag, wd, x, w, n, what):
  """
  Presentation hook for the `pmull' flavour.

  Product pieces, full reduction terms and the shifted product are not
  shown.  V inputs and V Karatsuba pieces are rounded up to whole 64-bit
  chunks, and each chunk is duplicated (doubling the width).  Product sums
  and products are shifted up one bit.  Every value shown then has the
  bits in each byte reversed and its byte order reversed.
  """
  if tag in (TAG_PRODPIECE, TAG_REDCFULL, TAG_SHIFTED): return

  if tag in (TAG_INPUT_V, TAG_KPIECE_V):
    w = (w + 63)&~63
    src = C.ReadBuffer(x.storeb(w//8))
    dst = C.WriteBuffer()
    while src.left:
      chunk = src.get(8)
      dst.put(chunk)
      dst.put(chunk)
    x = C.GF.loadb(dst.contents)
    w *= 2
  elif TAG_PRODSUM <= tag <= TAG_PRODUCT:
    x <<= 1

  present_gf(reverse(rev8(x), w), w, n, what)
271
def poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat):
  """
  Multiply U by V, returning the product.

  This is the fallback long multiplication.  U and V are big-endian byte
  strings and may differ in length.  Each partial product, each
  accumulated coefficient and the full product are reported through
  PRESFN (tags TAG_PRODPIECE, TAG_PRODSUM and TAG_PRODUCT), labelled using
  UWHAT and VWHAT.  MULWD (assumed to be a multiple of 8) is the width of
  the primitive multiplier.
  """

  uw, vw = 8*len(u), 8*len(v)

  ## We start by carving the operands into MULWD-bit pieces.  An operand
  ## whose width isn't a multiple of MULWD (e.g., 96 bits with 64-bit
  ## pieces) is padded with zeros at its low-order (trailing) end, which
  ## multiplies it by t^pad.  The product is shifted back down at the end.
  ## The pad amounts are in bits, but `ByteString.zero' counts bytes.
  upad = (-uw)%mulwd; u += C.ByteString.zero(upad//8); uw += upad
  vpad = (-vw)%mulwd; v += C.ByteString.zero(vpad//8); vw += vpad
  uu = split_gf(u, mulwd); vv = split_gf(v, mulwd)

  ## Report and accumulate the individual product pieces.  Pieces are listed
  ## most-significant first, so the K-th piece from the bottom of U is
  ## uu[ulim - 1 - K].  Coefficient I collects the products of U's piece
  ## I - J with V's piece J, which needs 0 <= J < vlim and 0 <= I - J < ulim.
  x = C.GF(0)
  ulim, vlim = uw//mulwd, vw//mulwd
  for i in xrange(ulim + vlim - 2, -1, -1):
    t = C.GF(0)
    for j in xrange(max(0, i - ulim + 1), min(vlim, i + 1)):
      s = uu[ulim - 1 - i + j]*vv[vlim - 1 - j]
      presfn(TAG_PRODPIECE, wd, s, 2*mulwd, dispwd,
             '%s_%d %s_%d' % (uwhat, i - j, vwhat, j))
      t += s
    presfn(TAG_PRODSUM, wd, t, 2*mulwd, dispwd,
           '(%s %s)_%d' % (uwhat, vwhat, ulim + vlim - 2 - i))
    x += t << (mulwd*i)
  presfn(TAG_PRODUCT, wd, x, uw + vw, dispwd, '%s %s' % (uwhat, vwhat))

  ## Undo the low-order padding.
  return x >> (upad + vpad)
9e6a4409
MW
304
def poly64_mul_karatsuba(u, v, klimit, presfn, wd,
                         dispwd, mulwd, uwhat, vwhat):
  """
  Multiply U by V, returning the product.

  If the length of U and V is at least KLIMIT, and the operands are otherwise
  suitable (equal lengths, each half a whole number of MULWD-bit pieces),
  then do Karatsuba--Ofman multiplication; otherwise, delegate to
  `poly64_mul_simple'.

  Each Karatsuba step reports its half-width operands (TAG_KPIECE_U/V), the
  middle term (TAG_PRODSUM) and the full product (TAG_PRODUCT) through
  PRESFN.
  """
  nbits = 8*len(u)

  if nbits < klimit or nbits != 8*len(v) or nbits%(2*mulwd):
    return poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat)

  half = nbits//2
  cut = half//8
  u_hi, u_lo = u[:cut], u[cut:]
  v_hi, v_lo = v[:cut], v[cut:]

  def half_product(uu, vv, suffix):
    ## Report one pair of half-width operands, then multiply them
    ## (recursively).
    uname, vname = uwhat + suffix, vwhat + suffix
    presfn(TAG_KPIECE_U, wd, C.GF.loadb(uu), half, dispwd, uname)
    presfn(TAG_KPIECE_V, wd, C.GF.loadb(vv), half, dispwd, vname)
    return poly64_mul_karatsuba(uu, vv, klimit, presfn, wd,
                                dispwd, mulwd, uname, vname)

  ## (u0 + u1)(v0 + v1), then u0 v0, then u1 v1 -- in this order, since the
  ## trace output follows it.
  mid = half_product(u_hi ^ u_lo, v_hi ^ v_lo, '*')
  hi = half_product(u_hi, v_hi, '0')
  lo = half_product(u_lo, v_lo, '1')

  cross = mid + hi + lo
  presfn(TAG_PRODSUM, wd, cross, nbits, dispwd, '%s!%s' % (uwhat, vwhat))

  prod = lo + (cross << half) + (hi << nbits)
  presfn(TAG_PRODUCT, wd, prod, 2*nbits, dispwd, '%s %s' % (uwhat, vwhat))
  return prod
345
def poly64_common(u, v, presfn, dispwd = 32, mulwd = 64, redcwd = 32,
                  klimit = 256):
  """
  Multiply U by V using a primitive 64-bit binary polynomial multiplier.

  Such a multiplier exists as the appallingly-named `pclmul[lh]q[lh]qdq' on
  x86, and as `vmull.p64'/`pmull' on ARM.

  Operands arrive in a `register format', which is a byte-swapped variant of
  the external format.  Implementations differ on the precise details,
  though.

  Every intermediate value is passed to PRESFN, tagged with a `TAG_...'
  constant, for display in DISPWD-bit chunks.  MULWD is the primitive
  multiplier width, REDCWD the chunk width used by the reduction, and
  KLIMIT the operand width from which Karatsuba--Ofman is attempted.
  Returns the reduced product as a byte string as long as U.
  """

  ## We work in two main phases: first, calculate the full double-width
  ## product; and, second, reduce it modulo the field polynomial.

  nbits = 8*len(u); assert(nbits == 8*len(v))
  modulus = poly(nbits)
  presfn(TAG_INPUT_U, nbits, C.GF.loadb(u), nbits, dispwd, 'u')
  presfn(TAG_INPUT_V, nbits, C.GF.loadb(v), nbits, dispwd, 'v')

  ## Phase one: the multiplication.
  prod = poly64_mul_karatsuba(u, v, klimit, presfn, nbits,
                              dispwd, mulwd, 'u', 'v')

  ## Shift everything up one bit to account for GCM's crazy bit ordering.
  y = prod << 1
  presfn(TAG_SHIFTED, nbits, y, 2*nbits, dispwd, 'y')

  ## Phase two: the reduction.
  ##
  ## Our polynomial has the form p = t^d + r, where r = SUM_{0<=i<d} r_i t^i
  ## with each r_i either 0 or 1.  Because we choose the lexically earliest
  ## irreducible polynomial of the necessary degree, r_i = 1 only for a small
  ## number of tiny i.  In our field, t^d = r.
  ##
  ## Carve the product into convenient n-bit pieces, for some n dividing d
  ## -- typically n = 32 or 64.  Let d = m n, and write
  ## y = SUM_{0<=i<2m} y_i t^{ni}.  The upper portion, the y_i with i >= m,
  ## needs reduction; but y_i t^{ni} = y_i r t^{n(i-m)}, so we just multiply
  ## the top half by r and add it to the bottom half.  This relies on
  ## r_i = 0 for all i >= n/2.  Each nonzero coefficient of r is handled
  ## separately, in two passes.
  ##
  ## Multiplying a chunk y_i by t^j shifts it left by j bits (or would, if
  ## GCM weren't backwards -- but let's not worry about that here).  The high
  ## j bits spill over into the next chunk, while the low n - j bits stay
  ## where they are.  The spilling bits cause trouble, particularly those of
  ## the top chunk, since they land on y_m, which then needs reducing again.
  ## Only the topmost j bits do this.
  ##
  ## The trick is to handle all of the spilling bits first -- the top j bits
  ## of each chunk, for every j -- in one pass, and then the non-spilling
  ## bits in a second pass.  Because j, j' < n/2 for any two nonzero
  ## coefficient degrees j and j', we have j + j' < n, whence j < n - j'.
  ## So every bit contributed to y_m is handled in the second pass, along
  ## with the bits that don't spill over.
  rr = [j for j in xrange(1, nbits) if modulus.testbit(j)]
  chunkmask = gfmask(redcwd)

  def redc_pass(val, label, piecefn):
    ## Apply PIECEFN(chunk, rj) to each REDCWD-bit chunk of the second half
    ## of the 2*NBITS-bit VAL, once per rj in RR.  Report each joined result
    ## and return their sum.
    pieces = split_gf(val.storeb(nbits//4), redcwd)[nbits//redcwd:]
    total = C.GF(0)
    for rj in rr:
      part = join_gf([piecefn(pc, rj) for pc in pieces], redcwd)
      presfn(TAG_REDCBITS, nbits, part, nbits, dispwd,
             '%s(%d)' % (label, rj))
      total += part
    return total

  ## Handle the spilling bits.
  b = redc_pass(y, 'b',
                lambda pc, rj: (pc << (redcwd - rj))&chunkmask) << \
      (nbits - redcwd)
  presfn(TAG_REDCFULL, nbits, b, 2*nbits, dispwd, 'b')
  s = y + b
  presfn(TAG_REDCMIX, nbits, s, 2*nbits, dispwd, 's')

  ## Handle the non-spilling bits.
  a = redc_pass(s, 'a', lambda pc, rj: pc >> rj)
  presfn(TAG_REDCFULL, nbits, a, nbits, dispwd, 'a')

  ## Mix everything together.
  z = (s&gfmask(nbits)) + (s >> nbits) + a
  presfn(TAG_OUTPUT, nbits, z, nbits, dispwd, 'z')

  ## And we're done.
  return z.storeb(nbits//8)
434
@demo
def demo_pclmul(u, v):
  """Multiply U by V with `poly64_common', traced in the `pclmul' style."""
  presenter = present_gf_pclmul
  return poly64_common(u, v, presfn = presenter)
438
@demo
def demo_vmullp64(u, v):
  """
  Multiply U by V with `poly64_common', traced in the `vmull.p64' style.

  A 32-bit reduction chunk is used when the field width is an odd multiple
  of 32 bits, and 64-bit chunks otherwise.
  """
  nbits = 8*len(u)
  if nbits%64 == 32: redc = 32
  else: redc = 64
  return poly64_common(u, v, presfn = present_gf_vmullp64, redcwd = redc)
444
@demo
def demo_pmull(u, v):
  """
  Multiply U by V with `poly64_common', traced in the `pmull' style.

  A 32-bit reduction chunk is used when the field width is an odd multiple
  of 32 bits, and 64-bit chunks otherwise.
  """
  redc = 32 if (8*len(u))%64 == 32 else 64
  return poly64_common(u, v, presfn = present_gf_pmull, redcwd = redc)
450
451###--------------------------------------------------------------------------
452### @@@ Random debris to be deleted. @@@
453
def cutting_room_floor():
  """
  Scratch work: a hand-unrolled trace of a 192-bit multiply and reduction.

  This is hard-wired for two fixed 24-byte operands.  The shift amounts
  191, 190 and 185 (192 - 1, 192 - 2, 192 - 7) suggest a modulus with terms
  t, t^2 and t^7 -- presumably poly(192); unverified.  It prints each
  intermediate and finally the reference `gcm_mul' result for comparison.
  """
  x = C.bytes('cde4bef260d7bcda163547d348b7551195e77022907dd1df')
  y = C.bytes('f7dac5c9941d26d0c6eb14ad568f86edd1dc9268eeee5332')

  u, v = C.GF.loadb(x), C.GF.loadb(y)

  def low24(val):
    ## The first 24 bytes of VAL stored as a 48-byte big-endian string.
    return words(val.storeb(48)[0:24])

  g = u*v << 1
  print('y = %s' % words(g.storeb(48)))
  b = C.GF(0)
  for pattern, shift in ((0x01, 191), (0x03, 190), (0x7f, 185)):
    b = b + ((g&repmask(pattern, 32, 6)) << shift)
  print('b = %s' % words(b.storeb(48)[0:28]))
  h = g + b
  print('w = %s' % words(h.storeb(48)))

  a0 = (h&repmask(0xffffffff, 32, 6)) << 192
  a1 = (h&repmask(0xfffffffe, 32, 6)) << 191
  a2 = (h&repmask(0xfffffffc, 32, 6)) << 190
  a7 = (h&repmask(0xffffff80, 32, 6)) << 185
  a = a0 + a1 + a2 + a7

  for label, part in ((' a_1', a1), (' a_2', a2), (' a_7', a7)):
    print('%s = %s' % (label, low24(part)))

  print('low+unit = %s' % low24(h + a0))
  print(' low+0,2 = %s' % low24(h + a0 + a2))
  print(' 1,7 = %s' % low24(a1 + a7))

  print('a = %s' % low24(a))
  print('z = %s' % words((h + a).storeb(48)))

  print('u v mod p = %s' % words(gcm_mul(x, y)))
491
492###--------------------------------------------------------------------------
493### Main program.
494
## Command line: STYLE U V.  Run the named demonstration on the operands U
## and V (converted with `C.bytes') and cross-check its result against the
## reference `gcm_mul'.  Explicit checks are used rather than `assert',
## which `python -O' would strip, silently disabling the cross-check.
if len(argv) != 4:
  exit('usage: %s STYLE U V' % argv[0])
style = argv[1]
if style not in DEMOMAP:
  exit("%s: unknown style `%s' (expected one of: %s)" %
       (argv[0], style, ', '.join(sorted(DEMOMAP.keys()))))
u = C.bytes(argv[2])
v = C.bytes(argv[3])
zz = DEMOMAP[style](u, v)
if zz != gcm_mul(u, v):
  exit("%s: style `%s' disagrees with reference multiplication" %
       (argv[0], style))
500
501###----- That's all, folks --------------------------------------------------