symm/ocb1.h, symm/pmac1.h, ...: Implement PMAC1 and OCB1.
[catacomb] / utils / advmodes
index 96e5f5b..e0dd386 100755 (executable)
@@ -1,7 +1,7 @@
 #! /usr/bin/python
 
 from sys import argv
-from struct import pack
+from struct import unpack, pack
 from itertools import izip
 import catacomb as C
 
@@ -103,6 +103,8 @@ def blocks0(x, w):
   if len(tl) == w: v.append(tl); tl = EMPTY
   return v, tl
 
+def dummygen(bc): return []
+
 CUSTOM = {}
 
 ###--------------------------------------------------------------------------
@@ -266,6 +268,394 @@ def cmacgen(bc):
           (3*bc.blksz - 5,)]
 
 ###--------------------------------------------------------------------------
+### Counter mode.
+
+def ctr(E, m, c0):
+  blksz = E.__class__.blksz
+  y = C.WriteBuffer()
+  c = C.MP.loadb(c0)
+  while y.size < len(m):
+    y.put(E.encrypt(c.storeb(blksz)))
+    c += 1
+  return C.ByteString(m) ^ C.ByteString(y)[:len(m)]
+
+###--------------------------------------------------------------------------
+### GCM.
+
+def gcm_mangle(x):
+  y = C.WriteBuffer()
+  for b in x:
+    b = ord(b)
+    bb = 0
+    for i in xrange(8):
+      bb <<= 1
+      if b&1: bb |= 1
+      b >>= 1
+    y.putu8(bb)
+  return C.ByteString(y)
+
+def gcm_mul(x, y):
+  w = len(x)
+  p = poly(8*w)
+  u, v = C.GF.loadl(gcm_mangle(x)), C.GF.loadl(gcm_mangle(y))
+  z = (u*v)%p
+  return gcm_mangle(z.storel(w))
+
+def gcm_pow(x, n):
+  w = len(x)
+  p = poly(8*w)
+  u = C.GF.loadl(gcm_mangle(x))
+  z = pow(u, n, p)
+  return gcm_mangle(z.storel(w))
+
+def gcm_ctr(E, m, c0):
+  y = C.WriteBuffer()
+  pre = c0[:-4]
+  c, = unpack('>L', c0[-4:])
+  while y.size < len(m):
+    c += 1
+    y.put(E.encrypt(pre + pack('>L', c)))
+  return C.ByteString(m) ^ C.ByteString(y)[:len(m)]
+
+def g(what, x, m, a0 = None):
+  n = len(x)
+  if a0 is None: a = Z(n)
+  else: a = a0
+  i = 0
+  for b in blocks0(m, n)[0]:
+    a = gcm_mul(a ^ b, x)
+    if VERBOSE: print '%s[%d] = %s -> %s' % (what, i, hex(b), hex(a))
+    i += 1
+  return a
+
+def gcm_pad(w, x):
+  return C.ByteString(x + Z(-len(x)%w))
+
+def gcm_lens(w, a, b):
+  if w < 12: n = w
+  else: n = w/2
+  return C.ByteString(C.MP(a).storeb(n) + C.MP(b).storeb(n))
+
+def ghash(whata, whatb, x, a, b):
+  w = len(x)
+  ha = g(whata, x, gcm_pad(w, a))
+  hb = g(whatb, x, gcm_pad(w, b))
+  if a:
+    hc = gcm_mul(ha, gcm_pow(x, (len(b) + w - 1)/w)) ^ hb
+    if VERBOSE: print '%s || %s -> %s' % (whata, whatb, hex(hc))
+  else:
+    hc = hb
+  return g(whatb, x, gcm_lens(w, 8*len(a), 8*len(b)), hc)
+
+def gcmenc(E, n, h, m, tsz = None):
+  w = E.__class__.blksz
+  x = E.encrypt(Z(w))
+  if VERBOSE: print 'x = %s' % hex(x)
+  if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1))
+  else: c0 = ghash('?', 'n', x, EMPTY, n)
+  if VERBOSE: print 'c0 = %s' % hex(c0)
+  y = gcm_ctr(E, m, c0)
+  t = ghash('h', 'y', x, h, y) ^ E.encrypt(c0)
+  return y, t
+
+def gcmdec(E, n, h, y, t):
+  w = E.__class__.blksz
+  x = E.encrypt(Z(w))
+  if VERBOSE: print 'x = %s' % hex(x)
+  if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1))
+  else: c0 = ghash('?', 'n', x, EMPTY, n)
+  if VERBOSE: print 'c0 = %s' % hex(c0)
+  m = gcm_ctr(E, y, c0)
+  tt = ghash('h', 'y', x, h, y) ^ E.encrypt(c0)
+  if t == tt: return m,
+  else: return None,
+
+def gcmgen(bc):
+  return [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
+          (bc.blksz, 3*bc.blksz, 3*bc.blksz),
+          (bc.blksz - 4, bc.blksz + 3, 3*bc.blksz + 9),
+          (bc.blksz - 1, 3*bc.blksz - 5, 3*bc.blksz + 5)]
+
+###--------------------------------------------------------------------------
+### CCM.
+
+def stbe(n, w): return C.MP(n).storeb(w)
+
+def ccm_fmthdr(blksz, n, hsz, msz, tsz):
+  b = C.WriteBuffer()
+  if blksz == 8:
+    q = blksz - len(n) - 1
+    f = 0
+    if hsz: f |= 0x40
+    f |= (tsz - 1) << 3
+    f |= q - 1
+    b.putu8(f).put(n).put(stbe(msz, q))
+  elif blksz == 16:
+    q = blksz - len(n) - 1
+    f = 0
+    if hsz: f |= 0x40
+    f |= (tsz - 2)/2 << 3
+    f |= q - 1
+    b.putu8(f).put(n).put(stbe(msz, q))
+  else:
+    q = blksz - len(n) - 2
+    f0 = f1 = 0
+    if hsz: f1 |= 0x80
+    f0 |= tsz
+    f1 |= q
+    b.putu8(f0).putu8(f1).put(n).put(stbe(msz, q))
+  b = C.ByteString(b)
+  if VERBOSE: print 'hdr = %s' % hex(b)
+  return b
+
+def ccm_fmtctr(blksz, n, i = 0):
+  b = C.WriteBuffer()
+  if blksz == 8 or blksz == 16:
+    q = blksz - len(n) - 1
+    b.putu8(q - 1).put(n).put(stbe(i, q))
+  else:
+    q = blksz - len(n) - 2
+    b.putu8(0).putu8(q).put(n).put(stbe(i, q))
+  b = C.ByteString(b)
+  if VERBOSE: print 'ctr = %s' % hex(b)
+  return b
+
+def ccmaad(b, h, blksz):
+  hsz = len(h)
+  if not hsz: pass
+  elif hsz < 0xfffe: b.putu16(hsz)
+  elif hsz <= 0xffffffff: b.putu16(0xfffe).putu32(hsz)
+  else: b.putu16(0xffff).putu64(hsz)
+  b.put(h); b.zero((-b.size)%blksz)
+
+def ccmenc(E, n, h, m, tsz = None):
+  blksz = E.__class__.blksz
+  if tsz is None: tsz = blksz
+  b = C.WriteBuffer()
+  b.put(ccm_fmthdr(blksz, n, len(h), len(m), tsz))
+  ccmaad(b, h, blksz)
+  b.put(m); b.zero((-b.size)%blksz)
+  b = C.ByteString(b)
+  a = Z(blksz)
+  v, _ = blocks0(b, blksz)
+  i = 0
+  for x in v:
+    a = E.encrypt(a ^ x)
+    if VERBOSE:
+      print 'b[%d] = %s' % (i, hex(x))
+      print 'a[%d] = %s' % (i + 1, hex(a))
+    i += 1
+  y = ctr(E, a + m, ccm_fmtctr(blksz, n))
+  return C.ByteString(y[blksz:]), C.ByteString(y[0:tsz])
+
+def ccmdec(E, n, h, y, t):
+  blksz = E.__class__.blksz
+  tsz = len(t)
+  b = C.WriteBuffer()
+  b.put(ccm_fmthdr(blksz, n, len(h), len(y), tsz))
+  ccmaad(b, h, blksz)
+  mm = ctr(E, t + Z(blksz - tsz) + y, ccm_fmtctr(blksz, n))
+  u, m = C.ByteString(mm[0:tsz]), C.ByteString(mm[blksz:])
+  b.put(m); b.zero((-b.size)%blksz)
+  b = C.ByteString(b)
+  a = Z(blksz)
+  v, _ = blocks0(b, blksz)
+  i = 0
+  for x in v:
+    a = E.encrypt(a ^ x)
+    if VERBOSE:
+      print 'b[%d] = %s' % (i, hex(x))
+      print 'a[%d] = %s' % (i + 1, hex(a))
+    i += 1
+  if u == a[:tsz]: return m,
+  else: return None,
+
+def ccmgen(bc):
+  bsz = bc.blksz
+  return [(bsz - 5, 0, 0, 4), (bsz - 5, 1, 0, 4), (bsz - 5, 0, 1, 4),
+          (bsz/2 + 1, 3*bc.blksz, 3*bc.blksz),
+          (bsz/2 + 1, 3*bc.blksz - 5, 3*bc.blksz + 5)]
+
+###--------------------------------------------------------------------------
+### EAX.
+
+def eaxenc(E, n, h, m, tsz = None):
+  if VERBOSE:
+    print 'k = %s' % hex(k)
+    print 'n = %s' % hex(n)
+    print 'h = %s' % hex(h)
+    print 'm = %s' % hex(m)
+    dump_omac(E)
+  if tsz is None: tsz = E.__class__.blksz
+  c0 = omac(E, 0, n)
+  y = ctr(E, m, c0)
+  ht = omac(E, 1, h)
+  yt = omac(E, 2, y)
+  if VERBOSE:
+    print 'c0 = %s' % hex(c0)
+    print 'ht = %s' % hex(ht)
+    print 'yt = %s' % hex(yt)
+  return y, C.ByteString((c0 ^ ht ^ yt)[:tsz])
+
+def eaxdec(E, n, h, y, t):
+  if VERBOSE:
+    print 'k = %s' % hex(k)
+    print 'n = %s' % hex(n)
+    print 'h = %s' % hex(h)
+    print 'y = %s' % hex(y)
+    print 't = %s' % hex(t)
+    dump_omac(E)
+  c0 = omac(E, 0, n)
+  m = ctr(E, y, c0)
+  ht = omac(E, 1, h)
+  yt = omac(E, 2, y)
+  if VERBOSE:
+    print 'c0 = %s' % hex(c0)
+    print 'ht = %s' % hex(ht)
+    print 'yt = %s' % hex(yt)
+  if t == (c0 ^ ht ^ yt)[:len(t)]: return m,
+  else: return None,
+
+def eaxgen(bc):
+  return [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
+          (bc.blksz, 3*bc.blksz, 3*bc.blksz),
+          (bc.blksz - 1, 3*bc.blksz - 5, 3*bc.blksz + 5)]
+
+###--------------------------------------------------------------------------
+### PMAC.
+
+def ocb_masks(E):
+  blksz = E.__class__.blksz
+  p = poly(8*blksz)
+  x = C.GF(2); xinv = p.modinv(x)
+  z = Z(blksz)
+  L = E.encrypt(z)
+  Lxinv = mul_blk_gf(L, xinv, p)
+  Lgamma = 66*[L]
+  for i in xrange(1, len(Lgamma)):
+    Lgamma[i] = mul_blk_gf(Lgamma[i - 1], x, p)
+  return Lgamma, Lxinv
+
+def dump_ocb(E):
+  Lgamma, Lxinv = ocb_masks(E)
+  print 'L x^-1 = %s' % hex(Lxinv)
+  for i, lg in enumerate(Lgamma[:16]):
+    print 'L x^%d = %s' % (i, hex(lg))
+
+def pmac1(E, m):
+  blksz = E.__class__.blksz
+  Lgamma, Lxinv = ocb_masks(E)
+  a = o = Z(blksz)
+  i = 0
+  v, tl = blocks(m, blksz)
+  for x in v:
+    i += 1
+    b = ntz(i)
+    o ^= Lgamma[b]
+    a ^= E.encrypt(x ^ o)
+    if VERBOSE:
+      print 'Z[%d]: %d -> %s' % (i, b, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+  if len(tl) == blksz: a ^= tl ^ Lxinv
+  else: a ^= pad10star(tl, blksz)
+  return E.encrypt(a)
+
+def pmac1_pub(E, m):
+  if VERBOSE: dump_ocb(E)
+  return pmac1(E, m),
+
+def pmacgen(bc):
+  return [(0,), (1,),
+          (3*bc.blksz,),
+          (3*bc.blksz - 5,)]
+
+###--------------------------------------------------------------------------
+### OCB.
+
+def ocb1enc(E, n, h, m, tsz = None):
+  ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
+  ## Associated-Data'.
+  blksz = E.__class__.blksz
+  if VERBOSE: dump_ocb(E)
+  Lgamma, Lxinv = ocb_masks(E)
+  if tsz is None: tsz = blksz
+  a = Z(blksz)
+  o = E.encrypt(n ^ Lgamma[0])
+  if VERBOSE: print 'R = %s' % hex(o)
+  i = 0
+  y = C.WriteBuffer()
+  v, tl = blocks(m, blksz)
+  for x in v:
+    i += 1
+    b = ntz(i)
+    o ^= Lgamma[b]
+    a ^= x
+    if VERBOSE:
+      print 'Z[%d]: %d -> %s' % (i, b, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+    y.put(E.encrypt(x ^ o) ^ o)
+  i += 1
+  b = ntz(i)
+  o ^= Lgamma[b]
+  n = len(tl)
+  if VERBOSE:
+    print 'Z[%d]: %d -> %s' % (i, b, hex(o))
+    print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
+  yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
+  cfinal = tl ^ yfinal[:n]
+  a ^= o ^ (tl + yfinal[n:])
+  y.put(cfinal)
+  t = E.encrypt(a)
+  if h: t ^= pmac1(E, h)
+  return C.ByteString(y), C.ByteString(t[:tsz])
+
+def ocb1dec(E, n, h, y, t):
+  ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
+  ## Associated-Data'.
+  blksz = E.__class__.blksz
+  if VERBOSE: dump_ocb(E)
+  Lgamma, Lxinv = ocb_masks(E)
+  a = Z(blksz)
+  o = E.encrypt(n ^ Lgamma[0])
+  if VERBOSE: print 'R = %s' % hex(o)
+  i = 0
+  m = C.WriteBuffer()
+  v, tl = blocks(y, blksz)
+  for x in v:
+    i += 1
+    b = ntz(i)
+    o ^= Lgamma[b]
+    if VERBOSE:
+      print 'Z[%d]: %d -> %s' % (i, b, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+    u = E.decrypt(x ^ o) ^ o
+    m.put(u)
+    a ^= u
+  i += 1
+  b = ntz(i)
+  o ^= Lgamma[b]
+  n = len(tl)
+  if VERBOSE:
+    print 'Z[%d]: %d -> %s' % (i, b, hex(o))
+    print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
+  yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
+  mfinal = tl ^ yfinal[:n]
+  a ^= o ^ (mfinal + yfinal[n:])
+  m.put(mfinal)
+  u = E.encrypt(a)
+  if h: u ^= pmac1(E, h)
+  if t == u[:len(t)]: return C.ByteString(m),
+  else: return None,
+
+def ocbgen(bc):
+  w = bc.blksz
+  return [(w, 0, 0), (w, 1, 0), (w, 0, 1),
+          (w, 0, 3*w),
+          (w, 3*w, 3*w),
+          (w, 0, 3*w + 5),
+          (w, 3*w - 5, 3*w + 5)]
+
+###--------------------------------------------------------------------------
 ### Main program.
 
 class struct (object):
@@ -275,7 +665,16 @@ class struct (object):
 binarg = struct(mk = R.block, parse = C.bytes, show = safehex)
 intarg = struct(mk = lambda x: x, parse = int, show = None)
 
-MODEMAP = { 'cmac': (cmacgen, [binarg], cmac) }
+MODEMAP = { 'eax-enc': (eaxgen, 3*[binarg] + [intarg], eaxenc),
+            'eax-dec': (dummygen, 4*[binarg], eaxdec),
+            'ccm-enc': (ccmgen, 3*[binarg] + [intarg], ccmenc),
+            'ccm-dec': (dummygen, 4*[binarg], ccmdec),
+            'cmac': (cmacgen, [binarg], cmac),
+            'gcm-enc': (gcmgen, 3*[binarg] + [intarg], gcmenc),
+            'gcm-dec': (dummygen, 4*[binarg], gcmdec),
+            'ocb1-enc': (ocbgen, 3*[binarg] + [intarg], ocb1enc),
+            'ocb1-dec': (dummygen, 4*[binarg], ocb1dec),
+            'pmac1': (pmacgen, [binarg], pmac1_pub) }
 
 mode = argv[1]
 bc = None
@@ -287,6 +686,7 @@ if bc is None: raise KeyError, argv[2]
 if len(argv) == 3:
   VERBOSE = False
   gen, argty, func = MODEMAP[mode]
+  if mode.endswith('-enc'): mode = mode[:-4]
   print '%s-%s {' % (bc.name, mode)
   for ksz in keylens(bc.keysz):
     for argvals in gen(bc):
@@ -305,6 +705,8 @@ else:
   gen, argty, func = MODEMAP[mode]
   args = [t.parse(a) for t, a in izip(argty, argv[4:])]
   rets = func(bc(k), *args)
-  for r in rets: print hex(r)
+  for r in rets:
+    if r is None: print "X"
+    else: print hex(r)
 
 ###----- That's all, folks --------------------------------------------------