algorithms.c: Add a Keccak `set' method now there's upstream support.
[catacomb-python] / t / t-algorithms.py
index fc6bff2..455e947 100644 (file)
@@ -38,7 +38,7 @@ def bad_key_size(ksz):
   if isinstance(ksz, C.KeySZAny): return None
   elif isinstance(ksz, C.KeySZRange):
     if ksz.mod != 1: return ksz.min + 1
-    elif ksz.max != 0: return ksz.max + 1
+    elif ksz.max is not None: return ksz.max + 1
     elif ksz.min != 0: return ksz.min - 1
     else: return None
   elif isinstance(ksz, C.KeySZSet):
@@ -52,7 +52,7 @@ def different_key_size(ksz, sz):
   if isinstance(ksz, C.KeySZAny): return sz + 1
   elif isinstance(ksz, C.KeySZRange):
     if sz > ksz.min: return sz - ksz.mod
-    elif ksz.max == 0 or sz < ksz.max: return sz + ksz.mod
+    elif ksz.max is None or sz < ksz.max: return sz + ksz.mod
     else: return None
   elif isinstance(ksz, C.KeySZSet):
     for sz1 in sorted(ksz.set):
@@ -76,8 +76,7 @@ class HashBufferTestMixin (U.TestCase):
 
     ## Check overflow detection.
     h0, _ = makefn(w)
-    me.assertRaises((OverflowError, ValueError),
-                    hashfn, h0, 1 << 8*w)
+    me.assertRaises(OverflowError, hashfn, h0, 1 << 8*w)
 
   def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
     """Check `hashbufN'."""
@@ -95,8 +94,7 @@ class HashBufferTestMixin (U.TestCase):
     if w <= 3:
       n = 1 << 8*w
       h0, _ = makefn(w + n)
-      me.assertRaises((ValueError, OverflowError, TypeError),
-                      hashfn, h0, C.ByteString.zero(n))
+      me.assertRaises(ValueError, hashfn, h0, C.ByteString.zero(n))
 
   def check_hashbuffer(me, makefn):
     """Test the various `hash...' methods."""
@@ -106,10 +104,9 @@ class HashBufferTestMixin (U.TestCase):
     me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16(n))
     me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16b(n))
     me.check_hashbuffer_hashn(2, False, makefn, lambda h, n: h.hashu16l(n))
-    if hasattr(makefn(0)[0], "hashu24"):
-      me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n))
-      me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n))
-      me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n))
+    me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n))
+    me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n))
+    me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n))
     me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32(n))
     me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32b(n))
     me.check_hashbuffer_hashn(4, False, makefn, lambda h, n: h.hashu32l(n))
@@ -123,10 +120,9 @@ class HashBufferTestMixin (U.TestCase):
     me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16(x))
     me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16b(x))
     me.check_hashbuffer_bufn(2, False, makefn, lambda h, x: h.hashbuf16l(x))
-    if hasattr(makefn(0)[0], "hashbuf24"):
-      me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x))
-      me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x))
-      me.check_hashbuffer_bufn(3, False, makefn, lambda h, x: h.hashbuf24l(x))
+    me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x))
+    me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x))
+    me.check_hashbuffer_bufn(3, False, makefn, lambda h, x: h.hashbuf24l(x))
     me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32(x))
     me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32b(x))
     me.check_hashbuffer_bufn(4, False, makefn, lambda h, x: h.hashbuf32l(x))
@@ -145,7 +141,7 @@ class TestKeysize (U.TestCase):
     me.assertEqual(type(ksz), C.KeySZAny)
     me.assertEqual(ksz.default, 20)
     me.assertEqual(ksz.min, 0)
-    me.assertEqual(ksz.max, 0)
+    me.assertEqual(ksz.max, None)
     for n in [0, 12, 20, 5000]:
       me.assertTrue(ksz.check(n))
       me.assertEqual(ksz.best(n), n)
@@ -157,7 +153,7 @@ class TestKeysize (U.TestCase):
     me.assertEqual(type(ksz), C.KeySZAny)
     me.assertEqual(ksz.default, 32)
     me.assertEqual(ksz.min, 0)
-    me.assertEqual(ksz.max, 0)
+    me.assertEqual(ksz.max, None)
     for n in [0, 12, 20, 5000]:
       me.assertTrue(ksz.check(n))
       me.assertEqual(ksz.best(n), n)
@@ -167,7 +163,7 @@ class TestKeysize (U.TestCase):
     ksz = C.KeySZAny(15)
     me.assertEqual(ksz.default, 15)
     me.assertEqual(ksz.min, 0)
-    me.assertEqual(ksz.max, 0)
+    me.assertEqual(ksz.max, None)
     me.assertRaises(ValueError, lambda: C.KeySZAny(-8))
     me.assertEqual(C.KeySZAny(0).default, 0)
 
@@ -180,7 +176,7 @@ class TestKeysize (U.TestCase):
     me.assertEqual(ksz.default, 32)
     me.assertEqual(ksz.min, 10)
     me.assertEqual(ksz.max, 32)
-    me.assertEqual(set(ksz.set), set([10, 16, 32]))
+    me.assertEqual(ksz.set, set([10, 16, 32]))
     for x, best, pad in [(9, None, 10), (10, 10, 10), (11, 10, 16),
                          (15, 10, 16), (16, 16, 16), (17, 16, 32),
                          (31, 16, 32), (32, 32, 32), (33, 32, None)]:
@@ -194,12 +190,12 @@ class TestKeysize (U.TestCase):
     ## Check construction.
     ksz = C.KeySZSet(7)
     me.assertEqual(ksz.default, 7)
-    me.assertEqual(set(ksz.set), set([7]))
+    me.assertEqual(ksz.set, set([7]))
     me.assertEqual(ksz.min, 7)
     me.assertEqual(ksz.max, 7)
-    ksz = C.KeySZSet(7, [3, 6, 9])
+    ksz = C.KeySZSet(7, iter([3, 6, 9]))
     me.assertEqual(ksz.default, 7)
-    me.assertEqual(set(ksz.set), set([3, 6, 7, 9]))
+    me.assertEqual(ksz.set, set([3, 6, 7, 9]))
     me.assertEqual(ksz.min, 3)
     me.assertEqual(ksz.max, 9)
 
@@ -230,6 +226,11 @@ class TestKeysize (U.TestCase):
     me.assertEqual(ksz.min, 21)
     me.assertEqual(ksz.max, 35)
     me.assertEqual(ksz.mod, 7)
+    ksz = C.KeySZRange(28, 21, None, 7)
+    me.assertEqual(ksz.min, 21)
+    me.assertEqual(ksz.max, None)
+    me.assertEqual(ksz.mod, 7)
+    me.assertEqual(ksz.pad(36), 42)
     me.assertRaises(ValueError, C.KeySZRange, 29, 21, 35, 7)
     me.assertRaises(ValueError, C.KeySZRange, 28, 20, 35, 7)
     me.assertRaises(ValueError, C.KeySZRange, 28, 21, 34, 7)
@@ -597,29 +598,28 @@ class BaseTestHash (HashBufferTestMixin):
     """
     Check hash class HCLS.
 
-    If NEED_BUFSZ is false, then don't insist that HCLS have working `bufsz',
-    `name', or `hashsz' attributes.  This test is mostly reused for MACs,
-    which don't have these attributes.
+    If NEED_BUFSZ is false, then don't insist that HCLS has a working `bufsz'
+    attribute.  This test is mostly reused for MACs, which don't have this
+    attribute.
     """
     ## Check the class properties.
-    if need_bufsz:
-      me.assertEqual(type(hcls.name), str)
-      me.assertEqual(type(hcls.bufsz), int)
-      me.assertEqual(type(hcls.hashsz), int)
+    me.assertEqual(type(hcls.name), str)
+    if need_bufsz: me.assertEqual(type(hcls.bufsz), int)
+    me.assertEqual(type(hcls.hashsz), int)
 
     ## Set some initial values.
     m = T.span(131)
     h = hcls().hash(m).done()
 
     ## Check that hash length comes out right.
-    if need_bufsz: me.assertEqual(len(h), hcls.hashsz)
+    me.assertEqual(len(h), hcls.hashsz)
 
     ## Check that we get the same answer if we split the message up.
     me.assertEqual(h, hcls().hash(m[0:73]).hash(m[73:131]).done())
 
     ## Check the `check' method.
     me.assertTrue(hcls().hash(m).check(h))
-    me.assertFalse(hcls().hash(m).check(h ^ len(h)*C.bytes("aa")))
+    me.assertFalse(hcls().hash(m).check(h ^ hcls.hashsz*C.bytes("aa")))
 
     ## Check the menagerie of random hashing methods.
     def mkhash(_):
@@ -649,6 +649,7 @@ class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
     ## Test hashing.
     k = T.span(mcls.keysz.default)
     key = mcls(k)
+    me.assertEqual(key.hashsz, key.tagsz)
     me.check_hash(key, need_bufsz = False)
 
     ## Check that bad key lengths are rejected.
@@ -668,7 +669,7 @@ class TestPoly1305 (HashBufferTestMixin):
     me.assertEqual(C.poly1305.name, "poly1305")
     me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
     me.assertEqual(C.poly1305.keysz.default, 16)
-    me.assertEqual(set(C.poly1305.keysz.set), set([16]))
+    me.assertEqual(C.poly1305.keysz.set, set([16]))
     me.assertEqual(C.poly1305.tagsz, 16)
     me.assertEqual(C.poly1305.masksz, 16)
 
@@ -680,6 +681,9 @@ class TestPoly1305 (HashBufferTestMixin):
     t = key(u).hash(m).done()
 
     ## Check the key properties.
+    me.assertEqual(key.name, "poly1305")
+    me.assertEqual(key.tagsz, 16)
+    me.assertEqual(key.tagsz, 16)
     me.assertEqual(len(t), 16)
 
     ## Check that we get the same answer if we split the message up.
@@ -742,11 +746,21 @@ class TestKeccak (HashBufferTestMixin):
     st1.mix(m0).step()
     me.assertNotEqual(st0.extract(32), st1.extract(32))
 
+    ## Check state copying and `mix' vs `set'.
+    st1 = st0.copy()
+    mask = st1.extract(len(m1))
+    st0.mix(m1)
+    st1.set(m1 ^ mask)
+    me.assertEqual(st0.extract(32), st1.extract(32))
+
     ## Check error conditions.
     _ = st0.extract(200)
     me.assertRaises(ValueError, st0.extract, 201)
     st0.mix(T.span(200))
     me.assertRaises(ValueError, st0.mix, T.span(201))
+    st0.set(T.span(200))
+    me.assertRaises(ValueError, st0.set, T.span(201))
+    me.assertRaises(ValueError, st0.set, T.span(199))
 
   def check_shake(me, xcls, c, done_matches_xof = True):
     """
@@ -781,7 +795,7 @@ class TestKeccak (HashBufferTestMixin):
 
     ## Check masking.
     x = xcls().hash(m).xof()
-    me.assertEqual(x.mask(m), C.ByteString(m) ^ C.ByteString(h[0:len(m)]))
+    me.assertEqual(x.mask(m), m ^ h[0:len(m)])
 
     ## Check the `check' method.
     me.assertTrue(xcls().hash(m).check(h0))
@@ -790,14 +804,14 @@ class TestKeccak (HashBufferTestMixin):
     ## Check the menagerie of random hashing methods.
     def mkhash(_):
       x = xcls(func = func, perso = perso)
-      return x, lambda: x.done(100 - x.rate//2)
+      return x, x.done
     me.check_hashbuffer(mkhash)
 
     ## Check the state machine tracking.
     x = xcls(); me.assertEqual(x.state, "absorb")
     x.hash(m); me.assertEqual(x.state, "absorb")
     xx = x.copy()
-    h = xx.done(100 - x.rate//2)
+    h = xx.done(); me.assertEqual(len(h), 100 - x.rate//2)
     me.assertEqual(xx.state, "dead")
     me.assertRaises(ValueError, xx.done, 1)
     me.assertRaises(ValueError, xx.get, 1)
@@ -813,7 +827,7 @@ class TestKeccak (HashBufferTestMixin):
 
   def check_kmac(me, mcls, c):
     k = T.span(32)
-    me.check_shake(lambda func = None, perso = T.bin(""):
+    me.check_shake(lambda func = None, perso = None:
                      mcls(k, perso = perso),
                    c, done_matches_xof = False)