algorithms.c: Add bindings for STROBE.
[catacomb-python] / t / t-algorithms.py
index 52decd6..fd239fb 100644 (file)
@@ -64,76 +64,82 @@ def different_key_size(ksz, sz):
 class HashBufferTestMixin (U.TestCase):
   """Mixin class for testing all of the various `hash...' methods."""
 
-  def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
+  HASHMETH = "hash"
+  def dohash(me, h, op, arg):
+    getattr(h, me.HASHMETH + op)(arg)
+    return me
+
+  def check_hashbuffer_hashn(me, w, bigendp, makefn, hashop):
     """Check `hashuN'."""
 
     ## Check encoding an integer.
     h0, donefn0 = makefn(w + 2)
-    hashfn(h0.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
+    me.dohash(h0, "u8", 0x00)
+    me.dohash(h0, hashop, T.bytes_as_int(w, bigendp))
+    me.dohash(h0, "u8", w + 1)
     h1, donefn1 = makefn(w + 2)
-    h1.hash(T.span(w + 2))
+    me.dohash(h1, "", T.span(w + 2))
     me.assertEqual(donefn0(), donefn1())
 
     ## Check overflow detection.
     h0, _ = makefn(w)
-    me.assertRaises((OverflowError, ValueError),
-                    hashfn, h0, 1 << 8*w)
+    me.assertRaises(OverflowError, me.dohash, h0, hashop, 1 << 8*w)
 
-  def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
+  def check_hashbuffer_bufn(me, w, bigendp, makefn, hashop):
     """Check `hashbufN'."""
 
     ## Go through a number of different sizes.
     for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]:
       if n >= 1 << 8*w: continue
       h0, donefn0 = makefn(2 + w + n)
-      hashfn(h0.hashu8(0x00), T.span(n)).hashu8(0xff)
+      me.dohash(h0, "u8", 0x00)
+      me.dohash(h0, hashop, T.span(n))
+      me.dohash(h0, "u8", 0xff)
       h1, donefn1 = makefn(2 + w + n)
-      h1.hash(T.prep_lenseq(w, n, bigendp, True))
+      me.dohash(h1, "", T.prep_lenseq(w, n, bigendp, True))
       me.assertEqual(donefn0(), donefn1())
 
     ## Check blocks which are too large for the length prefix.
     if w <= 3:
       n = 1 << 8*w
       h0, _ = makefn(w + n)
-      me.assertRaises((ValueError, OverflowError, TypeError),
-                      hashfn, h0, C.ByteString.zero(n))
+      me.assertRaises(ValueError,
+                      me.dohash, h0, hashop, C.ByteString.zero(n))
 
   def check_hashbuffer(me, makefn):
     """Test the various `hash...' methods."""
 
     ## Check `hashuN'.
-    me.check_hashbuffer_hashn(1, True, makefn, lambda h, n: h.hashu8(n))
-    me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16(n))
-    me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16b(n))
-    me.check_hashbuffer_hashn(2, False, makefn, lambda h, n: h.hashu16l(n))
-    if hasattr(makefn(0)[0], "hashu24"):
-      me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n))
-      me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n))
-      me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n))
-    me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32(n))
-    me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32b(n))
-    me.check_hashbuffer_hashn(4, False, makefn, lambda h, n: h.hashu32l(n))
-    if hasattr(makefn(0)[0], "hashu64"):
-      me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64(n))
-      me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64b(n))
-      me.check_hashbuffer_hashn(8, False, makefn, lambda h, n: h.hashu64l(n))
+    me.check_hashbuffer_hashn(1, True, makefn, "u8")
+    me.check_hashbuffer_hashn(2, True, makefn, "u16")
+    me.check_hashbuffer_hashn(2, True, makefn, "u16b")
+    me.check_hashbuffer_hashn(2, False, makefn, "u16l")
+    me.check_hashbuffer_hashn(3, True, makefn, "u24")
+    me.check_hashbuffer_hashn(3, True, makefn, "u24b")
+    me.check_hashbuffer_hashn(3, False, makefn, "u24l")
+    me.check_hashbuffer_hashn(4, True, makefn, "u32")
+    me.check_hashbuffer_hashn(4, True, makefn, "u32b")
+    me.check_hashbuffer_hashn(4, False, makefn, "u32l")
+    if hasattr(makefn(0)[0], me.HASHMETH + "u64"):
+      me.check_hashbuffer_hashn(8, True, makefn, "u64")
+      me.check_hashbuffer_hashn(8, True, makefn, "u64b")
+      me.check_hashbuffer_hashn(8, False, makefn, "u64l")
 
     ## Check `hashbufN'.
-    me.check_hashbuffer_bufn(1, True, makefn, lambda h, x: h.hashbuf8(x))
-    me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16(x))
-    me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16b(x))
-    me.check_hashbuffer_bufn(2, False, makefn, lambda h, x: h.hashbuf16l(x))
-    if hasattr(makefn(0)[0], "hashbuf24"):
-      me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x))
-      me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x))
-      me.check_hashbuffer_bufn(3, False, makefn, lambda h, x: h.hashbuf24l(x))
-    me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32(x))
-    me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32b(x))
-    me.check_hashbuffer_bufn(4, False, makefn, lambda h, x: h.hashbuf32l(x))
-    if hasattr(makefn(0)[0], "hashbuf64"):
-      me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64(x))
-      me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64b(x))
-      me.check_hashbuffer_bufn(8, False, makefn, lambda h, x: h.hashbuf64l(x))
+    me.check_hashbuffer_bufn(1, True, makefn, "buf8")
+    me.check_hashbuffer_bufn(2, True, makefn, "buf16")
+    me.check_hashbuffer_bufn(2, True, makefn, "buf16b")
+    me.check_hashbuffer_bufn(2, False, makefn, "buf16l")
+    me.check_hashbuffer_bufn(3, True, makefn, "buf24")
+    me.check_hashbuffer_bufn(3, True, makefn, "buf24b")
+    me.check_hashbuffer_bufn(3, False, makefn, "buf24l")
+    me.check_hashbuffer_bufn(4, True, makefn, "buf32")
+    me.check_hashbuffer_bufn(4, True, makefn, "buf32b")
+    me.check_hashbuffer_bufn(4, False, makefn, "buf32l")
+    if hasattr(makefn(0)[0], me.HASHMETH + "u64"):
+      me.check_hashbuffer_bufn(8, True, makefn, "buf64")
+      me.check_hashbuffer_bufn(8, True, makefn, "buf64b")
+      me.check_hashbuffer_bufn(8, False, makefn, "buf64l")
 
 ###--------------------------------------------------------------------------
 class TestKeysize (U.TestCase):
@@ -602,29 +608,28 @@ class BaseTestHash (HashBufferTestMixin):
     """
     Check hash class HCLS.
 
-    If NEED_BUFSZ is false, then don't insist that HCLS have working `bufsz',
-    `name', or `hashsz' attributes.  This test is mostly reused for MACs,
-    which don't have these attributes.
+    If NEED_BUFSZ is false, then don't insist that HCLS has a working `bufsz'
+    attribute.  This test is mostly reused for MACs, which don't have this
+    attribute.
     """
     ## Check the class properties.
-    if need_bufsz:
-      me.assertEqual(type(hcls.name), str)
-      me.assertEqual(type(hcls.bufsz), int)
-      me.assertEqual(type(hcls.hashsz), int)
+    me.assertEqual(type(hcls.name), str)
+    if need_bufsz: me.assertEqual(type(hcls.bufsz), int)
+    me.assertEqual(type(hcls.hashsz), int)
 
     ## Set some initial values.
     m = T.span(131)
     h = hcls().hash(m).done()
 
     ## Check that hash length comes out right.
-    if need_bufsz: me.assertEqual(len(h), hcls.hashsz)
+    me.assertEqual(len(h), hcls.hashsz)
 
     ## Check that we get the same answer if we split the message up.
     me.assertEqual(h, hcls().hash(m[0:73]).hash(m[73:131]).done())
 
     ## Check the `check' method.
     me.assertTrue(hcls().hash(m).check(h))
-    me.assertFalse(hcls().hash(m).check(h ^ len(h)*C.bytes("aa")))
+    me.assertFalse(hcls().hash(m).check(h ^ hcls.hashsz*C.bytes("aa")))
 
     ## Check the menagerie of random hashing methods.
     def mkhash(_):
@@ -654,6 +659,7 @@ class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
     ## Test hashing.
     k = T.span(mcls.keysz.default)
     key = mcls(k)
+    me.assertEqual(key.hashsz, key.tagsz)
     me.check_hash(key, need_bufsz = False)
 
     ## Check that bad key lengths are rejected.
@@ -685,6 +691,9 @@ class TestPoly1305 (HashBufferTestMixin):
     t = key(u).hash(m).done()
 
     ## Check the key properties.
+    me.assertEqual(key.name, "poly1305")
+    me.assertEqual(key.tagsz, 16)
+    me.assertEqual(key.tagsz, 16)
     me.assertEqual(len(t), 16)
 
     ## Check that we get the same answer if we split the message up.
@@ -747,11 +756,11 @@ class TestKeccak (HashBufferTestMixin):
     st1.mix(m0).step()
     me.assertNotEqual(st0.extract(32), st1.extract(32))
 
-    ## Check state copying.
+    ## Check state copying and `mix' vs `set'.
     st1 = st0.copy()
     mask = st1.extract(len(m1))
     st0.mix(m1)
-    st1.mix(m1)
+    st1.set(m1 ^ mask)
     me.assertEqual(st0.extract(32), st1.extract(32))
 
     ## Check error conditions.
@@ -759,6 +768,9 @@ class TestKeccak (HashBufferTestMixin):
     me.assertRaises(ValueError, st0.extract, 201)
     st0.mix(T.span(200))
     me.assertRaises(ValueError, st0.mix, T.span(201))
+    st0.set(T.span(200))
+    me.assertRaises(ValueError, st0.set, T.span(201))
+    me.assertRaises(ValueError, st0.set, T.span(199))
 
   def check_shake(me, xcls, c, done_matches_xof = True):
     """
@@ -833,6 +845,96 @@ class TestKeccak (HashBufferTestMixin):
   def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
 
 ###--------------------------------------------------------------------------
+class TestStrobe (HashBufferTestMixin):
+
+  HASHMETH = "process"
+
+  def test_strobe(me):
+
+    ## Construction.
+    s0 = C.Strobe()
+    s1 = C.Strobe(128)
+    me.assertEqual(s0.l, 128)
+    me.assertEqual(s1.l, 128)
+    C.Strobe(704)
+    me.assertRaises(ValueError, C.Strobe, 127)
+    me.assertRaises(ValueError, C.Strobe, 736)
+    me.assertEqual(s0.role, C.STRBRL_UNDCD)
+    me.assertEqual(s1.role, C.STRBRL_UNDCD)
+
+    ## `process' vs operation-specific functions.  (This follows Hamburg's
+    ## `Concurrence' test.)
+    h = T.bin("testing")
+    s0.ad(h, f = "M"); s1.begin(C.STROBE_AD | C.STRBF_M).process(h).done()
+
+    t = s1.begin(C.STRBF_I | C.STRBF_A | C.STRBF_C).process(10); s1.done()
+    me.assertEqual(s0.prf(10), t)
+    me.assertEqual(t, C.bytes("8a13a189683bf5678170"))
+
+    h = T.bin("Hello")
+    s0.ad(h); s1.begin(" A   ").process(h).done()
+
+    m = T.bin("World"); c = s0.encout(m)
+    me.assertFalse(s1.activep)
+    m1 = s1.begin("IACT ").process(c); me.assertTrue(s1.activep);
+    s1.done(); me.assertFalse(s1.activep)
+    me.assertEqual(c, C.bytes("123bfbee34"))
+    me.assertEqual(m1, m)
+    me.assertEqual(s0.role, C.STRBRL_INIT)
+    me.assertEqual(s1.role, C.STRBRL_RESP)
+
+    m = T.bin("foo"); s0.clrout(m); s1.begin("IA T ").process(m).done()
+    m = T.bin("bar"); s0.clrin(m);  s1.begin(" A T ").process(m).done()
+
+    c = T.bin("baz"); m = s0.encin(c)
+    c1 = s1.begin(" ACT ").process(m); s1.done()
+    me.assertEqual(m, C.bytes("15e518"))
+    me.assertEqual(c1, c)
+
+    xxx = T.bin(199*"X")
+    for i in T.range(200):
+      c = s0.begin(" ACT ").process(xxx[:i]); s0.done(); s1.encin(c)
+
+    t = s1.begin("IAC  ").process(123); s1.done()
+    me.assertEqual(t, C.bytes
+      ("45476fc0806aee35e864c4f18e6ba62bd3eb1b1e8bef9042b30b0f15d00c3e9f"
+       "5d5904ab789d4c67eaed582473c15aa4424f11d52b21a296b36db3392e2ecbb2"
+       "dc6963bafba3b23882d061f1d335e86e470e8d819591bf0c223e24b925751d04"
+       "f789fc73bc55f7d2b3ed4881c625aa6321d31511b13f6d5e4ce54a"))
+    me.assertEqual(s0.prf(123), t)
+
+    ## Copying and MAC.
+    s2 = s0.copy()
+    t = s0.macout(16)
+    me.assertEqual(t, C.bytes("171419608e11e7c907d493209e17f26b"))
+    me.assertEqual(s2.begin("  CT ").process(16), t); s2.done()
+    s3 = s1.copy(); me.assertFalse(s3.macin(~t))
+    s3 = s1.copy(); me.assertTrue(s3.macin(t))
+    s3 = s1.copy(); me.assertFalse(s3.begin("I CT ").process(~t).done())
+    me.assertTrue(s1.begin("I CT ").process(t).done())
+
+    ## Test the remaining operations.  (From the Catacomb test vectors.)
+    k = T.bin("this is my super-secret key")
+    s0.key(k); s1.begin(" AC   ").process(k).done()
+
+    t = s1.begin(C.STROBE_PRF).process(32); s1.done()
+    me.assertEqual(t, C.bytes
+      ("5418e0f0ee7f982bbbdc2b4bbf49425d0088abfa98ee21d8ad8a3610d15ebb68"))
+    me.assertEqual(s0.prf(32), t)
+
+    s0.ratchet(32); s1.begin("C").process(32).done()
+
+    t = s1.begin("  CT ").process(16); s1.done()
+    me.assertEqual(t, C.bytes("b2084ebdfabd50768c91eebc190132cc"))
+    me.assertTrue(s0.macin(t))
+
+  def test_hashbuffer(me):
+    def mkhash(hsz):
+      s = C.Strobe(256).begin(C.STROBE_AD)
+      return s, lambda: s.done().macout(16)
+    me.check_hashbuffer(mkhash)
+
+###--------------------------------------------------------------------------
 class TestPRP (T.GenericTestMixin):
   """Test pseudorandom permutations (PRPs)."""