symm/cmac.h, symm/cmac-def.h: Implement the CMAC (OMAC) message auth'n mode.
[catacomb] / utils / advmodes
diff --git a/utils/advmodes b/utils/advmodes
new file mode 100755 (executable)
index 0000000..96e5f5b
--- /dev/null
@@ -0,0 +1,310 @@
+#! /usr/bin/python
+
+from sys import argv
+from struct import pack
+from itertools import izip
+import catacomb as C
+
+R = C.FibRand(0)
+
+###--------------------------------------------------------------------------
+### Utilities.
+
+def combs(things, k):
+  ii = range(k)
+  n = len(things)
+  while True:
+    yield [things[i] for i in ii]
+    for j in xrange(k):
+      if j == k - 1: lim = n
+      else: lim = ii[j + 1]
+      i = ii[j] + 1
+      if i < lim:
+        ii[j] = i
+        break
+      ii[j] = j
+    else:
+      return
+
+POLYMAP = {}
+
+def poly(nbits):
+  try: return POLYMAP[nbits]
+  except KeyError: pass
+  base = C.GF(0).setbit(nbits).setbit(0)
+  for k in xrange(1, nbits, 2):
+    for cc in combs(range(1, nbits), k):
+      p = base + sum(C.GF(0).setbit(c) for c in cc)
+      if p.irreduciblep(): POLYMAP[nbits] = p; return p
+  raise ValueError, nbits
+
+def Z(n):
+  return C.ByteString.zero(n)
+
+def mul_blk_gf(m, x, p): return ((C.GF.loadb(m)*x)%p).storeb((p.nbits + 6)/8)
+
+def with_lastp(it):
+  it = iter(it)
+  try: j = next(it)
+  except StopIteration: raise ValueError, 'empty iter'
+  lastp = False
+  while not lastp:
+    i = j
+    try: j = next(it)
+    except StopIteration: lastp = True
+    yield i, lastp
+
+def safehex(x):
+  if len(x): return hex(x)
+  else: return '""'
+
+def keylens(ksz):
+  sel = []
+  if isinstance(ksz, C.KeySZSet): kk = ksz.set
+  elif isinstance(ksz, C.KeySZRange): kk = range(ksz.min, ksz.max, ksz.mod)
+  elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0]
+  kk = list(kk); kk = kk[:]
+  n = len(kk)
+  while n and len(sel) < 4:
+    i = R.range(n)
+    n -= 1
+    kk[i], kk[n] = kk[n], kk[i]
+    sel.append(kk[n])
+  return sel
+
+def pad0star(m, w):
+  n = len(m)
+  if not n: r = w
+  else: r = (-len(m))%w
+  if r: m += Z(r)
+  return C.ByteString(m)
+
+def pad10star(m, w):
+  r = w - len(m)%w
+  if r: m += '\x80' + Z(r - 1)
+  return C.ByteString(m)
+
+def ntz(i):
+  j = 0
+  while (i&1) == 0: i >>= 1; j += 1
+  return j
+
+def blocks(x, w):
+  v, i, n = [], 0, len(x)
+  while n - i > w:
+    v.append(C.ByteString(x[i:i + w]))
+    i += w
+  return v, C.ByteString(x[i:])
+
+EMPTY = C.bytes('')
+
+def blocks0(x, w):
+  v, tl = blocks(x, w)
+  if len(tl) == w: v.append(tl); tl = EMPTY
+  return v, tl
+
+CUSTOM = {}
+
+###--------------------------------------------------------------------------
+### RC6.
+
+class RC6Cipher (type):
+  def __new__(cls, w, r):
+    name = 'rc6-%d/%d' % (w, r)
+    me = type(name, (RC6Base,), {})
+    me.name = name
+    me.r = r
+    me.w = w
+    me.blksz = w/2
+    me.keysz = C.KeySZRange(me.blksz, 1, 255, 1)
+    return me
+
+def rotw(w):
+  return w.bit_length() - 1
+
+def rol(w, x, n):
+  m0, m1 = C.MP(0).setbit(w - n) - 1, C.MP(0).setbit(n) - 1
+  return ((x&m0) << n) | (x >> (w - n))&m1
+
+def ror(w, x, n):
+  m0, m1 = C.MP(0).setbit(n) - 1, C.MP(0).setbit(w - n) - 1
+  return ((x&m0) << (w - n)) | (x >> n)&m1
+
+class RC6Base (object):
+
+  ## Magic constants.
+  P400 = C.MP(0xb7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190cfef324e7738926cfbe5f4bf8d8d8c31d763da06)
+  Q400 = C.MP(0x9e3779b97f4a7c15f39cc0605cedc8341082276bf3a27251f86c6a11d0c18e952767f0b153d27b7f0347045b5bf1827f0188)
+
+  def __init__(me, k):
+
+    ## Build the magic numbers.
+    P = me.P400 >> (400 - me.w)
+    if P%2 == 0: P += 1
+    Q = me.Q400 >> (400 - me.w)
+    if Q%2 == 0: Q += 1
+    M = C.MP(0).setbit(me.w) - 1
+
+    ## Convert the key into words.
+    wb = me.w/8
+    c = (len(k) + wb - 1)/wb
+    kb, ktl = blocks(k, me.w/8)
+    L = map(C.MP.loadl, kb + [ktl])
+    assert c == len(L)
+
+    ## Build the subkey table.
+    me.d = rotw(me.w)
+    n = 2*me.r + 4
+    S = [(P + i*Q)&M for i in xrange(n)]
+
+    ##for j in xrange(c):
+    ##  print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0'))
+    ##for i in xrange(n):
+    ##  print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0'))
+
+    i = j = 0
+    A = B = C.MP(0)
+
+    for s in xrange(3*max(c, n)):
+      A = S[i] = rol(me.w, S[i] + A + B, 3)
+      B = L[j] = rol(me.w, L[j] + A + B, (A + B)%(1 << me.d))
+      ##print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0'))
+      ##print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0'))
+      i = (i + 1)%n
+      j = (j + 1)%c
+
+    ## Done.
+    me.s = S
+
+  def encrypt(me, x):
+    M = C.MP(0).setbit(me.w) - 1
+    a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4)[0])
+    b = (b + me.s[0])&M
+    d = (d + me.s[1])&M
+    ##print 'B = %s' % (hex(b).upper()[2:].rjust(me.w/4, '0'))
+    ##print 'D = %s' % (hex(d).upper()[2:].rjust(me.w/4, '0'))
+    for i in xrange(2, 2*me.r + 2, 2):
+      t = rol(me.w, 2*b*b + b, me.d)
+      u = rol(me.w, 2*d*d + d, me.d)
+      a = (rol(me.w, a ^ t, u%(1 << me.d)) + me.s[i + 0])&M
+      c = (rol(me.w, c ^ u, t%(1 << me.d)) + me.s[i + 1])&M
+      ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0'))
+      ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0'))
+      a, b, c, d = b, c, d, a
+    a = (a + me.s[2*me.r + 2])&M
+    c = (c + me.s[2*me.r + 3])&M
+    ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0'))
+    ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0'))
+    return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) +
+                        c.storel(me.blksz/4) + d.storel(me.blksz/4))
+
+  def decrypt(me, x):
+    M = C.MP(0).setbit(me.w) - 1
+    a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4))
+    c = (c - me.s[2*me.r + 3])&M
+    a = (a - me.s[2*me.r + 2])&M
+    for i in xrange(2*me.r + 1, 1, -2):
+      a, b, c, d = d, a, b, c
+      u = rol(me.w, 2*d*d + d, me.d)
+      t = rol(me.w, 2*b*b + b, me.d)
+      c = ror(me.w, (c - me.s[i + 1])&M, t%(1 << me.d)) ^ u
+      a = ror(me.w, (a - me.s[i + 0])&M, u%(1 << me.d)) ^ t
+    a = (a + s[2*me.r + 2])&M
+    c = (c + s[2*me.r + 3])&M
+    return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) +
+                        c.storel(me.blksz/4) + d.storel(me.blksz/4))
+
+for (w, r) in [(8, 16), (16, 16), (24, 16), (32, 16),
+               (32, 20), (48, 16), (64, 16), (96, 16), (128, 16),
+               (192, 16), (256, 16), (400, 16)]:
+  CUSTOM['rc6-%d/%d' % (w, r)] = RC6Cipher(w, r)
+
+###--------------------------------------------------------------------------
+### OMAC (or CMAC).
+
+def omac_masks(E):
+  blksz = E.__class__.blksz
+  p = poly(8*blksz)
+  z = Z(blksz)
+  L = E.encrypt(z)
+  m0 = mul_blk_gf(L, 2, p)
+  m1 = mul_blk_gf(m0, 2, p)
+  return m0, m1
+
+def dump_omac(E):
+  blksz = E.__class__.blksz
+  m0, m1 = omac_masks(E)
+  print 'L = %s' % hex(E.encrypt(Z(blksz)))
+  print 'm0 = %s' % hex(m0)
+  print 'm1 = %s' % hex(m1)
+  for t in xrange(3):
+    print 'v%d = %s' % (t, hex(E.encrypt(C.MP(t).storeb(blksz))))
+    print 'z%d = %s' % (t, hex(omac(E, t, '')))
+
+def omac(E, t, m):
+  blksz = E.__class__.blksz
+  m0, m1 = omac_masks(E)
+  a = Z(blksz)
+  if t is not None: m = C.MP(t).storeb(blksz) + m
+  v, tl = blocks(m, blksz)
+  for x in v: a = E.encrypt(a ^ x)
+  r = blksz - len(tl)
+  if r == 0:
+    a = E.encrypt(a ^ tl ^ m0)
+  else:
+    pad = pad10star(tl, blksz)
+    a = E.encrypt(a ^ pad ^ m1)
+  return a
+
+def cmac(E, m):
+  if VERBOSE: dump_omac(E)
+  return omac(E, None, m),
+
+def cmacgen(bc):
+  return [(0,), (1,),
+          (3*bc.blksz,),
+          (3*bc.blksz - 5,)]
+
+###--------------------------------------------------------------------------
+### Main program.
+
+class struct (object):
+  def __init__(me, **kw):
+    me.__dict__.update(kw)
+
+binarg = struct(mk = R.block, parse = C.bytes, show = safehex)
+intarg = struct(mk = lambda x: x, parse = int, show = None)
+
+MODEMAP = { 'cmac': (cmacgen, [binarg], cmac) }
+
+mode = argv[1]
+bc = None
+for d in CUSTOM, C.gcprps:
+  try: bc = d[argv[2]]
+  except KeyError: pass
+  else: break
+if bc is None: raise KeyError, argv[2]
+if len(argv) == 3:
+  VERBOSE = False
+  gen, argty, func = MODEMAP[mode]
+  print '%s-%s {' % (bc.name, mode)
+  for ksz in keylens(bc.keysz):
+    for argvals in gen(bc):
+      k = R.block(ksz)
+      args = [t.mk(a) for t, a in izip(argty, argvals)]
+      rets = func(bc(k), *args)
+      print '  %s' % safehex(k)
+      for t, a in izip(argty, args):
+        if t.show: print '    %s' % t.show(a)
+      for r, lastp in with_lastp(rets):
+        print '    %s%s' % (safehex(r), lastp and ';' or '')
+  print '}'
+else:
+  VERBOSE = True
+  k = C.bytes(argv[3])
+  gen, argty, func = MODEMAP[mode]
+  args = [t.parse(a) for t, a in izip(argty, argv[4:])]
+  rets = func(bc(k), *args)
+  for r in rets: print hex(r)
+
+###----- That's all, folks --------------------------------------------------