utils/gcm-ref (present_gf_pmull): Round width up to a multiple of 64 bits.
[catacomb] / utils / gcm-ref
1 #! /usr/bin/python
2 ### -*- coding: utf-8 -*-
3
4 from sys import argv, exit
5
6 import catacomb as C
7
8 ###--------------------------------------------------------------------------
9 ### Random utilities.
10
def words(s):
  """Render the octet string S as space-separated 8-digit hex numbers, one per 32-bit word."""
  pieces = []
  for off in xrange(0, len(s), 4):
    pieces.append('%08x' % C.MP.loadb(s[off:off + 4]))
  return ' '.join(pieces)
15
def words_64(s):
  """Render the octet string S as space-separated 16-digit hex numbers, one per 64-bit word."""
  pieces = []
  for off in xrange(0, len(s), 8):
    pieces.append('%016x' % C.MP.loadb(s[off:off + 8]))
  return ' '.join(pieces)
20
def repmask(val, wd, n):
  """
  Return a GF value made of N back-to-back copies of the WD-bit pattern VAL.

  For example, repmask(0xff, 16, 2) is 0x00ff00ff.
  """
  piece = C.GF(val)
  return reduce(lambda acc, _: (acc << wd) | piece, xrange(n), C.GF(0))
27
def combs(things, k):
  """
  Generate every K-element combination of THINGS, each as a fresh list.

  Combinations come out in colexicographic order of their index sets: sets
  are compared by their largest index first, then the next largest, and so
  on.  (So for K = 2 over `abcd': ab, ac, bc, ad, bd, cd.)  `poly' picks the
  first irreducible candidate in this order, so the order matters.
  """
  n = len(things)
  idx = list(range(k))
  while True:
    yield [things[i] for i in idx]

    ## Bump the lowest position that can advance without catching up with
    ## the position above it (or, for the top position, running off the
    ## end); every position below it resets to its smallest value.
    for pos in range(k):
      ceiling = idx[pos + 1] if pos + 1 < k else n
      if idx[pos] + 1 < ceiling:
        idx[pos] += 1
        break
      idx[pos] = pos
    else:
      ## Nothing could advance: that was the last combination.
      return
44
POLYMAP = {}

def poly(nbits):
  """
  Return the field polynomial of degree NBITS, caching it in POLYMAP.

  Candidates have the form t^NBITS + (middle terms) + 1 with an odd number
  of middle terms (an even-weight polynomial is divisible by t + 1, so could
  not be irreducible).  Fewer middle terms are tried first; for each count,
  the sets of middle-term degrees are taken in the order `combs' yields
  them, and the first irreducible candidate is returned.  Raise ValueError
  if no candidate is irreducible.
  """
  if nbits in POLYMAP: return POLYMAP[nbits]
  ends = C.GF(0).setbit(nbits).setbit(0)
  for nmid in xrange(1, nbits, 2):
    for degs in combs(range(1, nbits), nmid):
      cand = ends + sum((C.GF(0).setbit(d) for d in degs), C.GF(0))
      if cand.irreduciblep():
        POLYMAP[nbits] = cand
        return cand
  raise ValueError(nbits)
60
def gcm_mangle(x):
  """
  Return X with the order of the bits within each octet reversed.

  This converts between GCM's bit-numbering convention and the ordinary
  one; applying it twice gives back the original string.
  """
  out = C.WriteBuffer()
  for ch in x:
    octet = ord(ch)
    flipped = 0
    for _ in xrange(8):
      flipped = (flipped << 1) | (octet & 1)
      octet >>= 1
    out.putu8(flipped)
  return out.contents
73
def endswap_words_32(x):
  """
  Return X with the octets of each 32-bit word in the opposite order.

  X is presumably a whole number of 32-bit words long -- TODO confirm what
  `ReadBuffer' does with a ragged tail.
  """
  src = C.ReadBuffer(x)
  dst = C.WriteBuffer()
  while src.left:
    word = src.getu32b()
    dst.putu32l(word)
  return dst.contents
80
def endswap_words_64(x):
  """
  Return X with the octets of each 64-bit word in the opposite order.

  X is presumably a whole number of 64-bit words long -- TODO confirm what
  `ReadBuffer' does with a ragged tail.
  """
  src = C.ReadBuffer(x)
  dst = C.WriteBuffer()
  while src.left:
    word = src.getu64b()
    dst.putu64l(word)
  return dst.contents
87
def endswap_bytes(x):
  """Return the octets of X in the opposite order."""
  out = C.WriteBuffer()
  for octet in reversed(x):
    out.put(octet)
  return out.contents
93
def gfmask(n):
  """Return the GF value whose N least significant bits are all set."""
  allones = C.MP(0).setbit(n) - 1
  return C.GF(allones)
96
def gcm_mul(x, y):
  """
  Multiply X and Y according to the GCM rules; this is the reference result.

  X and Y are field elements in external format; only len(X) is consulted,
  so they are assumed to be the same length.  Each is bit-flipped within
  octets and read little-endian as a polynomial, the product is reduced
  modulo `poly' of matching degree, and the result is converted back to
  external format at the same length.
  """
  nbytes = len(x)
  modulus = poly(8*nbytes)
  a = C.GF.loadl(gcm_mangle(x))
  b = C.GF.loadl(gcm_mangle(y))
  prod = (a*b)%modulus
  return gcm_mangle(prod.storel(nbytes))
104
DEMOMAP = {}

def demo(func):
  """
  Decorator: register FUNC in DEMOMAP under its command-line style name.

  FUNC's name must start with `demo_'; the style name is the remainder with
  underscores turned into hyphens (`demo_table_b' becomes `table-b').  FUNC
  itself is returned unchanged.
  """
  name = func.__name__
  assert name.startswith('demo_')
  style = name[len('demo_'):].replace('_', '-')
  DEMOMAP[style] = func
  return func
111
def iota(i = 0):
  """
  Return a counter function: successive calls return I, I + 1, I + 2, ...
  """
  state = {'next': i}
  def step():
    val = state['next']
    state['next'] = val + 1
    return val
  return step
116
117 ###--------------------------------------------------------------------------
118 ### Portable table-driven implementation.
119
def shift_left(x):
  """
  Given a field element X (in external format), return X t.

  The result always has the same length as X.  A GF value carries no notion
  of the field width, so the width must be passed to `storel' explicitly
  (as `gcm_mul' does).  Otherwise a result whose high-order octet happens to
  be zero comes out short, and a later call would then pick a field
  polynomial of the wrong degree.
  """
  w = len(x)
  p = poly(8*w)
  shifted = (C.GF.loadl(gcm_mangle(x)) << 1)%p
  return gcm_mangle(shifted.storel(w))
125
126 def table_common(u, v, flip, getword, ixmask):
127 """
128 Multiply U by V using table lookup; common for `table-b' and `table-l'.
129
130 This matches the `simple_mulk_...' implementation in `gcm.c'. One-entry
131 per bit is the best we can manage if we want a constant-time
132 implementation: processing n bits at a time means we need to scan
133 (2^n - 1)/n times as much memory.
134
135 * FLIP is a function (assumed to be an involution) on one argument X to
136 convert X from external format to table-entry format or back again.
137
138 * GETWORD is a function on one argument B to retrieve the next 32-bit
139 chunk of a field element held in a `ReadBuffer'. Bits within a word
140 are processed most-significant first.
141
142 * IXMASK is a mask XORed into table indices to permute the table so that
143 it's order matches that induced by GETWORD.
144
145 The table is built such that tab[i XOR IXMASK] = U t^i.
146 """
147 w = len(u); assert(w == len(v))
148 a = C.ByteString.zero(w)
149 tab = [None]*(8*w)
150 for i in xrange(8*w):
151 print ';; %9s = %7s = %s' % ('utab[%d]' % i, 'u t^%d' % i, words(u))
152 tab[i ^ ixmask] = flip(u)
153 u = shift_left(u)
154 v = C.ReadBuffer(v)
155 i = 0
156 while v.left:
157 t = getword(v)
158 for j in xrange(32):
159 bit = (t >> 31)&1
160 if bit: a ^= tab[i]
161 print ';; %6s = %d: a <- %s [%9s = %s]' % \
162 ('v[%d]' % (i ^ ixmask), bit, words(a),
163 'utab[%d]' % (i ^ ixmask), words(tab[i]))
164 i += 1; t <<= 1
165 return flip(a)
166
@demo
def demo_table_b(u, v):
  """
  Big-endian table lookup.

  Words of V are read big-endian, table entries are kept in external
  format, and the table needs no index permutation.
  """
  def unchanged(x): return x
  def getword(buf): return buf.getu32b()
  return table_common(u, v, unchanged, getword, 0)
171
@demo
def demo_table_l(u, v):
  """
  Little-endian table lookup.

  Words of V are read little-endian and table entries are stored with the
  octets of each 32-bit word swapped.  The index mask 0x18 flips bits 3 and
  4 of a bit index, i.e., it reverses the octet position within a 32-bit
  word, matching the table to that word order.
  """
  def getword(buf): return buf.getu32l()
  return table_common(u, v, endswap_words_32, getword, 0x18)
176
177 ###--------------------------------------------------------------------------
178 ### Implementation using 64×64->128-bit binary polynomial multiplication.
179
## Tags saying which intermediate value is being handed to a presentation
## function.  They are numbered consecutively from zero in the order listed
## here, and the presentation functions rely on that order: they test
## ranges such as `TAG_PRODSUM <= tag <= TAG_PRODUCT'.  Keep new tags in
## sequence.
_i = iota()
(TAG_INPUT_U, TAG_INPUT_V,
 TAG_KPIECE_U, TAG_KPIECE_V,
 TAG_PRODPIECE, TAG_PRODSUM, TAG_PRODUCT,
 TAG_SHIFTED,
 TAG_REDCBITS, TAG_REDCFULL, TAG_REDCMIX,
 TAG_OUTPUT) = tuple(_i() for _tag in xrange(12))
193
def split_gf(x, n):
  """
  Split the octet string X into N-bit big-endian chunks, most significant
  first, returning them as a list of GF values.

  N is assumed to be a multiple of 8; if len(X) is not a multiple of N/8,
  the last chunk is short.
  """
  step = n//8
  return [C.GF.loadb(x[off:off + step]) for off in xrange(0, len(x), step)]
197
def join_gf(xx, n):
  """
  Concatenate the N-bit GF chunks XX, first element most significant; this
  undoes `split_gf' for whole chunks.
  """
  acc = C.GF(0)
  for piece in xx:
    acc = (acc << n) | piece
  return acc
202
203 def present_gf(x, w, n, what):
204 firstp = True
205 m = gfmask(n)
206 for i in xrange(0, w, 128):
207 print ';; %12s%c =%s' % \
208 (firstp and what or '',
209 firstp and ':' or ' ',
210 ''.join([j < w
211 and ' 0x%s' % hex(((x >> j)&m).storeb(n/8))
212 or ''
213 for j in xrange(i, i + 128, n)]))
214 firstp = False
215
def present_gf_pclmul(tag, wd, x, w, n, what):
  """
  Presentation hook for the `pclmul' style: show every value as it is,
  except the individual product pieces, which are skipped.
  """
  if tag == TAG_PRODPIECE: return
  present_gf(x, w, n, what)
218
def reverse(x, w):
  """Return the W-bit value X with its W/8 octets in the opposite order."""
  octets = x.storeb(w//8)
  return C.GF.loadl(octets)
221
def rev32(x):
  """
  Reverse the order of the octets within each 32-bit word of X.

  The number of words is derived from the size of X, rounded up to whole
  words.  A value whose top word is only partly occupied (so its octet
  count isn't a multiple of 4) therefore keeps that word rather than
  losing it to undersized masks.
  """
  nwords = (x.noctets + 3)//4
  m_ffff = repmask(0xffff, 32, nwords)
  m_ff = repmask(0xff, 16, 2*nwords)
  ## Swap the 16-bit halves of each word, then the octets of each half.
  x = ((x&m_ffff) << 16) | ((x >> 16)&m_ffff)
  x = ((x&m_ff) << 8) | ((x >> 8)&m_ff)
  return x
229
def rev8(x):
  """
  Reverse the order of the bits within each octet of the GF value X, by
  swapping nibbles, then bit pairs, then single bits.
  """
  nbytes = x.noctets
  for pattern, shift in ((0x0f, 4), (0x33, 2), (0x55, 1)):
    mask = repmask(pattern, 8, nbytes)
    x = ((x&mask) << shift) | ((x >> shift)&mask)
  return x
239
def present_gf_mullp64(tag, wd, x, w, n, what):
  """
  Presentation hook for the `vmullp64' style.

  Individual product pieces and the full reduction terms are skipped.
  Product sums and products (for 64- and 128-bit fields), and everything
  from the product sums up to but excluding the output (for 96-, 192- and
  256-bit fields), are shown as they are.  Anything else is stored as W/8
  octets, zero-padded at the end to whole 64-bit words, and shown with the
  order of those words reversed.  The display width is W rounded up to a
  multiple of 64 bits.
  """
  if tag in (TAG_PRODPIECE, TAG_REDCFULL): return

  as_is = (wd in (64, 128) and TAG_PRODSUM <= tag <= TAG_PRODUCT) or \
          (wd in (96, 192, 256) and TAG_PRODSUM <= tag < TAG_OUTPUT)
  if as_is:
    y = x
  else:
    octets = x.storeb(w//8)
    short = len(octets)%8
    if short: octets += C.ByteString.zero(8 - short)
    out = C.WriteBuffer()
    for end in xrange(len(octets), 0, -8):
      out.put(octets[end - 8:end])
    y = C.GF.loadb(out.contents)
  present_gf(y, (w + 63)&~63, n, what)
256
def present_gf_pmull(tag, wd, x, w, n, what):
  """
  Presentation hook for the `pmull' style.

  Individual product pieces, the full reduction terms and the shifted
  product are skipped.  V-side operands are padded to whole 64-bit lanes
  and each lane is shown twice (presumably mirroring the register layout
  of the assembler code -- confirm).  Product sums and products are shown
  shifted up one bit.  Everything shown is bit-reversed within octets and
  then octet-reversed overall.
  """
  if tag in (TAG_PRODPIECE, TAG_REDCFULL, TAG_SHIFTED): return

  if tag in (TAG_INPUT_V, TAG_KPIECE_V):
    w = (w + 63)&~63
    src = C.ReadBuffer(x.storeb(w//8))
    dup = C.WriteBuffer()
    while src.left:
      lane = src.get(8)
      dup.put(lane)
      dup.put(lane)
    x = C.GF.loadb(dup.contents)
    w *= 2
  elif TAG_PRODSUM <= tag <= TAG_PRODUCT:
    x <<= 1

  present_gf(reverse(rev8(x), w), w, n, what)
271
def poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat):
  """
  Multiply U by V, returning the product.

  This is the fallback long multiplication.  U and V are big-endian octet
  strings; each is carved into MULWD-bit pieces, and the piecewise products
  are accumulated into the result.  Each piece product, each coefficient
  sum and the final product are reported through PRESFN.  U and V need not
  be the same length.
  """

  uw, vw = 8*len(u), 8*len(v)

  ## We start by carving the operands into MULWD-bit pieces.  This is
  ## straightforward except for the 96-bit case, where each operand is short
  ## of a whole piece.  The zero padding is appended, so it sits at the
  ## low-order end and multiplies the operand by t^pad; the caller has to
  ## allow for that.  PAD counts bits, but `ByteString.zero' counts octets.
  if uw%mulwd:
    pad = (-uw)%mulwd; u += C.ByteString.zero(pad//8); uw += pad
  if vw%mulwd:
    pad = (-vw)%mulwd; v += C.ByteString.zero(pad//8); vw += pad
  uu = split_gf(u, mulwd)
  vv = split_gf(v, mulwd)

  ## Report and accumulate the individual product pieces.  UU and VV list
  ## the pieces most significant first, so piece k of U is uu[ulim - 1 - k].
  x = C.GF(0)
  ulim, vlim = uw//mulwd, vw//mulwd
  for i in xrange(ulim + vlim - 2, -1, -1):
    ## Coefficient i collects u_{i-j} v_j over every j with both indices in
    ## range: 0 <= j < vlim and 0 <= i - j < ulim.
    t = C.GF(0)
    for j in xrange(max(0, i - ulim + 1), min(vlim, i + 1)):
      s = uu[ulim - 1 - i + j]*vv[vlim - 1 - j]
      presfn(TAG_PRODPIECE, wd, s, 2*mulwd, dispwd,
             '%s_%d %s_%d' % (uwhat, i - j, vwhat, j))
      t += s
    presfn(TAG_PRODSUM, wd, t, 2*mulwd, dispwd,
           '(%s %s)_%d' % (uwhat, vwhat, ulim + vlim - 2 - i))
    x += t << (mulwd*i)
  presfn(TAG_PRODUCT, wd, x, uw + vw, dispwd, '%s %s' % (uwhat, vwhat))

  return x
305
def poly64_mul_karatsuba(u, v, klimit, presfn, wd,
                         dispwd, mulwd, uwhat, vwhat):
  """
  Multiply U by V, returning the product.

  When U and V are the same length, at least KLIMIT bits long, and a
  multiple of 2 MULWD bits, split each into halves and combine three
  recursively computed half-size products Karatsuba--Ofman style.
  Otherwise delegate to `poly64_mul_simple'.  Each pair of half-size
  operands, the middle sum and the product are reported through PRESFN.
  """
  w = 8*len(u)

  if w < klimit or w != 8*len(v) or w%(2*mulwd) != 0:
    return poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat)

  hw = w//2
  cut = hw//8

  ## The operands are big-endian, so the `0' halves are the high-order ones.
  u0, u1 = u[:cut], u[cut:]
  v0, v1 = v[:cut], v[cut:]

  ## The three half-size multiplications, in the order they're reported:
  ## the sums of the halves first, then the high halves, then the low ones.
  jobs = [('*', u0 ^ u1, v0 ^ v1), ('0', u0, v0), ('1', u1, v1)]
  prod = {}
  for sfx, uh, vh in jobs:
    ulab, vlab = '%s%s' % (uwhat, sfx), '%s%s' % (vwhat, sfx)
    presfn(TAG_KPIECE_U, wd, C.GF.loadb(uh), hw, dispwd, ulab)
    presfn(TAG_KPIECE_V, wd, C.GF.loadb(vh), hw, dispwd, vlab)
    prod[sfx] = poly64_mul_karatsuba(uh, vh, klimit, presfn, wd,
                                     dispwd, mulwd, ulab, vlab)

  ## Middle term: (u0 + u1)(v0 + v1) - u0 v0 - u1 v1, where subtraction is
  ## just addition in characteristic 2.
  mid = prod['*'] + prod['0'] + prod['1']
  presfn(TAG_PRODSUM, wd, mid, w, dispwd, '%s!%s' % (uwhat, vwhat))

  x = prod['1'] + (mid << hw) + (prod['0'] << w)
  presfn(TAG_PRODUCT, wd, x, 2*w, dispwd, '%s %s' % (uwhat, vwhat))
  return x
346
def poly64_common(u, v, presfn, dispwd = 32, mulwd = 64, redcwd = 32,
                  klimit = 256):
  """
  Multiply U by V using a primitive 64-bit binary polynomial multiplier.

  Such a multiplier exists as the appallingly-named `pclmul[lh]q[lh]qdq' on
  x86, and as `vmull.p64'/`pmull' on ARM.

  Operands arrive in a `register format', which is a byte-swapped variant of
  the external format.  Implementations differ on the precise details,
  though.

  U and V are octet strings of the same length (asserted); the result is an
  octet string of that length.  For each intermediate value, PRESFN is
  called as PRESFN(TAG, W, VALUE, VALUE-WIDTH, DISPWD, LABEL), where W is
  the field width in bits.  MULWD is the width of the primitive multiplier's
  operands; REDCWD is the chunk width used during reduction; KLIMIT is the
  smallest operand width, in bits, for which Karatsuba multiplication is
  tried.
  """

  ## We work in two main phases: first, calculate the full double-width
  ## product; and, second, reduce it modulo the field polynomial.

  w = 8*len(u); assert(w == 8*len(v))
  p = poly(w)
  presfn(TAG_INPUT_U, w, C.GF.loadb(u), w, dispwd, 'u')
  presfn(TAG_INPUT_V, w, C.GF.loadb(v), w, dispwd, 'v')

  ## So, on to the first part: the multiplication.
  x = poly64_mul_karatsuba(u, v, klimit, presfn, w, dispwd, mulwd, 'u', 'v')

  ## Now we have to shift everything up one bit to account for GCM's crazy
  ## bit ordering.
  y = x << 1
  ## For 96-bit operands, the long multiplication pads each operand out to
  ## 128 bits with zeros at the low-order end, so the product carries an
  ## extra factor of t^64; shift it back out.
  if w == 96: y >>= 64
  presfn(TAG_SHIFTED, w, y, 2*w, dispwd, 'y')

  ## Now for the reduction.
  ##
  ## Our polynomial has the form p = t^d + r where r = SUM_{0<=i<d} r_i t^i,
  ## with each r_i either 0 or 1.  Because we choose the lexically earliest
  ## irreducible polynomial with the necessary degree, r_i = 1 happens only
  ## for a small number of tiny i.  In our field, we have t^d = r.
  ##
  ## We carve the product into convenient n-bit pieces, for some n dividing d
  ## -- typically n = 32 or 64.  Let d = m n, and write y = SUM_{0<=i<2m} y_i
  ## t^{ni}.  The upper portion, the y_i with i >= m, needs reduction; but
  ## y_i t^{ni} = y_i r t^{n(i-m)}, so we just multiply the top half by r and
  ## add it to the bottom half.  This all depends on r_i = 0 for all i >=
  ## n/2.  We process each nonzero coefficient of r separately, in two
  ## passes.
  ##
  ## Multiplying a chunk y_i by some t^j is the same as shifting it left by j
  ## bits (or would be if GCM weren't backwards, but let's not worry about
  ## that right now).  The high j bits will spill over into the next chunk,
  ## while the low n - j bits will stay where they are.  It's these high bits
  ## which cause trouble -- particularly the high bits of the top chunk,
  ## since we'll add them on to y_m, which will need further reduction.  But
  ## only the topmost j bits will do this.
  ##
  ## The trick is that we do all of the bits which spill over first -- all of
  ## the top j bits in each chunk, for each j -- in one pass, and then a
  ## second pass of all the bits which don't.  Because j, j' < n/2 for any
  ## two nonzero coefficient degrees j and j', we have j + j' < n whence j <
  ## n - j' -- so all of the bits contributed to y_m will be handled in the
  ## second pass when we handle the bits that don't spill over.
  ##
  ## RR lists the degrees of the nonzero middle terms of p (the constant
  ## term is excluded); M masks a single REDCWD-bit chunk.
  rr = [i for i in xrange(1, w) if p.testbit(i)]
  m = gfmask(redcwd)

  ## Handle the spilling bits.
  ##
  ## `split_gf' lists chunks most significant first, so yy[w/redcwd:] below
  ## are the low-order W bits of y; these are the chunks being folded in
  ## (presumably because GCM's backwards bit order puts the high-degree
  ## coefficients there -- see above).  For each rj, BR keeps the low rj bits
  ## of each such chunk, moved to the top of a REDCWD-bit chunk; shifting by
  ## W - REDCWD lines them up against the upper half of y.
  yy = split_gf(y.storeb(w/4), redcwd)
  b = C.GF(0)
  for rj in rr:
    br = [(yi << (redcwd - rj))&m for yi in yy[w/redcwd:]]
    presfn(TAG_REDCBITS, w, join_gf(br, redcwd), w, dispwd, 'b(%d)' % rj)
    b += join_gf(br, redcwd) << (w - redcwd)
  presfn(TAG_REDCFULL, w, b, 2*w, dispwd, 'b')
  s = y + b
  presfn(TAG_REDCMIX, w, s, 2*w, dispwd, 's')

  ## Handle the nonspilling bits.  For each rj, AR is each low-order chunk of
  ## s with its low rj bits (already dealt with above) shifted away.
  ss = split_gf(s.storeb(w/4), redcwd)
  a = C.GF(0)
  for rj in rr:
    ar = [si >> rj for si in ss[w/redcwd:]]
    presfn(TAG_REDCBITS, w, join_gf(ar, redcwd), w, dispwd, 'a(%d)' % rj)
    a += join_gf(ar, redcwd)
  presfn(TAG_REDCFULL, w, a, w, dispwd, 'a')

  ## Mix everything together: the low W bits of s, its high W bits, and the
  ## nonspilling contributions.
  m = gfmask(w)
  z = (s&m) + (s >> w) + a
  presfn(TAG_OUTPUT, w, z, w, dispwd, 'z')

  ## And we're done.
  return z.storeb(w/8)
436
@demo
def demo_pclmul(u, v):
  """
  Multiply U by V via `poly64_common', traced for the x86 `pclmul' style,
  using that function's default chunk and multiplier widths.
  """
  hook = present_gf_pclmul
  return poly64_common(u, v, presfn = hook)
440
@demo
def demo_vmullp64(u, v):
  """
  Multiply U by V via `poly64_common', traced for the `vmull.p64' style.

  The reduction works in 32-bit chunks when the field width leaves 32 bits
  over a multiple of 64 (e.g., 96 bits), and in 64-bit chunks otherwise.
  """
  nbits = 8*len(u)
  chunk = 32 if nbits%64 == 32 else 64
  return poly64_common(u, v, presfn = present_gf_mullp64, redcwd = chunk)
446
@demo
def demo_pmull(u, v):
  """
  Multiply U by V via `poly64_common', traced for the `pmull' style.

  The reduction works in 32-bit chunks when the field width leaves 32 bits
  over a multiple of 64 (e.g., 96 bits), and in 64-bit chunks otherwise.
  """
  nbits = 8*len(u)
  chunk = 32 if nbits%64 == 32 else 64
  return poly64_common(u, v, presfn = present_gf_pmull, redcwd = chunk)
452
453 ###--------------------------------------------------------------------------
454 ### @@@ Random debris to be deleted. @@@
455
def cutting_room_floor():
  """
  Scratch calculation left over from development; nothing calls it.

  Multiplies two fixed 24-octet (192-bit) operands, shifts the product up
  one bit, and prints the intermediate values of a hand-written two-pass
  reduction: the `b' terms, built from the low bits of each 32-bit word,
  then the `a' terms, built from the high bits.  It finishes with the
  reference result from `gcm_mul' for comparison.  The shift amounts
  (191, 190, 185) suggest the 192-bit field polynomial has t, t^2 and t^7
  terms -- not checked here.
  """

  x = C.bytes('cde4bef260d7bcda163547d348b7551195e77022907dd1df')
  y = C.bytes('f7dac5c9941d26d0c6eb14ad568f86edd1dc9268eeee5332')

  u, v = C.GF.loadb(x), C.GF.loadb(y)

  ## Product, shifted up one bit (cf. the shift in `poly64_common').
  g = u*v << 1
  print 'y = %s' % words(g.storeb(48))
  ## First pass: the low 1, 2 and 7 bits of each 32-bit word, moved up.
  b1 = (g&repmask(0x01, 32, 6)) << 191
  b2 = (g&repmask(0x03, 32, 6)) << 190
  b7 = (g&repmask(0x7f, 32, 6)) << 185
  b = b1 + b2 + b7
  print 'b = %s' % words(b.storeb(48)[0:28])
  h = g + b
  print 'w = %s' % words(h.storeb(48))

  ## Second pass: the remaining high bits of each 32-bit word.
  a0 = (h&repmask(0xffffffff, 32, 6)) << 192
  a1 = (h&repmask(0xfffffffe, 32, 6)) << 191
  a2 = (h&repmask(0xfffffffc, 32, 6)) << 190
  a7 = (h&repmask(0xffffff80, 32, 6)) << 185
  a = a0 + a1 + a2 + a7

  print ' a_1 = %s' % words(a1.storeb(48)[0:24])
  print ' a_2 = %s' % words(a2.storeb(48)[0:24])
  print ' a_7 = %s' % words(a7.storeb(48)[0:24])

  print 'low+unit = %s' % words((h + a0).storeb(48)[0:24])
  print ' low+0,2 = %s' % words((h + a0 + a2).storeb(48)[0:24])
  print ' 1,7 = %s' % words((a1 + a7).storeb(48)[0:24])

  print 'a = %s' % words(a.storeb(48)[0:24])
  z = h + a
  print 'z = %s' % words(z.storeb(48))

  ## Reference answer for comparison.
  z = gcm_mul(x, y)
  print 'u v mod p = %s' % words(z)
491 z = gcm_mul(x, y)
492 print 'u v mod p = %s' % words(z)
493
494 ###--------------------------------------------------------------------------
495 ### Main program.
496
497 style = argv[1]
498 u = C.bytes(argv[2])
499 v = C.bytes(argv[3])
500 zz = DEMOMAP[style](u, v)
501 assert zz == gcm_mul(u, v)
502
503 ###----- That's all, folks --------------------------------------------------