Merge branch '1.3.x'
[catacomb-python] / t / t-algorithms.py
diff --git a/t/t-algorithms.py b/t/t-algorithms.py
new file mode 100644 (file)
index 0000000..d7b97cb
--- /dev/null
@@ -0,0 +1,863 @@
+### -*- mode: python, coding: utf-8 -*-
+###
+### Test symmetric algorithms
+###
+### (c) 2019 Straylight/Edgeware
+###
+
+###----- Licensing notice ---------------------------------------------------
+###
+### This file is part of the Python interface to Catacomb.
+###
+### Catacomb/Python is free software: you can redistribute it and/or
+### modify it under the terms of the GNU General Public License as
+### published by the Free Software Foundation; either version 2 of the
+### License, or (at your option) any later version.
+###
+### Catacomb/Python is distributed in the hope that it will be useful, but
+### WITHOUT ANY WARRANTY; without even the implied warranty of
+### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+### General Public License for more details.
+###
+### You should have received a copy of the GNU General Public License
+### along with Catacomb/Python.  If not, write to the Free Software
+### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
+### USA.
+
+###--------------------------------------------------------------------------
+### Imported modules.
+
+import catacomb as C
+import unittest as U
+import testutils as T
+
+###--------------------------------------------------------------------------
+### Utilities.
+
+def bad_key_size(ksz):
+  if isinstance(ksz, C.KeySZAny): return None
+  elif isinstance(ksz, C.KeySZRange):
+    if ksz.mod != 1: return ksz.min + 1
+    elif ksz.max != 0: return ksz.max + 1
+    elif ksz.min != 0: return ksz.min - 1
+    else: return None
+  elif isinstance(ksz, C.KeySZSet):
+    for sz in sorted(ksz.set):
+      if sz + 1 not in ksz.set: return sz + 1
+    assert False, "That should have worked."
+  else:
+    return None
+
+def different_key_size(ksz, sz):
+  if isinstance(ksz, C.KeySZAny): return sz + 1
+  elif isinstance(ksz, C.KeySZRange):
+    if sz > ksz.min: return sz - ksz.mod
+    elif ksz.max == 0 or sz < ksz.max: return sz + ksz.mod
+    else: return None
+  elif isinstance(ksz, C.KeySZSet):
+    for sz1 in sorted(ksz.set):
+      if sz != sz1: return sz1
+    return None
+  else:
+    return None
+
+class HashBufferTestMixin (U.TestCase):
+  """Mixin class for testing all of the various `hash...' methods."""
+
+  def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
+    """Check `hashuN'."""
+
+    ## Check encoding an integer.
+    h0, donefn0 = makefn(w + 2)
+    hashfn(h0.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
+    h1, donefn1 = makefn(w + 2)
+    h1.hash(T.span(w + 2))
+    me.assertEqual(donefn0(), donefn1())
+
+    ## Check overflow detection.
+    h0, _ = makefn(w)
+    me.assertRaises((OverflowError, ValueError),
+                    hashfn, h0, 1 << 8*w)
+
+  def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
+    """Check `hashbufN'."""
+
+    ## Go through a number of different sizes.
+    for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]:
+      if n >= 1 << 8*w: continue
+      h0, donefn0 = makefn(2 + w + n)
+      hashfn(h0.hashu8(0x00), T.span(n)).hashu8(0xff)
+      h1, donefn1 = makefn(2 + w + n)
+      h1.hash(T.prep_lenseq(w, n, bigendp, True))
+      me.assertEqual(donefn0(), donefn1())
+
+    ## Check blocks which are too large for the length prefix.
+    if w <= 3:
+      n = 1 << 8*w
+      h0, _ = makefn(w + n)
+      me.assertRaises((ValueError, OverflowError, TypeError),
+                      hashfn, h0, C.ByteString.zero(n))
+
+  def check_hashbuffer(me, makefn):
+    """Test the various `hash...' methods."""
+
+    ## Check `hashuN'.
+    me.check_hashbuffer_hashn(1, True, makefn, lambda h, n: h.hashu8(n))
+    me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16(n))
+    me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16b(n))
+    me.check_hashbuffer_hashn(2, False, makefn, lambda h, n: h.hashu16l(n))
+    if hasattr(makefn(0)[0], "hashu24"):
+      me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n))
+      me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n))
+      me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n))
+    me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32(n))
+    me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32b(n))
+    me.check_hashbuffer_hashn(4, False, makefn, lambda h, n: h.hashu32l(n))
+    if hasattr(makefn(0)[0], "hashu64"):
+      me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64(n))
+      me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64b(n))
+      me.check_hashbuffer_hashn(8, False, makefn, lambda h, n: h.hashu64l(n))
+
+    ## Check `hashbufN'.
+    me.check_hashbuffer_bufn(1, True, makefn, lambda h, x: h.hashbuf8(x))
+    me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16(x))
+    me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16b(x))
+    me.check_hashbuffer_bufn(2, False, makefn, lambda h, x: h.hashbuf16l(x))
+    if hasattr(makefn(0)[0], "hashbuf24"):
+      me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x))
+      me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x))
+      me.check_hashbuffer_bufn(3, False, makefn, lambda h, x: h.hashbuf24l(x))
+    me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32(x))
+    me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32b(x))
+    me.check_hashbuffer_bufn(4, False, makefn, lambda h, x: h.hashbuf32l(x))
+    if hasattr(makefn(0)[0], "hashbuf64"):
+      me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64(x))
+      me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64b(x))
+      me.check_hashbuffer_bufn(8, False, makefn, lambda h, x: h.hashbuf64l(x))
+
+###--------------------------------------------------------------------------
+class TestKeysize (U.TestCase):
+
+  def test_any(me):
+
+    ## A typical one-byte spec.
+    ksz = C.seal.keysz
+    me.assertEqual(type(ksz), C.KeySZAny)
+    me.assertEqual(ksz.default, 20)
+    me.assertEqual(ksz.min, 0)
+    me.assertEqual(ksz.max, 0)
+    for n in [0, 12, 20, 5000]:
+      me.assertTrue(ksz.check(n))
+      me.assertEqual(ksz.best(n), n)
+      me.assertEqual(ksz.pad(n), n)
+
+    ## A typical two-byte spec.  (No published algorithms actually /need/ a
+    ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
+    ksz = C.sha256_hmac.keysz
+    me.assertEqual(type(ksz), C.KeySZAny)
+    me.assertEqual(ksz.default, 32)
+    me.assertEqual(ksz.min, 0)
+    me.assertEqual(ksz.max, 0)
+    for n in [0, 12, 20, 5000]:
+      me.assertTrue(ksz.check(n))
+      me.assertEqual(ksz.best(n), n)
+      me.assertEqual(ksz.pad(n), n)
+
+    ## Check construction.
+    ksz = C.KeySZAny(15)
+    me.assertEqual(ksz.default, 15)
+    me.assertEqual(ksz.min, 0)
+    me.assertEqual(ksz.max, 0)
+    me.assertRaises(ValueError, lambda: C.KeySZAny(-8))
+    me.assertEqual(C.KeySZAny(0).default, 0)
+
+  def test_set(me):
+    ## Note that no published algorithm uses a 16-bit `set' spec.
+
+    ## A typical spec.
+    ksz = C.salsa20.keysz
+    me.assertEqual(type(ksz), C.KeySZSet)
+    me.assertEqual(ksz.default, 32)
+    me.assertEqual(ksz.min, 10)
+    me.assertEqual(ksz.max, 32)
+    me.assertEqual(set(ksz.set), set([10, 16, 32]))
+    for x, best, pad in [(9, None, 10), (10, 10, 10), (11, 10, 16),
+                         (15, 10, 16), (16, 16, 16), (17, 16, 32),
+                         (31, 16, 32), (32, 32, 32), (33, 32, None)]:
+      if x == best == pad: me.assertTrue(ksz.check(x))
+      else: me.assertFalse(ksz.check(x))
+      if best is None: me.assertRaises(ValueError, ksz.best, x)
+      else: me.assertEqual(ksz.best(x), best)
+      if pad is None: me.assertRaises(ValueError, ksz.pad, x)
+      else: me.assertEqual(ksz.pad(x), pad)
+
+    ## Check construction.
+    ksz = C.KeySZSet(7)
+    me.assertEqual(ksz.default, 7)
+    me.assertEqual(set(ksz.set), set([7]))
+    me.assertEqual(ksz.min, 7)
+    me.assertEqual(ksz.max, 7)
+    ksz = C.KeySZSet(7, [3, 6, 9])
+    me.assertEqual(ksz.default, 7)
+    me.assertEqual(set(ksz.set), set([3, 6, 7, 9]))
+    me.assertEqual(ksz.min, 3)
+    me.assertEqual(ksz.max, 9)
+
+  def test_range(me):
+    ## Note that no published algorithm uses a 16-bit `range' spec, or an
+    ## unbounded `range'.
+
+    ## A typical spec.
+    ksz = C.rijndael.keysz
+    me.assertEqual(type(ksz), C.KeySZRange)
+    me.assertEqual(ksz.default, 32)
+    me.assertEqual(ksz.min, 4)
+    me.assertEqual(ksz.max, 32)
+    me.assertEqual(ksz.mod, 4)
+    for x, best, pad in [(3, None, 4), (4, 4, 4), (5, 4, 8),
+                         (15, 12, 16), (16, 16, 16), (17, 16, 20),
+                         (31, 28, 32), (32, 32, 32), (33, 32, None)]:
+      if x == best == pad: me.assertTrue(ksz.check(x))
+      else: me.assertFalse(ksz.check(x))
+      if best is None: me.assertRaises(ValueError, ksz.best, x)
+      else: me.assertEqual(ksz.best(x), best)
+      if pad is None: me.assertRaises(ValueError, ksz.pad, x)
+      else: me.assertEqual(ksz.pad(x), pad)
+
+    ## Check construction.
+    ksz = C.KeySZRange(28, 21, 35, 7)
+    me.assertEqual(ksz.default, 28)
+    me.assertEqual(ksz.min, 21)
+    me.assertEqual(ksz.max, 35)
+    me.assertEqual(ksz.mod, 7)
+    me.assertRaises(ValueError, C.KeySZRange, 29, 21, 35, 7)
+    me.assertRaises(ValueError, C.KeySZRange, 28, 20, 35, 7)
+    me.assertRaises(ValueError, C.KeySZRange, 28, 21, 34, 7)
+    me.assertRaises(ValueError, C.KeySZRange, 28, -7, 35, 7)
+    me.assertRaises(ValueError, C.KeySZRange, 28, 35, 21, 7)
+    me.assertRaises(ValueError, C.KeySZRange, 35, 21, 28, 7)
+    me.assertRaises(ValueError, C.KeySZRange, 21, 28, 35, 7)
+
+  def test_conversions(me):
+    me.assertEqual(C.KeySZ.fromec(256), 128)
+    me.assertEqual(C.KeySZ.fromschnorr(256), 128)
+    me.assertEqual(round(C.KeySZ.fromdl(2958.6875)), 128)
+    me.assertEqual(round(C.KeySZ.fromif(2958.6875)), 128)
+    me.assertEqual(C.KeySZ.toec(128), 256)
+    me.assertEqual(C.KeySZ.toschnorr(128), 256)
+    me.assertEqual(C.KeySZ.todl(128), 2958.6875)
+    me.assertEqual(C.KeySZ.toif(128), 2958.6875)
+
+###--------------------------------------------------------------------------
+class TestCipher (T.GenericTestMixin):
+  """Test basic symmetric ciphers."""
+
+  def _test_cipher(me, ccls):
+
+    ## Check the class properties.
+    me.assertEqual(type(ccls.name), str)
+    me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
+    me.assertEqual(type(ccls.blksz), int)
+
+    ## Check round-tripping.
+    k = T.span(ccls.keysz.default)
+    iv = T.span(ccls.blksz)
+    m = T.span(253)
+    enc = ccls(k)
+    dec = ccls(k)
+    try: enc.setiv(iv)
+    except ValueError: can_setiv = False
+    else:
+      can_setiv = True
+      dec.setiv(iv)
+    c0 = enc.encrypt(m[0:57])
+    m0 = dec.decrypt(c0)
+    c1 = enc.encrypt(m[57:189])
+    m1 = dec.decrypt(c1)
+    try: enc.bdry()
+    except ValueError: can_bdry = False
+    else:
+      dec.bdry()
+      can_bdry = True
+    c2 = enc.encrypt(m[189:253])
+    m2 = dec.decrypt(c2)
+    me.assertEqual(len(c0) + len(c1) + len(c2), len(m))
+    me.assertEqual(m0, m[0:57])
+    me.assertEqual(m1, m[57:189])
+    me.assertEqual(m2, m[189:253])
+
+    ## Check the `enczero' and `deczero' methods.
+    c3 = enc.enczero(32)
+    me.assertEqual(dec.decrypt(c3), C.ByteString.zero(32))
+    m4 = dec.deczero(32)
+    me.assertEqual(enc.encrypt(m4), C.ByteString.zero(32))
+
+    ## Check that ciphers which support a `boundary' operation actually
+    ## need it.
+    if can_bdry:
+      dec = ccls(k)
+      if can_setiv: dec.setiv(iv)
+      m01 = dec.decrypt(c0 + c1)
+      me.assertEqual(m01, m[0:189])
+
+    ## Check that the boundary actually does something.
+    if can_bdry:
+      dec = ccls(k)
+      if can_setiv: dec.setiv(iv)
+      m012 = dec.decrypt(c0 + c1 + c2)
+      me.assertNotEqual(m012, m)
+
+    ## Check that bad key lengths are rejected.
+    badlen = bad_key_size(ccls.keysz)
+    if badlen is not None: me.assertRaises(ValueError, ccls, T.span(badlen))
+
+TestCipher.generate_testcases((name, C.gcciphers[name]) for name in
+  ["des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
+   "blowfish-counter", "rc4", "seal", "salsa20/8", "shake128-xof"])
+
+###--------------------------------------------------------------------------
+class TestAuthenticatedEncryption \
+        (HashBufferTestMixin, T.GenericTestMixin):
+  """Test authenticated encryption schemes."""
+
+  def _test_aead(me, aecls):
+
+    ## Check the class properties.
+    me.assertEqual(type(aecls.name), str)
+    me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
+    me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
+    me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
+    me.assertEqual(type(aecls.blksz), int)
+    me.assertEqual(type(aecls.bufsz), int)
+    me.assertEqual(type(aecls.ohd), int)
+    me.assertEqual(type(aecls.flags), int)
+
+    ## Check round-tripping, with full precommitment.  First, select some
+    ## parameters.  (It's conceivable that some AEAD schemes are more
+    ## restrictive than advertised by the various properties, but this works
+    ## out OK in practice.)
+    k = T.span(aecls.keysz.default)
+    n = T.span(aecls.noncesz.default)
+    if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
+    else: h = T.span(131)
+    m = T.span(253)
+    tsz = aecls.tagsz.default
+    key = aecls(k)
+
+    ## Next, encrypt a message, checking that things are proper as we go.
+    enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+    me.assertEqual(enc.hsz, len(h))
+    me.assertEqual(enc.msz, len(m))
+    me.assertEqual(enc.mlen, 0)
+    me.assertEqual(enc.tsz, tsz)
+    aad = enc.aad()
+    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
+    else: me.assertEqual(aad.hsz, None)
+    me.assertEqual(aad.hlen, 0)
+    if not aecls.flags&C.AEADF_NOAAD:
+      aad.hash(h[0:83])
+      me.assertEqual(aad.hlen, 83)
+      aad.hash(h[83:131])
+      me.assertEqual(aad.hlen, 131)
+    c0 = enc.encrypt(m[0:57])
+    me.assertEqual(enc.mlen, 57)
+    me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
+    c1 = enc.encrypt(m[57:189])
+    me.assertEqual(enc.mlen, 189)
+    me.assertTrue(132 - aecls.bufsz <= len(c1) <=
+                  132 + aecls.bufsz + aecls.ohd)
+    c2 = enc.encrypt(m[189:253])
+    me.assertEqual(enc.mlen, 253)
+    me.assertTrue(64 - aecls.bufsz <= len(c2) <=
+                  64 + aecls.bufsz + aecls.ohd)
+    c3, t = enc.done(aad = aad)
+    me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
+    c = c0 + c1 + c2 + c3
+    me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
+    me.assertEqual(len(t), tsz)
+
+    ## And now decrypt it again, with different record boundaries.
+    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
+    me.assertEqual(dec.hsz, len(h))
+    me.assertEqual(dec.csz, len(c))
+    me.assertEqual(dec.clen, 0)
+    me.assertEqual(dec.tsz, tsz)
+    aad = dec.aad()
+    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
+    else: me.assertEqual(aad.hsz, None)
+    me.assertEqual(aad.hlen, 0)
+    aad.hash(h)
+    m0 = dec.decrypt(c[0:156])
+    me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
+    m1 = dec.decrypt(c[156:])
+    me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
+                  len(c) - 156 + aecls.bufsz)
+    m2 = dec.done(tag = t, aad = aad)
+    me.assertEqual(m0 + m1 + m2, m)
+
+    ## And again, with the wrong tag.
+    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
+    aad = dec.aad(); aad.hash(h)
+    _ = dec.decrypt(c)
+    me.assertRaises(ValueError, dec.done, tag = t ^ tsz*C.bytes("55"))
+
+    ## Check that the all-in-one methods work.
+    me.assertEqual((c, t),
+                   key.encrypt(n = n, h = h, m = m, tsz = tsz))
+    me.assertEqual(m,
+                   key.decrypt(n = n, h = h, c = c, t = t))
+
+    ## Check that bad key, nonce, and tag lengths are rejected.
+    badlen = bad_key_size(aecls.keysz)
+    if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
+    badlen = bad_key_size(aecls.noncesz)
+    if badlen is not None:
+      me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
+                      hsz = len(h), msz = len(m), tsz = tsz)
+      me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
+                      hsz = len(h), csz = len(c), tsz = tsz)
+      if not aecls.flags&C.AEADF_PCTSZ:
+        enc = key.enc(nonce = n, hsz = 0, msz = len(m))
+        _ = enc.encrypt(m)
+        me.assertRaises(ValueError, enc.done, tsz = badlen)
+    badlen = bad_key_size(aecls.tagsz)
+    if badlen is not None:
+      me.assertRaises(ValueError, key.enc, nonce = n,
+                      hsz = len(h), msz = len(m), tsz = badlen)
+      me.assertRaises(ValueError, key.dec, nonce = n,
+                      hsz = len(h), csz = len(c), tsz = badlen)
+
+    ## Check that we can't get a loose `aad' object from a scheme which has
+    ## nonce-dependent AAD processing.
+    if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)
+
+    ## Check the menagerie of AAD hashing methods.
+    if not aecls.flags&C.AEADF_NOAAD:
+      def mkhash(hsz):
+        enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
+        aad = enc.aad()
+        return aad, lambda: enc.done(aad = aad)[1]
+      me.check_hashbuffer(mkhash)
+
+    ## Check that encryption/decryption works with the given precommitments.
+    def quick_enc_check(**kw):
+      enc = key.enc(**kw)
+      aad = enc.aad().hash(h)
+      c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
+      me.assertEqual((c, t), (c0 + c1, tt))
+    def quick_dec_check(**kw):
+      dec = key.dec(**kw)
+      aad = dec.aad().hash(h)
+      m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
+      me.assertEqual(m, m0 + m1)
+
+    ## Check that we can get away without precommitting to the header length
+    ## if and only if the AEAD scheme says it will let us.
+    if aecls.flags&C.AEADF_PCHSZ:
+      me.assertRaises(ValueError, key.enc, nonce = n,
+                      msz = len(m), tsz = tsz)
+      me.assertRaises(ValueError, key.dec, nonce = n,
+                      csz = len(c), tsz = tsz)
+    else:
+      quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
+      quick_dec_check(nonce = n, csz = len(c), tsz = tsz)
+
+    ## Check that we can get away without precommitting to the message/
+    ## ciphertext length if and only if the AEAD scheme says it will let us.
+    if aecls.flags&C.AEADF_PCMSZ:
+      me.assertRaises(ValueError, key.enc, nonce = n,
+                      hsz = len(h), tsz = tsz)
+      me.assertRaises(ValueError, key.dec, nonce = n,
+                      hsz = len(h), tsz = tsz)
+    else:
+      quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
+      quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)
+
+    ## Check that we can get away without precommitting to the tag length if
+    ## and only if the AEAD scheme says it will let us.
+    if aecls.flags&C.AEADF_PCTSZ:
+      me.assertRaises(ValueError, key.enc, nonce = n,
+                      hsz = len(h), msz = len(m))
+      me.assertRaises(ValueError, key.dec, nonce = n,
+                      hsz = len(h), csz = len(c))
+    else:
+      quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
+      quick_dec_check(nonce = n, hsz = len(h), csz = len(c))
+
+    ## Check that if we precommit to the header length, we're properly held
+    ## to the commitment.
+    if not aecls.flags&C.AEADF_NOAAD:
+
+      ## First, check encryption with underrun.  If we must supply AAD first,
+      ## then the underrun will be reported when we start trying to encrypt;
+      ## otherwise, checking is delayed until `done'.
+      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+      aad = enc.aad().hash(h[0:83])
+      if aecls.flags&C.AEADF_AADFIRST:
+        me.assertRaises(ValueError, enc.encrypt, m)
+      else:
+        _ = enc.encrypt(m)
+        me.assertRaises(ValueError, enc.done, aad = aad)
+
+      ## Next, check decryption with underrun.  If we must supply AAD first,
+      ## then the underrun will be reported when we start trying to encrypt;
+      ## otherwise, checking is delayed until `done'.
+      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
+      aad = dec.aad().hash(h[0:83])
+      if aecls.flags&C.AEADF_AADFIRST:
+        me.assertRaises(ValueError, dec.decrypt, c)
+      else:
+        _ = dec.decrypt(c)
+        me.assertRaises(ValueError, dec.done, tag = t, aad = aad)
+
+      ## If AAD processing is nonce-dependent then an overrun will be
+      ## detected imediately.
+      if aecls.flags&C.AEADF_AADNDEP:
+        enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+        aad = enc.aad().hash(h[0:83])
+        me.assertRaises(ValueError, aad.hash, h[82:131])
+        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
+        aad = dec.aad().hash(h[0:83])
+        me.assertRaises(ValueError, aad.hash, h[82:131])
+
+    ## Some additional tests for nonce-dependent `aad' objects.
+    if aecls.flags&C.AEADF_AADNDEP:
+
+      ## Check that `aad' objects can't be used once their parents are gone.
+      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+      aad = enc.aad()
+      del enc
+      me.assertRaises(ValueError, aad.hash, h)
+
+      ## Check that they can't be crossed over.
+      enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+      enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+      enc0.aad().hash(h)
+      aad1 = enc1.aad().hash(h)
+      _ = enc0.encrypt(m)
+      me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)
+
+    ## Test copying AAD.
+    if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
+      aad0 = key.aad()
+      aad0.hash(h[0:83])
+      aad1 = aad0.copy()
+      aad2 = aad1.copy()
+      aad0.hash(h[83:131])
+      aad1.hash(h[83:131])
+      aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
+      me.assertEqual(key.enc(nonce = n, hsz = len(h),
+                             msz = 0, tsz = tsz).done(aad = aad0),
+                     key.enc(nonce = n, hsz = len(h),
+                             msz = 0, tsz = tsz).done(aad = aad1))
+      me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
+                                msz = 0, tsz = tsz).done(aad = aad0),
+                        key.enc(nonce = n, hsz = len(h),
+                                msz = 0, tsz = tsz).done(aad = aad2))
+
+    ## Check that if we precommit to the message length, we're properly held
+    ## to the commitment.  (Fortunately, this is way simpler than the AAD
+    ## case above.)  First, try an underrun.
+    enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
+    _ = enc.encrypt(m[0:183])
+    me.assertRaises(ValueError, enc.done, tsz = tsz)
+    dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
+    _ = dec.decrypt(c[0:183])
+    me.assertRaises(ValueError, dec.done, tag = t)
+
+    ## And now an overrun.
+    enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
+    me.assertRaises(ValueError, enc.encrypt, m)
+    dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
+    me.assertRaises(ValueError, dec.decrypt, c)
+
+    ## Finally, check that if we precommit to a tag length, we're properly
+    ## held to the commitment.  This depends on being able to find a tag size
+    ## which isn't the default.
+    tsz1 = different_key_size(aecls.tagsz, tsz)
+    if tsz1 is not None:
+      enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
+      _ = enc.encrypt(m)
+      me.assertRaises(ValueError, enc.done, tsz = tsz)
+      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
+      aad = dec.aad().hash(h)
+      _ = dec.decrypt(c)
+      me.assertRaises(ValueError, enc.done, tsz = tsz, aad = aad)
+
+TestAuthenticatedEncryption.generate_testcases \
+  ((name, C.gcaeads[name]) for name in
+   ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
+    "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])
+
+###--------------------------------------------------------------------------
+class BaseTestHash (HashBufferTestMixin):
+  """Base class for testing hash functions."""
+
+  def check_hash(me, hcls, need_bufsz = True):
+    """
+    Check hash class HCLS.
+
+    If NEED_BUFSZ is false, then don't insist that HCLS have working `bufsz',
+    `name', or `hashsz' attributes.  This test is mostly reused for MACs,
+    which don't have these attributes.
+    """
+    ## Check the class properties.
+    if need_bufsz:
+      me.assertEqual(type(hcls.name), str)
+      me.assertEqual(type(hcls.bufsz), int)
+      me.assertEqual(type(hcls.hashsz), int)
+
+    ## Set some initial values.
+    m = T.span(131)
+    h = hcls().hash(m).done()
+
+    ## Check that hash length comes out right.
+    if need_bufsz: me.assertEqual(len(h), hcls.hashsz)
+
+    ## Check that we get the same answer if we split the message up.
+    me.assertEqual(h, hcls().hash(m[0:73]).hash(m[73:131]).done())
+
+    ## Check the `check' method.
+    me.assertTrue(hcls().hash(m).check(h))
+    me.assertFalse(hcls().hash(m).check(h ^ len(h)*C.bytes("aa")))
+
+    ## Check the menagerie of random hashing methods.
+    def mkhash(_):
+      h = hcls()
+      return h, h.done
+    me.check_hashbuffer(mkhash)
+
+class TestHash (BaseTestHash, T.GenericTestMixin):
+  """Test hash functions."""
+  def _test_hash(me, hcls):    me.check_hash(hcls, need_bufsz = True)
+
+TestHash.generate_testcases((name, C.gchashes[name]) for name in
+  ["md5", "sha", "whirlpool", "sha256", "sha512/224", "sha3-384", "shake256",
+   "crc32"])
+
+###--------------------------------------------------------------------------
+class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
+  """Test message authentication codes."""
+
+  def _test_mac(me, mcls):
+
+    ## Check the MAC properties.
+    me.assertEqual(type(mcls.name), str)
+    me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
+    me.assertEqual(type(mcls.tagsz), int)
+
+    ## Test hashing.
+    k = T.span(mcls.keysz.default)
+    key = mcls(k)
+    me.check_hash(key, need_bufsz = False)
+
+    ## Check that bad key lengths are rejected.
+    badlen = bad_key_size(mcls.keysz)
+    if badlen is not None: me.assertRaises(ValueError, mcls, T.span(badlen))
+
+TestMessageAuthentication.generate_testcases \
+  ((name, C.gcmacs[name]) for name in
+   ["sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"])
+
+class TestPoly1305 (HashBufferTestMixin):
+  """Check the Poly1305 one-time message authentication function."""
+
+  def test_poly1305(me):
+
+    ## Check the MAC properties.
+    me.assertEqual(C.poly1305.name, "poly1305")
+    me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
+    me.assertEqual(C.poly1305.keysz.default, 16)
+    me.assertEqual(set(C.poly1305.keysz.set), set([16]))
+    me.assertEqual(C.poly1305.tagsz, 16)
+    me.assertEqual(C.poly1305.masksz, 16)
+
+    ## Set some initial values.
+    k = T.span(16)
+    u = T.span(64)[-16:]
+    m = T.span(149)
+    key = C.poly1305(k)
+    t = key(u).hash(m).done()
+
+    ## Check the key properties.
+    me.assertEqual(len(t), 16)
+
+    ## Check that we get the same answer if we split the message up.
+    me.assertEqual(t, key(u).hash(m[0:86]).hash(m[86:149]).done())
+
+    ## Check the `check' method.
+    me.assertTrue(key(u).hash(m).check(t))
+    me.assertFalse(key(u).hash(m).check(t ^ 16*C.bytes("cc")))
+
+    ## Check the menagerie of random hashing methods.
+    def mkhash(_):
+      h = key(u)
+      return h, h.done
+    me.check_hashbuffer(mkhash)
+
+    ## Check that we can't complete hashing without a mask.
+    me.assertRaises(ValueError, key().hash(m).done)
+
+    ## Check `concat'.
+    h0 = key().hash(m[0:96])
+    h1 = key().hash(m[96:117])
+    me.assertEqual(t, key(u).concat(h0, h1).hash(m[117:149]).done())
+    key1 = C.poly1305(k)
+    me.assertRaises(TypeError, key().concat, key1().hash(m[0:96]), h1)
+    me.assertRaises(TypeError, key().concat, h0, key1().hash(m[96:117]))
+    me.assertRaises(ValueError, key().concat, key().hash(m[0:93]), h1)
+
+###--------------------------------------------------------------------------
+class TestHLatin (U.TestCase):
+  """Test the `hsalsa20' and `hchacha20' functions."""
+
+  def test_hlatin(me):
+    kk = [T.span(sz) for sz in [10, 16, 32]]
+    n = T.span(16)
+    bad_k = T.span(18)
+    bad_n = T.span(13)
+    for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
+               C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:
+      for k in kk:
+        h = fn(k, n)
+        me.assertEqual(len(h), 32)
+      me.assertRaises(ValueError, fn, bad_k, n)
+      me.assertRaises(ValueError, fn, k, bad_n)
+
+###--------------------------------------------------------------------------
+class TestKeccak (HashBufferTestMixin):
+  """Test the Keccak-p[1600, n] sponge function."""
+
+  def test_keccak(me):
+
+    ## Make a state and feed some stuff into it.
+    m0 = T.bin("some initial string")
+    m1 = T.bin("awesome follow-up string")
+    st0 = C.Keccak1600()
+    me.assertEqual(st0.nround, 24)
+    st0.mix(m0).step()
+
+    ## Make another step with a different round count.
+    st1 = C.Keccak1600(23)
+    st1.mix(m0).step()
+    me.assertNotEqual(st0.extract(32), st1.extract(32))
+
+    ## Check state copying.
+    st1 = st0.copy()
+    mask = st1.extract(len(m1))
+    st0.mix(m1)
+    st1.mix(m1)
+    me.assertEqual(st0.extract(32), st1.extract(32))
+
+    ## Check error conditions.
+    _ = st0.extract(200)
+    me.assertRaises(ValueError, st0.extract, 201)
+    st0.mix(T.span(200))
+    me.assertRaises(ValueError, st0.mix, T.span(201))
+
+  def check_shake(me, xcls, c, done_matches_xof = True):
+    """
+    Test the SHAKE and cSHAKE XOFs.
+
+    This is also used for testing KMAC, but that sets DONE_MATCHES_XOF false
+    to indicate that the XOF output is range-separated from the fixed-length
+    outputs (unlike the basic SHAKE functions).
+    """
+
+    ## Check the hash attributes.
+    x = xcls()
+    me.assertEqual(x.rate, 200 - c)
+    me.assertEqual(x.buffered, 0)
+    me.assertEqual(x.state, "absorb")
+
+    ## Set some initial values.
+    func = T.bin("TESTXOF")
+    perso = T.bin("catacomb-python test")
+    m = T.span(167)
+    h0 = xcls().hash(m).done(193)
+    me.assertEqual(len(h0), 193)
+    h1 = xcls(func = func, perso = perso).hash(m).done(193)
+    me.assertEqual(len(h1), 193)
+    me.assertNotEqual(h0, h1)
+
+    ## Check input and output in pieces, and the state machine.
+    if done_matches_xof: h = h0
+    else: h = xcls().hash(m).xof().get(len(h0))
+    x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
+    me.assertEqual(h, x.get(98) + x.get(95))
+
+    ## Check masking.
+    x = xcls().hash(m).xof()
+    me.assertEqual(x.mask(m), m ^ h[0:len(m)])
+
+    ## Check the `check' method.
+    me.assertTrue(xcls().hash(m).check(h0))
+    me.assertFalse(xcls().hash(m).check(h1))
+
+    ## Check the menagerie of random hashing methods.
+    def mkhash(_):
+      x = xcls(func = func, perso = perso)
+      return x, lambda: x.done(100 - x.rate//2)
+    me.check_hashbuffer(mkhash)
+
+    ## Check the state machine tracking.
+    x = xcls(); me.assertEqual(x.state, "absorb")
+    x.hash(m); me.assertEqual(x.state, "absorb")
+    xx = x.copy()
+    h = xx.done(100 - x.rate//2)
+    me.assertEqual(xx.state, "dead")
+    me.assertRaises(ValueError, xx.done, 1)
+    me.assertRaises(ValueError, xx.get, 1)
+    me.assertEqual(x.state, "absorb")
+    me.assertRaises(ValueError, x.get, 1)
+    x.xof(); me.assertEqual(x.state, "squeeze")
+    me.assertRaises(ValueError, x.done, 1)
+    _ = x.get(1)
+    yy = x.copy(); me.assertEqual(yy.state, "squeeze")
+
+  def test_shake128(me): me.check_shake(C.Shake128, 32)
+  def test_shake256(me): me.check_shake(C.Shake256, 64)
+
+  def check_kmac(me, mcls, c):
+    k = T.span(32)
+    me.check_shake(lambda func = None, perso = T.bin(""):
+                     mcls(k, perso = perso),
+                   c, done_matches_xof = False)
+
+  def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
+  def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
+
+###--------------------------------------------------------------------------
+class TestPRP (T.GenericTestMixin):
+  """Test pseudorandom permutations (PRPs)."""
+
+  def _test_prp(me, pcls):
+
+    ## Check the PRP properties.
+    me.assertEqual(type(pcls.name), str)
+    me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
+    me.assertEqual(type(pcls.blksz), int)
+
+    ## Check round-tripping.
+    k = T.span(pcls.keysz.default)
+    key = pcls(k)
+    m = T.span(pcls.blksz)
+    c = key.encrypt(m)
+    me.assertEqual(len(c), pcls.blksz)
+    me.assertEqual(m, key.decrypt(c))
+
+    ## Check that bad key lengths are rejected.
+    badlen = bad_key_size(pcls.keysz)
+    if badlen is not None: me.assertRaises(ValueError, pcls, T.span(badlen))
+
+    ## Check that bad blocks are rejected.
+    badblk = T.span(pcls.blksz + 1)
+    me.assertRaises(ValueError, key.encrypt, badblk)
+    me.assertRaises(ValueError, key.decrypt, badblk)
+
+TestPRP.generate_testcases((name, C.gcprps[name]) for name in
+  ["desx", "blowfish", "rijndael"])
+
+###----- That's all, folks --------------------------------------------------
+
+if __name__ == "__main__": U.main()