X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb-python/blobdiff_plain/204416123385794c92b7767e7702c0d4fd387468..740847afe208bb8f33e7d6cf642acaf4aa739f6a:/t/t-algorithms.py diff --git a/t/t-algorithms.py b/t/t-algorithms.py new file mode 100644 index 0000000..d7b97cb --- /dev/null +++ b/t/t-algorithms.py @@ -0,0 +1,863 @@ +### -*- mode: python, coding: utf-8 -*- +### +### Test symmetric algorithms +### +### (c) 2019 Straylight/Edgeware +### + +###----- Licensing notice --------------------------------------------------- +### +### This file is part of the Python interface to Catacomb. +### +### Catacomb/Python is free software: you can redistribute it and/or +### modify it under the terms of the GNU General Public License as +### published by the Free Software Foundation; either version 2 of the +### License, or (at your option) any later version. +### +### Catacomb/Python is distributed in the hope that it will be useful, but +### WITHOUT ANY WARRANTY; without even the implied warranty of +### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +### General Public License for more details. +### +### You should have received a copy of the GNU General Public License +### along with Catacomb/Python. If not, write to the Free Software +### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, +### USA. + +###-------------------------------------------------------------------------- +### Imported modules. + +import catacomb as C +import unittest as U +import testutils as T + +###-------------------------------------------------------------------------- +### Utilities. 
def bad_key_size(ksz):
  """
  Return a key size which KSZ should reject, or `None' if there isn't one.

  Any size is acceptable to a `KeySZAny'.  For a `KeySZRange' we try, in
  order: one past the minimum (off the step grid when the step isn't 1), one
  past the maximum (if bounded), and one below the minimum (if nonzero).  For
  a `KeySZSet' we return the smallest successor of a member which isn't
  itself a member.
  """
  if isinstance(ksz, C.KeySZAny): return None
  elif isinstance(ksz, C.KeySZRange):
    if ksz.mod != 1: return ksz.min + 1
    elif ksz.max != 0: return ksz.max + 1
    elif ksz.min != 0: return ksz.min - 1
    else: return None
  elif isinstance(ksz, C.KeySZSet):
    ## The largest member's successor is never a member, so this loop always
    ## returns for a nonempty set.
    for sz in sorted(ksz.set):
      if sz + 1 not in ksz.set: return sz + 1
    ## Use an explicit raise: a bare `assert' vanishes under `-O' and we'd
    ## silently return `None' instead.
    raise AssertionError("That should have worked.")
  else:
    return None

def different_key_size(ksz, sz):
  """
  Return a size other than SZ which KSZ should also accept, or `None'.

  This assumes that SZ is itself acceptable to KSZ.
  """
  if isinstance(ksz, C.KeySZAny): return sz + 1
  elif isinstance(ksz, C.KeySZRange):
    if sz > ksz.min: return sz - ksz.mod
    elif ksz.max == 0 or sz < ksz.max: return sz + ksz.mod
    else: return None
  elif isinstance(ksz, C.KeySZSet):
    for sz1 in sorted(ksz.set):
      if sz != sz1: return sz1
    return None
  else:
    return None

class HashBufferTestMixin (U.TestCase):
  """Mixin class for testing all of the various `hash...' methods."""

  def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
    """Check `hashuN'."""

    ## Check encoding an integer.
    h0, donefn0 = makefn(w + 2)
    hashfn(h0.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
    h1, donefn1 = makefn(w + 2)
    h1.hash(T.span(w + 2))
    me.assertEqual(donefn0(), donefn1())

    ## Check overflow detection.
    h0, _ = makefn(w)
    me.assertRaises((OverflowError, ValueError),
                    hashfn, h0, 1 << 8*w)

  def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
    """Check `hashbufN'."""

    ## Go through a number of different sizes.
    for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]:
      if n >= 1 << 8*w: continue
      h0, donefn0 = makefn(2 + w + n)
      hashfn(h0.hashu8(0x00), T.span(n)).hashu8(0xff)
      h1, donefn1 = makefn(2 + w + n)
      h1.hash(T.prep_lenseq(w, n, bigendp, True))
      me.assertEqual(donefn0(), donefn1())

    ## Check blocks which are too large for the length prefix.  (Only for
    ## narrow prefixes: wider ones would need an absurdly large block.)
    if w <= 3:
      n = 1 << 8*w
      h0, _ = makefn(w + n)
      me.assertRaises((ValueError, OverflowError, TypeError),
                      hashfn, h0, C.ByteString.zero(n))

  def check_hashbuffer(me, makefn):
    """
    Test the various `hash...' methods.

    MAKEFN is called with a byte count and returns a pair (H, DONEFN), where
    H is a fresh object with the `hash...' methods and DONEFN finishes it,
    returning something comparable.
    """

    ## Check `hashuN'.
    me.check_hashbuffer_hashn(1, True, makefn, lambda h, n: h.hashu8(n))
    me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16(n))
    me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16b(n))
    me.check_hashbuffer_hashn(2, False, makefn, lambda h, n: h.hashu16l(n))
    if hasattr(makefn(0)[0], "hashu24"):
      me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n))
      me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n))
      me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n))
    me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32(n))
    me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32b(n))
    me.check_hashbuffer_hashn(4, False, makefn, lambda h, n: h.hashu32l(n))
    if hasattr(makefn(0)[0], "hashu64"):
      me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64(n))
      me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64b(n))
      me.check_hashbuffer_hashn(8, False, makefn, lambda h, n: h.hashu64l(n))

    ## Check `hashbufN'.
    me.check_hashbuffer_bufn(1, True, makefn, lambda h, x: h.hashbuf8(x))
    me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16(x))
    me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16b(x))
    me.check_hashbuffer_bufn(2, False, makefn, lambda h, x: h.hashbuf16l(x))
    if hasattr(makefn(0)[0], "hashbuf24"):
      me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x))
      me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x))
      me.check_hashbuffer_bufn(3, False, makefn,
                               lambda h, x: h.hashbuf24l(x))
    me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32(x))
    me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32b(x))
    me.check_hashbuffer_bufn(4, False, makefn, lambda h, x: h.hashbuf32l(x))
    if hasattr(makefn(0)[0], "hashbuf64"):
      me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64(x))
      me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64b(x))
      me.check_hashbuffer_bufn(8, False, makefn,
                               lambda h, x: h.hashbuf64l(x))

###--------------------------------------------------------------------------
class TestKeysize (U.TestCase):
  """
  Test the key-size specification classes (`KeySZAny', `KeySZSet',
  `KeySZRange') and the key-size conversion functions.
  """

  def test_any(me):

    ## A typical one-byte spec.
    ksz = C.seal.keysz
    me.assertEqual(type(ksz), C.KeySZAny)
    me.assertEqual(ksz.default, 20)
    me.assertEqual(ksz.min, 0)
    me.assertEqual(ksz.max, 0)
    for n in [0, 12, 20, 5000]:
      me.assertTrue(ksz.check(n))
      me.assertEqual(ksz.best(n), n)
      me.assertEqual(ksz.pad(n), n)

    ## A typical two-byte spec.  (No published algorithms actually /need/ a
    ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
    ksz = C.sha256_hmac.keysz
    me.assertEqual(type(ksz), C.KeySZAny)
    me.assertEqual(ksz.default, 32)
    me.assertEqual(ksz.min, 0)
    me.assertEqual(ksz.max, 0)
    for n in [0, 12, 20, 5000]:
      me.assertTrue(ksz.check(n))
      me.assertEqual(ksz.best(n), n)
      me.assertEqual(ksz.pad(n), n)

    ## Check construction.
    ksz = C.KeySZAny(15)
    me.assertEqual(ksz.default, 15)
    me.assertEqual(ksz.min, 0)
    me.assertEqual(ksz.max, 0)
    me.assertRaises(ValueError, lambda: C.KeySZAny(-8))
    me.assertEqual(C.KeySZAny(0).default, 0)

  def test_set(me):
    ## Note that no published algorithm uses a 16-bit `set' spec.

    ## A typical spec.
    ksz = C.salsa20.keysz
    me.assertEqual(type(ksz), C.KeySZSet)
    me.assertEqual(ksz.default, 32)
    me.assertEqual(ksz.min, 10)
    me.assertEqual(ksz.max, 32)
    me.assertEqual(set(ksz.set), set([10, 16, 32]))
    for x, best, pad in [(9, None, 10), (10, 10, 10), (11, 10, 16),
                         (15, 10, 16), (16, 16, 16), (17, 16, 32),
                         (31, 16, 32), (32, 32, 32), (33, 32, None)]:
      if x == best == pad: me.assertTrue(ksz.check(x))
      else: me.assertFalse(ksz.check(x))
      if best is None: me.assertRaises(ValueError, ksz.best, x)
      else: me.assertEqual(ksz.best(x), best)
      if pad is None: me.assertRaises(ValueError, ksz.pad, x)
      else: me.assertEqual(ksz.pad(x), pad)

    ## Check construction.
    ksz = C.KeySZSet(7)
    me.assertEqual(ksz.default, 7)
    me.assertEqual(set(ksz.set), set([7]))
    me.assertEqual(ksz.min, 7)
    me.assertEqual(ksz.max, 7)
    ksz = C.KeySZSet(7, [3, 6, 9])
    me.assertEqual(ksz.default, 7)
    me.assertEqual(set(ksz.set), set([3, 6, 7, 9]))
    me.assertEqual(ksz.min, 3)
    me.assertEqual(ksz.max, 9)

  def test_range(me):
    ## Note that no published algorithm uses a 16-bit `range' spec, or an
    ## unbounded `range'.

    ## A typical spec.
    ksz = C.rijndael.keysz
    me.assertEqual(type(ksz), C.KeySZRange)
    me.assertEqual(ksz.default, 32)
    me.assertEqual(ksz.min, 4)
    me.assertEqual(ksz.max, 32)
    me.assertEqual(ksz.mod, 4)
    for x, best, pad in [(3, None, 4), (4, 4, 4), (5, 4, 8),
                         (15, 12, 16), (16, 16, 16), (17, 16, 20),
                         (31, 28, 32), (32, 32, 32), (33, 32, None)]:
      if x == best == pad: me.assertTrue(ksz.check(x))
      else: me.assertFalse(ksz.check(x))
      if best is None: me.assertRaises(ValueError, ksz.best, x)
      else: me.assertEqual(ksz.best(x), best)
      if pad is None: me.assertRaises(ValueError, ksz.pad, x)
      else: me.assertEqual(ksz.pad(x), pad)

    ## Check construction.
    ksz = C.KeySZRange(28, 21, 35, 7)
    me.assertEqual(ksz.default, 28)
    me.assertEqual(ksz.min, 21)
    me.assertEqual(ksz.max, 35)
    me.assertEqual(ksz.mod, 7)
    me.assertRaises(ValueError, C.KeySZRange, 29, 21, 35, 7)
    me.assertRaises(ValueError, C.KeySZRange, 28, 20, 35, 7)
    me.assertRaises(ValueError, C.KeySZRange, 28, 21, 34, 7)
    me.assertRaises(ValueError, C.KeySZRange, 28, -7, 35, 7)
    me.assertRaises(ValueError, C.KeySZRange, 28, 35, 21, 7)
    me.assertRaises(ValueError, C.KeySZRange, 35, 21, 28, 7)
    me.assertRaises(ValueError, C.KeySZRange, 21, 28, 35, 7)

  def test_conversions(me):
    me.assertEqual(C.KeySZ.fromec(256), 128)
    me.assertEqual(C.KeySZ.fromschnorr(256), 128)
    me.assertEqual(round(C.KeySZ.fromdl(2958.6875)), 128)
    me.assertEqual(round(C.KeySZ.fromif(2958.6875)), 128)
    me.assertEqual(C.KeySZ.toec(128), 256)
    me.assertEqual(C.KeySZ.toschnorr(128), 256)
    me.assertEqual(C.KeySZ.todl(128), 2958.6875)
    me.assertEqual(C.KeySZ.toif(128), 2958.6875)

###--------------------------------------------------------------------------
class TestCipher (T.GenericTestMixin):
  """Test basic symmetric ciphers."""

  def _test_cipher(me, ccls):
    """Check the symmetric cipher class CCLS."""

    ## Check the class properties.
    me.assertEqual(type(ccls.name), str)
    me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
    me.assertEqual(type(ccls.blksz), int)

    ## Check round-tripping.  Not every cipher has an IV or a `boundary'
    ## operation; we find out by trying, and remember the answer.
    k = T.span(ccls.keysz.default)
    iv = T.span(ccls.blksz)
    m = T.span(253)
    enc = ccls(k)
    dec = ccls(k)
    try: enc.setiv(iv)
    except ValueError: can_setiv = False
    else:
      can_setiv = True
      dec.setiv(iv)
    c0 = enc.encrypt(m[0:57])
    m0 = dec.decrypt(c0)
    c1 = enc.encrypt(m[57:189])
    m1 = dec.decrypt(c1)
    try: enc.bdry()
    except ValueError: can_bdry = False
    else:
      dec.bdry()
      can_bdry = True
    c2 = enc.encrypt(m[189:253])
    m2 = dec.decrypt(c2)
    me.assertEqual(len(c0) + len(c1) + len(c2), len(m))
    me.assertEqual(m0, m[0:57])
    me.assertEqual(m1, m[57:189])
    me.assertEqual(m2, m[189:253])

    ## Check the `enczero' and `deczero' methods.
    c3 = enc.enczero(32)
    me.assertEqual(dec.decrypt(c3), C.ByteString.zero(32))
    m4 = dec.deczero(32)
    me.assertEqual(enc.encrypt(m4), C.ByteString.zero(32))

    if can_bdry:

      ## Check that, before the boundary, the way the input was split into
      ## chunks makes no difference.
      dec = ccls(k)
      if can_setiv: dec.setiv(iv)
      m01 = dec.decrypt(c0 + c1)
      me.assertEqual(m01, m[0:189])

      ## Check that the boundary actually does something: decrypting
      ## straight through it, without a matching `bdry', goes wrong.
      dec = ccls(k)
      if can_setiv: dec.setiv(iv)
      m012 = dec.decrypt(c0 + c1 + c2)
      me.assertNotEqual(m012, m)

    ## Check that bad key lengths are rejected.
    badlen = bad_key_size(ccls.keysz)
    if badlen is not None: me.assertRaises(ValueError, ccls, T.span(badlen))

TestCipher.generate_testcases((name, C.gcciphers[name]) for name in
  ["des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
   "blowfish-counter", "rc4", "seal", "salsa20/8", "shake128-xof"])

###--------------------------------------------------------------------------
class TestAuthenticatedEncryption \
      (HashBufferTestMixin, T.GenericTestMixin):
  """Test authenticated encryption schemes."""

  def _test_aead(me, aecls):
    """Check the AEAD scheme class AECLS."""

    ## Check the class properties.
    me.assertEqual(type(aecls.name), str)
    me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
    me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
    me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
    me.assertEqual(type(aecls.blksz), int)
    me.assertEqual(type(aecls.bufsz), int)
    me.assertEqual(type(aecls.ohd), int)
    me.assertEqual(type(aecls.flags), int)

    ## Check round-tripping, with full precommitment.  First, select some
    ## parameters.  (It's conceivable that some AEAD schemes are more
    ## restrictive than advertised by the various properties, but this works
    ## out OK in practice.)
    k = T.span(aecls.keysz.default)
    n = T.span(aecls.noncesz.default)
    if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
    else: h = T.span(131)
    m = T.span(253)
    tsz = aecls.tagsz.default
    key = aecls(k)

    ## Next, encrypt a message, checking that things are proper as we go.
    enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
    me.assertEqual(enc.hsz, len(h))
    me.assertEqual(enc.msz, len(m))
    me.assertEqual(enc.mlen, 0)
    me.assertEqual(enc.tsz, tsz)
    aad = enc.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    if not aecls.flags&C.AEADF_NOAAD:
      aad.hash(h[0:83])
      me.assertEqual(aad.hlen, 83)
      aad.hash(h[83:131])
      me.assertEqual(aad.hlen, 131)
    c0 = enc.encrypt(m[0:57])
    me.assertEqual(enc.mlen, 57)
    me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
    c1 = enc.encrypt(m[57:189])
    me.assertEqual(enc.mlen, 189)
    me.assertTrue(132 - aecls.bufsz <= len(c1) <=
                  132 + aecls.bufsz + aecls.ohd)
    c2 = enc.encrypt(m[189:253])
    me.assertEqual(enc.mlen, 253)
    me.assertTrue(64 - aecls.bufsz <= len(c2) <=
                  64 + aecls.bufsz + aecls.ohd)
    c3, t = enc.done(aad = aad)
    me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
    c = c0 + c1 + c2 + c3
    me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
    me.assertEqual(len(t), tsz)

    ## And now decrypt it again, with different record boundaries.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    me.assertEqual(dec.hsz, len(h))
    me.assertEqual(dec.csz, len(c))
    me.assertEqual(dec.clen, 0)
    me.assertEqual(dec.tsz, tsz)
    aad = dec.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    aad.hash(h)
    m0 = dec.decrypt(c[0:156])
    me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
    m1 = dec.decrypt(c[156:])
    me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
                  len(c) - 156 + aecls.bufsz)
    m2 = dec.done(tag = t, aad = aad)
    me.assertEqual(m0 + m1 + m2, m)

    ## And again, with the wrong tag.  Supply the correct AAD, so that the
    ## failure can only be down to the tag, and not to an AAD underrun.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    aad = dec.aad(); aad.hash(h)
    _ = dec.decrypt(c)
    me.assertRaises(ValueError, dec.done,
                    tag = t ^ tsz*C.bytes("55"), aad = aad)

    ## Check that the all-in-one methods work.
    me.assertEqual((c, t),
                   key.encrypt(n = n, h = h, m = m, tsz = tsz))
    me.assertEqual(m,
                   key.decrypt(n = n, h = h, c = c, t = t))

    ## Check that bad key, nonce, and tag lengths are rejected.
    badlen = bad_key_size(aecls.keysz)
    if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
    badlen = bad_key_size(aecls.noncesz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
                      hsz = len(h), msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
                      hsz = len(h), csz = len(c), tsz = tsz)
    badlen = bad_key_size(aecls.tagsz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m), tsz = badlen)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c), tsz = badlen)
      ## If the tag size needn't be fixed in advance, then a bad /tag/ size
      ## must instead be rejected when the tag is finally requested.
      if not aecls.flags&C.AEADF_PCTSZ:
        enc = key.enc(nonce = n, hsz = 0, msz = len(m))
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, tsz = badlen)

    ## Check that we can't get a loose `aad' object from a scheme which has
    ## nonce-dependent AAD processing.
    if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)

    ## Check the menagerie of AAD hashing methods.
    if not aecls.flags&C.AEADF_NOAAD:
      def mkhash(hsz):
        enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
        aad = enc.aad()
        return aad, lambda: enc.done(aad = aad)[1]
      me.check_hashbuffer(mkhash)

    ## Check that encryption/decryption works with the given precommitments.
    def quick_enc_check(**kw):
      enc = key.enc(**kw)
      aad = enc.aad().hash(h)
      c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
      me.assertEqual((c, t), (c0 + c1, tt))
    def quick_dec_check(**kw):
      dec = key.dec(**kw)
      aad = dec.aad().hash(h)
      m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
      me.assertEqual(m, m0 + m1)

    ## Check that we can get away without precommitting to the header length
    ## if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCHSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      csz = len(c), tsz = tsz)
    else:
      quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
      quick_dec_check(nonce = n, csz = len(c), tsz = tsz)

    ## Check that we can get away without precommitting to the message/
    ## ciphertext length if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCMSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), tsz = tsz)
    else:
      quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
      quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)

    ## Check that we can get away without precommitting to the tag length if
    ## and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCTSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m))
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c))
    else:
      quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
      quick_dec_check(nonce = n, hsz = len(h), csz = len(c))

    ## Check that if we precommit to the header length, we're properly held
    ## to the commitment.
    if not aecls.flags&C.AEADF_NOAAD:

      ## First, check encryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to encrypt;
      ## otherwise, checking is delayed until `done'.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, enc.encrypt, m)
      else:
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, aad = aad)

      ## Next, check decryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to decrypt;
      ## otherwise, checking is delayed until `done'.
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
      aad = dec.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, dec.decrypt, c)
      else:
        _ = dec.decrypt(c)
        me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

      ## If AAD processing is nonce-dependent then an overrun will be
      ## detected immediately.
      if aecls.flags&C.AEADF_AADNDEP:
        enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
        aad = enc.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])
        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
        aad = dec.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])

    ## Some additional tests for nonce-dependent `aad' objects.
    if aecls.flags&C.AEADF_AADNDEP:

      ## Check that `aad' objects can't be used once their parents are gone.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad()
      del enc
      me.assertRaises(ValueError, aad.hash, h)

      ## Check that they can't be crossed over.
      enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc0.aad().hash(h)
      aad1 = enc1.aad().hash(h)
      _ = enc0.encrypt(m)
      me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)

    ## Test copying AAD.  Copies must evolve independently of the original:
    ## equal input gives equal tags, different input gives different ones.
    if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
      aad0 = key.aad()
      aad0.hash(h[0:83])
      aad1 = aad0.copy()
      aad2 = aad1.copy()
      aad0.hash(h[83:131])
      aad1.hash(h[83:131])
      aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
      me.assertEqual(key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad0),
                     key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad1))
      me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad0),
                        key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad2))

    ## Check that if we precommit to the message length, we're properly held
    ## to the commitment.  (Fortunately, this is way simpler than the AAD
    ## case above.)  First, try an underrun.
    enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
    _ = enc.encrypt(m[0:183])
    me.assertRaises(ValueError, enc.done, tsz = tsz)
    dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
    _ = dec.decrypt(c[0:183])
    me.assertRaises(ValueError, dec.done, tag = t)

    ## And now an overrun.
    enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
    me.assertRaises(ValueError, enc.encrypt, m)
    dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
    me.assertRaises(ValueError, dec.decrypt, c)

    ## Finally, check that if we precommit to a tag length, we're properly
    ## held to the commitment.  This depends on being able to find a tag size
    ## which isn't the default.
    tsz1 = different_key_size(aecls.tagsz, tsz)
    if tsz1 is not None:
      enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
      _ = enc.encrypt(m)
      me.assertRaises(ValueError, enc.done, tsz = tsz)
      ## The decryptor was committed to TSZ1, so finishing it with the
      ## TSZ-byte tag T must fail.
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
      aad = dec.aad().hash(h)
      _ = dec.decrypt(c)
      me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

TestAuthenticatedEncryption.generate_testcases \
  ((name, C.gcaeads[name]) for name in
   ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
    "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])

###--------------------------------------------------------------------------
class BaseTestHash (HashBufferTestMixin):
  """Base class for testing hash functions."""

  def check_hash(me, hcls, need_bufsz = True):
    """
    Check hash class HCLS.

    If NEED_BUFSZ is false, then don't insist that HCLS have working `bufsz',
    `name', or `hashsz' attributes.  This test is mostly reused for MACs,
    which don't have these attributes.
    """
    ## Check the class properties.
    if need_bufsz:
      me.assertEqual(type(hcls.name), str)
      me.assertEqual(type(hcls.bufsz), int)
      me.assertEqual(type(hcls.hashsz), int)

    ## Set some initial values.
    m = T.span(131)
    h = hcls().hash(m).done()

    ## Check that hash length comes out right.
    if need_bufsz: me.assertEqual(len(h), hcls.hashsz)

    ## Check that we get the same answer if we split the message up.
    me.assertEqual(h, hcls().hash(m[0:73]).hash(m[73:131]).done())

    ## Check the `check' method.
    me.assertTrue(hcls().hash(m).check(h))
    me.assertFalse(hcls().hash(m).check(h ^ len(h)*C.bytes("aa")))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      h = hcls()
      return h, h.done
    me.check_hashbuffer(mkhash)

class TestHash (BaseTestHash, T.GenericTestMixin):
  """Test hash functions."""
  def _test_hash(me, hcls): me.check_hash(hcls, need_bufsz = True)

TestHash.generate_testcases((name, C.gchashes[name]) for name in
  ["md5", "sha", "whirlpool", "sha256", "sha512/224", "sha3-384", "shake256",
   "crc32"])

###--------------------------------------------------------------------------
class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
  """Test message authentication codes."""

  def _test_mac(me, mcls):
    """Check the MAC class MCLS."""

    ## Check the MAC properties.
    me.assertEqual(type(mcls.name), str)
    me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
    me.assertEqual(type(mcls.tagsz), int)

    ## Test hashing.  A keyed MAC object acts like a hash class here.
    k = T.span(mcls.keysz.default)
    key = mcls(k)
    me.check_hash(key, need_bufsz = False)

    ## Check that bad key lengths are rejected.
    badlen = bad_key_size(mcls.keysz)
    if badlen is not None: me.assertRaises(ValueError, mcls, T.span(badlen))

TestMessageAuthentication.generate_testcases \
  ((name, C.gcmacs[name]) for name in
   ["sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"])

class TestPoly1305 (HashBufferTestMixin):
  """Check the Poly1305 one-time message authentication function."""

  def test_poly1305(me):

    ## Check the MAC properties.
    me.assertEqual(C.poly1305.name, "poly1305")
    me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
    me.assertEqual(C.poly1305.keysz.default, 16)
    me.assertEqual(set(C.poly1305.keysz.set), set([16]))
    me.assertEqual(C.poly1305.tagsz, 16)
    me.assertEqual(C.poly1305.masksz, 16)

    ## Set some initial values.
    k = T.span(16)
    u = T.span(64)[-16:]
    m = T.span(149)
    key = C.poly1305(k)
    t = key(u).hash(m).done()

    ## Check the tag length.
    me.assertEqual(len(t), 16)

    ## Check that we get the same answer if we split the message up.
    me.assertEqual(t, key(u).hash(m[0:86]).hash(m[86:149]).done())

    ## Check the `check' method.
    me.assertTrue(key(u).hash(m).check(t))
    me.assertFalse(key(u).hash(m).check(t ^ 16*C.bytes("cc")))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      h = key(u)
      return h, h.done
    me.check_hashbuffer(mkhash)

    ## Check that we can't complete hashing without a mask.
    me.assertRaises(ValueError, key().hash(m).done)

    ## Check `concat'.  Contexts from a different key are the wrong type;
    ## a first piece whose length isn't suitable is rejected as a value.
    h0 = key().hash(m[0:96])
    h1 = key().hash(m[96:117])
    me.assertEqual(t, key(u).concat(h0, h1).hash(m[117:149]).done())
    key1 = C.poly1305(k)
    me.assertRaises(TypeError, key().concat, key1().hash(m[0:96]), h1)
    me.assertRaises(TypeError, key().concat, h0, key1().hash(m[96:117]))
    me.assertRaises(ValueError, key().concat, key().hash(m[0:93]), h1)

###--------------------------------------------------------------------------
class TestHLatin (U.TestCase):
  """Test the `hsalsa20' and `hchacha20' functions."""

  def test_hlatin(me):
    kk = [T.span(sz) for sz in [10, 16, 32]]
    n = T.span(16)
    bad_k = T.span(18)
    bad_n = T.span(13)
    for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
               C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:
      for k in kk:
        h = fn(k, n)
        me.assertEqual(len(h), 32)
        me.assertRaises(ValueError, fn, k, bad_n)
      me.assertRaises(ValueError, fn, bad_k, n)

###--------------------------------------------------------------------------
class TestKeccak (HashBufferTestMixin):
  """Test the Keccak-p[1600, n] sponge function."""

  def test_keccak(me):
    """Check the raw `Keccak1600' state object."""

    ## Make a state and feed some stuff into it.
    m0 = T.bin("some initial string")
    m1 = T.bin("awesome follow-up string")
    st0 = C.Keccak1600()
    me.assertEqual(st0.nround, 24)
    st0.mix(m0).step()

    ## Make another step with a different round count.
    st1 = C.Keccak1600(23)
    st1.mix(m0).step()
    me.assertNotEqual(st0.extract(32), st1.extract(32))

    ## Check state copying.  The copy also has an extra `extract' applied,
    ## which mustn't disturb it.
    st1 = st0.copy()
    _ = st1.extract(len(m1))
    st0.mix(m1)
    st1.mix(m1)
    me.assertEqual(st0.extract(32), st1.extract(32))

    ## Check error conditions.
    _ = st0.extract(200)
    me.assertRaises(ValueError, st0.extract, 201)
    st0.mix(T.span(200))
    me.assertRaises(ValueError, st0.mix, T.span(201))

  def check_shake(me, xcls, c, done_matches_xof = True):
    """
    Test the SHAKE and cSHAKE XOFs.

    This is also used for testing KMAC, but that sets DONE_MATCHES_XOF false
    to indicate that the XOF output is range-separated from the fixed-length
    outputs (unlike the basic SHAKE functions).
    """

    ## Check the hash attributes.
    x = xcls()
    me.assertEqual(x.rate, 200 - c)
    me.assertEqual(x.buffered, 0)
    me.assertEqual(x.state, "absorb")

    ## Set some initial values.
    func = T.bin("TESTXOF")
    perso = T.bin("catacomb-python test")
    m = T.span(167)
    h0 = xcls().hash(m).done(193)
    me.assertEqual(len(h0), 193)
    h1 = xcls(func = func, perso = perso).hash(m).done(193)
    me.assertEqual(len(h1), 193)
    me.assertNotEqual(h0, h1)

    ## Check input and output in pieces, and the state machine.
    if done_matches_xof: h = h0
    else: h = xcls().hash(m).xof().get(len(h0))
    x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
    me.assertEqual(h, x.get(98) + x.get(95))

    ## Check masking.
    x = xcls().hash(m).xof()
    me.assertEqual(x.mask(m), m ^ h[0:len(m)])

    ## Check the `check' method.
    me.assertTrue(xcls().hash(m).check(h0))
    me.assertFalse(xcls().hash(m).check(h1))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      x = xcls(func = func, perso = perso)
      return x, lambda: x.done(100 - x.rate//2)
    me.check_hashbuffer(mkhash)

    ## Check the state machine tracking.
    x = xcls(); me.assertEqual(x.state, "absorb")
    x.hash(m); me.assertEqual(x.state, "absorb")
    xx = x.copy()
    _ = xx.done(100 - x.rate//2)
    me.assertEqual(xx.state, "dead")
    me.assertRaises(ValueError, xx.done, 1)
    me.assertRaises(ValueError, xx.get, 1)
    me.assertEqual(x.state, "absorb")
    me.assertRaises(ValueError, x.get, 1)
    x.xof(); me.assertEqual(x.state, "squeeze")
    me.assertRaises(ValueError, x.done, 1)
    _ = x.get(1)
    yy = x.copy(); me.assertEqual(yy.state, "squeeze")

  def test_shake128(me): me.check_shake(C.Shake128, 32)
  def test_shake256(me): me.check_shake(C.Shake256, 64)

  def check_kmac(me, mcls, c):
    """
    Test the KMAC class MCLS by dressing it up as a cSHAKE-like class (with a
    fixed key, ignoring the function name) and running `check_shake' on it.
    """
    k = T.span(32)
    me.check_shake(lambda func = None, perso = T.bin(""):
                     mcls(k, perso = perso),
                   c, done_matches_xof = False)

  def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
  def test_kmac256(me): me.check_kmac(C.KMAC256, 64)

###--------------------------------------------------------------------------
class TestPRP (T.GenericTestMixin):
  """Test pseudorandom permutations (PRPs)."""

  def _test_prp(me, pcls):
    """Check the PRP class PCLS."""

    ## Check the PRP properties.
    me.assertEqual(type(pcls.name), str)
    me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
    me.assertEqual(type(pcls.blksz), int)

    ## Check round-tripping.
    k = T.span(pcls.keysz.default)
    key = pcls(k)
    m = T.span(pcls.blksz)
    c = key.encrypt(m)
    me.assertEqual(len(c), pcls.blksz)
    me.assertEqual(m, key.decrypt(c))

    ## Check that bad key lengths are rejected.
    badlen = bad_key_size(pcls.keysz)
    if badlen is not None: me.assertRaises(ValueError, pcls, T.span(badlen))

    ## Check that bad blocks are rejected.
    badblk = T.span(pcls.blksz + 1)
    me.assertRaises(ValueError, key.encrypt, badblk)
    me.assertRaises(ValueError, key.decrypt, badblk)

TestPRP.generate_testcases((name, C.gcprps[name]) for name in
  ["desx", "blowfish", "rijndael"])

###----- That's all, folks --------------------------------------------------

if __name__ == "__main__": U.main()