#! /usr/bin/python from sys import argv, exit from struct import unpack, pack from itertools import izip import catacomb as C R = C.FibRand(0) ###-------------------------------------------------------------------------- ### Utilities. def combs(things, k): ii = range(k) n = len(things) while True: yield [things[i] for i in ii] for j in xrange(k): if j == k - 1: lim = n else: lim = ii[j + 1] i = ii[j] + 1 if i < lim: ii[j] = i break ii[j] = j else: return POLYMAP = {} def poly(nbits): try: return POLYMAP[nbits] except KeyError: pass base = C.GF(0).setbit(nbits).setbit(0) for k in xrange(1, nbits, 2): for cc in combs(range(1, nbits), k): p = base + sum((C.GF(0).setbit(c) for c in cc), C.GF(0)) if p.irreduciblep(): POLYMAP[nbits] = p; return p raise ValueError, nbits def prim(nbits): ## No fancy way to do this: I'd need a much cleverer factoring algorithm ## than I have in my pockets. if nbits == 64: cc = [64, 4, 3, 1, 0] elif nbits == 96: cc = [96, 10, 9, 6, 0] elif nbits == 128: cc = [128, 7, 2, 1, 0] elif nbits == 192: cc = [192, 15, 11, 5, 0] elif nbits == 256: cc = [256, 10, 5, 2, 0] else: raise ValueError, 'no field for %d bits' % nbits p = C.GF(0) for c in cc: p = p.setbit(c) return p def Z(n): return C.ByteString.zero(n) def mul_blk_gf(m, x, p): return ((C.GF.loadb(m)*x)%p).storeb((p.nbits + 6)/8) def with_lastp(it): it = iter(it) try: j = next(it) except StopIteration: raise ValueError, 'empty iter' lastp = False while not lastp: i = j try: j = next(it) except StopIteration: lastp = True yield i, lastp def safehex(x): if len(x): return hex(x) else: return '""' def keylens(ksz): sel = [] if isinstance(ksz, C.KeySZSet): kk = ksz.set elif isinstance(ksz, C.KeySZRange): kk = range(ksz.min, ksz.max, ksz.mod) elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0] kk = list(kk); kk = kk[:] n = len(kk) while n and len(sel) < 4: i = R.range(n) n -= 1 kk[i], kk[n] = kk[n], kk[i] sel.append(kk[n]) return sel def pad0star(m, w): n = len(m) if not n: r = w else: r = (-len(m))%w if r: m += Z(r) return C.ByteString(m) def pad10star(m, w): r = w - len(m)%w if r: m += '\x80' + Z(r - 1) return C.ByteString(m) def ntz(i): j = 0 while (i&1) == 0: i >>= 1; j += 1 return j def blocks(x, w): v, i, n = [], 0, len(x) while n - i > w: v.append(C.ByteString(x[i:i + w])) i += w return v, C.ByteString(x[i:]) EMPTY = C.bytes('') def blocks0(x, w): v, tl = blocks(x, w) if len(tl) == w: v.append(tl); tl = EMPTY return v, tl def dummygen(bc): return [] CUSTOM = {} ###-------------------------------------------------------------------------- ### RC6. class RC6Cipher (type): def __new__(cls, w, r): name = 'rc6-%d/%d' % (w, r) me = type(name, (RC6Base,), {}) me.name = name me.r = r me.w = w me.blksz = w/2 me.keysz = C.KeySZRange(me.blksz, 1, 255, 1) return me def rotw(w): return w.bit_length() - 1 def rol(w, x, n): m0, m1 = C.MP(0).setbit(w - n) - 1, C.MP(0).setbit(n) - 1 return ((x&m0) << n) | (x >> (w - n))&m1 def ror(w, x, n): m0, m1 = C.MP(0).setbit(n) - 1, C.MP(0).setbit(w - n) - 1 return ((x&m0) << (w - n)) | (x >> n)&m1 class RC6Base (object): ## Magic constants. P400 = C.MP(0xb7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190cfef324e7738926cfbe5f4bf8d8d8c31d763da06) Q400 = C.MP(0x9e3779b97f4a7c15f39cc0605cedc8341082276bf3a27251f86c6a11d0c18e952767f0b153d27b7f0347045b5bf1827f0188) def __init__(me, k): ## Build the magic numbers. P = me.P400 >> (400 - me.w) if P%2 == 0: P += 1 Q = me.Q400 >> (400 - me.w) if Q%2 == 0: Q += 1 M = C.MP(0).setbit(me.w) - 1 ## Convert the key into words. wb = me.w/8 c = (len(k) + wb - 1)/wb kb, ktl = blocks(k, me.w/8) L = map(C.MP.loadl, kb + [ktl]) assert c == len(L) ## Build the subkey table. me.d = rotw(me.w) n = 2*me.r + 4 S = [(P + i*Q)&M for i in xrange(n)] ##for j in xrange(c): ## print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0')) ##for i in xrange(n): ## print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0')) i = j = 0 A = B = C.MP(0) for s in xrange(3*max(c, n)): A = S[i] = rol(me.w, S[i] + A + B, 3) B = L[j] = rol(me.w, L[j] + A + B, (A + B)%(1 << me.d)) ##print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0')) ##print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0')) i = (i + 1)%n j = (j + 1)%c ## Done. me.s = S def encrypt(me, x): M = C.MP(0).setbit(me.w) - 1 a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4)[0]) b = (b + me.s[0])&M d = (d + me.s[1])&M ##print 'B = %s' % (hex(b).upper()[2:].rjust(me.w/4, '0')) ##print 'D = %s' % (hex(d).upper()[2:].rjust(me.w/4, '0')) for i in xrange(2, 2*me.r + 2, 2): t = rol(me.w, 2*b*b + b, me.d) u = rol(me.w, 2*d*d + d, me.d) a = (rol(me.w, a ^ t, u%(1 << me.d)) + me.s[i + 0])&M c = (rol(me.w, c ^ u, t%(1 << me.d)) + me.s[i + 1])&M ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0')) ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0')) a, b, c, d = b, c, d, a a = (a + me.s[2*me.r + 2])&M c = (c + me.s[2*me.r + 3])&M ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0')) ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0')) return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) + c.storel(me.blksz/4) + d.storel(me.blksz/4)) def decrypt(me, x): M = C.MP(0).setbit(me.w) - 1 a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4)) c = (c - me.s[2*me.r + 3])&M a = (a - me.s[2*me.r + 2])&M for i in xrange(2*me.r + 1, 1, -2): a, b, c, d = d, a, b, c u = rol(me.w, 2*d*d + d, me.d) t = rol(me.w, 2*b*b + b, me.d) c = ror(me.w, (c - me.s[i + 1])&M, t%(1 << me.d)) ^ u a = ror(me.w, (a - me.s[i + 0])&M, u%(1 << me.d)) ^ t a = (a + s[2*me.r + 2])&M c = (c + s[2*me.r + 3])&M return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) + c.storel(me.blksz/4) + d.storel(me.blksz/4)) for (w, r) in [(8, 16), (16, 16), (24, 16), (32, 16), (32, 20), (48, 16), (64, 16), (96, 16), (128, 16), (192, 16), (256, 16), (400, 16)]: CUSTOM['rc6-%d/%d' % (w, r)] = RC6Cipher(w, r) ###-------------------------------------------------------------------------- ### OMAC (or CMAC). def omac_masks(E): blksz = E.__class__.blksz p = poly(8*blksz) z = Z(blksz) L = E.encrypt(z) m0 = mul_blk_gf(L, C.GF(2), p) m1 = mul_blk_gf(m0, C.GF(2), p) return m0, m1 def dump_omac(E): blksz = E.__class__.blksz m0, m1 = omac_masks(E) print 'L = %s' % hex(E.encrypt(Z(blksz))) print 'm0 = %s' % hex(m0) print 'm1 = %s' % hex(m1) for t in xrange(3): print 'v%d = %s' % (t, hex(E.encrypt(C.MP(t).storeb(blksz)))) print 'z%d = %s' % (t, hex(omac(E, t, ''))) def omac(E, t, m): blksz = E.__class__.blksz m0, m1 = omac_masks(E) a = Z(blksz) if t is not None: m = C.MP(t).storeb(blksz) + m v, tl = blocks(m, blksz) for x in v: a = E.encrypt(a ^ x) r = blksz - len(tl) if r == 0: a = E.encrypt(a ^ tl ^ m0) else: pad = pad10star(tl, blksz) a = E.encrypt(a ^ pad ^ m1) return a def cmac(E, m): if VERBOSE: dump_omac(E) return omac(E, None, m), def cmacgen(bc): return [(0,), (1,), (3*bc.blksz,), (3*bc.blksz - 5,)] ###-------------------------------------------------------------------------- ### Counter mode. def ctr(E, m, c0): blksz = E.__class__.blksz y = C.WriteBuffer() c = C.MP.loadb(c0) while y.size < len(m): y.put(E.encrypt(c.storeb(blksz))) c += 1 return C.ByteString(m) ^ C.ByteString(y)[:len(m)] ###-------------------------------------------------------------------------- ### GCM. def gcm_mangle(x): y = C.WriteBuffer() for b in x: b = ord(b) bb = 0 for i in xrange(8): bb <<= 1 if b&1: bb |= 1 b >>= 1 y.putu8(bb) return C.ByteString(y) def gcm_mul(x, y): w = len(x) p = poly(8*w) u, v = C.GF.loadl(gcm_mangle(x)), C.GF.loadl(gcm_mangle(y)) z = (u*v)%p return gcm_mangle(z.storel(w)) def gcm_pow(x, n): w = len(x) p = poly(8*w) u = C.GF.loadl(gcm_mangle(x)) z = pow(u, n, p) return gcm_mangle(z.storel(w)) def gcm_ctr(E, m, c0): y = C.WriteBuffer() pre = c0[:-4] c, = unpack('>L', c0[-4:]) while y.size < len(m): c += 1 y.put(E.encrypt(pre + pack('>L', c))) return C.ByteString(m) ^ C.ByteString(y)[:len(m)] def g(what, x, m, a0 = None): n = len(x) if a0 is None: a = Z(n) else: a = a0 i = 0 for b in blocks0(m, n)[0]: a = gcm_mul(a ^ b, x) if VERBOSE: print '%s[%d] = %s -> %s' % (what, i, hex(b), hex(a)) i += 1 return a def gcm_pad(w, x): return C.ByteString(x + Z(-len(x)%w)) def gcm_lens(w, a, b): if w < 12: n = w else: n = w/2 return C.ByteString(C.MP(a).storeb(n) + C.MP(b).storeb(n)) def ghash(whata, whatb, x, a, b): w = len(x) ha = g(whata, x, gcm_pad(w, a)) hb = g(whatb, x, gcm_pad(w, b)) if a: hc = gcm_mul(ha, gcm_pow(x, (len(b) + w - 1)/w)) ^ hb if VERBOSE: print '%s || %s -> %s' % (whata, whatb, hex(hc)) else: hc = hb return g(whatb, x, gcm_lens(w, 8*len(a), 8*len(b)), hc) def gcmenc(E, n, h, m, tsz = None): w = E.__class__.blksz x = E.encrypt(Z(w)) if VERBOSE: print 'x = %s' % hex(x) if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1)) else: c0 = ghash('?', 'n', x, EMPTY, n) if VERBOSE: print 'c0 = %s' % hex(c0) y = gcm_ctr(E, m, c0) t = ghash('h', 'y', x, h, y) ^ E.encrypt(c0) return y, t def gcmdec(E, n, h, y, t): w = E.__class__.blksz x = E.encrypt(Z(w)) if VERBOSE: print 'x = %s' % hex(x) if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1)) else: c0 = ghash('?', 'n', x, EMPTY, n) if VERBOSE: print 'c0 = %s' % hex(c0) m = gcm_ctr(E, y, c0) tt = ghash('h', 'y', x, h, y) ^ E.encrypt(c0) if t == tt: return m, else: return None, def gcmgen(bc): return [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1), (bc.blksz, 3*bc.blksz, 3*bc.blksz), (bc.blksz - 4, bc.blksz + 3, 3*bc.blksz + 9), (bc.blksz - 1, 3*bc.blksz - 5, 3*bc.blksz + 5)] def gcm_mul_tests(nbits): print 'gcm-mul%d {' % nbits for i in xrange(64): x = R.block(nbits/8) y = R.block(nbits/8) z = gcm_mul(x, y) print ' %s\n %s\n %s;' % (hex(x), hex(y), hex(z)) print '}' ###-------------------------------------------------------------------------- ### CCM. def stbe(n, w): return C.MP(n).storeb(w) def ccm_fmthdr(blksz, n, hsz, msz, tsz): b = C.WriteBuffer() if blksz == 8: q = blksz - len(n) - 1 f = 0 if hsz: f |= 0x40 f |= (tsz - 1) << 3 f |= q - 1 b.putu8(f).put(n).put(stbe(msz, q)) elif blksz == 16: q = blksz - len(n) - 1 f = 0 if hsz: f |= 0x40 f |= (tsz - 2)/2 << 3 f |= q - 1 b.putu8(f).put(n).put(stbe(msz, q)) else: q = blksz - len(n) - 2 f0 = f1 = 0 if hsz: f1 |= 0x80 f0 |= tsz f1 |= q b.putu8(f0).putu8(f1).put(n).put(stbe(msz, q)) b = C.ByteString(b) if VERBOSE: print 'hdr = %s' % hex(b) return b def ccm_fmtctr(blksz, n, i = 0): b = C.WriteBuffer() if blksz == 8 or blksz == 16: q = blksz - len(n) - 1 b.putu8(q - 1).put(n).put(stbe(i, q)) else: q = blksz - len(n) - 2 b.putu8(0).putu8(q).put(n).put(stbe(i, q)) b = C.ByteString(b) if VERBOSE: print 'ctr = %s' % hex(b) return b def ccmaad(b, h, blksz): hsz = len(h) if not hsz: pass elif hsz < 0xfffe: b.putu16(hsz) elif hsz <= 0xffffffff: b.putu16(0xfffe).putu32(hsz) else: b.putu16(0xffff).putu64(hsz) b.put(h); b.zero((-b.size)%blksz) def ccmenc(E, n, h, m, tsz = None): blksz = E.__class__.blksz if tsz is None: tsz = blksz b = C.WriteBuffer() b.put(ccm_fmthdr(blksz, n, len(h), len(m), tsz)) ccmaad(b, h, blksz) b.put(m); b.zero((-b.size)%blksz) b = C.ByteString(b) a = Z(blksz) v, _ = blocks0(b, blksz) i = 0 for x in v: a = E.encrypt(a ^ x) if VERBOSE: print 'b[%d] = %s' % (i, hex(x)) print 'a[%d] = %s' % (i + 1, hex(a)) i += 1 y = ctr(E, a + m, ccm_fmtctr(blksz, n)) return C.ByteString(y[blksz:]), C.ByteString(y[0:tsz]) def ccmdec(E, n, h, y, t): blksz = E.__class__.blksz tsz = len(t) b = C.WriteBuffer() b.put(ccm_fmthdr(blksz, n, len(h), len(y), tsz)) ccmaad(b, h, blksz) mm = ctr(E, t + Z(blksz - tsz) + y, ccm_fmtctr(blksz, n)) u, m = C.ByteString(mm[0:tsz]), C.ByteString(mm[blksz:]) b.put(m); b.zero((-b.size)%blksz) b = C.ByteString(b) a = Z(blksz) v, _ = blocks0(b, blksz) i = 0 for x in v: a = E.encrypt(a ^ x) if VERBOSE: print 'b[%d] = %s' % (i, hex(x)) print 'a[%d] = %s' % (i + 1, hex(a)) i += 1 if u == a[:tsz]: return m, else: return None, def ccmgen(bc): bsz = bc.blksz return [(bsz - 5, 0, 0, 4), (bsz - 5, 1, 0, 4), (bsz - 5, 0, 1, 4), (bsz/2 + 1, 3*bc.blksz, 3*bc.blksz), (bsz/2 + 1, 3*bc.blksz - 5, 3*bc.blksz + 5)] ###-------------------------------------------------------------------------- ### EAX. def eaxenc(E, n, h, m, tsz = None): if VERBOSE: print 'k = %s' % hex(k) print 'n = %s' % hex(n) print 'h = %s' % hex(h) print 'm = %s' % hex(m) dump_omac(E) if tsz is None: tsz = E.__class__.blksz c0 = omac(E, 0, n) y = ctr(E, m, c0) ht = omac(E, 1, h) yt = omac(E, 2, y) if VERBOSE: print 'c0 = %s' % hex(c0) print 'ht = %s' % hex(ht) print 'yt = %s' % hex(yt) return y, C.ByteString((c0 ^ ht ^ yt)[:tsz]) def eaxdec(E, n, h, y, t): if VERBOSE: print 'k = %s' % hex(k) print 'n = %s' % hex(n) print 'h = %s' % hex(h) print 'y = %s' % hex(y) print 't = %s' % hex(t) dump_omac(E) c0 = omac(E, 0, n) m = ctr(E, y, c0) ht = omac(E, 1, h) yt = omac(E, 2, y) if VERBOSE: print 'c0 = %s' % hex(c0) print 'ht = %s' % hex(ht) print 'yt = %s' % hex(yt) if t == (c0 ^ ht ^ yt)[:len(t)]: return m, else: return None, def eaxgen(bc): return [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1), (bc.blksz, 3*bc.blksz, 3*bc.blksz), (bc.blksz - 1, 3*bc.blksz - 5, 3*bc.blksz + 5)] ###-------------------------------------------------------------------------- ### PMAC. def ocb_masks(E): blksz = E.__class__.blksz p = poly(8*blksz) x = C.GF(2); xinv = p.modinv(x) z = Z(blksz) L = E.encrypt(z) Lxinv = mul_blk_gf(L, xinv, p) Lgamma = 66*[L] for i in xrange(1, len(Lgamma)): Lgamma[i] = mul_blk_gf(Lgamma[i - 1], x, p) return Lgamma, Lxinv def dump_ocb(E): Lgamma, Lxinv = ocb_masks(E) print 'L x^-1 = %s' % hex(Lxinv) for i, lg in enumerate(Lgamma[:16]): print 'L x^%d = %s' % (i, hex(lg)) def pmac1(E, m): blksz = E.__class__.blksz Lgamma, Lxinv = ocb_masks(E) a = o = Z(blksz) i = 0 v, tl = blocks(m, blksz) for x in v: i += 1 b = ntz(i) o ^= Lgamma[b] a ^= E.encrypt(x ^ o) if VERBOSE: print 'Z[%d]: %d -> %s' % (i, b, hex(o)) print 'A[%d]: %s' % (i, hex(a)) if len(tl) == blksz: a ^= tl ^ Lxinv else: a ^= pad10star(tl, blksz) return E.encrypt(a) def pmac2(E, m): blksz = E.__class__.blksz p = prim(8*blksz) L = E.encrypt(Z(blksz)) o = mul_blk_gf(L, C.GF(10), p) a = Z(blksz) v, tl = blocks(m, blksz) for x in v: a ^= E.encrypt(x ^ o) o = mul_blk_gf(o, C.GF(2), p) if len(tl) == blksz: a ^= tl ^ mul_blk_gf(o, C.GF(3), p) else: a ^= pad10star(tl, blksz) ^ mul_blk_gf(o, C.GF(5), p) return E.encrypt(a) def ocb3_masks(E): Lgamma, _ = ocb_masks(E) Lstar = Lgamma[0] Ldollar = Lgamma[1] return Lstar, Ldollar, Lgamma[2:] def dump_ocb3(E): Lstar, Ldollar, Lgamma = ocb3_masks(E) print 'L_* = %s' % hex(Lstar) print 'L_$ = %s' % hex(Ldollar) for i, lg in enumerate(Lgamma[:16]): print 'L x^%d = %s' % (i, hex(lg)) def pmac3(E, m): ## Note that `PMAC3' is /not/ a secure MAC. It depends on other parts of ## OCB3 to prevent a rather easy linear-algebra attack. blksz = E.__class__.blksz Lstar, Ldollar, Lgamma = ocb3_masks(E) a = o = Z(blksz) i = 0 v, tl = blocks0(m, blksz) for x in v: i += 1 b = ntz(i) o ^= Lgamma[b] a ^= E.encrypt(x ^ o) if VERBOSE: print 'Z[%d]: %d -> %s' % (i, b, hex(o)) print 'A[%d]: %s' % (i, hex(a)) if tl: o ^= Lstar a ^= E.encrypt(pad10star(tl, blksz) ^ o) if VERBOSE: print 'Z[%d]: * -> %s' % (i, hex(o)) print 'A[%d]: %s' % (i, hex(a)) return a def pmac1_pub(E, m): if VERBOSE: dump_ocb(E) return pmac1(E, m), def pmacgen(bc): return [(0,), (1,), (3*bc.blksz,), (3*bc.blksz - 5,)] ###-------------------------------------------------------------------------- ### OCB. def ocb1enc(E, n, h, m, tsz = None): ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with ## Associated-Data'. blksz = E.__class__.blksz if VERBOSE: dump_ocb(E) Lgamma, Lxinv = ocb_masks(E) if tsz is None: tsz = blksz a = Z(blksz) o = E.encrypt(n ^ Lgamma[0]) if VERBOSE: print 'R = %s' % hex(o) i = 0 y = C.WriteBuffer() v, tl = blocks(m, blksz) for x in v: i += 1 b = ntz(i) o ^= Lgamma[b] a ^= x if VERBOSE: print 'Z[%d]: %d -> %s' % (i, b, hex(o)) print 'A[%d]: %s' % (i, hex(a)) y.put(E.encrypt(x ^ o) ^ o) i += 1 b = ntz(i) o ^= Lgamma[b] n = len(tl) if VERBOSE: print 'Z[%d]: %d -> %s' % (i, b, hex(o)) print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz)) yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o) cfinal = tl ^ yfinal[:n] a ^= o ^ (tl + yfinal[n:]) y.put(cfinal) t = E.encrypt(a) if h: t ^= pmac1(E, h) return C.ByteString(y), C.ByteString(t[:tsz]) def ocb1dec(E, n, h, y, t): ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with ## Associated-Data'. blksz = E.__class__.blksz if VERBOSE: dump_ocb(E) Lgamma, Lxinv = ocb_masks(E) a = Z(blksz) o = E.encrypt(n ^ Lgamma[0]) if VERBOSE: print 'R = %s' % hex(o) i = 0 m = C.WriteBuffer() v, tl = blocks(y, blksz) for x in v: i += 1 b = ntz(i) o ^= Lgamma[b] if VERBOSE: print 'Z[%d]: %d -> %s' % (i, b, hex(o)) print 'A[%d]: %s' % (i, hex(a)) u = E.decrypt(x ^ o) ^ o m.put(u) a ^= u i += 1 b = ntz(i) o ^= Lgamma[b] n = len(tl) if VERBOSE: print 'Z[%d]: %d -> %s' % (i, b, hex(o)) print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz)) yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o) mfinal = tl ^ yfinal[:n] a ^= o ^ (mfinal + yfinal[n:]) m.put(mfinal) u = E.encrypt(a) if h: u ^= pmac1(E, h) if t == u[:len(t)]: return C.ByteString(m), else: return None, def ocb2enc(E, n, h, m, tsz = None): ## For OCB2, it's important for security that n = log_x (x + 1) is large in ## the field representations of GF(2^w) used -- in fact, we need more, that ## i n (mod 2^w - 1) is large for i in {4, -3, -2, -1, 1, 2, 3, 4}. The ## original paper lists the values for 64 and 128, but we support other ## block sizes, so here's the result of the (rather large, in some cases) ## computation. ## ## Block size log_x (x + 1) ## ## 64 9686038906114705801 ## 96 63214690573408919568138788065 ## 128 338793687469689340204974836150077311399 ## 192 161110085006042185925119981866940491651092686475226538785 ## 256 22928580326165511958494515843249267194111962539778797914076675796261938307298 blksz = E.__class__.blksz if tsz is None: tsz = blksz p = prim(8*blksz) L = E.encrypt(n) o = mul_blk_gf(L, C.GF(2), p) a = Z(blksz) v, tl = blocks(m, blksz) y = C.WriteBuffer() for x in v: a ^= x y.put(E.encrypt(x ^ o) ^ o) o = mul_blk_gf(o, C.GF(2), p) n = len(tl) yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ o) cfinal = tl ^ yfinal[:n] a ^= (tl + yfinal[n:]) ^ mul_blk_gf(o, C.GF(3), p) y.put(cfinal) t = E.encrypt(a) if h: t ^= pmac2(E, h) return C.ByteString(y), C.ByteString(t[:tsz]) def ocb2dec(E, n, h, y, t): blksz = E.__class__.blksz p = prim(8*blksz) L = E.encrypt(n) o = mul_blk_gf(L, C.GF(2), p) a = Z(blksz) v, tl = blocks(y, blksz) m = C.WriteBuffer() for x in v: u = E.encrypt(x ^ o) ^ o y.put(u) a ^= u o = mul_blk_gf(o, C.GF(2), p) n = len(tl) yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ o) mfinal = tl ^ yfinal[:n] a ^= (mfinal + yfinal[n:]) ^ mul_blk_gf(o, C.GF(3), p) m.put(mfinal) u = E.encrypt(a) if h: u ^= pmac2(E, h) if t == u[:len(t)]: return C.ByteString(m), else: return None, OCB3_STRETCH = { 4: ( 4, 17), 8: ( 5, 25), 12: ( 6, 33), 16: ( 6, 8), 24: ( 7, 40), 32: ( 8, 1), 48: ( 8, 80), 64: ( 8, 176), 96: ( 9, 160), 128: ( 9, 352), 200: (10, 192) } def ocb3nonce(E, n, tsz): ## Figure out how much we need to glue onto the nonce. This ends up being ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and p = w - l(N) - ## v - 1. But this is an annoying way to think about it because of the ## byte misalignment. Instead, think of it as a byte-aligned prefix ## encoding the tag and an `is the nonce full-length' flag, followed by ## optional padding, and then the nonce: ## ## F || N if l(N) = w - f ## F || 0^p || 1 || N otherwise ## ## where F is [t mod w]_v || 0^{f-v-1} || b; f = floor(log_2(w - 1)) + 2; ## b is 1 if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1. blksz = E.__class__.blksz tszbits = min(C.MP(8*blksz - 1).nbits, 8) fwd = tszbits/8 + 1 f = 8*(tsz%blksz) << + 8*fwd - tszbits ## Form the augmented nonce. nb = C.WriteBuffer() nsz, nwd = len(n), blksz - fwd if nsz == nwd: f |= 1 nb.put(C.MP(f).storeb(fwd)) if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1) nb.put(n) nn = C.ByteString(nb) if VERBOSE: print 'aug-nonce = %s' % hex(nn) ## Calculate the initial offset. split, shift = OCB3_STRETCH[blksz] t2pw = C.MP(0).setbit(8*blksz) - 1 lomask = (C.MP(0).setbit(split) - 1) himask = ~lomask top, bottom = nn&himask.storeb2c(blksz), C.MP.loadb(nn)&lomask ktop = C.MP.loadb(E.encrypt(top)) stretch = (ktop << 8*blksz) | (ktop ^ (ktop << shift)&t2pw) o = (stretch >> 8*blksz - bottom).storeb(blksz) if VERBOSE: print 'stretch = %s' % hex(stretch.storeb(2*blksz)) print 'Z[0] = %s' % hex(o) return o def ocb3enc(E, n, h, m, tsz = None): blksz = E.__class__.blksz if tsz is None: tsz = blksz Lstar, Ldollar, Lgamma = ocb3_masks(E) if VERBOSE: dump_ocb3(E) ## Set things up. o = ocb3nonce(E, n, tsz) a = C.ByteString.zero(blksz) ## Split the message into blocks. i = 0 y = C.WriteBuffer() v, tl = blocks0(m, blksz) for x in v: i += 1 b = ntz(i) o ^= Lgamma[b] a ^= x if VERBOSE: print 'Z[%d]: %d -> %s' % (i, b, hex(o)) print 'A[%d]: %s' % (i, hex(a)) y.put(E.encrypt(x ^ o) ^ o) if tl: o ^= Lstar n = len(tl) pad = E.encrypt(o) a ^= pad10star(tl, blksz) if VERBOSE: print 'Z[%d]: * -> %s' % (i, hex(o)) print 'A[%d]: %s' % (i, hex(a)) y.put(tl ^ pad[0:n]) o ^= Ldollar t = E.encrypt(a ^ o) ^ pmac3(E, h) return C.ByteString(y), C.ByteString(t[:tsz]) def ocb3dec(E, n, h, y, t): blksz = E.__class__.blksz tsz = len(t) Lstar, Ldollar, Lgamma = ocb3_masks(E) if VERBOSE: dump_ocb3(E) ## Set things up. o = ocb3nonce(E, n, tsz) a = C.ByteString.zero(blksz) ## Split the message into blocks. i = 0 m = C.WriteBuffer() v, tl = blocks0(y, blksz) for x in v: i += 1 b = ntz(i) o ^= Lgamma[b] if VERBOSE: print 'Z[%d]: %d -> %s' % (i, b, hex(o)) print 'A[%d]: %s' % (i, hex(a)) u = E.encrypt(x ^ o) ^ o m.put(u) a ^= u if tl: o ^= Lstar n = len(tl) pad = E.encrypt(o) if VERBOSE: print 'Z[%d]: * -> %s' % (i, hex(o)) print 'A[%d]: %s' % (i, hex(a)) u = tl ^ pad[0:n] m.put(u) a ^= pad10star(u, blksz) o ^= Ldollar u = E.encrypt(a ^ o) ^ pmac3(E, h) if t == u[:tsz]: return C.ByteString(m), else: return None, def ocbgen(bc): w = bc.blksz return [(w, 0, 0), (w, 1, 0), (w, 0, 1), (w, 0, 3*w), (w, 3*w, 3*w), (w, 0, 3*w + 5), (w, 3*w - 5, 3*w + 5)] def ocb3gen(bc): w = bc.blksz return [(w - 2, 0, 0), (w - 2, 1, 0), (w - 2, 0, 1), (w - 5, 0, 3*w), (w - 3, 3*w, 3*w), (w - 2, 0, 3*w + 5), (w - 2, 3*w - 5, 3*w + 5)] def ocb3_mct(bc, ksz, tsz): k = C.ByteString(C.WriteBuffer().zero(ksz - 4).putu32(8*tsz)) E = bc(k) n = C.MP(1) nw = bc.blksz - 4 cbuf = C.WriteBuffer() for i in xrange(128): s = C.ByteString.zero(i) y, t = ocb3enc(E, n.storeb(nw), s, s, tsz); n += 1; cbuf.put(y).put(t) y, t = ocb3enc(E, n.storeb(nw), EMPTY, s, tsz); n += 1; cbuf.put(y).put(t) y, t = ocb3enc(E, n.storeb(nw), s, EMPTY, tsz); n += 1; cbuf.put(y).put(t) _, t = ocb3enc(E, n.storeb(nw), C.ByteString(cbuf), EMPTY, tsz) print hex(t) def ocb3_mct2(bc): k = C.bytes('000102030405060708090a0b0c0d0e0f') E = bc(k) tsz = min(E.blksz, 32) n = C.MP(1) cbuf = C.WriteBuffer() for i in xrange(128): sbuf = C.WriteBuffer() for j in xrange(i): sbuf.putu8(j) s = C.ByteString(sbuf) y, t = ocb3enc(E, n.storeb(2), s, s, tsz); n += 1; cbuf.put(y).put(t) y, t = ocb3enc(E, n.storeb(2), EMPTY, s, tsz); n += 1; cbuf.put(y).put(t) y, t = ocb3enc(E, n.storeb(2), s, EMPTY, tsz); n += 1; cbuf.put(y).put(t) _, t = ocb3enc(E, n.storeb(2), C.ByteString(cbuf), EMPTY, tsz) print hex(t) ###-------------------------------------------------------------------------- ### Main program. class struct (object): def __init__(me, **kw): me.__dict__.update(kw) binarg = struct(mk = R.block, parse = C.bytes, show = safehex) intarg = struct(mk = lambda x: x, parse = int, show = None) MODEMAP = { 'eax-enc': (eaxgen, 3*[binarg] + [intarg], eaxenc), 'eax-dec': (dummygen, 4*[binarg], eaxdec), 'ccm-enc': (ccmgen, 3*[binarg] + [intarg], ccmenc), 'ccm-dec': (dummygen, 4*[binarg], ccmdec), 'cmac': (cmacgen, [binarg], cmac), 'gcm-enc': (gcmgen, 3*[binarg] + [intarg], gcmenc), 'gcm-dec': (dummygen, 4*[binarg], gcmdec), 'ocb1-enc': (ocbgen, 3*[binarg] + [intarg], ocb1enc), 'ocb1-dec': (dummygen, 4*[binarg], ocb1dec), 'ocb2-enc': (ocbgen, 3*[binarg] + [intarg], ocb2enc), 'ocb2-dec': (dummygen, 4*[binarg], ocb2dec), 'ocb3-enc': (ocb3gen, 3*[binarg] + [intarg], ocb3enc), 'ocb3-dec': (dummygen, 4*[binarg], ocb3dec), 'pmac1': (pmacgen, [binarg], pmac1_pub) } mode = argv[1] if len(argv) == 3 and mode == 'gcm-mul': VERBOSE = False nbits = int(argv[2]) gcm_mul_tests(nbits) exit(0) bc = None for d in CUSTOM, C.gcprps: try: bc = d[argv[2]] except KeyError: pass else: break if bc is None: raise KeyError, argv[2] if len(argv) == 5 and mode == 'ocb3-mct': VERBOSE = False ksz, tsz = int(argv[3]), int(argv[4]) ocb3_mct(bc, ksz, tsz) exit(0) if len(argv) == 3 and mode == 'ocb3-mct2': VERBOSE = False ocb3_mct2(bc) exit(0) if len(argv) == 3: VERBOSE = False gen, argty, func = MODEMAP[mode] if mode.endswith('-enc'): mode = mode[:-4] print '%s-%s {' % (bc.name, mode) for ksz in keylens(bc.keysz): for argvals in gen(bc): k = R.block(ksz) args = [t.mk(a) for t, a in izip(argty, argvals)] rets = func(bc(k), *args) print ' %s' % safehex(k) for t, a in izip(argty, args): if t.show: print ' %s' % t.show(a) for r, lastp in with_lastp(rets): print ' %s%s' % (safehex(r), lastp and ';' or '') print '}' else: VERBOSE = True k = C.bytes(argv[3]) gen, argty, func = MODEMAP[mode] args = [t.parse(a) for t, a in izip(argty, argv[4:])] rets = func(bc(k), *args) for r in rets: if r is None: print "X" else: print hex(r) ###----- That's all, folks --------------------------------------------------