#! /usr/bin/python
### -*- coding: utf-8 -*-

from sys import argv, exit

import catacomb as C

###--------------------------------------------------------------------------
### Random utilities.

def words(s):
  """Split S into 32-bit pieces and report their values as hex."""
  return ' '.join('%08x' % C.MP.loadb(s[i:i + 4])
                  for i in xrange(0, len(s), 4))

def words_64(s):
  """Split S into 64-bit pieces and report their values as hex."""
  return ' '.join('%016x' % C.MP.loadb(s[i:i + 8])
                  for i in xrange(0, len(s), 8))

def repmask(val, wd, n):
  """Return a mask consisting of N copies of the WD-bit value VAL."""
  v = C.GF(val)
  a = C.GF(0)
  for i in xrange(n): a = (a << wd) | v
  return a

def combs(things, k):
  """
  Iterate over all possible combinations of K of the THINGS.

  Each combination is yielded as a list whose elements keep their relative
  order from THINGS.  Combinations are produced in colexicographic order of
  their index sets: ordered first by the largest index, then by the next
  largest, and so on.  If K exceeds the number of THINGS then there are no
  combinations and nothing is yielded.
  """
  n = len(things)
  if k > n: return
  ii = range(k)
  while True:
    yield [things[i] for i in ii]
    ## Advance to the next index set: find the lowest position J whose index
    ## can be bumped without colliding with the index above it (or, for the
    ## last position, without running off the end), bump it, and reset each
    ## position below J to its minimum value.  If no position can be bumped,
    ## we've produced every combination.
    for j in xrange(k):
      if j == k - 1: lim = n
      else: lim = ii[j + 1]
      i = ii[j] + 1
      if i < lim:
        ii[j] = i
        break
      ii[j] = j
    else:
      return

## Cache of the polynomials found by `poly', keyed by degree.
POLYMAP = {}

def poly(nbits):
  """
  Return the lexically first irreducible polynomial of degree NBITS of lowest
  weight.

  Only polynomials t^NBITS + ... + 1 with an odd number of nonzero
  coefficients are tried (those of even weight are divisible by t + 1).
  Within each weight, the middle terms are tried in the order produced by
  `combs', so the first irreducible polynomial found is the numerically
  smallest of that weight.  Results are cached in POLYMAP.

  Raises ValueError if no candidate is irreducible (e.g., for NBITS < 2,
  where there are no candidates at all).
  """
  try: return POLYMAP[nbits]
  except KeyError: pass
  base = C.GF(0).setbit(nbits).setbit(0)
  for k in xrange(1, nbits, 2):
    for cc in combs(range(1, nbits), k):
      p = base + sum((C.GF(0).setbit(c) for c in cc), C.GF(0))
      if p.irreduciblep():
        POLYMAP[nbits] = p
        return p
  raise ValueError(nbits)

def gcm_mangle(x):
  """
  Flip the bits within each byte according to GCM's insane convention.

  Returns a byte string of the same length as X in which the bit order of
  every byte has been reversed.  The operation is its own inverse.
  """
  y = C.WriteBuffer()
  for b in x:
    b = ord(b)
    bb = 0
    for i in xrange(8):
      bb <<= 1
      if b&1: bb |= 1
      b >>= 1
    y.putu8(bb)
  return y.contents

def endswap_words_32(x):
  """End-swap each 32-bit word of X."""
  x = C.ReadBuffer(x)
  y = C.WriteBuffer()
  while x.left: y.putu32l(x.getu32b())
  return y.contents

def endswap_words_64(x):
  """End-swap each 64-bit word of X."""
  x = C.ReadBuffer(x)
  y = C.WriteBuffer()
  while x.left: y.putu64l(x.getu64b())
  return y.contents

def endswap_bytes(x):
  """End-swap X by bytes."""
  y = C.WriteBuffer()
  for ch in reversed(x): y.put(ch)
  return y.contents

def gfmask(n):
  """Return the binary polynomial with the low N bits set, i.e., 2^N - 1."""
  return C.GF(C.MP(0).setbit(n) - 1)

def gcm_mul(x, y):
  """
  Multiply X and Y according to the GCM rules.

  X and Y are byte strings in GCM's external format (Y is assumed to be the
  same length as X).  The product is reduced modulo `poly(8*len(X))' and
  returned in the same format, exactly len(X) bytes long.  This is the
  reference against which the demo implementations are checked.
  """
  w = len(x)
  p = poly(8*w)
  u, v = C.GF.loadl(gcm_mangle(x)), C.GF.loadl(gcm_mangle(y))
  z = (u*v)%p
  return gcm_mangle(z.storel(w))

## Map from demo style names (as given on the command line) to the functions
## implementing them; populated by the `demo' decorator.
DEMOMAP = {}

def demo(func):
  """
  Decorator: register FUNC, which must be named `demo_NAME', in DEMOMAP.

  The key is NAME with underscores replaced by hyphens; FUNC itself is
  returned unchanged.
  """
  name = func.func_name
  assert name.startswith('demo_')
  DEMOMAP[name[5:].replace('_', '-')] = func
  return func

def iota(i = 0):
  """Return a counter function whose successive calls return I, I + 1, ...."""
  vi = [i]                              # a mutable cell: no `nonlocal' here
  def next():
    vi[0] += 1; return vi[0] - 1
  return next

###--------------------------------------------------------------------------
### Portable table-driven implementation.
def shift_left(x): """Given a field element X (in external format), return X t.""" w = len(x) p = poly(8*w) return gcm_mangle(C.GF.storel((C.GF.loadl(gcm_mangle(x)) << 1)%p)) def shift_right(x): """Given a field element X (in external format), return X/t.""" w = len(x) p = poly(8*w) return gcm_mangle(C.GF.storel((C.GF.loadl(gcm_mangle(x))*p.modinv(2))%p)) def table_common(u, v, flip, getword, ixmask): """ Multiply U by V using table lookup; common for `table-b' and `table-l'. This matches the `simple_mulk_...' implementation in `gcm.c'. One entry per bit is the best we can manage if we want a constant-time implementation: processing n bits at a time means we need to scan (2^n - 1)/n times as much memory. * FLIP is a function (assumed to be an involution) on one argument X to convert X from external format to table-entry format or back again. * GETWORD is a function on one argument B to retrieve the next 32-bit chunk of a field element held in a `ReadBuffer'. Bits within a word are processed most-significant first. * IXMASK is a mask XORed into table indices to permute the table so that its order matches that induced by GETWORD. The table is built such that tab[i XOR IXMASK] = U t^i. 
""" w = len(u); assert(w == len(v)) a = C.ByteString.zero(w) tab = [None]*(8*w) for i in xrange(8*w): print ';; %9s = %7s = %s' % ('utab[%d]' % i, 'u t^%d' % i, words(u)) tab[i ^ ixmask] = flip(u) u = shift_left(u) v = C.ReadBuffer(v) i = 0 while v.left: t = getword(v) for j in xrange(32): bit = (t >> 31)&1 if bit: a ^= tab[i] print ';; %6s = %d: a <- %s [%9s = %s]' % \ ('v[%d]' % (i ^ ixmask), bit, words(a), 'utab[%d]' % (i ^ ixmask), words(tab[i])) i += 1; t <<= 1 return flip(a) @demo def demo_table_b(u, v): """Big-endian table lookup.""" return table_common(u, v, lambda x: x, lambda b: b.getu32b(), 0) @demo def demo_table_l(u, v): """Little-endian table lookup.""" return table_common(u, v, endswap_words_32, lambda b: b.getu32l(), 0x18) ###-------------------------------------------------------------------------- ### Implementation using 64×64->128-bit binary polynomial multiplication. _i = iota() TAG_INPUT_U = _i() TAG_INPUT_V = _i() TAG_SHIFTED_V = _i() TAG_KPIECE_U = _i() TAG_KPIECE_V = _i() TAG_PRODPIECE = _i() TAG_PRODSUM = _i() TAG_PRODUCT = _i() TAG_REDCBITS = _i() TAG_REDCFULL = _i() TAG_REDCMIX = _i() TAG_OUTPUT = _i() def split_gf(x, n): n /= 8 return [C.GF.loadb(x[i:i + n]) for i in xrange(0, len(x), n)] def join_gf(xx, n): x = C.GF(0) for i in xrange(len(xx)): x = (x << n) | xx[i] return x def present_gf(x, w, n, what): firstp = True m = gfmask(n) for i in xrange(0, w, 128): print ';; %12s%c =%s' % \ (firstp and what or '', firstp and ':' or ' ', ''.join([j < w and ' 0x%s' % hex(((x >> j)&m).storeb(n/8)) or '' for j in xrange(i, i + 128, n)])) firstp = False def present_gf_pclmul(tag, wd, x, w, n, what): if tag != TAG_PRODPIECE: present_gf(x, w, n, what) def reverse(x, w): return C.GF.loadl(x.storeb(w/8)) def rev32(x): w = x.noctets m_ffff = repmask(0xffff, 32, w/4) m_ff = repmask(0xff, 16, w/2) x = ((x&m_ffff) << 16) | ((x >> 16)&m_ffff) x = ((x&m_ff) << 8) | ((x >> 8)&m_ff) return x def rev8(x): w = x.noctets m_0f = repmask(0x0f, 8, w) m_33 = 
repmask(0x33, 8, w) m_55 = repmask(0x55, 8, w) x = ((x&m_0f) << 4) | ((x >> 4)&m_0f) x = ((x&m_33) << 2) | ((x >> 2)&m_33) x = ((x&m_55) << 1) | ((x >> 1)&m_55) return x def present_gf_vmullp64(tag, wd, x, w, n, what): if tag == TAG_PRODPIECE or tag == TAG_REDCFULL: return elif (wd == 128 or wd == 64) and TAG_PRODSUM <= tag <= TAG_PRODUCT: y = x elif (wd == 96 or wd == 192 or wd == 256) and \ TAG_PRODSUM <= tag < TAG_OUTPUT: y = x else: xx = x.storeb(w/8) extra = len(xx)%8 if extra: xx += C.ByteString.zero(8 - extra) yb = C.WriteBuffer() for i in xrange(len(xx), 0, -8): yb.put(xx[i - 8:i]) y = C.GF.loadb(yb.contents) present_gf(y, (w + 63)&~63, n, what) def present_gf_pmull(tag, wd, x, w, n, what): if tag == TAG_PRODPIECE or tag == TAG_REDCFULL: return elif tag == TAG_INPUT_V or tag == TAG_SHIFTED_V or tag == TAG_KPIECE_V: w = (w + 63)&~63 bx = C.ReadBuffer(x.storeb(w/8)) by = C.WriteBuffer() while bx.left: chunk = bx.get(8); by.put(chunk).put(chunk) x = C.GF.loadb(by.contents) w *= 2 elif TAG_PRODSUM <= tag <= TAG_PRODUCT: x <<= 1 y = reverse(rev8(x), w) present_gf(y, w, n, what) def poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat): """ Multiply U by V, returning the product. This is the fallback long multiplication. """ uw, vw = 8*len(u), 8*len(v) ## We start by carving the operands into 64-bit pieces. This is ## straightforward except for the 96-bit case, where we end up with two ## short pieces which we pad at the beginning. upad = (-uw)%mulwd; u += C.ByteString.zero(upad); uw += upad vpad = (-vw)%mulwd; v += C.ByteString.zero(vpad); vw += vpad uu = split_gf(u, mulwd); vv = split_gf(v, mulwd) ## Report and accumulate the individual product pieces. 
x = C.GF(0) ulim, vlim = uw/mulwd, vw/mulwd for i in xrange(ulim + vlim - 2, -1, -1): t = C.GF(0) for j in xrange(max(0, i - vlim + 1), min(vlim, i + 1)): s = uu[ulim - 1 - i + j]*vv[vlim - 1 - j] presfn(TAG_PRODPIECE, wd, s, 2*mulwd, dispwd, '%s_%d %s_%d' % (uwhat, i - j, vwhat, j)) t += s presfn(TAG_PRODSUM, wd, t, 2*mulwd, dispwd, '(%s %s)_%d' % (uwhat, vwhat, ulim + vlim - 2 - i)) x += t << (mulwd*i) presfn(TAG_PRODUCT, wd, x, uw + vw, dispwd, '%s %s' % (uwhat, vwhat)) return x >> (upad + vpad) def poly64_mul_karatsuba(u, v, klimit, presfn, wd, dispwd, mulwd, uwhat, vwhat): """ Multiply U by V, returning the product. If the length of U and V is at least KLIMIT, and the operands are otherwise suitable, then do Karatsuba--Ofman multiplication; otherwise, delegate to `poly64_mul_simple'. """ w = 8*len(u) if w < klimit or w != 8*len(v) or w%(2*mulwd) != 0: return poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat) hw = w/2 u0, u1 = u[:hw/8], u[hw/8:] v0, v1 = v[:hw/8], v[hw/8:] uu, vv = u0 ^ u1, v0 ^ v1 presfn(TAG_KPIECE_U, wd, C.GF.loadb(uu), hw, dispwd, '%s*' % uwhat) presfn(TAG_KPIECE_V, wd, C.GF.loadb(vv), hw, dispwd, '%s*' % vwhat) uuvv = poly64_mul_karatsuba(uu, vv, klimit, presfn, wd, dispwd, mulwd, '%s*' % uwhat, '%s*' % vwhat) presfn(TAG_KPIECE_U, wd, C.GF.loadb(u0), hw, dispwd, '%s0' % uwhat) presfn(TAG_KPIECE_V, wd, C.GF.loadb(v0), hw, dispwd, '%s0' % vwhat) u0v0 = poly64_mul_karatsuba(u0, v0, klimit, presfn, wd, dispwd, mulwd, '%s0' % uwhat, '%s0' % vwhat) presfn(TAG_KPIECE_U, wd, C.GF.loadb(u1), hw, dispwd, '%s1' % uwhat) presfn(TAG_KPIECE_V, wd, C.GF.loadb(v1), hw, dispwd, '%s1' % vwhat) u1v1 = poly64_mul_karatsuba(u1, v1, klimit, presfn, wd, dispwd, mulwd, '%s1' % uwhat, '%s1' % vwhat) uvuv = uuvv + u0v0 + u1v1 presfn(TAG_PRODSUM, wd, uvuv, w, dispwd, '%s!%s' % (uwhat, vwhat)) x = u1v1 + (uvuv << hw) + (u0v0 << w) presfn(TAG_PRODUCT, wd, x, 2*w, dispwd, '%s %s' % (uwhat, vwhat)) return x def poly64_mul(u, v, presfn, dispwd, mulwd, 
klimit, uwhat, vwhat): """ Multiply U by V using a primitive 64-bit binary polynomial mutliplier. Such a multiplier exists as the appallingly-named `pclmul[lh]q[lh]qdq' on x86, and as `vmull.p64'/`pmull' on ARM. Operands arrive in a `register format', which is a byte-swapped variant of the external format. Implementations differ on the precise details, though. Returns the double-precision product. """ w = 8*len(u); assert(w == 8*len(v)) x = poly64_mul_karatsuba(u, v, klimit, presfn, w, dispwd, mulwd, uwhat, vwhat) return x.storeb(w/4) def poly64_redc(y, presfn, dispwd, redcwd): """ Reduce a double-precision product X modulo the appropriate polynomial. The operand arrives in a `register format', which is a byte-swapped variant of the external format. Implementations differ on the precise details, though. Returns the single-precision reduced value. """ w = 4*len(y) p = poly(w) ## Our polynomial has the form p = t^d + r where r = SUM_{0<=i= m, needs reduction; but ## y_i t^{ni} = y_i r t^{n(i-m)}, so we just multiply the top half by r and ## add it to the bottom half. This all depends on r_i = 0 for all i >= ## n/2. We process each nonzero coefficient of r separately, in two ## passes. ## ## Multiplying a chunk y_i by some t^j is the same as shifting it left by j ## bits (or would be if GCM weren't backwards, but let's not worry about ## that right now). The high j bits will spill over into the next chunk, ## while the low n - j bits will stay where they are. It's these high bits ## which cause trouble -- particularly the high bits of the top chunk, ## since we'll add them on to y_m, which will need further reduction. But ## only the topmost j bits will do this. ## ## The trick is that we do all of the bits which spill over first -- all of ## the top j bits in each chunk, for each j -- in one pass, and then a ## second pass of all the bits which don't. 
Because j, j' < n/2 for any ## two nonzero coefficient degrees j and j', we have j + j' < n whence j < ## n - j' -- so all of the bits contributed to y_m will be handled in the ## second pass when we handle the bits that don't spill over. rr = [i for i in xrange(1, w) if p.testbit(i)] m = gfmask(redcwd) ## Handle the spilling bits. yy = split_gf(y, redcwd) b = C.GF(0) for rj in rr: br = [(yi << (redcwd - rj))&m for yi in yy[w/redcwd:]] presfn(TAG_REDCBITS, w, join_gf(br, redcwd), w, dispwd, 'b(%d)' % rj) b += join_gf(br, redcwd) << (w - redcwd) presfn(TAG_REDCFULL, w, b, 2*w, dispwd, 'b') s = C.GF.loadb(y) + b presfn(TAG_REDCMIX, w, s, 2*w, dispwd, 's') ## Handle the nonspilling bits. ss = split_gf(s.storeb(w/4), redcwd) a = C.GF(0) for rj in rr: ar = [si >> rj for si in ss[w/redcwd:]] presfn(TAG_REDCBITS, w, join_gf(ar, redcwd), w, dispwd, 'a(%d)' % rj) a += join_gf(ar, redcwd) presfn(TAG_REDCFULL, w, a, w, dispwd, 'a') ## Mix everything together. m = gfmask(w) z = (s&m) + (s >> w) + a presfn(TAG_OUTPUT, w, z, w, dispwd, 'z') ## And we're done. 
return z.storeb(w/8) def poly64_shiftcommon(u, v, presfn, dispwd = 32, mulwd = 64, redcwd = 32, klimit = 256): w = 8*len(u) presfn(TAG_INPUT_U, w, C.GF.loadb(u), w, dispwd, 'u') presfn(TAG_INPUT_V, w, C.GF.loadb(v), w, dispwd, 'v') vv = shift_right(v) presfn(TAG_SHIFTED_V, w, C.GF.loadb(vv), w, dispwd, "v'") y = poly64_mul(u, vv, presfn, dispwd, mulwd, klimit, "u", "v'") z = poly64_redc(y, presfn, dispwd, redcwd) return z def poly64_directcommon(u, v, presfn, dispwd = 32, mulwd = 64, redcwd = 32, klimit = 256): w = 8*len(u) presfn(TAG_INPUT_U, w, C.GF.loadb(u), w, dispwd, 'u') presfn(TAG_INPUT_V, w, C.GF.loadb(v), w, dispwd, 'v') y = poly64_mul(u, v, presfn, dispwd, mulwd, klimit, "u", "v") y = (C.GF.loadb(y) << 1).storeb(w/4) z = poly64_redc(y, presfn, dispwd, redcwd) return z @demo def demo_pclmul(u, v): return poly64_shiftcommon(u, v, presfn = present_gf_pclmul) @demo def demo_vmullp64(u, v): w = 8*len(u) return poly64_shiftcommon(u, v, presfn = present_gf_vmullp64, redcwd = w%64 == 32 and 32 or 64) @demo def demo_pmull(u, v): w = 8*len(u) return poly64_directcommon(u, v, presfn = present_gf_pmull, redcwd = w%64 == 32 and 32 or 64) ###-------------------------------------------------------------------------- ### @@@ Random debris to be deleted. 
@@@ def cutting_room_floor(): x = C.bytes('cde4bef260d7bcda163547d348b7551195e77022907dd1df') y = C.bytes('f7dac5c9941d26d0c6eb14ad568f86edd1dc9268eeee5332') u, v = C.GF.loadb(x), C.GF.loadb(y) g = u*v << 1 print 'y = %s' % words(g.storeb(48)) b1 = (g&repmask(0x01, 32, 6)) << 191 b2 = (g&repmask(0x03, 32, 6)) << 190 b7 = (g&repmask(0x7f, 32, 6)) << 185 b = b1 + b2 + b7 print 'b = %s' % words(b.storeb(48)[0:28]) h = g + b print 'w = %s' % words(h.storeb(48)) a0 = (h&repmask(0xffffffff, 32, 6)) << 192 a1 = (h&repmask(0xfffffffe, 32, 6)) << 191 a2 = (h&repmask(0xfffffffc, 32, 6)) << 190 a7 = (h&repmask(0xffffff80, 32, 6)) << 185 a = a0 + a1 + a2 + a7 print ' a_1 = %s' % words(a1.storeb(48)[0:24]) print ' a_2 = %s' % words(a2.storeb(48)[0:24]) print ' a_7 = %s' % words(a7.storeb(48)[0:24]) print 'low+unit = %s' % words((h + a0).storeb(48)[0:24]) print ' low+0,2 = %s' % words((h + a0 + a2).storeb(48)[0:24]) print ' 1,7 = %s' % words((a1 + a7).storeb(48)[0:24]) print 'a = %s' % words(a.storeb(48)[0:24]) z = h + a print 'z = %s' % words(z.storeb(48)) z = gcm_mul(x, y) print 'u v mod p = %s' % words(z) ###-------------------------------------------------------------------------- ### Main program. style = argv[1] u = C.bytes(argv[2]) v = C.bytes(argv[3]) zz = DEMOMAP[style](u, v) assert zz == gcm_mul(u, v) ###----- That's all, folks --------------------------------------------------