X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb-python/blobdiff_plain/5a5e2e11718380b70ec5c9ef5c15e8ea2773ceba..54fd7594ee5df9dbc9745d98adaa01a5ed43b6e4:/t/t-algorithms.py diff --git a/t/t-algorithms.py b/t/t-algorithms.py index d7b97cb..fd239fb 100644 --- a/t/t-algorithms.py +++ b/t/t-algorithms.py @@ -38,7 +38,7 @@ def bad_key_size(ksz): if isinstance(ksz, C.KeySZAny): return None elif isinstance(ksz, C.KeySZRange): if ksz.mod != 1: return ksz.min + 1 - elif ksz.max != 0: return ksz.max + 1 + elif ksz.max is not None: return ksz.max + 1 elif ksz.min != 0: return ksz.min - 1 else: return None elif isinstance(ksz, C.KeySZSet): @@ -52,7 +52,7 @@ def different_key_size(ksz, sz): if isinstance(ksz, C.KeySZAny): return sz + 1 elif isinstance(ksz, C.KeySZRange): if sz > ksz.min: return sz - ksz.mod - elif ksz.max == 0 or sz < ksz.max: return sz + ksz.mod + elif ksz.max is None or sz < ksz.max: return sz + ksz.mod else: return None elif isinstance(ksz, C.KeySZSet): for sz1 in sorted(ksz.set): @@ -64,76 +64,82 @@ def different_key_size(ksz, sz): class HashBufferTestMixin (U.TestCase): """Mixin class for testing all of the various `hash...' methods.""" - def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn): + HASHMETH = "hash" + def dohash(me, h, op, arg): + getattr(h, me.HASHMETH + op)(arg) + return me + + def check_hashbuffer_hashn(me, w, bigendp, makefn, hashop): """Check `hashuN'.""" ## Check encoding an integer. h0, donefn0 = makefn(w + 2) - hashfn(h0.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1) + me.dohash(h0, "u8", 0x00) + me.dohash(h0, hashop, T.bytes_as_int(w, bigendp)) + me.dohash(h0, "u8", w + 1) h1, donefn1 = makefn(w + 2) - h1.hash(T.span(w + 2)) + me.dohash(h1, "", T.span(w + 2)) me.assertEqual(donefn0(), donefn1()) ## Check overflow detection. h0, _ = makefn(w) - me.assertRaises((OverflowError, ValueError), - hashfn, h0, 1 << 8*w) + me.assertRaises(OverflowError, me.dohash, h0, hashop, 1 << 8*w) - def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn): + def check_hashbuffer_bufn(me, w, bigendp, makefn, hashop): """Check `hashbufN'.""" ## Go through a number of different sizes. for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]: if n >= 1 << 8*w: continue h0, donefn0 = makefn(2 + w + n) - hashfn(h0.hashu8(0x00), T.span(n)).hashu8(0xff) + me.dohash(h0, "u8", 0x00) + me.dohash(h0, hashop, T.span(n)) + me.dohash(h0, "u8", 0xff) h1, donefn1 = makefn(2 + w + n) - h1.hash(T.prep_lenseq(w, n, bigendp, True)) + me.dohash(h1, "", T.prep_lenseq(w, n, bigendp, True)) me.assertEqual(donefn0(), donefn1()) ## Check blocks which are too large for the length prefix. if w <= 3: n = 1 << 8*w h0, _ = makefn(w + n) - me.assertRaises((ValueError, OverflowError, TypeError), - hashfn, h0, C.ByteString.zero(n)) + me.assertRaises(ValueError, + me.dohash, h0, hashop, C.ByteString.zero(n)) def check_hashbuffer(me, makefn): """Test the various `hash...' methods.""" ## Check `hashuN'. - me.check_hashbuffer_hashn(1, True, makefn, lambda h, n: h.hashu8(n)) - me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16(n)) - me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16b(n)) - me.check_hashbuffer_hashn(2, False, makefn, lambda h, n: h.hashu16l(n)) - if hasattr(makefn(0)[0], "hashu24"): - me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n)) - me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n)) - me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n)) - me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32(n)) - me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32b(n)) - me.check_hashbuffer_hashn(4, False, makefn, lambda h, n: h.hashu32l(n)) - if hasattr(makefn(0)[0], "hashu64"): - me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64(n)) - me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64b(n)) - me.check_hashbuffer_hashn(8, False, makefn, lambda h, n: h.hashu64l(n)) + me.check_hashbuffer_hashn(1, True, makefn, "u8") + me.check_hashbuffer_hashn(2, True, makefn, "u16") + me.check_hashbuffer_hashn(2, True, makefn, "u16b") + me.check_hashbuffer_hashn(2, False, makefn, "u16l") + me.check_hashbuffer_hashn(3, True, makefn, "u24") + me.check_hashbuffer_hashn(3, True, makefn, "u24b") + me.check_hashbuffer_hashn(3, False, makefn, "u24l") + me.check_hashbuffer_hashn(4, True, makefn, "u32") + me.check_hashbuffer_hashn(4, True, makefn, "u32b") + me.check_hashbuffer_hashn(4, False, makefn, "u32l") + if hasattr(makefn(0)[0], me.HASHMETH + "u64"): + me.check_hashbuffer_hashn(8, True, makefn, "u64") + me.check_hashbuffer_hashn(8, True, makefn, "u64b") + me.check_hashbuffer_hashn(8, False, makefn, "u64l") ## Check `hashbufN'. - me.check_hashbuffer_bufn(1, True, makefn, lambda h, x: h.hashbuf8(x)) - me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16(x)) - me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16b(x)) - me.check_hashbuffer_bufn(2, False, makefn, lambda h, x: h.hashbuf16l(x)) - if hasattr(makefn(0)[0], "hashbuf24"): - me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x)) - me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x)) - me.check_hashbuffer_bufn(3, False, makefn, lambda h, x: h.hashbuf24l(x)) - me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32(x)) - me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32b(x)) - me.check_hashbuffer_bufn(4, False, makefn, lambda h, x: h.hashbuf32l(x)) - if hasattr(makefn(0)[0], "hashbuf64"): - me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64(x)) - me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64b(x)) - me.check_hashbuffer_bufn(8, False, makefn, lambda h, x: h.hashbuf64l(x)) + me.check_hashbuffer_bufn(1, True, makefn, "buf8") + me.check_hashbuffer_bufn(2, True, makefn, "buf16") + me.check_hashbuffer_bufn(2, True, makefn, "buf16b") + me.check_hashbuffer_bufn(2, False, makefn, "buf16l") + me.check_hashbuffer_bufn(3, True, makefn, "buf24") + me.check_hashbuffer_bufn(3, True, makefn, "buf24b") + me.check_hashbuffer_bufn(3, False, makefn, "buf24l") + me.check_hashbuffer_bufn(4, True, makefn, "buf32") + me.check_hashbuffer_bufn(4, True, makefn, "buf32b") + me.check_hashbuffer_bufn(4, False, makefn, "buf32l") + if hasattr(makefn(0)[0], me.HASHMETH + "u64"): + me.check_hashbuffer_bufn(8, True, makefn, "buf64") + me.check_hashbuffer_bufn(8, True, makefn, "buf64b") + me.check_hashbuffer_bufn(8, False, makefn, "buf64l") ###-------------------------------------------------------------------------- class TestKeysize (U.TestCase): @@ -145,7 +151,7 @@ class TestKeysize (U.TestCase): me.assertEqual(type(ksz), C.KeySZAny) me.assertEqual(ksz.default, 20) me.assertEqual(ksz.min, 0) - me.assertEqual(ksz.max, 0) + me.assertEqual(ksz.max, None) for n in [0, 12, 20, 5000]: me.assertTrue(ksz.check(n)) me.assertEqual(ksz.best(n), n) @@ -157,7 +163,7 @@ class TestKeysize (U.TestCase): me.assertEqual(type(ksz), C.KeySZAny) me.assertEqual(ksz.default, 32) me.assertEqual(ksz.min, 0) - me.assertEqual(ksz.max, 0) + me.assertEqual(ksz.max, None) for n in [0, 12, 20, 5000]: me.assertTrue(ksz.check(n)) me.assertEqual(ksz.best(n), n) @@ -167,7 +173,7 @@ class TestKeysize (U.TestCase): ksz = C.KeySZAny(15) me.assertEqual(ksz.default, 15) me.assertEqual(ksz.min, 0) - me.assertEqual(ksz.max, 0) + me.assertEqual(ksz.max, None) me.assertRaises(ValueError, lambda: C.KeySZAny(-8)) me.assertEqual(C.KeySZAny(0).default, 0) @@ -180,7 +186,7 @@ class TestKeysize (U.TestCase): me.assertEqual(ksz.default, 32) me.assertEqual(ksz.min, 10) me.assertEqual(ksz.max, 32) - me.assertEqual(set(ksz.set), set([10, 16, 32])) + me.assertEqual(ksz.set, set([10, 16, 32])) for x, best, pad in [(9, None, 10), (10, 10, 10), (11, 10, 16), (15, 10, 16), (16, 16, 16), (17, 16, 32), (31, 16, 32), (32, 32, 32), (33, 32, None)]: @@ -194,12 +200,12 @@ class TestKeysize (U.TestCase): ## Check construction. ksz = C.KeySZSet(7) me.assertEqual(ksz.default, 7) - me.assertEqual(set(ksz.set), set([7])) + me.assertEqual(ksz.set, set([7])) me.assertEqual(ksz.min, 7) me.assertEqual(ksz.max, 7) - ksz = C.KeySZSet(7, [3, 6, 9]) + ksz = C.KeySZSet(7, iter([3, 6, 9])) me.assertEqual(ksz.default, 7) - me.assertEqual(set(ksz.set), set([3, 6, 7, 9])) + me.assertEqual(ksz.set, set([3, 6, 7, 9])) me.assertEqual(ksz.min, 3) me.assertEqual(ksz.max, 9) @@ -230,6 +236,11 @@ class TestKeysize (U.TestCase): me.assertEqual(ksz.min, 21) me.assertEqual(ksz.max, 35) me.assertEqual(ksz.mod, 7) + ksz = C.KeySZRange(28, 21, None, 7) + me.assertEqual(ksz.min, 21) + me.assertEqual(ksz.max, None) + me.assertEqual(ksz.mod, 7) + me.assertEqual(ksz.pad(36), 42) me.assertRaises(ValueError, C.KeySZRange, 29, 21, 35, 7) me.assertRaises(ValueError, C.KeySZRange, 28, 20, 35, 7) me.assertRaises(ValueError, C.KeySZRange, 28, 21, 34, 7) @@ -597,29 +608,28 @@ class BaseTestHash (HashBufferTestMixin): """ Check hash class HCLS. - If NEED_BUFSZ is false, then don't insist that HCLS have working `bufsz', - `name', or `hashsz' attributes. This test is mostly reused for MACs, - which don't have these attributes. + If NEED_BUFSZ is false, then don't insist that HCLS has a working `bufsz' + attribute. This test is mostly reused for MACs, which don't have this + attribute. """ ## Check the class properties. - if need_bufsz: - me.assertEqual(type(hcls.name), str) - me.assertEqual(type(hcls.bufsz), int) - me.assertEqual(type(hcls.hashsz), int) + me.assertEqual(type(hcls.name), str) + if need_bufsz: me.assertEqual(type(hcls.bufsz), int) + me.assertEqual(type(hcls.hashsz), int) ## Set some initial values. m = T.span(131) h = hcls().hash(m).done() ## Check that hash length comes out right. - if need_bufsz: me.assertEqual(len(h), hcls.hashsz) + me.assertEqual(len(h), hcls.hashsz) ## Check that we get the same answer if we split the message up. me.assertEqual(h, hcls().hash(m[0:73]).hash(m[73:131]).done()) ## Check the `check' method. me.assertTrue(hcls().hash(m).check(h)) - me.assertFalse(hcls().hash(m).check(h ^ len(h)*C.bytes("aa"))) + me.assertFalse(hcls().hash(m).check(h ^ hcls.hashsz*C.bytes("aa"))) ## Check the menagerie of random hashing methods. def mkhash(_): @@ -649,6 +659,7 @@ class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin): ## Test hashing. k = T.span(mcls.keysz.default) key = mcls(k) + me.assertEqual(key.hashsz, key.tagsz) me.check_hash(key, need_bufsz = False) ## Check that bad key lengths are rejected. @@ -668,7 +679,7 @@ class TestPoly1305 (HashBufferTestMixin): me.assertEqual(C.poly1305.name, "poly1305") me.assertEqual(type(C.poly1305.keysz), C.KeySZSet) me.assertEqual(C.poly1305.keysz.default, 16) - me.assertEqual(set(C.poly1305.keysz.set), set([16])) + me.assertEqual(C.poly1305.keysz.set, set([16])) me.assertEqual(C.poly1305.tagsz, 16) me.assertEqual(C.poly1305.masksz, 16) @@ -680,6 +691,9 @@ class TestPoly1305 (HashBufferTestMixin): t = key(u).hash(m).done() ## Check the key properties. + me.assertEqual(key.name, "poly1305") + me.assertEqual(key.tagsz, 16) + me.assertEqual(key.tagsz, 16) me.assertEqual(len(t), 16) ## Check that we get the same answer if we split the message up. @@ -742,11 +756,11 @@ class TestKeccak (HashBufferTestMixin): st1.mix(m0).step() me.assertNotEqual(st0.extract(32), st1.extract(32)) - ## Check state copying. + ## Check state copying and `mix' vs `set'. st1 = st0.copy() mask = st1.extract(len(m1)) st0.mix(m1) - st1.mix(m1) + st1.set(m1 ^ mask) me.assertEqual(st0.extract(32), st1.extract(32)) ## Check error conditions. @@ -754,6 +768,9 @@ class TestKeccak (HashBufferTestMixin): me.assertRaises(ValueError, st0.extract, 201) st0.mix(T.span(200)) me.assertRaises(ValueError, st0.mix, T.span(201)) + st0.set(T.span(200)) + me.assertRaises(ValueError, st0.set, T.span(201)) + me.assertRaises(ValueError, st0.set, T.span(199)) def check_shake(me, xcls, c, done_matches_xof = True): """ @@ -797,14 +814,14 @@ class TestKeccak (HashBufferTestMixin): ## Check the menagerie of random hashing methods. def mkhash(_): x = xcls(func = func, perso = perso) - return x, lambda: x.done(100 - x.rate//2) + return x, x.done me.check_hashbuffer(mkhash) ## Check the state machine tracking. x = xcls(); me.assertEqual(x.state, "absorb") x.hash(m); me.assertEqual(x.state, "absorb") xx = x.copy() - h = xx.done(100 - x.rate//2) + h = xx.done(); me.assertEqual(len(h), 100 - x.rate//2) me.assertEqual(xx.state, "dead") me.assertRaises(ValueError, xx.done, 1) me.assertRaises(ValueError, xx.get, 1) @@ -820,7 +837,7 @@ class TestKeccak (HashBufferTestMixin): def check_kmac(me, mcls, c): k = T.span(32) - me.check_shake(lambda func = None, perso = T.bin(""): + me.check_shake(lambda func = None, perso = None: mcls(k, perso = perso), c, done_matches_xof = False) @@ -828,6 +845,96 @@ class TestKeccak (HashBufferTestMixin): def test_kmac256(me): me.check_kmac(C.KMAC256, 64) ###-------------------------------------------------------------------------- +class TestStrobe (HashBufferTestMixin): + + HASHMETH = "process" + + def test_strobe(me): + + ## Construction. + s0 = C.Strobe() + s1 = C.Strobe(128) + me.assertEqual(s0.l, 128) + me.assertEqual(s1.l, 128) + C.Strobe(704) + me.assertRaises(ValueError, C.Strobe, 127) + me.assertRaises(ValueError, C.Strobe, 736) + me.assertEqual(s0.role, C.STRBRL_UNDCD) + me.assertEqual(s1.role, C.STRBRL_UNDCD) + + ## `process' vs operation-specific functions. (This follows Hamburg's + ## `Concurrence' test.) + h = T.bin("testing") + s0.ad(h, f = "M"); s1.begin(C.STROBE_AD | C.STRBF_M).process(h).done() + + t = s1.begin(C.STRBF_I | C.STRBF_A | C.STRBF_C).process(10); s1.done() + me.assertEqual(s0.prf(10), t) + me.assertEqual(t, C.bytes("8a13a189683bf5678170")) + + h = T.bin("Hello") + s0.ad(h); s1.begin(" A ").process(h).done() + + m = T.bin("World"); c = s0.encout(m) + me.assertFalse(s1.activep) + m1 = s1.begin("IACT ").process(c); me.assertTrue(s1.activep); + s1.done(); me.assertFalse(s1.activep) + me.assertEqual(c, C.bytes("123bfbee34")) + me.assertEqual(m1, m) + me.assertEqual(s0.role, C.STRBRL_INIT) + me.assertEqual(s1.role, C.STRBRL_RESP) + + m = T.bin("foo"); s0.clrout(m); s1.begin("IA T ").process(m).done() + m = T.bin("bar"); s0.clrin(m); s1.begin(" A T ").process(m).done() + + c = T.bin("baz"); m = s0.encin(c) + c1 = s1.begin(" ACT ").process(m); s1.done() + me.assertEqual(m, C.bytes("15e518")) + me.assertEqual(c1, c) + + xxx = T.bin(199*"X") + for i in T.range(200): + c = s0.begin(" ACT ").process(xxx[:i]); s0.done(); s1.encin(c) + + t = s1.begin("IAC ").process(123); s1.done() + me.assertEqual(t, C.bytes + ("45476fc0806aee35e864c4f18e6ba62bd3eb1b1e8bef9042b30b0f15d00c3e9f" + "5d5904ab789d4c67eaed582473c15aa4424f11d52b21a296b36db3392e2ecbb2" + "dc6963bafba3b23882d061f1d335e86e470e8d819591bf0c223e24b925751d04" + "f789fc73bc55f7d2b3ed4881c625aa6321d31511b13f6d5e4ce54a")) + me.assertEqual(s0.prf(123), t) + + ## Copying and MAC. + s2 = s0.copy() + t = s0.macout(16) + me.assertEqual(t, C.bytes("171419608e11e7c907d493209e17f26b")) + me.assertEqual(s2.begin(" CT ").process(16), t); s2.done() + s3 = s1.copy(); me.assertFalse(s3.macin(~t)) + s3 = s1.copy(); me.assertTrue(s3.macin(t)) + s3 = s1.copy(); me.assertFalse(s3.begin("I CT ").process(~t).done()) + me.assertTrue(s1.begin("I CT ").process(t).done()) + + ## Test the remaining operations. (From the Catacomb test vectors.) + k = T.bin("this is my super-secret key") + s0.key(k); s1.begin(" AC ").process(k).done() + + t = s1.begin(C.STROBE_PRF).process(32); s1.done() + me.assertEqual(t, C.bytes + ("5418e0f0ee7f982bbbdc2b4bbf49425d0088abfa98ee21d8ad8a3610d15ebb68")) + me.assertEqual(s0.prf(32), t) + + s0.ratchet(32); s1.begin("C").process(32).done() + + t = s1.begin(" CT ").process(16); s1.done() + me.assertEqual(t, C.bytes("b2084ebdfabd50768c91eebc190132cc")) + me.assertTrue(s0.macin(t)) + + def test_hashbuffer(me): + def mkhash(hsz): + s = C.Strobe(256).begin(C.STROBE_AD) + return s, lambda: s.done().macout(16) + me.check_hashbuffer(mkhash) + +###-------------------------------------------------------------------------- class TestPRP (T.GenericTestMixin): """Test pseudorandom permutations (PRPs)."""