t/t-bytes.py: Check that indexing, slicing, etc. return `C.ByteString'.
[catacomb-python] / t / t-algorithms.py
index 8e073f2..d7b97cb 100644 (file)
@@ -149,6 +149,7 @@ class TestKeysize (U.TestCase):
     for n in [0, 12, 20, 5000]:
       me.assertTrue(ksz.check(n))
       me.assertEqual(ksz.best(n), n)
+      me.assertEqual(ksz.pad(n), n)
 
     ## A typical two-byte spec.  (No published algorithms actually /need/ a
     ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
@@ -160,6 +161,7 @@ class TestKeysize (U.TestCase):
     for n in [0, 12, 20, 5000]:
       me.assertTrue(ksz.check(n))
       me.assertEqual(ksz.best(n), n)
+      me.assertEqual(ksz.pad(n), n)
 
     ## Check construction.
     ksz = C.KeySZAny(15)
@@ -186,6 +188,8 @@ class TestKeysize (U.TestCase):
       else: me.assertFalse(ksz.check(x))
       if best is None: me.assertRaises(ValueError, ksz.best, x)
       else: me.assertEqual(ksz.best(x), best)
+      if pad is None: me.assertRaises(ValueError, ksz.pad, x)
+      else: me.assertEqual(ksz.pad(x), pad)
 
     ## Check construction.
     ksz = C.KeySZSet(7)
@@ -210,13 +214,15 @@ class TestKeysize (U.TestCase):
     me.assertEqual(ksz.min, 4)
     me.assertEqual(ksz.max, 32)
     me.assertEqual(ksz.mod, 4)
-    for x, best in [(3, None), (4, 4), (5, 4),
-                    (15, 12), (16, 16), (17, 16),
-                    (31, 28), (32, 32), (33, 32)]:
-      if x == best: me.assertTrue(ksz.check(x))
+    for x, best, pad in [(3, None, 4), (4, 4, 4), (5, 4, 8),
+                         (15, 12, 16), (16, 16, 16), (17, 16, 20),
+                         (31, 28, 32), (32, 32, 32), (33, 32, None)]:
+      if x == best == pad: me.assertTrue(ksz.check(x))
       else: me.assertFalse(ksz.check(x))
       if best is None: me.assertRaises(ValueError, ksz.best, x)
       else: me.assertEqual(ksz.best(x), best)
+      if pad is None: me.assertRaises(ValueError, ksz.pad, x)
+      else: me.assertEqual(ksz.pad(x), pad)
 
     ## Check construction.
     ksz = C.KeySZRange(28, 21, 35, 7)
@@ -310,6 +316,280 @@ TestCipher.generate_testcases((name, C.gcciphers[name]) for name in
    "blowfish-counter", "rc4", "seal", "salsa20/8", "shake128-xof"])
 
 ###--------------------------------------------------------------------------
+class TestAuthenticatedEncryption \
+        (HashBufferTestMixin, T.GenericTestMixin):
+  """Test authenticated encryption schemes."""
+
+  def _test_aead(me, aecls):
+
+    ## Check the class properties.
+    me.assertEqual(type(aecls.name), str)
+    me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
+    me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
+    me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
+    me.assertEqual(type(aecls.blksz), int)
+    me.assertEqual(type(aecls.bufsz), int)
+    me.assertEqual(type(aecls.ohd), int)
+    me.assertEqual(type(aecls.flags), int)
+
+    ## Check round-tripping, with full precommitment.  First, select some
+    ## parameters.  (It's conceivable that some AEAD schemes are more
+    ## restrictive than advertised by the various properties, but this works
+    ## out OK in practice.)
+    k = T.span(aecls.keysz.default)
+    n = T.span(aecls.noncesz.default)
+    if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
+    else: h = T.span(131)
+    m = T.span(253)
+    tsz = aecls.tagsz.default
+    key = aecls(k)
+
+    ## Next, encrypt a message, checking that things are proper as we go.
+    enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+    me.assertEqual(enc.hsz, len(h))
+    me.assertEqual(enc.msz, len(m))
+    me.assertEqual(enc.mlen, 0)
+    me.assertEqual(enc.tsz, tsz)
+    aad = enc.aad()
+    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
+    else: me.assertEqual(aad.hsz, None)
+    me.assertEqual(aad.hlen, 0)
+    if not aecls.flags&C.AEADF_NOAAD:
+      aad.hash(h[0:83])
+      me.assertEqual(aad.hlen, 83)
+      aad.hash(h[83:131])
+      me.assertEqual(aad.hlen, 131)
+    c0 = enc.encrypt(m[0:57])
+    me.assertEqual(enc.mlen, 57)
+    me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
+    c1 = enc.encrypt(m[57:189])
+    me.assertEqual(enc.mlen, 189)
+    me.assertTrue(132 - aecls.bufsz <= len(c1) <=
+                  132 + aecls.bufsz + aecls.ohd)
+    c2 = enc.encrypt(m[189:253])
+    me.assertEqual(enc.mlen, 253)
+    me.assertTrue(64 - aecls.bufsz <= len(c2) <=
+                  64 + aecls.bufsz + aecls.ohd)
+    c3, t = enc.done(aad = aad)
+    me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
+    c = c0 + c1 + c2 + c3
+    me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
+    me.assertEqual(len(t), tsz)
+
+    ## And now decrypt it again, with different record boundaries.
+    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
+    me.assertEqual(dec.hsz, len(h))
+    me.assertEqual(dec.csz, len(c))
+    me.assertEqual(dec.clen, 0)
+    me.assertEqual(dec.tsz, tsz)
+    aad = dec.aad()
+    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
+    else: me.assertEqual(aad.hsz, None)
+    me.assertEqual(aad.hlen, 0)
+    aad.hash(h)
+    m0 = dec.decrypt(c[0:156])
+    me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
+    m1 = dec.decrypt(c[156:])
+    me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
+                  len(c) - 156 + aecls.bufsz)
+    m2 = dec.done(tag = t, aad = aad)
+    me.assertEqual(m0 + m1 + m2, m)
+
+    ## And again, with the wrong tag.
+    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
+    aad = dec.aad(); aad.hash(h)
+    _ = dec.decrypt(c)
+    me.assertRaises(ValueError, dec.done, tag = t ^ tsz*C.bytes("55"))
+
+    ## Check that the all-in-one methods work.
+    me.assertEqual((c, t),
+                   key.encrypt(n = n, h = h, m = m, tsz = tsz))
+    me.assertEqual(m,
+                   key.decrypt(n = n, h = h, c = c, t = t))
+
+    ## Check that bad key, nonce, and tag lengths are rejected.
+    badlen = bad_key_size(aecls.keysz)
+    if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
+    badlen = bad_key_size(aecls.noncesz)
+    if badlen is not None:
+      me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
+                      hsz = len(h), msz = len(m), tsz = tsz)
+      me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
+                      hsz = len(h), csz = len(c), tsz = tsz)
+      if not aecls.flags&C.AEADF_PCTSZ:
+        enc = key.enc(nonce = n, hsz = 0, msz = len(m))
+        _ = enc.encrypt(m)
+        me.assertRaises(ValueError, enc.done, tsz = badlen)
+    badlen = bad_key_size(aecls.tagsz)
+    if badlen is not None:
+      me.assertRaises(ValueError, key.enc, nonce = n,
+                      hsz = len(h), msz = len(m), tsz = badlen)
+      me.assertRaises(ValueError, key.dec, nonce = n,
+                      hsz = len(h), csz = len(c), tsz = badlen)
+
+    ## Check that we can't get a loose `aad' object from a scheme which has
+    ## nonce-dependent AAD processing.
+    if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)
+
+    ## Check the menagerie of AAD hashing methods.
+    if not aecls.flags&C.AEADF_NOAAD:
+      def mkhash(hsz):
+        enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
+        aad = enc.aad()
+        return aad, lambda: enc.done(aad = aad)[1]
+      me.check_hashbuffer(mkhash)
+
+    ## Check that encryption/decryption works with the given precommitments.
+    def quick_enc_check(**kw):
+      enc = key.enc(**kw)
+      aad = enc.aad().hash(h)
+      c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
+      me.assertEqual((c, t), (c0 + c1, tt))
+    def quick_dec_check(**kw):
+      dec = key.dec(**kw)
+      aad = dec.aad().hash(h)
+      m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
+      me.assertEqual(m, m0 + m1)
+
+    ## Check that we can get away without precommitting to the header length
+    ## if and only if the AEAD scheme says it will let us.
+    if aecls.flags&C.AEADF_PCHSZ:
+      me.assertRaises(ValueError, key.enc, nonce = n,
+                      msz = len(m), tsz = tsz)
+      me.assertRaises(ValueError, key.dec, nonce = n,
+                      csz = len(c), tsz = tsz)
+    else:
+      quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
+      quick_dec_check(nonce = n, csz = len(c), tsz = tsz)
+
+    ## Check that we can get away without precommitting to the message/
+    ## ciphertext length if and only if the AEAD scheme says it will let us.
+    if aecls.flags&C.AEADF_PCMSZ:
+      me.assertRaises(ValueError, key.enc, nonce = n,
+                      hsz = len(h), tsz = tsz)
+      me.assertRaises(ValueError, key.dec, nonce = n,
+                      hsz = len(h), tsz = tsz)
+    else:
+      quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
+      quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)
+
+    ## Check that we can get away without precommitting to the tag length if
+    ## and only if the AEAD scheme says it will let us.
+    if aecls.flags&C.AEADF_PCTSZ:
+      me.assertRaises(ValueError, key.enc, nonce = n,
+                      hsz = len(h), msz = len(m))
+      me.assertRaises(ValueError, key.dec, nonce = n,
+                      hsz = len(h), csz = len(c))
+    else:
+      quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
+      quick_dec_check(nonce = n, hsz = len(h), csz = len(c))
+
+    ## Check that if we precommit to the header length, we're properly held
+    ## to the commitment.
+    if not aecls.flags&C.AEADF_NOAAD:
+
+      ## First, check encryption with underrun.  If we must supply AAD first,
+      ## then the underrun will be reported when we start trying to encrypt;
+      ## otherwise, checking is delayed until `done'.
+      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+      aad = enc.aad().hash(h[0:83])
+      if aecls.flags&C.AEADF_AADFIRST:
+        me.assertRaises(ValueError, enc.encrypt, m)
+      else:
+        _ = enc.encrypt(m)
+        me.assertRaises(ValueError, enc.done, aad = aad)
+
+      ## Next, check decryption with underrun.  If we must supply AAD first,
+      ## then the underrun will be reported when we start trying to encrypt;
+      ## otherwise, checking is delayed until `done'.
+      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
+      aad = dec.aad().hash(h[0:83])
+      if aecls.flags&C.AEADF_AADFIRST:
+        me.assertRaises(ValueError, dec.decrypt, c)
+      else:
+        _ = dec.decrypt(c)
+        me.assertRaises(ValueError, dec.done, tag = t, aad = aad)
+
+      ## If AAD processing is nonce-dependent then an overrun will be
+      ## detected imediately.
+      if aecls.flags&C.AEADF_AADNDEP:
+        enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+        aad = enc.aad().hash(h[0:83])
+        me.assertRaises(ValueError, aad.hash, h[82:131])
+        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
+        aad = dec.aad().hash(h[0:83])
+        me.assertRaises(ValueError, aad.hash, h[82:131])
+
+    ## Some additional tests for nonce-dependent `aad' objects.
+    if aecls.flags&C.AEADF_AADNDEP:
+
+      ## Check that `aad' objects can't be used once their parents are gone.
+      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+      aad = enc.aad()
+      del enc
+      me.assertRaises(ValueError, aad.hash, h)
+
+      ## Check that they can't be crossed over.
+      enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+      enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
+      enc0.aad().hash(h)
+      aad1 = enc1.aad().hash(h)
+      _ = enc0.encrypt(m)
+      me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)
+
+    ## Test copying AAD.
+    if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
+      aad0 = key.aad()
+      aad0.hash(h[0:83])
+      aad1 = aad0.copy()
+      aad2 = aad1.copy()
+      aad0.hash(h[83:131])
+      aad1.hash(h[83:131])
+      aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
+      me.assertEqual(key.enc(nonce = n, hsz = len(h),
+                             msz = 0, tsz = tsz).done(aad = aad0),
+                     key.enc(nonce = n, hsz = len(h),
+                             msz = 0, tsz = tsz).done(aad = aad1))
+      me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
+                                msz = 0, tsz = tsz).done(aad = aad0),
+                        key.enc(nonce = n, hsz = len(h),
+                                msz = 0, tsz = tsz).done(aad = aad2))
+
+    ## Check that if we precommit to the message length, we're properly held
+    ## to the commitment.  (Fortunately, this is way simpler than the AAD
+    ## case above.)  First, try an underrun.
+    enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
+    _ = enc.encrypt(m[0:183])
+    me.assertRaises(ValueError, enc.done, tsz = tsz)
+    dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
+    _ = dec.decrypt(c[0:183])
+    me.assertRaises(ValueError, dec.done, tag = t)
+
+    ## And now an overrun.
+    enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
+    me.assertRaises(ValueError, enc.encrypt, m)
+    dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
+    me.assertRaises(ValueError, dec.decrypt, c)
+
+    ## Finally, check that if we precommit to a tag length, we're properly
+    ## held to the commitment.  This depends on being able to find a tag size
+    ## which isn't the default.
+    tsz1 = different_key_size(aecls.tagsz, tsz)
+    if tsz1 is not None:
+      enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
+      _ = enc.encrypt(m)
+      me.assertRaises(ValueError, enc.done, tsz = tsz)
+      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
+      aad = dec.aad().hash(h)
+      _ = dec.decrypt(c)
+      me.assertRaises(ValueError, enc.done, tsz = tsz, aad = aad)
+
+TestAuthenticatedEncryption.generate_testcases \
+  ((name, C.gcaeads[name]) for name in
+   ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
+    "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])
+
+###--------------------------------------------------------------------------
 class BaseTestHash (HashBufferTestMixin):
   """Base class for testing hash functions."""
 
@@ -432,7 +712,7 @@ class TestHLatin (U.TestCase):
   """Test the `hsalsa20' and `hchacha20' functions."""
 
   def test_hlatin(me):
-    kk = [T.span(sz) for sz in [32]]
+    kk = [T.span(sz) for sz in [10, 16, 32]]
     n = T.span(16)
     bad_k = T.span(18)
     bad_n = T.span(13)
@@ -462,6 +742,13 @@ class TestKeccak (HashBufferTestMixin):
     st1.mix(m0).step()
     me.assertNotEqual(st0.extract(32), st1.extract(32))
 
+    ## Check state copying.
+    st1 = st0.copy()
+    mask = st1.extract(len(m1))
+    st0.mix(m1)
+    st1.mix(m1)
+    me.assertEqual(st0.extract(32), st1.extract(32))
+
     ## Check error conditions.
     _ = st0.extract(200)
     me.assertRaises(ValueError, st0.extract, 201)
@@ -501,7 +788,7 @@ class TestKeccak (HashBufferTestMixin):
 
     ## Check masking.
     x = xcls().hash(m).xof()
-    me.assertEqual(x.mask(m), C.ByteString(m) ^ C.ByteString(h[0:len(m)]))
+    me.assertEqual(x.mask(m), m ^ h[0:len(m)])
 
     ## Check the `check' method.
     me.assertTrue(xcls().hash(m).check(h0))