symm/ocb3.h, symm/ocb3-def.h: Implement the OCB3 auth'ned encryption mode.
[catacomb] / utils / advmodes
index c65a933..bee4e34 100755 (executable)
@@ -1,6 +1,6 @@
 #! /usr/bin/python
 
-from sys import argv
+from sys import argv, exit
 from struct import unpack, pack
 from itertools import izip
 import catacomb as C
@@ -587,6 +587,43 @@ def pmac2(E, m):
   else: a ^= pad10star(tl, blksz) ^ mul_blk_gf(o, 5, p)
   return E.encrypt(a)
 
+def ocb3_masks(E):
+  Lgamma, _ = ocb_masks(E)
+  Lstar = Lgamma[0]
+  Ldollar = Lgamma[1]
+  return Lstar, Ldollar, Lgamma[2:]
+
+def dump_ocb3(E):
+  Lstar, Ldollar, Lgamma = ocb3_masks(E)
+  print 'L_* = %s' % hex(Lstar)
+  print 'L_$ = %s' % hex(Ldollar)
+  for i, lg in enumerate(Lgamma[:16]):
+    print 'L x^%d = %s' % (i, hex(lg))
+
+def pmac3(E, m):
+  ## Note that `PMAC3' is /not/ a secure MAC.  It depends on other parts of
+  ## OCB3 to prevent a rather easy linear-algebra attack.
+  blksz = E.__class__.blksz
+  Lstar, Ldollar, Lgamma = ocb3_masks(E)
+  a = o = Z(blksz)
+  i = 0
+  v, tl = blocks0(m, blksz)
+  for x in v:
+    i += 1
+    b = ntz(i)
+    o ^= Lgamma[b]
+    a ^= E.encrypt(x ^ o)
+    if VERBOSE:
+      print 'Z[%d]: %d -> %s' % (i, b, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+  if tl:
+    o ^= Lstar
+    a ^= E.encrypt(pad10star(tl, blksz) ^ o)
+    if VERBOSE:
+      print 'Z[%d]: * -> %s' % (i, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+  return a
+
 def pmac1_pub(E, m):
   if VERBOSE: dump_ocb(E)
   return pmac1(E, m),
@@ -734,6 +771,138 @@ def ocb2dec(E, n, h, y, t):
   if t == u[:len(t)]: return C.ByteString(m),
   else: return None,
 
+OCB3_STRETCH = {  4: ( 4,  17),
+                  8: ( 5,  25),
+                 12: ( 6,  33),
+                 16: ( 6,   8),
+                 24: ( 7,  40),
+                 32: ( 8,   1),
+                 48: ( 8,  80),
+                 64: ( 8, 176),
+                 96: ( 9, 160),
+                128: ( 9, 352),
+                200: (10, 192) }
+
+def ocb3nonce(E, n, tsz):
+
+  ## Figure out how much we need to glue onto the nonce.  This ends up being
+  ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is
+  ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and p = w - l(N) -
+  ## v - 1.  But this is an annoying way to think about it because of the
+  ## byte misalignment.  Instead, think of it as a byte-aligned prefix
+  ## encoding the tag and an `is the nonce full-length' flag, followed by
+  ## optional padding, and then the nonce:
+  ##
+  ##    F || N                  if l(N) = w - f
+  ##    F || 0^p || 1 || N      otherwise
+  ##
+  ## where F is [t mod w]_v || 0^{f-v-1} || b; f = floor(log_2(w - 1)) + 2;
+  ## b is 1 if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1.
+  blksz = E.__class__.blksz
+  tszbits = min(C.MP(8*blksz - 1).nbits, 8)
+  fwd = tszbits/8 + 1
+  f = 8*(tsz%blksz) << + 8*fwd - tszbits
+
+  ## Form the augmented nonce.
+  nb = C.WriteBuffer()
+  nsz, nwd = len(n), blksz - fwd
+  if nsz == nwd: f |= 1
+  nb.put(C.MP(f).storeb(fwd))
+  if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1)
+  nb.put(n)
+  nn = C.ByteString(nb)
+  if VERBOSE: print 'aug-nonce = %s' % hex(nn)
+
+  ## Calculate the initial offset.
+  split, shift = OCB3_STRETCH[blksz]
+  t2pw = C.MP(0).setbit(8*blksz) - 1
+  lomask = (C.MP(0).setbit(split) - 1)
+  himask = ~lomask
+  top, bottom = nn&himask.storeb2c(blksz), C.MP.loadb(nn)&lomask
+  ktop = C.MP.loadb(E.encrypt(top))
+  stretch = (ktop << 8*blksz) | (ktop ^ (ktop << shift)&t2pw)
+  o = (stretch >> 8*blksz - bottom).storeb(blksz)
+  if VERBOSE:
+    print 'stretch = %s' % hex(stretch.storeb(2*blksz))
+    print 'Z[0] = %s' % hex(o)
+
+  return o
+
+def ocb3enc(E, n, h, m, tsz = None):
+  blksz = E.__class__.blksz
+  if tsz is None: tsz = blksz
+  Lstar, Ldollar, Lgamma = ocb3_masks(E)
+  if VERBOSE: dump_ocb3(E)
+
+  ## Set things up.
+  o = ocb3nonce(E, n, tsz)
+  a = C.ByteString.zero(blksz)
+
+  ## Split the message into blocks.
+  i = 0
+  y = C.WriteBuffer()
+  v, tl = blocks0(m, blksz)
+  for x in v:
+    i += 1
+    b = ntz(i)
+    o ^= Lgamma[b]
+    a ^= x
+    if VERBOSE:
+      print 'Z[%d]: %d -> %s' % (i, b, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+    y.put(E.encrypt(x ^ o) ^ o)
+  if tl:
+    o ^= Lstar
+    n = len(tl)
+    pad = E.encrypt(o)
+    a ^= pad10star(tl, blksz)
+    if VERBOSE:
+      print 'Z[%d]: * -> %s' % (i, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+    y.put(tl ^ pad[0:n])
+  o ^= Ldollar
+  t = E.encrypt(a ^ o) ^ pmac3(E, h)
+  return C.ByteString(y), C.ByteString(t[:tsz])
+
+def ocb3dec(E, n, h, y, t):
+  blksz = E.__class__.blksz
+  tsz = len(t)
+  Lstar, Ldollar, Lgamma = ocb3_masks(E)
+  if VERBOSE: dump_ocb3(E)
+
+  ## Set things up.
+  o = ocb3nonce(E, n, tsz)
+  a = C.ByteString.zero(blksz)
+
+  ## Split the message into blocks.
+  i = 0
+  m = C.WriteBuffer()
+  v, tl = blocks0(y, blksz)
+  for x in v:
+    i += 1
+    b = ntz(i)
+    o ^= Lgamma[b]
+    if VERBOSE:
+      print 'Z[%d]: %d -> %s' % (i, b, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+    u = E.encrypt(x ^ o) ^ o
+    m.put(u)
+    a ^= u
+  if tl:
+    o ^= Lstar
+    n = len(tl)
+    pad = E.encrypt(o)
+    if VERBOSE:
+      print 'Z[%d]: * -> %s' % (i, hex(o))
+      print 'A[%d]: %s' % (i, hex(a))
+    u = tl ^ pad[0:n]
+    m.put(u)
+    a ^= pad10star(u, blksz)
+  o ^= Ldollar
+  u = E.encrypt(a ^ o) ^ pmac3(E, h)
+  if t == u[:tsz]: return C.ByteString(m),
+  else: return None,
+
 def ocbgen(bc):
   w = bc.blksz
   return [(w, 0, 0), (w, 1, 0), (w, 0, 1),
@@ -742,6 +911,44 @@ def ocbgen(bc):
           (w, 0, 3*w + 5),
           (w, 3*w - 5, 3*w + 5)]
 
+def ocb3gen(bc):
+  w = bc.blksz
+  return [(w - 2, 0, 0), (w - 2, 1, 0), (w - 2, 0, 1),
+          (w - 5, 0, 3*w),
+          (w - 3, 3*w, 3*w),
+          (w - 2, 0, 3*w + 5),
+          (w - 2, 3*w - 5, 3*w + 5)]
+
+def ocb3_mct(bc, ksz, tsz):
+  k = C.ByteString(C.WriteBuffer().zero(ksz - 4).putu32(8*tsz))
+  E = bc(k)
+  n = C.MP(1)
+  nw = bc.blksz - 4
+  cbuf = C.WriteBuffer()
+  for i in xrange(128):
+    s = C.ByteString.zero(i)
+    y, t = ocb3enc(E, n.storeb(nw), s, s, tsz); n += 1; cbuf.put(y).put(t)
+    y, t = ocb3enc(E, n.storeb(nw), EMPTY, s, tsz); n += 1; cbuf.put(y).put(t)
+    y, t = ocb3enc(E, n.storeb(nw), s, EMPTY, tsz); n += 1; cbuf.put(y).put(t)
+  _, t = ocb3enc(E, n.storeb(nw), C.ByteString(cbuf), EMPTY, tsz)
+  print hex(t)
+
+def ocb3_mct2(bc):
+  k = C.bytes('000102030405060708090a0b0c0d0e0f')
+  E = bc(k)
+  tsz = min(E.blksz, 32)
+  n = C.MP(1)
+  cbuf = C.WriteBuffer()
+  for i in xrange(128):
+    sbuf = C.WriteBuffer()
+    for j in xrange(i): sbuf.putu8(j)
+    s = C.ByteString(sbuf)
+    y, t = ocb3enc(E, n.storeb(2), s, s, tsz); n += 1; cbuf.put(y).put(t)
+    y, t = ocb3enc(E, n.storeb(2), EMPTY, s, tsz); n += 1; cbuf.put(y).put(t)
+    y, t = ocb3enc(E, n.storeb(2), s, EMPTY, tsz); n += 1; cbuf.put(y).put(t)
+  _, t = ocb3enc(E, n.storeb(2), C.ByteString(cbuf), EMPTY, tsz)
+  print hex(t)
+
 ###--------------------------------------------------------------------------
 ### Main program.
 
@@ -763,6 +970,8 @@ MODEMAP = { 'eax-enc': (eaxgen, 3*[binarg] + [intarg], eaxenc),
             'ocb1-dec': (dummygen, 4*[binarg], ocb1dec),
             'ocb2-enc': (ocbgen, 3*[binarg] + [intarg], ocb2enc),
             'ocb2-dec': (dummygen, 4*[binarg], ocb2dec),
+            'ocb3-enc': (ocb3gen, 3*[binarg] + [intarg], ocb3enc),
+            'ocb3-dec': (dummygen, 4*[binarg], ocb3dec),
             'pmac1': (pmacgen, [binarg], pmac1_pub) }
 
 mode = argv[1]
@@ -772,6 +981,15 @@ for d in CUSTOM, C.gcprps:
   except KeyError: pass
   else: break
 if bc is None: raise KeyError, argv[2]
+if len(argv) == 5 and mode == 'ocb3-mct':
+  VERBOSE = False
+  ksz, tsz = int(argv[3]), int(argv[4])
+  ocb3_mct(bc, ksz, tsz)
+  exit(0)
+if len(argv) == 3 and mode == 'ocb3-mct2':
+  VERBOSE = False
+  ocb3_mct2(bc)
+  exit(0)
 if len(argv) == 3:
   VERBOSE = False
   gen, argty, func = MODEMAP[mode]