#! /usr/bin/python
### -*-python-*-
###
### Generalization of OCB mode for other block sizes
###
### (c) 2017 Mark Wooding
###

###----- Licensing notice ---------------------------------------------------
###
### This program is free software; you can redistribute it and/or modify
### it under the terms of the GNU General Public License as published by
### the Free Software Foundation; either version 2 of the License, or
### (at your option) any later version.
###
### This program is distributed in the hope that it will be useful,
### but WITHOUT ANY WARRANTY; without even the implied warranty of
### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
### GNU General Public License for more details.
###
### You should have received a copy of the GNU General Public License
### along with this program; if not, write to the Free Software Foundation,
### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

from sys import argv, stderr
from struct import pack
from itertools import izip
from contextlib import contextmanager
try: from kalyna import Kalyna
except ImportError: Kalyna = None
import catacomb as C

## Shared generator used by `keylens'; constructed with a fixed argument of 0.
R = C.FibRand(0)

###--------------------------------------------------------------------------
### Utilities.

def combs(things, k):
  ## Generate every K-element selection from THINGS, each as a fresh list.
  ##
  ## The selections are taken at strictly increasing index positions; the
  ## first one yielded is THINGS[0:K], and the lowest-numbered position
  ## advances fastest.  For K = 0 a single empty list is yielded.

  ## A real list is needed because the positions are updated in place.
  ii = list(range(k))
  n = len(things)
  while True:
    yield [things[i] for i in ii]
    for j in xrange(k):
      ## Position J may advance up to, but not onto, position J + 1 (or the
      ## end of THINGS for the final position).
      if j == k - 1: lim = n
      else: lim = ii[j + 1]
      i = ii[j] + 1
      if i < lim:
        ii[j] = i
        break
      ## No room: reset this position to its minimum and carry upwards.
      ii[j] = j
    else:
      ## Every position overflowed, so the enumeration is complete.
      return

## Cache for `poly', mapping a degree to the polynomial found for it.
POLYMAP = {}

def poly(nbits):
  ## Return an irreducible polynomial of degree NBITS over GF(2).
  ##
  ## Candidates have the form x^NBITS + (K middle terms) + 1 for K = 1, 3,
  ## 5, ..., with the middle terms enumerated in the order `combs' produces
  ## them; the first candidate for which `irreduciblep' is true is cached in
  ## POLYMAP and returned.  Raises ValueError(NBITS) if the search finds
  ## nothing.
  try: return POLYMAP[nbits]
  except KeyError: pass
  base = C.GF(0).setbit(nbits).setbit(0)
  for k in xrange(1, nbits, 2):
    for cc in combs(range(1, nbits), k):
      p = base + sum(C.GF(0).setbit(c) for c in cc)
      if p.irreduciblep():
        POLYMAP[nbits] = p
        return p
  raise ValueError(nbits)

def prim(nbits):
  ## Return the fixed field polynomial used for an NBITS-bit block.
  ##
  ## Only 64, 96, 128, 192 and 256 bits are tabulated; anything else raises
  ## ValueError.  (Presumably these are primitive polynomials, going by the
  ## function's name -- that isn't checked here.)

  ## No fancy way to do this: I'd need a much cleverer factoring algorithm
  ## than I have in my pockets.
  if nbits == 64: cc = [64, 4, 3, 1, 0]
  elif nbits == 96: cc = [96, 10, 9, 6, 0]
  elif nbits == 128: cc = [128, 7, 2, 1, 0]
  elif nbits == 192: cc = [192, 15, 11, 5, 0]
  elif nbits == 256: cc = [256, 10, 5, 2, 0]
  else: raise ValueError('no field for %d bits' % nbits)

  ## Build the polynomial by setting one coefficient per listed exponent.
  p = C.GF(0)
  for c in cc: p = p.setbit(c)
  return p

def Z(n):
  ## Return a string of N zero bytes.
  return C.ByteString.zero(n)

def mul_blk_gf(m, x, p):
  ## Multiply the block M by X modulo the polynomial P.
  ##
  ## M is loaded with `C.GF.loadb' and the product is stored with `storeb'
  ## into (P.nbits + 6)//8 bytes.  The division is on integers, so `//'
  ## gives the same result that `/' gives under Python 2.
  return ((C.GF.loadb(m)*x)%p).storeb((p.nbits + 6)//8)

def with_lastp(it):
  ## Yield (ITEM, LASTP) pairs for each item of the iterable IT.
  ##
  ## LASTP is true only for the final item.  This works by reading one item
  ## ahead.  An empty IT causes ValueError('empty iter') when the generator
  ## is first advanced.
  it = iter(it)
  try: j = next(it)
  except StopIteration: raise ValueError('empty iter')
  lastp = False
  while not lastp:
    i = j
    try: j = next(it)
    except StopIteration: lastp = True
    yield i, lastp

def safehex(x):
  ## Return hex(X), or the two-character text `""' if X is empty.
  if len(x): return hex(x)
  else: return '""'

def keylens(ksz):
  ## Choose up to four key lengths permitted by the key-size descriptor KSZ.
  ##
  ## Candidates come from a KeySZSet's `set', from
  ## range(min, max, mod) for a KeySZRange, or from range(64) for a
  ## KeySZAny (in which case the result also starts with 0).  The choice is
  ## made by repeatedly swapping an entry picked with `R.range' to the end
  ## of the live part of the list.  Raises TypeError for any other kind of
  ## descriptor; previously that case fell through to an unbound-local
  ## error.
  sel = []
  if isinstance(ksz, C.KeySZSet): kk = ksz.set
  elif isinstance(ksz, C.KeySZRange): kk = range(ksz.min, ksz.max, ksz.mod)
  elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0]
  else: raise TypeError('unsupported key-size descriptor %r' % (ksz,))

  ## Work on a private copy so the descriptor's own data is never permuted.
  kk = list(kk)
  n = len(kk)
  while n and len(sel) < 4:
    i = R.range(n)
    n -= 1
    kk[i], kk[n] = kk[n], kk[i]
    sel.append(kk[n])
  return sel

def pad0star(m, w):
  ## Pad M with zero bytes to a multiple of W bytes.
  ##
  ## An empty M becomes a whole block of W zero bytes; an M which is already
  ## a nonzero multiple of W is returned unpadded.
  n = len(m)
  if not n: r = w
  else: r = (-n)%w
  if r: m += Z(r)
  return C.ByteString(m)

def pad10star(m, w):
  ## Pad M with a 0x80 byte and then zero bytes, up to a multiple of W bytes.
  ##
  ## At least one byte is always added: len(M)%W is less than W, so R is
  ## never zero, and an M which is already a multiple of W gains an entire
  ## extra block.
  r = w - len(m)%w
  m += '\x80' + Z(r - 1)
  return C.ByteString(m)

def ntz(i):
  ## Return the number of trailing zero bits in the nonzero integer I.
  ##
  ## Zero has no lowest set bit, and the loop below would never finish, so
  ## it is rejected with ValueError.
  if i == 0: raise ValueError('ntz(0) is undefined')
  j = 0
  while (i&1) == 0: i >>= 1; j += 1
  return j

def blocks(x, w):
  ## Split X into W-byte pieces, returning (LIST-OF-FULL-BLOCKS, TAIL).
  ##
  ## The loop only continues while strictly more than W bytes remain, so for
  ## a nonempty X the tail holds between 1 and W bytes; the tail is empty
  ## only if X is.
  v, i, n = [], 0, len(x)
  while n - i > w:
    v.append(C.ByteString(x[i:i + w]))
    i += w
  return v, C.ByteString(x[i:])

EMPTY = C.bytes('')

def blocks0(x, w):
  ## Like `blocks', except that a full-length tail is moved onto the end of
  ## the block list, leaving EMPTY as the tail.
  v, tl = blocks(x, w)
  if len(tl) == w: v.append(tl); tl = EMPTY
  return v, tl

###--------------------------------------------------------------------------
### Kalyna decoration.
## Map from command-line cipher name to a Kalyna cipher class.  This stays
## empty if the optional `kalyna' module could not be imported.
KALYNA = {}

if Kalyna is not None:

  class KalynaCipher (type):
    ## Class factory: `KalynaCipher(BLKSZ)' returns a fresh subclass of
    ## `KalynaBase' carrying `name', `blksz' (in bytes) and `keysz'
    ## attributes.
    def __new__(cls, blksz):
      assert blksz in [16, 32, 64]
      name = 'Kalyna-%d' % (8*blksz)
      me = type(name, (KalynaBase,), {})
      me.name = name
      me.blksz = blksz
      ## NOTE(review): presumably the KeySZSet arguments are the default
      ## key size followed by the other permitted sizes -- confirm against
      ## the catacomb bindings.
      if blksz == 64: me.keysz = C.KeySZSet(64)
      else: me.keysz = C.KeySZSet(2*blksz, [blksz])
      return me

  class KalynaBase (object):
    ## Thin wrapper around a `Kalyna' object which returns results as
    ## `C.ByteString'.  The block size comes from the `blksz' attribute
    ## installed on the subclass by `KalynaCipher'.
    def __init__(me, k):
      me._k = Kalyna(k, me.blksz)
    def encrypt(me, m):
      return C.ByteString(me._k.encrypt(m))
    def decrypt(me, m):
      return C.ByteString(me._k.decrypt(m))

  ## Register one class per supported block size, named by its size in bits.
  for i in [16, 32, 64]:
    KALYNA['kalyna%d' % (8*i)] = KalynaCipher(i)

###--------------------------------------------------------------------------
### Luby--Rackoff large-block ciphers.

class LubyRackoffCipher (type):
  ## Class factory: `LubyRackoffCipher(BC, BLKSZ)' returns a fresh subclass
  ## of `LubyRackoffBase' which uses the block cipher class BC to build a
  ## cipher with a BLKSZ-byte block.  BLKSZ must be even and at most twice
  ## BC's block size; the key sizes are inherited from BC.
  def __new__(cls, bc, blksz):
    assert blksz%2 == 0
    assert blksz <= 2*bc.blksz
    name = '%s-lr[%d]' % (bc.name, 8*blksz)
    me = type(name, (LubyRackoffBase,), {})
    me.name = name
    me.blksz = blksz
    me.keysz = bc.keysz
    me.bc = bc
    return me

@contextmanager
def muffle():
  ## Context manager which forces both VERBOSE and LRVERBOSE to false for
  ## the duration of the `with' body, and puts the previous values back
  ## afterwards, even if the body raises.
  global VERBOSE, LRVERBOSE
  _v, _lrv = VERBOSE, LRVERBOSE
  try:
    VERBOSE = LRVERBOSE = False
    yield None
  finally:
    VERBOSE, LRVERBOSE = _v, _lrv

class LubyRackoffBase (object):
  ## Base class for the Feistel constructions made by `LubyRackoffCipher';
  ## the subclass supplies the `bc' and `blksz' attributes.

  NR = 4 # for strong-PRP security

  def __init__(me, k):
    ## Derive NR round keys from K and instantiate one BC per round.
    ##
    ## The round keys are cut from the outputs of BC keyed with K, applied
    ## to the successive counter values 0, 1, 2, ...; each round key is
    ## len(K) bytes long, and the counter carries on from one round key to
    ## the next.
    if LRVERBOSE: print 'K = %s' % hex(k)
    bc, blksz = me.__class__.bc, me.__class__.blksz
    with muffle(): E = bc(k)
    me.f = []
    ksz = len(k)
    i = C.MP(0)
    for j in xrange(me.NR):
      b = C.WriteBuffer()
      ## Keep encrypting counters until there's enough output for one key.
      while b.size < ksz:
        with muffle(): x = E.encrypt(i.storeb(bc.blksz))
        b.put(x)
        if LRVERBOSE: print 'E(K; [%d]) = %s' % (i, hex(x))
        i += 1
      ## Any surplus beyond KSZ bytes is discarded.
      kj = C.ByteString(C.ByteString(b)[0:ksz])
      if LRVERBOSE: print 'K_%d = %s' % (j, hex(kj))
      with muffle(): me.f.append(bc(kj))

  def encrypt(me, m):
    ## Encrypt the single BLKSZ-byte block M.
    ##
    ## Each round zero-pads the current left half to BC's block size,
    ## encrypts it under that round's key, and XORs the leading BLKSZ/2
    ## bytes of the result into the right half; the halves then swap.
    bc, blksz = me.__class__.bc, me.__class__.blksz
    assert len(m) == blksz
    l, r = C.ByteString(m[:blksz/2]), C.ByteString(m[blksz/2:])
    if LRVERBOSE: print 'L_0, R_0 = %s, %s' % (hex(l), hex(r))
    for j in xrange(me.NR):
      l0 = pad0star(l, bc.blksz)
      with muffle(): t = me.f[j].encrypt(l0)
      l, r = r ^ t[:blksz/2], l
      if LRVERBOSE:
        print 'E(K_%d; L_%d || 0^*) = %s' % (j, j, hex(t))
        print 'L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r))
    ## The output is written with the halves exchanged.
    return C.ByteString(r + l)

  def decrypt(me, c):
    ## Decrypt the single BLKSZ-byte block C.
    ##
    ## This performs the same round operation as `encrypt', but takes the
    ## round keys in reverse order.
    bc, blksz = me.__class__.bc, me.__class__.blksz
    assert len(c) == blksz
    l, r = C.ByteString(c[:blksz/2]), C.ByteString(c[blksz/2:])
    for j in xrange(me.NR - 1, -1, -1):
      l0 = pad0star(l, bc.blksz)
      with muffle(): t = me.f[j].encrypt(l0)
      if LRVERBOSE:
        ## NOTE(review): these labels use index j + 1, whereas the key
        ## actually applied just above is me.f[j]; the labelling looks
        ## inconsistent with `encrypt' -- confirm before relying on the
        ## trace output.
        print 'L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r))
        print 'E(K_%d; L_%d || 0^*) = %s' % (j + 1, j + 1, hex(t))
      l, r = r ^ t[:blksz/2], l
    if LRVERBOSE: print 'L_0, R_0 = %s, %s' % (hex(l), hex(r))
    return C.ByteString(r + l)

## Map from command-line cipher name to a Luby--Rackoff class built on
## `C.rijndael', named by block size in bits.  `dlraes512' applies the
## construction twice: a 64-byte block built on the 32-byte one.
LRAES = {}
for i in [8, 12, 16, 24, 32]:
  LRAES['lraes%d' % (8*i)] = LubyRackoffCipher(C.rijndael, i)
LRAES['dlraes512'] = LubyRackoffCipher(LubyRackoffCipher(C.rijndael, 32), 64)

###--------------------------------------------------------------------------
### PMAC.

def ocb_masks(E):
  ## Compute the mask values for the keyed cipher E; return (LGAMMA, LXINV).
  ##
  ## With L = E(0^blksz) and arithmetic modulo poly(8*blksz), LGAMMA is a
  ## 66-element list whose entry I is L x^I, and LXINV is L x^-1.
  blksz = E.__class__.blksz
  p = poly(8*blksz)
  x = C.GF(2); xinv = p.modinv(x)
  z = Z(blksz)
  L = E.encrypt(z)
  Lxinv = mul_blk_gf(L, xinv, p)
  ## Start with 66 copies of L; each later slot is then overwritten with x
  ## times its predecessor.
  Lgamma = 66*[L]
  for i in xrange(1, len(Lgamma)):
    Lgamma[i] = mul_blk_gf(Lgamma[i - 1], x, p)
  return Lgamma, Lxinv

def dump_ocb(E):
  ## Print every mask computed by `ocb_masks' for E.
  Lgamma, Lxinv = ocb_masks(E)
  print 'L x^-1 = %s' % hex(Lxinv)
  for i, lg in enumerate(Lgamma):
    print 'L x^%d = %s' % (i, hex(lg))

def pmac1(E, m):
  ## Compute a full-block tag over M using the `ocb_masks' values.
  ##
  ## All blocks but the last are XORed with a running offset and encrypted
  ## into the accumulator; the offset for block I is updated with
  ## LGAMMA[ntz(I)].  `blocks' leaves a nonempty final piece whenever M is
  ## nonempty: a full-length one is XORed in along with LXINV, and a short
  ## (or empty) one is XORed in after `pad10star'.  The result is the
  ## encryption of the accumulator.
  blksz = E.__class__.blksz
  Lgamma, Lxinv = ocb_masks(E)
  a = o = Z(blksz)
  i = 1
  v, tl = blocks(m, blksz)
  for x in v:
    b = ntz(i)
    o ^= Lgamma[b]
    a ^= E.encrypt(x ^ o)
    if VERBOSE:
      print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
      print 'A[%d]: %s' % (i - 1, hex(a))
    i += 1
  if len(tl) == blksz: a ^= tl ^ Lxinv
  else: a ^= pad10star(tl, blksz)
  return E.encrypt(a)

def pmac2(E, m):
  ## Compute a full-block tag over M using doubling in the `prim' field.
  ##
  ## The offset starts at 10 L, where L = E(0^blksz), and is multiplied by
  ## 2 after each non-final block.  The final piece is XORed in with 3
  ## times the current offset if it is full-length, or padded with
  ## `pad10star' and XORed in with 5 times the offset otherwise.
  blksz = E.__class__.blksz
  p = prim(8*blksz)
  L = E.encrypt(Z(blksz))
  o = mul_blk_gf(L, 10, p)
  a = Z(blksz)
  v, tl = blocks(m, blksz)
  for x in v:
    a ^= E.encrypt(x ^ o)
    o = mul_blk_gf(o, 2, p)
  if len(tl) == blksz: a ^= tl ^ mul_blk_gf(o, 3, p)
  else: a ^= pad10star(tl, blksz) ^ mul_blk_gf(o, 5, p)
  return E.encrypt(a)

def ocb3_masks(E):
  ## Return (LSTAR, LDOLLAR, LGAMMA) for E: the first two `ocb_masks'
  ## values, and the list of the remaining ones.
  Lgamma, _ = ocb_masks(E)
  Lstar = Lgamma[0]
  Ldollar = Lgamma[1]
  return Lstar, Ldollar, Lgamma[2:]

def dump_ocb3(E):
  ## Print LSTAR, LDOLLAR and the first four LGAMMA values for E.
  Lstar, Ldollar, Lgamma = ocb3_masks(E)
  print 'L_* : %s' % hex(Lstar)
  print 'L_$ : %s' % hex(Ldollar)
  for i, lg in enumerate(Lgamma[:4]):
    print 'L_%-8d: %s' % (i, hex(lg))

def pmac3(E, m):
  ## Compute the OCB3-style hash of M.
  ##
  ## Here `blocks0' is used, so every full-length block (including a final
  ## one) is handled in the loop, and only a genuinely short final piece
  ## takes the LSTAR path.  Unlike `pmac1' and `pmac2', the accumulator is
  ## returned as it stands, with no final encryption; for an empty M that
  ## is the all-zero block.
  blksz = E.__class__.blksz
  Lstar, Ldollar, Lgamma = ocb3_masks(E)
  a = o = Z(blksz)
  i = 1
  v, tl = blocks0(m, blksz)
  for x in v:
    b = ntz(i)
    o ^= Lgamma[b]
    a ^= E.encrypt(x ^ o)
    if VERBOSE:
      print 'Offset\'_%-2d: %s' % (i, hex(o))
      print 'AuthSum_%-2d: %s' % (i, hex(a))
    i += 1
  if tl:
    o ^= Lstar
    a ^= E.encrypt(pad10star(tl, blksz) ^ o)
    if VERBOSE:
      print 'Offset\'_* : %s' % hex(o)
      print 'AuthSum_* : %s' % hex(a)
  return a

## The `_pub' wrappers return the tag as a one-element tuple: note the
## trailing commas.

def pmac1_pub(E, m):
  if VERBOSE: dump_ocb(E)
  return pmac1(E, m),

def pmac2_pub(E, m):
  return pmac2(E, m),

def pmac3_pub(E, m):
  return pmac3(E, m),

def pmacgen(bc):
  ## Return a list of one-element tuples, each holding a number derived
  ## from BC's block size: 0, 1, three blocks, and three blocks less 5.
  ## NOTE(review): presumably these are message lengths in bytes for the
  ## PMAC tests; no caller is visible in this file -- confirm.
  return [(0,), (1,), (3*bc.blksz,), (3*bc.blksz - 5,)]

###--------------------------------------------------------------------------
### OCB.

## For OCB2, it's important for security that n = log_x (x + 1) is large in
## the field representations of GF(2^w) used -- in fact, we need more, that
## i n (mod 2^w - 1) is large for i in {4, -3, -2, -1, 1, 2, 3, 4}.  The
## original paper lists the values for 64 and 128, but we support other block
## sizes, so here's the result of the (rather large, in some cases)
## computation.
##
##   Block size   log_x (x + 1)
##
##       64       9686038906114705801
##       96       63214690573408919568138788065
##      128       338793687469689340204974836150077311399
##      192       161110085006042185925119981866940491651092686475226538785
##      256       22928580326165511958494515843249267194111962539778797914076675796261938307298

def ocb1(E, n, h, m, tsz = None):
  ## Encrypt message M with nonce N and header H under the keyed cipher E;
  ## return (CIPHERTEXT, TAG), with the tag truncated to TSZ bytes (the
  ## whole block if TSZ is None).

  ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
  ## Associated-Data'.
  blksz = E.__class__.blksz
  if VERBOSE: dump_ocb(E)
  Lgamma, Lxinv = ocb_masks(E)
  if tsz is None: tsz = blksz
  a = Z(blksz)
  o = E.encrypt(n ^ Lgamma[0])
  if VERBOSE: print 'R = %s' % hex(o)
  i = 1
  y = C.WriteBuffer()
  v, tl = blocks(m, blksz)

  ## Full blocks other than the last: step the offset, fold the plaintext
  ## into the checksum, and encrypt with the offset XORed in on both sides.
  for x in v:
    b = ntz(i)
    o ^= Lgamma[b]
    a ^= x
    if VERBOSE:
      print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
      print 'A[%d]: %s' % (i - 1, hex(a))
    y.put(E.encrypt(x ^ o) ^ o)
    i += 1

  ## Final piece.  From here on N is reused to hold the tail's length in
  ## bytes; the nonce is no longer needed.
  b = ntz(i)
  o ^= Lgamma[b]
  n = len(tl)
  if VERBOSE:
    print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
    print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
  ## The pad is the encryption of the tail's length in bits, masked.
  yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
  cfinal = tl ^ yfinal[:n]
  a ^= o ^ (tl + yfinal[n:])
  y.put(cfinal)
  t = E.encrypt(a)
  ## The header only contributes if it is nonempty.
  if h: t ^= pmac1(E, h)
  return C.ByteString(y), C.ByteString(t[:tsz])

def ocb2(E, n, h, m, tsz = None):
  ## As `ocb1', but offsets are formed by repeated multiplication by 2 in
  ## the `prim' field, starting from twice E(N), and the header is hashed
  ## with `pmac2'.  Returns (CIPHERTEXT, TAG).
  blksz = E.__class__.blksz
  if tsz is None: tsz = blksz
  p = prim(8*blksz)
  L = E.encrypt(n)
  o = mul_blk_gf(L, 2, p)
  a = Z(blksz)
  v, tl = blocks(m, blksz)
  y = C.WriteBuffer()
  for x in v:
    a ^= x
    y.put(E.encrypt(x ^ o) ^ o)
    o = mul_blk_gf(o, 2, p)
  ## N is reused to hold the tail's length in bytes.
  n = len(tl)
  yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ o)
  cfinal = tl ^ yfinal[:n]
  a ^= (tl + yfinal[n:]) ^ mul_blk_gf(o, 3, p)
  y.put(cfinal)
  t = E.encrypt(a)
  if h: t ^= pmac2(E, h)
  return C.ByteString(y), C.ByteString(t[:tsz])

## Per-block-size parameters for `ocb3', keyed by block size in bytes: each
## value is (SPLIT, SHIFT), where SPLIT is the number of low nonce bits
## taken as `bottom', and SHIFT is the shift used to extend Ktop.
OCB3_STRETCH = { 8: (5, 25),
                 12: (6, 33),
                 16: (6, 8),
                 24: (7, 40),
                 32: (7, 120),
                 64: (8, 240) }

def ocb3(E, n, h, m, tsz = None):
  ## Encrypt message M with nonce N and header H under the keyed cipher E,
  ## OCB3 style; return (CIPHERTEXT, TAG), with the tag truncated to TSZ
  ## bytes (the whole block if TSZ is None).  Raises KeyError if the block
  ## size has no entry in OCB3_STRETCH.
  blksz = E.__class__.blksz
  if tsz is None: tsz = blksz
  Lstar, Ldollar, Lgamma = ocb3_masks(E)
  if VERBOSE: dump_ocb3(E)

  ## Figure out how much we need to glue onto the nonce.  This ends up being
  ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is
  ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and p = w - l(N) -
  ## v - 1.  But this is an annoying way to think about it because of the
  ## byte misalignment.  Instead, think of it as a byte-aligned prefix
  ## encoding the tag and an `is the nonce full-length' flag, followed by
  ## optional padding, and then the nonce:
  ##
  ##   F || N                  if l(N) = w - f
  ##   F || 0^p || 1 || N      otherwise
  ##
  ## where F is [t mod w]_v || 0^{f-v-1} || b; f = floor(log_2(w - 1)) + 2;
  ## b is 1 if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1.
  tszbits = C.MP(8*blksz - 1).nbits
  fwd = tszbits/8 + 1
  ## Precedence: `+' and `-' bind tighter than `<<', so this shifts TSZ left
  ## by (3 + 8*fwd - tszbits) bits.
  f = tsz << 3 + 8*fwd - tszbits

  ## Form the augmented nonce.
  nb = C.WriteBuffer()
  nsz, nwd = len(n), blksz - fwd
  if nsz == nwd: f |= 1
  nb.put(C.MP(f).storeb(fwd))
  if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1)
  nb.put(n)
  nn = C.ByteString(nb)
  if VERBOSE: print 'N\' : %s' % hex(nn)

  ## Calculate the initial offset.
  split, shift = OCB3_STRETCH[blksz]
  splitbits = 1 << split
  t2ps = C.MP(0).setbit(splitbits)
  lomask = (C.MP(0).setbit(split) - 1)
  himask = ~lomask
  ## TOP is the augmented nonce with its low SPLIT bits cleared; BOTTOM is
  ## those low bits as a number.
  top, bottom = nn&himask.storeb2c(blksz), C.MP.loadb(nn)&lomask
  ktop = C.MP.loadb(E.encrypt(top))
  stretch = (ktop << splitbits) | \
            (((ktop ^ (ktop << shift)) >> (8*blksz - splitbits))%t2ps)
  ## Precedence again: this shifts right by (splitbits - bottom) bits.
  o = (stretch >> splitbits - bottom).storeb(blksz)
  a = C.ByteString.zero(blksz)
  if VERBOSE:
    print 'bottom : %d' % bottom
    print 'Ktop : %s' % hex(ktop.storeb(blksz))
    ## `1 << split - 3' is 1 << (split - 3), i.e., SPLITBITS/8 bytes.
    print 'Stretch : %s' % hex(stretch.storeb(blksz + (1 << split - 3)))
    print 'Offset_0 : %s' % hex(o)

  ## Split the message into blocks.
  i = 1
  y = C.WriteBuffer()
  v, tl = blocks0(m, blksz)
  for x in v:
    b = ntz(i)
    o ^= Lgamma[b]
    a ^= x
    if VERBOSE:
      print 'Offset_%-3d: %s' % (i, hex(o))
      print 'Checksum_%d: %s' % (i, hex(a))
    y.put(E.encrypt(x ^ o) ^ o)
    i += 1

  ## A short final piece is XORed with a pad made by encrypting the offset.
  ## N is reused here to hold the tail's length in bytes.
  if tl:
    o ^= Lstar
    n = len(tl)
    pad = E.encrypt(o)
    a ^= pad10star(tl, blksz)
    if VERBOSE:
      print 'Offset_* : %s' % hex(o)
      print 'Checksum_*: %s' % hex(a)
    y.put(tl ^ pad[0:n])

  ## Unlike `ocb1' and `ocb2', the header hash is always included.
  o ^= Ldollar
  t = E.encrypt(a ^ o) ^ pmac3(E, h)
  return C.ByteString(y), C.ByteString(t[:tsz])

def ocbgen(bc):
  ## Return a list of three-element tuples of sizes derived from BC's block
  ## size W.
  ## NOTE(review): presumably (nonce, header, message) lengths in bytes for
  ## test-vector generation; no caller is visible in this file -- confirm.
  w = bc.blksz
  return [(w, 0, 0), (w, 1, 0), (w, 0, 1),
          (w, 0, 3*w),
          (w, 3*w, 3*w),
          (w, 0, 3*w + 5),
          (w, 3*w - 5, 3*w + 5)]

def ocb3gen(bc):
  ## As `ocbgen', but the first element of each tuple is shorter than the
  ## block size (by 2, 3 or 5).
  ## NOTE(review): same presumption about the meaning of the fields as for
  ## `ocbgen' -- confirm.
  w = bc.blksz
  return [(w - 2, 0, 0), (w - 2, 1, 0), (w - 2, 0, 1),
          (w - 5, 0, 3*w),
          (w - 3, 3*w, 3*w),
          (w - 2, 0, 3*w + 5),
          (w - 2, 3*w - 5, 3*w + 5)]

###--------------------------------------------------------------------------
### Main program.

## Tracing switches: VERBOSE is tested by the mode functions above, and
## LRVERBOSE by the Luby--Rackoff cipher.  `muffle' clears both temporarily.
VERBOSE = LRVERBOSE = False

class struct (object):
  ## Simple attribute bag: the keyword arguments become attributes.
  def __init__(me, **kw):
    me.__dict__.update(kw)

def mct(ocb, bc, ksz, nsz, tsz):
  ## Run an iterated test of the mode function OCB over cipher class BC, and
  ## print the final tag in hex.
  ##
  ## The key is the number 8*TSZ stored in KSZ bytes.  For each length I
  ## from 0 to 127, an all-zero string S of that length is processed three
  ## times -- as both header and message, as message only, and as header
  ## only -- each time with a fresh nonce from a counter starting at 1, and
  ## the ciphertexts and tags are accumulated.  The printed tag is that of
  ## the accumulated output, processed as a header with an empty message.
  ## (This appears to follow the sample-results procedure in RFC 7253,
  ## appendix A.)
  k = C.MP(8*tsz).storeb(ksz)
  E = bc(k)
  e = C.ByteString('')
  n = C.MP(1)
  cbuf = C.WriteBuffer()
  for i in xrange(128):
    s = C.ByteString.zero(i)
    y, t = ocb(E, n.storeb(nsz), s, s, tsz); n += 1; cbuf.put(y).put(t)
    y, t = ocb(E, n.storeb(nsz), e, s, tsz); n += 1; cbuf.put(y).put(t)
    y, t = ocb(E, n.storeb(nsz), s, e, tsz); n += 1; cbuf.put(y).put(t)
  _, t = ocb(E, n.storeb(nsz), C.ByteString(cbuf), e, tsz)
  print hex(t)

## Command-line cursor: ARGI is the index of the next unread argument.
argc = len(argv)
argi = 1

def usage():
  ## Print a usage message to stderr and exit with status 2.
  ##
  ## NOTE(review): the whitespace inside this literal looks as though it was
  ## collapsed in the source as received (the leading backslash-space was
  ## presumably a backslash-newline, and the per-operation lines presumably
  ## began on lines of their own).  It is reproduced verbatim here; restore
  ## the layout from the upstream file.
  print >>stderr, """\ usage: %s [-v] OCB BLKC OP ARGS... mct KSZ NSZ TSZ kat K N0 TSZ HSZ,MSZ ...
lraes W K M""" % argv[0]
  exit(2)

def arg(must = True, default = None):
  ## Return the next command-line argument and advance the cursor.  If the
  ## arguments have run out, return DEFAULT when MUST is false, and
  ## otherwise report a usage error (which exits).
  global argi
  if argi < argc: argi += 1; return argv[argi - 1]
  elif not must: return default
  else: usage()

## Map from command-line mode name to the function implementing it.
MODEMAP = { 'ocb1': ocb1,
            'ocb2': ocb2,
            'ocb3': ocb3 }

def pat(sz):
  ## Return the SZ-byte pattern 0, 1, 2, ..., wrapping round after 255.
  b = C.WriteBuffer()
  for i in xrange(sz): b.putu8(i%256)
  return C.ByteString(b)

## Optional `-v' flag, then the mode name; an unknown mode name raises
## KeyError.
opt = arg()
if opt == '-v': VERBOSE = True; opt = arg()
ocb = MODEMAP[opt]

## Look the cipher name up in each table in turn; the first hit wins.
bcname = arg()
bc = None
for d in LRAES, KALYNA, C.gcprps:
  try: bc = d[bcname]
  except KeyError: pass
  else: break
if bc is None: raise KeyError, bcname

mode = arg()
if mode == 'mct':
  ## Iterated test: the key, nonce and tag sizes are given in bytes.
  ksz = int(arg()); nsz = int(arg()); tsz = int(arg())
  mct(ocb, bc, ksz, nsz, tsz)
  exit(0)

elif mode == 'kat':
  ## Known-answer tests: key, initial nonce, tag size, and then any number
  ## of `HSZ,MSZ' pairs.  A trailing `+' on the nonce makes it increase by
  ## one after each test; otherwise the same nonce is used throughout.
  k = C.bytes(arg())
  E = bc(k)
  nspec = arg()
  if nspec.endswith('+'): ninc = 1; nspec = nspec[:-1]
  else: ninc = 0
  n0 = C.bytes(nspec)
  nz = C.MP.loadb(n0)
  nsz = len(n0)
  tsz = int(arg())

  print 'K: %s' % hex(k)
  while True:
    hmsz = arg(must = False)
    if hmsz is None: break
    hsz, msz = map(int, hmsz.split(','))
    n = nz.storeb(nsz)
    h = pat(hsz)
    m = pat(msz)
    y, t = ocb(E, n, h, m, tsz)
    print
    print 'N: %s' % hex(n)
    print 'A: %s' % hex(h)
    print 'P: %s' % hex(m)
    ## The ciphertext and tag are printed run together.
    print 'C: %s%s' % (hex(y), hex(t))
    nz += ninc

elif mode == 'lraes':
  ## Trace one Luby--Rackoff encryption: block size W in bytes, key K and
  ## message M, built over the cipher named on the command line.
  w = int(arg())
  k = C.bytes(arg())
  m = C.bytes(arg())
  LRVERBOSE = True
  lr = LubyRackoffCipher(bc, w)
  E = lr(k)
  print
  c = E.encrypt(m)
  print 'E\'(K, m) = %s' % hex(c)

else:
  usage()

###----- That's all, folks --------------------------------------------------