progs/perftest.c: Use from Glibc syscall numbers.
[catacomb] / utils / gcm-ref
CommitLineData
9e6a4409
MW
1#! /usr/bin/python
2### -*- coding: utf-8 -*-
3
4from sys import argv, exit
5
6import catacomb as C
7
8###--------------------------------------------------------------------------
9### Random utilities.
10
def words(s):
  """
  Render the byte string S as space-separated 32-bit hex words.

  S is cut into consecutive four-byte pieces; each piece is loaded with
  `C.MP.loadb' and formatted as (at least) eight hex digits.  A short
  final piece is formatted in just the same way.
  """
  pieces = []
  for off in range(0, len(s), 4):
    pieces.append('%08x' % C.MP.loadb(s[off:off + 4]))
  return ' '.join(pieces)
15
def words_64(s):
  """
  Render the byte string S as space-separated 64-bit hex words.

  S is cut into consecutive eight-byte pieces; each piece is loaded with
  `C.MP.loadb' and formatted as (at least) sixteen hex digits.
  """
  pieces = []
  for off in range(0, len(s), 8):
    pieces.append('%016x' % C.MP.loadb(s[off:off + 8]))
  return ' '.join(pieces)
20
def repmask(val, wd, n):
  """
  Return a `C.GF' mask made of N copies of VAL, one every WD bits.

  The mask is built by repeatedly shifting the accumulated value left by WD
  bits and ORing in VAL, so VAL is expected to fit within WD bits.
  """
  unit = C.GF(val)
  mask = C.GF(0)
  for _ in range(n):
    mask = (mask << wd) | unit
  return mask
27
def combs(things, k):
  """
  Iterate over all possible combinations of K of the THINGS.

  Each combination is yielded as a fresh list of K items, in increasing
  index order.  Combinations are generated by advancing the *lowest* index
  first, so for four things taken two at a time the index pairs come out as
  (0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3).  This differs from the
  order used by `itertools.combinations', and `poly' depends on it.

  If K exceeds the number of THINGS then there are no combinations, and
  nothing is yielded.
  """
  n = len(things)

  ## There's no way to choose more items than we have.  Without this check,
  ## the initial selection below would index off the end of THINGS.
  if k > n: return

  ## The current selection, as a list of strictly increasing indices.  This
  ## must be a real list because it's updated in place.
  picks = list(range(k))

  while True:
    yield [things[i] for i in picks]

    ## Step to the next selection.  Find the lowest position which can be
    ## advanced without colliding with its upper neighbour (or running off
    ## the end, for the last position), resetting all of the positions below
    ## it to their initial values.
    for j in range(k):
      if j + 1 < k: lim = picks[j + 1]
      else: lim = n
      nxt = picks[j] + 1
      if nxt < lim:
        picks[j] = nxt
        break
      picks[j] = j
    else:
      ## Nothing could be advanced, so we're done.
      return
44
## Cache of polynomials already found by `poly', indexed by degree.
POLYMAP = {}

def poly(nbits):
  """
  Return the lexically first irreducible polynomial of degree NBITS of lowest
  weight.

  Candidates always include the terms t^NBITS and 1; an odd number K = 1, 3,
  5, ... of middle terms is then added, trying the choices of exponents in
  the order produced by `combs'.  The first candidate for which
  `irreduciblep' holds is cached in `POLYMAP' and returned.  Raises
  `ValueError' (with NBITS as its argument) if no candidate is irreducible.
  """
  if nbits in POLYMAP: return POLYMAP[nbits]

  base = C.GF(0).setbit(nbits).setbit(0)
  middle = range(1, nbits)
  for k in range(1, nbits, 2):
    for cc in combs(middle, k):
      extra = C.GF(0)
      for c in cc: extra = extra + C.GF(0).setbit(c)
      p = base + extra
      if p.irreduciblep():
        POLYMAP[nbits] = p
        return p
  raise ValueError(nbits)
60
def gcm_mangle(x):
  """
  Reverse the order of the bits within each byte of X.

  This converts between GCM's bit ordering and the conventional one.  The
  bytes themselves stay in the same order.  The operation is its own
  inverse.
  """
  out = C.WriteBuffer()
  for ch in x:
    byte = ord(ch)
    flipped = 0
    ## Peel bits off the bottom of BYTE and push them onto the bottom of
    ## FLIPPED, so the first bit out ends up at the top.
    for _ in range(8):
      flipped = (flipped << 1) | (byte&1)
      byte >>= 1
    out.putu8(flipped)
  return out.contents
73
def endswap_words_32(x):
  """
  Return a copy of X with the bytes of each 32-bit word reversed.

  Each word is read with `getu32b' and written back with `putu32l'.
  """
  src = C.ReadBuffer(x)
  dst = C.WriteBuffer()
  while True:
    if not src.left: break
    word = src.getu32b()
    dst.putu32l(word)
  return dst.contents
80
def endswap_words_64(x):
  """
  Return a copy of X with the bytes of each 64-bit word reversed.

  Each word is read with `getu64b' and written back with `putu64l'.
  """
  src = C.ReadBuffer(x)
  dst = C.WriteBuffer()
  while True:
    if not src.left: break
    word = src.getu64b()
    dst.putu64l(word)
  return dst.contents
87
def endswap_bytes(x):
  """Return a copy of X with its bytes in the opposite order."""
  out = C.WriteBuffer()
  for i in range(len(x) - 1, -1, -1):
    out.put(x[i])
  return out.contents
93
def gfmask(n):
  """Return a `C.GF' value with exactly the low N bits set, i.e., 2^N - 1."""
  top = C.MP(0).setbit(n)
  return C.GF(top - 1)
96
def gcm_mul(x, y):
  """
  Multiply X and Y according to the GCM rules.

  Both operands are byte strings in external format.  They are bit-flipped
  with `gcm_mangle' and loaded little-endian as binary polynomials; the
  product is reduced modulo `poly(8*len(X))' and converted back the same way.
  The result has the same length as X.  This is the reference answer against
  which the demonstration implementations are checked.
  """
  w = len(x)
  p = poly(8*w)
  u = C.GF.loadl(gcm_mangle(x))
  v = C.GF.loadl(gcm_mangle(y))
  z = (u*v)%p
  return gcm_mangle(z.storel(w))
104
## Map from style names, as given on the command line, to demo functions.
DEMOMAP = {}
def demo(func):
  """
  Decorator: register FUNC as a demonstration implementation.

  FUNC's name must begin `demo_'.  The rest of the name, with underscores
  replaced by hyphens, becomes its key in `DEMOMAP'.  FUNC itself is returned
  unchanged.
  """
  prefix = 'demo_'
  name = func.__name__
  assert(name.startswith(prefix))
  style = name[len(prefix):].replace('_', '-')
  DEMOMAP[style] = func
  return func
111
def iota(i = 0):
  """
  Return a function which yields successive integers starting from I.

  Each call to the returned function gives the next integer: I first, then
  I + 1, and so on.  Counters from separate calls to `iota' are independent.
  """
  ## The counter lives in a one-element list so that the closure can update
  ## it.
  state = [i]
  def next():
    value = state[0]
    state[0] = value + 1
    return value
  return next
116
117###--------------------------------------------------------------------------
118### Portable table-driven implementation.
119
def shift_left(x):
  """
  Given a field element X (in external format), return X t.

  The result is in external format and always has the same length as X.
  """
  w = len(x)
  p = poly(8*w)
  xt = (C.GF.loadl(gcm_mangle(x)) << 1)%p

  ## Store to exactly W bytes, as `gcm_mul' does.  Leaving the length
  ## unspecified would let the result come out shorter than X whenever its
  ## high-order bytes happen to be zero, and callers derive the field size
  ## from the length of the string.
  return gcm_mangle(xt.storel(w))
125
ac4a43c1
MW
def shift_right(x):
  """
  Given a field element X (in external format), return X/t.

  The division is done by multiplying by the inverse of t modulo the field
  polynomial.  The result is in external format and always has the same
  length as X.
  """
  w = len(x)
  p = poly(8*w)
  xt = (C.GF.loadl(gcm_mangle(x))*p.modinv(2))%p

  ## Store to exactly W bytes, as `gcm_mul' does.  Leaving the length
  ## unspecified would let the result come out shorter than X whenever its
  ## high-order bytes happen to be zero.
  return gcm_mangle(xt.storel(w))
131
9e6a4409
MW
def table_common(u, v, flip, getword, ixmask):
  """
  Multiply U by V using table lookup; common for `table-b' and `table-l'.

  This matches the `simple_mulk_...' implementation in `gcm.c'.  One entry
  per bit is the best we can manage if we want a constant-time
  implementation: processing n bits at a time means we need to scan
  (2^n - 1)/n times as much memory.

    * FLIP is a function (assumed to be an involution) on one argument X to
      convert X from external format to table-entry format or back again.

    * GETWORD is a function on one argument B to retrieve the next 32-bit
      chunk of a field element held in a `ReadBuffer'.  Bits within a word
      are processed most-significant first.

    * IXMASK is a mask XORed into table indices to permute the table so that
      its order matches that induced by GETWORD.

  The table is built such that tab[i XOR IXMASK] = U t^i.

  U and V must have the same length.  The table and each step of the
  accumulation are printed as `;;' comment lines.  Returns the product, in
  external format.
  """
  w = len(u)
  assert(w == len(v))
  nbits = 8*w

  ## Build the table of multiples of U, one entry per bit of V.
  tab = [None]*nbits
  for i in range(nbits):
    print(';; %9s = %7s = %s' % ('utab[%d]' % i, 'u t^%d' % i, words(u)))
    tab[i ^ ixmask] = flip(u)
    u = shift_left(u)

  ## Scan the bits of V, adding in the table entry for each set bit.  The
  ## counter K is the table index; K XOR IXMASK is the corresponding bit
  ## position, used for the trace.
  acc = C.ByteString.zero(w)
  vbuf = C.ReadBuffer(v)
  k = 0
  while vbuf.left:
    word = getword(vbuf)
    for sh in range(31, -1, -1):
      bit = (word >> sh)&1
      if bit: acc = acc ^ tab[k]
      pos = k ^ ixmask
      print(';; %6s = %d: a <- %s [%9s = %s]' %
            ('v[%d]' % pos, bit, words(acc),
             'utab[%d]' % pos, words(tab[k])))
      k += 1
  return flip(acc)
172
@demo
def demo_table_b(u, v):
  """
  Big-endian table lookup.

  Table entries are kept in external format, words of V are fetched with
  `getu32b', and the table isn't permuted.
  """
  def keep(x): return x
  def fetch(b): return b.getu32b()
  return table_common(u, v, keep, fetch, 0)
177
@demo
def demo_table_l(u, v):
  """
  Little-endian table lookup.

  Table entries have each 32-bit word end-swapped, words of V are fetched
  with `getu32l', and table indices are XORed with 0x18.
  """
  def fetch(b): return b.getu32l()
  return table_common(u, v, endswap_words_32, fetch, 0x18)
9e6a4409
MW
182
183###--------------------------------------------------------------------------
184### Implementation using 64×64->128-bit binary polynomial multiplication.
185
## Tags identifying the various intermediate values handed to presentation
## functions.  These are consecutive integers starting from zero, and their
## ORDER IS SIGNIFICANT: presentation functions test for ranges such as
## `TAG_PRODSUM <= tag <= TAG_PRODUCT' and `TAG_PRODSUM <= tag < TAG_OUTPUT'.
## The count below must match the number of names; a mismatch fails loudly
## when the tuple is unpacked.
_i = iota()
(TAG_INPUT_U, TAG_INPUT_V, TAG_SHIFTED_V,
 TAG_KPIECE_U, TAG_KPIECE_V,
 TAG_PRODPIECE, TAG_PRODSUM, TAG_PRODUCT,
 TAG_REDCBITS, TAG_REDCFULL, TAG_REDCMIX,
 TAG_OUTPUT) = (_i() for _ in range(12))
199
def split_gf(x, n):
  """
  Split the byte string X into N-bit pieces, returning a list of `C.GF's.

  N should be a multiple of 8.  The pieces are taken in order from the start
  of X and each is loaded with `C.GF.loadb'; the last piece may be short.
  """
  step = n//8
  pieces = []
  for off in range(0, len(x), step):
    pieces.append(C.GF.loadb(x[off:off + step]))
  return pieces
203
def join_gf(xx, n):
  """
  Join the list XX of N-bit pieces into a single `C.GF' value.

  The first piece ends up most significant.  This is the inverse of
  `split_gf', provided that every piece fits in N bits.
  """
  acc = C.GF(0)
  for piece in xx:
    acc = (acc << n) | piece
  return acc
208
def present_gf(x, w, n, what):
  """
  Print the low W bits of X as N-bit hex pieces, labelled WHAT.

  Output is written as `;;' comment lines, each covering (up to) 128 bits:
  the first line carries the label WHAT and a colon, and continuation lines
  leave the label column blank.  Within a line, pieces are shown starting
  from the least significant.
  """
  mask = gfmask(n)
  nbytes = n//8

  ## Only the first line is labelled.
  label, sep = what or '', ':'
  for base in range(0, w, 128):
    cells = []
    for j in range(base, base + 128, n):
      if j < w:
        cells.append(' 0x%s' % hex(((x >> j)&mask).storeb(nbytes)))
    print(';; %12s%c =%s' % (label, sep, ''.join(cells)))
    label, sep = '', ' '
221
def present_gf_pclmul(tag, wd, x, w, n, what):
  """
  Presentation function for the `pclmul' demo.

  Everything except the individual product pieces (`TAG_PRODPIECE') is shown
  as is using `present_gf'.  WD is not used.
  """
  if tag == TAG_PRODPIECE: return
  present_gf(x, w, n, what)
224
def reverse(x, w):
  """
  Reverse the order of the bytes of the W-bit value X.

  X is stored as W/8 bytes with `storeb' and read back with `C.GF.loadl'.
  """
  octets = x.storeb(w//8)
  return C.GF.loadl(octets)
227
def rev32(x):
  """
  Reverse the order of the bytes within each 32-bit word of X.

  This swaps the 16-bit halves of each 32-bit word, and then the bytes
  within each 16-bit half.  The masks are sized from `X.noctets', which is
  presumably expected to be a multiple of 4.
  """
  noct = x.noctets
  steps = [(repmask(0xffff, 32, noct//4), 16),
           (repmask(0xff, 16, noct//2), 8)]
  for mask, sh in steps:
    x = ((x&mask) << sh) | ((x >> sh)&mask)
  return x
235
def rev8(x):
  """
  Reverse the order of the bits within each byte of X.

  This swaps the nibbles of each byte, then the bit-pairs within each
  nibble, and finally the bits within each pair.  The masks are sized from
  `X.noctets'.
  """
  noct = x.noctets
  steps = [(repmask(0x0f, 8, noct), 4),
           (repmask(0x33, 8, noct), 2),
           (repmask(0x55, 8, noct), 1)]
  for mask, sh in steps:
    x = ((x&mask) << sh) | ((x >> sh)&mask)
  return x
245
def present_gf_vmullp64(tag, wd, x, w, n, what):
  """
  Presentation function for the `vmullp64' demo.

  Individual product pieces and full reduction values (`TAG_PRODPIECE' and
  `TAG_REDCFULL') are not shown at all.  Some values are shown as they are:
  which ones depends on the overall operand width WD.  Everything else is
  padded to a whole number of 64-bit chunks, and the order of the chunks is
  reversed.  In every case the displayed width is W rounded up to a multiple
  of 64.
  """
  if tag in (TAG_PRODPIECE, TAG_REDCFULL): return

  ## Decide whether the value should be shown as it stands.
  asis = False
  if wd in (128, 64):
    asis = TAG_PRODSUM <= tag <= TAG_PRODUCT
  elif wd in (96, 192, 256):
    asis = TAG_PRODSUM <= tag < TAG_OUTPUT

  if asis:
    y = x
  else:
    ## Pad out to a multiple of eight bytes, and then copy the eight-byte
    ## chunks out again, last chunk first.
    raw = x.storeb(w//8)
    short = len(raw)%8
    if short: raw += C.ByteString.zero(8 - short)
    swapped = C.WriteBuffer()
    pos = len(raw)
    while pos > 0:
      swapped.put(raw[pos - 8:pos])
      pos -= 8
    y = C.GF.loadb(swapped.contents)

  present_gf(y, (w + 63)&~63, n, what)
262
def present_gf_pmull(tag, wd, x, w, n, what):
  """
  Presentation function for the `pmull' demo.

  Individual product pieces and full reduction values (`TAG_PRODPIECE' and
  `TAG_REDCFULL') are not shown at all.  Values derived from V
  (`TAG_INPUT_V', `TAG_SHIFTED_V', `TAG_KPIECE_V') are rounded up to a
  multiple of 64 bits and have each 64-bit chunk written out twice, doubling
  the width.  Product sums and products are shifted left by one bit.
  Whatever results then has the bits of each byte reversed and its bytes
  reversed before being shown.  WD is not used.
  """
  if tag in (TAG_PRODPIECE, TAG_REDCFULL): return

  if tag in (TAG_INPUT_V, TAG_SHIFTED_V, TAG_KPIECE_V):
    w = (w + 63)&~63
    src = C.ReadBuffer(x.storeb(w//8))
    dup = C.WriteBuffer()
    while src.left:
      chunk = src.get(8)
      dup.put(chunk)
      dup.put(chunk)
    x = C.GF.loadb(dup.contents)
    w *= 2
  elif TAG_PRODSUM <= tag <= TAG_PRODUCT:
    x = x << 1

  present_gf(reverse(rev8(x), w), w, n, what)
277
def poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat):
  """
  Multiply U by V, returning the product.

  This is the fallback long multiplication.

  U and V are byte strings; the product is returned as a `C.GF'.  PRESFN is
  called to present each MULWD-by-MULWD partial product (`TAG_PRODPIECE'),
  the sum of the partial products of each weight (`TAG_PRODSUM'), and the
  complete (still padded) product (`TAG_PRODUCT'); WD and DISPWD are simply
  passed on to it.  UWHAT and VWHAT are the operands' names, used to make
  labels.  MULWD should be a multiple of 8.
  """

  uw, vw = 8*len(u), 8*len(v)

  ## We start by carving the operands into MULWD-bit pieces.  This is
  ## straightforward except when an operand's length isn't a multiple of
  ## MULWD (e.g., the 96-bit case), where we end up with a short final piece
  ## which we fill out by appending zero bytes.  The pad widths are counted
  ## in bits, so convert to bytes when making the padding.
  upad = (-uw)%mulwd; u += C.ByteString.zero(upad//8); uw += upad
  vpad = (-vw)%mulwd; v += C.ByteString.zero(vpad//8); vw += vpad
  uu = split_gf(u, mulwd); vv = split_gf(v, mulwd)

  ## Report and accumulate the individual product pieces.  Pieces are
  ## stored most significant first, so the piece of U with weight
  ## t^(MULWD k) is at index ULIM - 1 - k, and similarly for V.  For the
  ## output weight I, we pair the U piece of weight I - J with the V piece of
  ## weight J: so we need 0 <= I - J < ULIM and 0 <= J < VLIM.
  x = C.GF(0)
  ulim, vlim = uw//mulwd, vw//mulwd
  for i in range(ulim + vlim - 2, -1, -1):
    t = C.GF(0)
    for j in range(max(0, i - ulim + 1), min(vlim, i + 1)):
      s = uu[ulim - 1 - i + j]*vv[vlim - 1 - j]
      presfn(TAG_PRODPIECE, wd, s, 2*mulwd, dispwd,
             '%s_%d %s_%d' % (uwhat, i - j, vwhat, j))
      t += s
    presfn(TAG_PRODSUM, wd, t, 2*mulwd, dispwd,
           '(%s %s)_%d' % (uwhat, vwhat, ulim + vlim - 2 - i))
    x += t << (mulwd*i)
  presfn(TAG_PRODUCT, wd, x, uw + vw, dispwd, '%s %s' % (uwhat, vwhat))

  ## The padding scaled the product by t^(UPAD + VPAD): undo that.
  return x >> (upad + vpad)
9e6a4409
MW
310
def poly64_mul_karatsuba(u, v, klimit, presfn, wd,
                         dispwd, mulwd, uwhat, vwhat):
  """
  Multiply U by V, returning the product.

  If the length of U and V is at least KLIMIT, and the operands are otherwise
  suitable, then do Karatsuba--Ofman multiplication; otherwise, delegate to
  `poly64_mul_simple'.

  `Suitable' means that U and V have the same length, and that this is a
  multiple of twice MULWD bits.  The operands are byte strings, and the
  product is returned as a `C.GF'.  The pieces, and the results of combining
  them, are reported through PRESFN.
  """
  w = 8*len(u)

  ## Check whether the operands are suitable for splitting.
  suitable = w >= klimit and w == 8*len(v) and w%(2*mulwd) == 0
  if not suitable:
    return poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat)

  ## Split each operand in half.  The `0' halves come first in the strings,
  ## and are the more significant.
  hw = w//2
  cut = hw//8
  u0, u1 = u[:cut], u[cut:]
  v0, v1 = v[:cut], v[cut:]

  ## Do the three half-size multiplications: first the sums of the halves,
  ## then the `0' halves, then the `1' halves.
  jobs = [('*', u0 ^ u1, v0 ^ v1), ('0', u0, v0), ('1', u1, v1)]
  prods = []
  for sfx, up, vp in jobs:
    ulabel = '%s%s' % (uwhat, sfx)
    vlabel = '%s%s' % (vwhat, sfx)
    presfn(TAG_KPIECE_U, wd, C.GF.loadb(up), hw, dispwd, ulabel)
    presfn(TAG_KPIECE_V, wd, C.GF.loadb(vp), hw, dispwd, vlabel)
    prods.append(poly64_mul_karatsuba(up, vp, klimit, presfn, wd,
                                      dispwd, mulwd, ulabel, vlabel))
  uuvv, u0v0, u1v1 = prods

  ## Recover the middle term, and assemble the complete product.
  uvuv = uuvv + u0v0 + u1v1
  presfn(TAG_PRODSUM, wd, uvuv, w, dispwd, '%s!%s' % (uwhat, vwhat))

  x = u1v1 + (uvuv << hw) + (u0v0 << w)
  presfn(TAG_PRODUCT, wd, x, 2*w, dispwd, '%s %s' % (uwhat, vwhat))
  return x
351
def poly64_mul(u, v, presfn, dispwd, mulwd, klimit, uwhat, vwhat):
  """
  Multiply U by V using a primitive 64-bit binary polynomial multiplier.

  Such a multiplier exists as the appallingly-named `pclmul[lh]q[lh]qdq' on
  x86, and as `vmull.p64'/`pmull' on ARM.

  Operands arrive in a `register format', which is a byte-swapped variant of
  the external format.  Implementations differ on the precise details,
  though.  Returns the double-precision product.

  U and V must have the same length; the product is returned as a byte
  string twice as long as either of them.  The overall operand width in bits
  is passed to PRESFN as its WD argument.
  """
  nbits = 8*len(u)
  assert(nbits == 8*len(v))
  prod = poly64_mul_karatsuba(u, v, klimit, presfn,
                              nbits, dispwd, mulwd, uwhat, vwhat)

  ## The product occupies 2*NBITS bits, i.e., NBITS/4 bytes.
  return prod.storeb(nbits//4)
9e6a4409 369
e95b355c
MW
def poly64_redc(y, presfn, dispwd, redcwd):
  """
  Reduce a double-precision product Y modulo the appropriate polynomial.

  The operand arrives in a `register format', which is a byte-swapped variant
  of the external format.  Implementations differ on the precise details,
  though.  Returns the single-precision reduced value.

  Y is a byte string, and the result is a byte string half as long.  The
  reduction is done in pieces REDCWD bits wide, and the intermediate values
  are reported through PRESFN.
  """

  w = 4*len(y)
  p = poly(w)

  ## Our polynomial has the form p = t^d + r where r = SUM_{0<=i<d} r_i t^i,
  ## with each r_i either 0 or 1.  Because we choose the lexically earliest
  ## irreducible polynomial with the necessary degree, r_i = 1 happens only
  ## for a small number of tiny i.  In our field, we have t^d = r.
  ##
  ## We carve the product into convenient n-bit pieces, for some n dividing d
  ## -- typically n = 32 or 64.  Let d = m n, and write y = SUM_{0<=i<2m} y_i
  ## t^{ni}.  The upper portion, the y_i with i >= m, needs reduction; but
  ## y_i t^{ni} = y_i r t^{n(i-m)}, so we just multiply the top half by r and
  ## add it to the bottom half.  This all depends on r_i = 0 for all i >=
  ## n/2.  We process each nonzero coefficient of r separately, in two
  ## passes.
  ##
  ## Multiplying a chunk y_i by some t^j is the same as shifting it left by j
  ## bits (or would be if GCM weren't backwards, but let's not worry about
  ## that right now).  The high j bits will spill over into the next chunk,
  ## while the low n - j bits will stay where they are.  It's these high bits
  ## which cause trouble -- particularly the high bits of the top chunk,
  ## since we'll add them on to y_m, which will need further reduction.  But
  ## only the topmost j bits will do this.
  ##
  ## The trick is that we do all of the bits which spill over first -- all of
  ## the top j bits in each chunk, for each j -- in one pass, and then a
  ## second pass of all the bits which don't.  Because j, j' < n/2 for any
  ## two nonzero coefficient degrees j and j', we have j + j' < n whence j <
  ## n - j' -- so all of the bits contributed to y_m will be handled in the
  ## second pass when we handle the bits that don't spill over.
  degrees = [i for i in range(1, w) if p.testbit(i)]
  piecemask = gfmask(redcwd)
  half = w//redcwd

  ## Handle the spilling bits.
  ypieces = split_gf(y, redcwd)[half:]
  b = C.GF(0)
  for rj in degrees:
    spill = join_gf([(yi << (redcwd - rj))&piecemask for yi in ypieces],
                    redcwd)
    presfn(TAG_REDCBITS, w, spill, w, dispwd, 'b(%d)' % rj)
    b += spill << (w - redcwd)
  presfn(TAG_REDCFULL, w, b, 2*w, dispwd, 'b')
  s = C.GF.loadb(y) + b
  presfn(TAG_REDCMIX, w, s, 2*w, dispwd, 's')

  ## Handle the nonspilling bits.
  spieces = split_gf(s.storeb(w//4), redcwd)[half:]
  a = C.GF(0)
  for rj in degrees:
    stay = join_gf([si >> rj for si in spieces], redcwd)
    presfn(TAG_REDCBITS, w, stay, w, dispwd, 'a(%d)' % rj)
    a += stay
  presfn(TAG_REDCFULL, w, a, w, dispwd, 'a')

  ## Mix everything together.
  z = (s&gfmask(w)) + (s >> w) + a
  presfn(TAG_OUTPUT, w, z, w, dispwd, 'z')

  ## And we're done.
  return z.storeb(w//8)
439
ac4a43c1
MW
def poly64_shiftcommon(u, v, presfn, dispwd = 32, mulwd = 64,
                       redcwd = 32, klimit = 256):
  """
  Multiply U by V, shifting V before multiplying.

  V is first divided by t (see `shift_right'); then U and the shifted V are
  multiplied by `poly64_mul', and the product is reduced by `poly64_redc'.
  The inputs and the shifted V are reported through PRESFN.  Returns the
  reduced product as a byte string.
  """
  nbits = 8*len(u)
  presfn(TAG_INPUT_U, nbits, C.GF.loadb(u), nbits, dispwd, 'u')
  presfn(TAG_INPUT_V, nbits, C.GF.loadb(v), nbits, dispwd, 'v')

  shifted = shift_right(v)
  presfn(TAG_SHIFTED_V, nbits, C.GF.loadb(shifted), nbits, dispwd, "v'")

  prod = poly64_mul(u, shifted, presfn, dispwd, mulwd, klimit, "u", "v'")
  return poly64_redc(prod, presfn, dispwd, redcwd)
450
def poly64_directcommon(u, v, presfn, dispwd = 32, mulwd = 64,
                        redcwd = 32, klimit = 256):
  """
  Multiply U by V, shifting the product rather than an operand.

  U and V are multiplied directly by `poly64_mul'; the double-width product
  is then shifted left by one bit before being reduced by `poly64_redc'.
  The inputs are reported through PRESFN.  Returns the reduced product as a
  byte string.
  """
  nbits = 8*len(u)
  presfn(TAG_INPUT_U, nbits, C.GF.loadb(u), nbits, dispwd, 'u')
  presfn(TAG_INPUT_V, nbits, C.GF.loadb(v), nbits, dispwd, 'v')

  prod = poly64_mul(u, v, presfn, dispwd, mulwd, klimit, "u", "v")
  shifted = C.GF.loadb(prod) << 1
  return poly64_redc(shifted.storeb(nbits//4), presfn, dispwd, redcwd)
460
9e6a4409
MW
@demo
def demo_pclmul(u, v):
  """
  Multiply U by V in the style of the `pclmul' implementation.

  This is `poly64_shiftcommon' with its default widths, presenting values
  using `present_gf_pclmul'.
  """
  product = poly64_shiftcommon(u, v, presfn = present_gf_pclmul)
  return product
9e6a4409
MW
464
@demo
def demo_vmullp64(u, v):
  """
  Multiply U by V in the style of the `vmullp64' implementation.

  This is `poly64_shiftcommon', presenting values using
  `present_gf_vmullp64'.  Reduction is done in 64-bit pieces, unless the
  operand width leaves an odd 32 bits over, in which case 32-bit pieces are
  used.
  """
  nbits = 8*len(u)
  if nbits%64 == 32: piece = 32
  else: piece = 64
  return poly64_shiftcommon(u, v, presfn = present_gf_vmullp64,
                            redcwd = piece)
9e6a4409
MW
470
@demo
def demo_pmull(u, v):
  """
  Multiply U by V in the style of the `pmull' implementation.

  This is `poly64_directcommon', presenting values using `present_gf_pmull'.
  Reduction is done in 64-bit pieces, unless the operand width leaves an odd
  32 bits over, in which case 32-bit pieces are used.
  """
  nbits = 8*len(u)
  if nbits%64 == 32: piece = 32
  else: piece = 64
  return poly64_directcommon(u, v, presfn = present_gf_pmull,
                             redcwd = piece)
9e6a4409
MW
476
477###--------------------------------------------------------------------------
478### @@@ Random debris to be deleted. @@@
479
def cutting_room_floor():
  """
  Leftover experiment: reduce a fixed 192-bit product by hand.

  This works through a shift-and-mask reduction on hardwired operands,
  printing the intermediate values, and finishes by printing the reference
  answer from `gcm_mul'.  Nothing else in this file calls it.
  """

  def show(fmt, val, lim = None):
    ## Print VAL as 48 bytes' worth of words, or just the first LIM bytes.
    octets = val.storeb(48)
    if lim is not None: octets = octets[0:lim]
    print(fmt % words(octets))

  def masked(val, pattern, shift):
    ## Mask each 32-bit word of VAL with PATTERN, and shift the lot left.
    return (val&repmask(pattern, 32, 6)) << shift

  x = C.bytes('cde4bef260d7bcda163547d348b7551195e77022907dd1df')
  y = C.bytes('f7dac5c9941d26d0c6eb14ad568f86edd1dc9268eeee5332')

  u, v = C.GF.loadb(x), C.GF.loadb(y)

  ## The product, shifted, and the bits which spill over.
  g = u*v << 1
  show('y = %s', g)
  b1 = masked(g, 0x01, 191)
  b2 = masked(g, 0x03, 190)
  b7 = masked(g, 0x7f, 185)
  b = b1 + b2 + b7
  show('b = %s', b, 28)
  h = g + b
  show('w = %s', h)

  ## The bits which don't spill.
  a0 = masked(h, 0xffffffff, 192)
  a1 = masked(h, 0xfffffffe, 191)
  a2 = masked(h, 0xfffffffc, 190)
  a7 = masked(h, 0xffffff80, 185)
  a = a0 + a1 + a2 + a7

  show(' a_1 = %s', a1, 24)
  show(' a_2 = %s', a2, 24)
  show(' a_7 = %s', a7, 24)

  show('low+unit = %s', h + a0, 24)
  show(' low+0,2 = %s', h + a0 + a2, 24)
  show(' 1,7 = %s', a1 + a7, 24)

  show('a = %s', a, 24)
  z = h + a
  show('z = %s', z)

  ## And the reference answer, for comparison.
  z = gcm_mul(x, y)
  print('u v mod p = %s' % words(z))
517
518###--------------------------------------------------------------------------
519### Main program.
520
## Usage: gcm-ref STYLE U V
##
## STYLE names one of the demonstration implementations registered in
## `DEMOMAP'; U and V are the operands, in the form accepted by `C.bytes'.
## The chosen implementation prints its working as it goes.  Its answer is
## then checked against the reference `gcm_mul'.  Problems are reported on
## stderr with a nonzero exit status (`exit' here is `sys.exit').
if __name__ == '__main__':
  if len(argv) != 4:
    exit('usage: %s STYLE U V' % argv[0])
  style = argv[1]
  if style not in DEMOMAP:
    exit("%s: unknown style `%s' (expected one of: %s)" %
         (argv[0], style, ', '.join(sorted(DEMOMAP))))
  u = C.bytes(argv[2])
  v = C.bytes(argv[3])
  zz = DEMOMAP[style](u, v)

  ## Check explicitly rather than with `assert', so that the check can't be
  ## compiled away by `python -O'.
  if zz != gcm_mul(u, v):
    exit("%s: style `%s' disagrees with the reference implementation" %
         (argv[0], style))
526
527###----- That's all, folks --------------------------------------------------