t/t-algorithms.py (HashBufferTestMixin): Assemble method names dynamically.
[catacomb-python] / t / t-algorithms.py
CommitLineData
553d59fe
MW
1### -*- mode: python, coding: utf-8 -*-
2###
3### Test symmetric algorithms
4###
5### (c) 2019 Straylight/Edgeware
6###
7
8###----- Licensing notice ---------------------------------------------------
9###
10### This file is part of the Python interface to Catacomb.
11###
12### Catacomb/Python is free software: you can redistribute it and/or
13### modify it under the terms of the GNU General Public License as
14### published by the Free Software Foundation; either version 2 of the
15### License, or (at your option) any later version.
16###
17### Catacomb/Python is distributed in the hope that it will be useful, but
18### WITHOUT ANY WARRANTY; without even the implied warranty of
19### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20### General Public License for more details.
21###
22### You should have received a copy of the GNU General Public License
23### along with Catacomb/Python. If not, write to the Free Software
24### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25### USA.
26
27###--------------------------------------------------------------------------
28### Imported modules.
29
30import catacomb as C
31import unittest as U
32import testutils as T
33
34###--------------------------------------------------------------------------
35### Utilities.
36
def bad_key_size(ksz):
  """
  Return a key length which the key-size spec KSZ rejects, or `None'.

  `None' means that no suitable length could be found: for example, a
  `KeySZAny' spec, or an unbounded unit-granularity range starting at zero,
  accepts every length.
  """
  if isinstance(ksz, C.KeySZAny):
    return None

  if isinstance(ksz, C.KeySZRange):
    ## With coarse granularity, one more than the minimum is misaligned;
    ## otherwise step just outside one end of the permitted interval.
    if ksz.mod != 1: return ksz.min + 1
    if ksz.max is not None: return ksz.max + 1
    if ksz.min != 0: return ksz.min - 1
    return None

  if isinstance(ksz, C.KeySZSet):
    ## The set is finite, so some member's successor must be missing.
    for member in sorted(ksz.set):
      if member + 1 not in ksz.set: return member + 1
    assert False, "That should have worked."

  ## Anything else: we don't know how to find a bad size.
  return None
50
def different_key_size(ksz, sz):
  """
  Return a key length acceptable to KSZ which differs from SZ, or `None'.

  NOTE(review): the `KeySZRange' case steps by `mod' from SZ, so it presumes
  that SZ is itself acceptable to KSZ; this isn't checked here.
  """
  if isinstance(ksz, C.KeySZAny):
    return sz + 1

  if isinstance(ksz, C.KeySZRange):
    ## Prefer stepping down; step up only if SZ is already at the minimum.
    if sz > ksz.min: return sz - ksz.mod
    if ksz.max is None or sz < ksz.max: return sz + ksz.mod
    return None

  if isinstance(ksz, C.KeySZSet):
    ## The smallest member which isn't SZ, if there is one.
    return next((alt for alt in sorted(ksz.set) if alt != sz), None)

  return None
63
class HashBufferTestMixin (U.TestCase):
  """Mixin class for testing all of the various `hash...' methods."""

  ## Prefix of the method names under test: `dohash' calls the method named
  ## `HASHMETH + OP', e.g., `hashu8'.  It's looked up through `me', so a
  ## subclass could substitute a different prefix.
  HASHMETH = "hash"

  def dohash(me, h, op, arg):
    """Call H's `HASHMETH + OP' method on ARG, and return ME."""
    getattr(h, me.HASHMETH + op)(arg)
    return me

  def check_hashbuffer_hashn(me, w, bigendp, makefn, hashop):
    """
    Check `hashuN'.

    W is the integer width in bytes, and BIGENDP is true if the encoding is
    expected to be big-endian.  MAKEFN(SZ) is called with the number of
    bytes about to be hashed, and returns a pair (H, DONEFN) of a fresh
    hashing object and a function which finishes it.  HASHOP is the
    method-name suffix (after `HASHMETH') under test.
    """

    ## Check encoding an integer: a zero byte, the integer, then the byte
    ## W + 1 should hash the same as `T.span(W + 2)' in one go.  (This relies
    ## on `T.bytes_as_int' giving the integer whose encoding is the middle W
    ## bytes of that span.)
    h0, donefn0 = makefn(w + 2)
    me.dohash(h0, "u8", 0x00)
    me.dohash(h0, hashop, T.bytes_as_int(w, bigendp))
    me.dohash(h0, "u8", w + 1)
    h1, donefn1 = makefn(w + 2)
    me.dohash(h1, "", T.span(w + 2))
    me.assertEqual(donefn0(), donefn1())

    ## Check overflow detection: 2^(8 W) doesn't fit in W bytes.
    h0, _ = makefn(w)
    me.assertRaises(OverflowError, me.dohash, h0, hashop, 1 << 8*w)

  def check_hashbuffer_bufn(me, w, bigendp, makefn, hashop):
    """
    Check `hashbufN'.

    The arguments are as for `check_hashbuffer_hashn', except that W is the
    width of the length prefix.
    """

    ## Go through a number of different sizes, skipping any too long for the
    ## length prefix.  A length-prefixed buffer between two marker bytes
    ## should hash the same as the equivalent explicitly-prefixed string.
    for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]:
      if n >= 1 << 8*w: continue
      h0, donefn0 = makefn(2 + w + n)
      me.dohash(h0, "u8", 0x00)
      me.dohash(h0, hashop, T.span(n))
      me.dohash(h0, "u8", 0xff)
      h1, donefn1 = makefn(2 + w + n)
      me.dohash(h1, "", T.prep_lenseq(w, n, bigendp, True))
      me.assertEqual(donefn0(), donefn1())

    ## Check blocks which are too large for the length prefix.  (Only for
    ## narrow prefixes: presumably a block of 2^32 bytes or more is
    ## impractical to build.)
    if w <= 3:
      n = 1 << 8*w
      h0, _ = makefn(w + n)
      me.assertRaises(ValueError,
                      me.dohash, h0, hashop, C.ByteString.zero(n))

  def check_hashbuffer(me, makefn):
    """
    Test the various `hash...' methods.

    First run `check_hashbuffer_hashn' over every `uN' variant, then
    `check_hashbuffer_bufn' over every `bufN' variant.  The 64-bit variants
    are tried only if the object returned by MAKEFN has a `u64' method.
    MAKEFN is as described for `check_hashbuffer_hashn'.
    """

    def check_width(checkfn, stem, w):
      ## The unadorned name and the `b' variant are big-endian; the `l'
      ## variant is little-endian.
      name = "%s%d" % (stem, 8*w)
      checkfn(w, True, makefn, name)
      checkfn(w, True, makefn, name + "b")
      checkfn(w, False, makefn, name + "l")

    def check_family(checkfn, stem):
      ## A single byte has no byte order, so there's only one 8-bit variant.
      checkfn(1, True, makefn, stem + "8")
      for w in [2, 3, 4]: check_width(checkfn, stem, w)
      ## 64-bit support is optional, so probe for it.
      if hasattr(makefn(0)[0], me.HASHMETH + "u64"):
        check_width(checkfn, stem, 8)

    check_family(me.check_hashbuffer_hashn, "u")
    check_family(me.check_hashbuffer_bufn, "buf")
553d59fe
MW
143
144###--------------------------------------------------------------------------
class TestKeysize (U.TestCase):
  """Test the `KeySZ' family of key-size specifications."""

  def _check_best_pad(me, ksz, cases):
    """
    Check KSZ's `check', `best' and `pad' methods against CASES.

    Each case is a triple (X, BEST, PAD): BEST and PAD are the expected
    results of `best' and `pad' applied to X, or `None' if the method should
    raise `ValueError'.  X should pass `check' exactly when all three agree.
    """
    for x, best, pad in cases:
      (me.assertTrue if x == best == pad else me.assertFalse)(ksz.check(x))
      for meth, want in [(ksz.best, best), (ksz.pad, pad)]:
        if want is None: me.assertRaises(ValueError, meth, x)
        else: me.assertEqual(meth(x), want)

  def test_any(me):

    ## Real specs: `seal' has a typical one-byte spec, and the HMAC variants
    ## use a two-byte spec (though no published algorithm actually /needs/
    ## one).  Both accept every length.
    for alg, dflt in [(C.seal, 20), (C.sha256_hmac, 32)]:
      ksz = alg.keysz
      me.assertEqual(type(ksz), C.KeySZAny)
      me.assertEqual(ksz.default, dflt)
      me.assertEqual(ksz.min, 0)
      me.assertEqual(ksz.max, None)
      for sz in [0, 12, 20, 5000]:
        me.assertTrue(ksz.check(sz))
        me.assertEqual(ksz.best(sz), sz)
        me.assertEqual(ksz.pad(sz), sz)

    ## Explicit construction.
    ksz = C.KeySZAny(15)
    me.assertEqual(ksz.default, 15)
    me.assertEqual(ksz.min, 0)
    me.assertEqual(ksz.max, None)
    me.assertRaises(ValueError, C.KeySZAny, -8)
    me.assertEqual(C.KeySZAny(0).default, 0)

  def test_set(me):

    ## No published algorithm uses a 16-bit `set' spec, so only a typical
    ## one-byte spec is checked here.
    ksz = C.salsa20.keysz
    me.assertEqual(type(ksz), C.KeySZSet)
    me.assertEqual(ksz.default, 32)
    me.assertEqual(ksz.min, 10)
    me.assertEqual(ksz.max, 32)
    me.assertEqual(ksz.set, set([10, 16, 32]))
    me._check_best_pad(ksz, [(9, None, 10), (10, 10, 10), (11, 10, 16),
                             (15, 10, 16), (16, 16, 16), (17, 16, 32),
                             (31, 16, 32), (32, 32, 32), (33, 32, None)])

    ## Explicit construction: the default always ends up in the set.
    for args, members in [((7,), [7]),
                          ((7, iter([3, 6, 9])), [3, 6, 7, 9])]:
      ksz = C.KeySZSet(*args)
      me.assertEqual(ksz.default, 7)
      me.assertEqual(ksz.set, set(members))
      me.assertEqual(ksz.min, min(members))
      me.assertEqual(ksz.max, max(members))

  def test_range(me):

    ## No published algorithm uses a 16-bit `range' spec, or an unbounded
    ## one, so the only real spec checked is a typical bounded one.
    ksz = C.rijndael.keysz
    me.assertEqual(type(ksz), C.KeySZRange)
    me.assertEqual(ksz.default, 32)
    me.assertEqual(ksz.min, 4)
    me.assertEqual(ksz.max, 32)
    me.assertEqual(ksz.mod, 4)
    me._check_best_pad(ksz, [(3, None, 4), (4, 4, 4), (5, 4, 8),
                             (15, 12, 16), (16, 16, 16), (17, 16, 20),
                             (31, 28, 32), (32, 32, 32), (33, 32, None)])

    ## Explicit construction, bounded and then unbounded.
    ksz = C.KeySZRange(28, 21, 35, 7)
    me.assertEqual(ksz.default, 28)
    me.assertEqual(ksz.min, 21)
    me.assertEqual(ksz.max, 35)
    me.assertEqual(ksz.mod, 7)
    ksz = C.KeySZRange(28, 21, None, 7)
    me.assertEqual(ksz.min, 21)
    me.assertEqual(ksz.max, None)
    me.assertEqual(ksz.mod, 7)
    me.assertEqual(ksz.pad(36), 42)

    ## Inconsistent parameters -- misaligned default, minimum or maximum; a
    ## negative minimum; bounds out of order; or a default outside the
    ## bounds -- must be refused.
    for bad in [(29, 21, 35, 7), (28, 20, 35, 7), (28, 21, 34, 7),
                (28, -7, 35, 7), (28, 35, 21, 7), (35, 21, 28, 7),
                (21, 28, 35, 7)]:
      me.assertRaises(ValueError, C.KeySZRange, *bad)

  def test_conversions(me):

    ## Conversions between symmetric key sizes and `ec', `schnorr', `dl' and
    ## `if' parameter sizes.  The `dl' and `if' sizes aren't integers, so
    ## conversions from them are rounded before comparison.
    me.assertEqual(C.KeySZ.fromec(256), 128)
    me.assertEqual(C.KeySZ.fromschnorr(256), 128)
    me.assertEqual(round(C.KeySZ.fromdl(2958.6875)), 128)
    me.assertEqual(round(C.KeySZ.fromif(2958.6875)), 128)
    me.assertEqual(C.KeySZ.toec(128), 256)
    me.assertEqual(C.KeySZ.toschnorr(128), 256)
    me.assertEqual(C.KeySZ.todl(128), 2958.6875)
    me.assertEqual(C.KeySZ.toif(128), 2958.6875)
261
262###--------------------------------------------------------------------------
class TestCipher (T.GenericTestMixin):
  """Test basic symmetric ciphers."""

  def _test_cipher(me, ccls):
    """Check the symmetric cipher class CCLS."""

    ## Check the class properties.
    me.assertEqual(type(ccls.name), str)
    me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
    me.assertEqual(type(ccls.blksz), int)

    ## Check round-tripping.  The message is processed in three pieces; if
    ## the cipher supports `bdry', a boundary is inserted before the third.
    ## Ciphers which don't support `setiv' or `bdry' raise `ValueError'.
    k = T.span(ccls.keysz.default)
    iv = T.span(ccls.blksz)
    m = T.span(253)
    enc = ccls(k)
    dec = ccls(k)
    try: enc.setiv(iv)
    except ValueError: can_setiv = False
    else:
      can_setiv = True
      dec.setiv(iv)
    c0 = enc.encrypt(m[0:57])
    m0 = dec.decrypt(c0)
    c1 = enc.encrypt(m[57:189])
    m1 = dec.decrypt(c1)
    try: enc.bdry()
    except ValueError: can_bdry = False
    else:
      dec.bdry()
      can_bdry = True
    c2 = enc.encrypt(m[189:253])
    m2 = dec.decrypt(c2)
    me.assertEqual(len(c0) + len(c1) + len(c2), len(m))
    me.assertEqual(m0, m[0:57])
    me.assertEqual(m1, m[57:189])
    me.assertEqual(m2, m[189:253])

    ## Check the `enczero' and `deczero' methods.
    c3 = enc.enczero(32)
    me.assertEqual(dec.decrypt(c3), C.ByteString.zero(32))
    m4 = dec.deczero(32)
    me.assertEqual(enc.encrypt(m4), C.ByteString.zero(32))

    ## For ciphers with a boundary operation, replay the ciphertext through
    ## fresh decryptors which never see the boundary.  Before the boundary,
    ## the split between the first two pieces mustn't matter; across it, the
    ## output must differ, showing that the boundary actually did something.
    if can_bdry:
      dec = ccls(k)
      if can_setiv: dec.setiv(iv)
      me.assertEqual(dec.decrypt(c0 + c1), m[0:189])
      dec = ccls(k)
      if can_setiv: dec.setiv(iv)
      me.assertNotEqual(dec.decrypt(c0 + c1 + c2), m)

    ## Check that bad key lengths are rejected.
    badlen = bad_key_size(ccls.keysz)
    if badlen is not None: me.assertRaises(ValueError, ccls, T.span(badlen))
324
TestCipher.generate_testcases(
  (cname, C.gcciphers[cname])
  for cname in ("des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
                "blowfish-counter", "rc4", "seal", "salsa20/8",
                "shake128-xof"))
328
329###--------------------------------------------------------------------------
10f3f611
MW
class TestAuthenticatedEncryption \
    (HashBufferTestMixin, T.GenericTestMixin):
  """Test authenticated encryption schemes."""

  def _test_aead(me, aecls):
    """Check the authenticated-encryption class AECLS."""

    ## Check the class properties.
    me.assertEqual(type(aecls.name), str)
    me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
    me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
    me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
    me.assertEqual(type(aecls.blksz), int)
    me.assertEqual(type(aecls.bufsz), int)
    me.assertEqual(type(aecls.ohd), int)
    me.assertEqual(type(aecls.flags), int)

    ## Check round-tripping, with full precommitment.  First, select some
    ## parameters.  (It's conceivable that some AEAD schemes are more
    ## restrictive than advertised by the various properties, but this works
    ## out OK in practice.)
    k = T.span(aecls.keysz.default)
    n = T.span(aecls.noncesz.default)
    if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
    else: h = T.span(131)
    m = T.span(253)
    tsz = aecls.tagsz.default
    key = aecls(k)

    ## Next, encrypt a message, checking that things are proper as we go.
    ## The bounds below allow output to lag input by up to `bufsz' bytes,
    ## plus up to `ohd' bytes of overhead.
    enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
    me.assertEqual(enc.hsz, len(h))
    me.assertEqual(enc.msz, len(m))
    me.assertEqual(enc.mlen, 0)
    me.assertEqual(enc.tsz, tsz)
    aad = enc.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    if not aecls.flags&C.AEADF_NOAAD:
      aad.hash(h[0:83])
      me.assertEqual(aad.hlen, 83)
      aad.hash(h[83:131])
      me.assertEqual(aad.hlen, 131)
    c0 = enc.encrypt(m[0:57])
    me.assertEqual(enc.mlen, 57)
    me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
    c1 = enc.encrypt(m[57:189])
    me.assertEqual(enc.mlen, 189)
    me.assertTrue(132 - aecls.bufsz <= len(c1) <=
                  132 + aecls.bufsz + aecls.ohd)
    c2 = enc.encrypt(m[189:253])
    me.assertEqual(enc.mlen, 253)
    me.assertTrue(64 - aecls.bufsz <= len(c2) <=
                  64 + aecls.bufsz + aecls.ohd)
    c3, t = enc.done(aad = aad)
    me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
    c = c0 + c1 + c2 + c3
    me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
    me.assertEqual(len(t), tsz)

    ## And now decrypt it again, with different record boundaries.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    me.assertEqual(dec.hsz, len(h))
    me.assertEqual(dec.csz, len(c))
    me.assertEqual(dec.clen, 0)
    me.assertEqual(dec.tsz, tsz)
    aad = dec.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    aad.hash(h)
    m0 = dec.decrypt(c[0:156])
    me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
    m1 = dec.decrypt(c[156:])
    me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
                  len(c) - 156 + aecls.bufsz)
    m2 = dec.done(tag = t, aad = aad)
    me.assertEqual(m0 + m1 + m2, m)

    ## And again, with the wrong tag.  Supply the AAD, so that the only thing
    ## wrong is the tag itself.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    aad = dec.aad(); aad.hash(h)
    _ = dec.decrypt(c)
    me.assertRaises(ValueError, dec.done,
                    tag = t ^ tsz*C.bytes("55"), aad = aad)

    ## Check that the all-in-one methods work.
    me.assertEqual((c, t),
                   key.encrypt(n = n, h = h, m = m, tsz = tsz))
    me.assertEqual(m,
                   key.decrypt(n = n, h = h, c = c, t = t))

    ## Check that bad key, nonce, and tag lengths are rejected.
    badlen = bad_key_size(aecls.keysz)
    if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
    badlen = bad_key_size(aecls.noncesz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
                      hsz = len(h), msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
                      hsz = len(h), csz = len(c), tsz = tsz)
    badlen = bad_key_size(aecls.tagsz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m), tsz = badlen)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c), tsz = badlen)
      ## If the tag length needn't be fixed in advance, then a bad one should
      ## be rejected when it's finally supplied to `done'.
      if not aecls.flags&C.AEADF_PCTSZ:
        enc = key.enc(nonce = n, hsz = 0, msz = len(m))
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, tsz = badlen)

    ## Check that we can't get a loose `aad' object from a scheme which has
    ## nonce-dependent AAD processing.
    if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)

    ## Check the menagerie of AAD hashing methods.
    if not aecls.flags&C.AEADF_NOAAD:
      def mkhash(hsz):
        enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
        aad = enc.aad()
        return aad, lambda: enc.done(aad = aad)[1]
      me.check_hashbuffer(mkhash)

    ## Helpers: check that encryption/decryption reproduces the reference
    ## ciphertext/tag or message under the given precommitments.
    def quick_enc_check(**kw):
      enc = key.enc(**kw)
      aad = enc.aad().hash(h)
      c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
      me.assertEqual((c, t), (c0 + c1, tt))
    def quick_dec_check(**kw):
      dec = key.dec(**kw)
      aad = dec.aad().hash(h)
      m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
      me.assertEqual(m, m0 + m1)

    ## Check that we can get away without precommitting to the header length
    ## if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCHSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      csz = len(c), tsz = tsz)
    else:
      quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
      quick_dec_check(nonce = n, csz = len(c), tsz = tsz)

    ## Check that we can get away without precommitting to the message/
    ## ciphertext length if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCMSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), tsz = tsz)
    else:
      quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
      quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)

    ## Check that we can get away without precommitting to the tag length if
    ## and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCTSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m))
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c))
    else:
      quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
      quick_dec_check(nonce = n, hsz = len(h), csz = len(c))

    ## Check that if we precommit to the header length, we're properly held
    ## to the commitment.
    if not aecls.flags&C.AEADF_NOAAD:

      ## First, check encryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to encrypt;
      ## otherwise, checking is delayed until `done'.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, enc.encrypt, m)
      else:
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, aad = aad)

      ## Next, check decryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to decrypt;
      ## otherwise, checking is delayed until `done'.
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
      aad = dec.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, dec.decrypt, c)
      else:
        _ = dec.decrypt(c)
        me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

      ## If AAD processing is nonce-dependent then an overrun will be
      ## detected immediately.  (83 + 49 bytes exceeds the committed 131.)
      if aecls.flags&C.AEADF_AADNDEP:
        enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
        aad = enc.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])
        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
        aad = dec.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])

    ## Some additional tests for nonce-dependent `aad' objects.
    if aecls.flags&C.AEADF_AADNDEP:

      ## Check that `aad' objects can't be used once their parents are gone.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad()
      del enc
      me.assertRaises(ValueError, aad.hash, h)

      ## Check that they can't be crossed over.
      enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc0.aad().hash(h)
      aad1 = enc1.aad().hash(h)
      _ = enc0.encrypt(m)
      me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)

    ## Test copying AAD: copies evolve independently of the original.
    if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
      aad0 = key.aad()
      aad0.hash(h[0:83])
      aad1 = aad0.copy()
      aad2 = aad1.copy()
      aad0.hash(h[83:131])
      aad1.hash(h[83:131])
      aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
      me.assertEqual(key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad0),
                     key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad1))
      me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad0),
                        key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad2))

    ## Check that if we precommit to the message length, we're properly held
    ## to the commitment.  (Fortunately, this is way simpler than the AAD
    ## case above.)  First, try an underrun.
    enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
    _ = enc.encrypt(m[0:183])
    me.assertRaises(ValueError, enc.done, tsz = tsz)
    dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
    _ = dec.decrypt(c[0:183])
    me.assertRaises(ValueError, dec.done, tag = t)

    ## And now an overrun.
    enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
    me.assertRaises(ValueError, enc.encrypt, m)
    dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
    me.assertRaises(ValueError, dec.decrypt, c)

    ## Finally, check that if we precommit to a tag length, we're properly
    ## held to the commitment, on both the encryption and decryption sides.
    ## This depends on being able to find a tag size which isn't the default.
    tsz1 = different_key_size(aecls.tagsz, tsz)
    if tsz1 is not None:
      enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
      _ = enc.encrypt(m)
      me.assertRaises(ValueError, enc.done, tsz = tsz)
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
      aad = dec.aad().hash(h)
      _ = dec.decrypt(c)
      me.assertRaises(ValueError, dec.done, tag = t, aad = aad)
597
TestAuthenticatedEncryption.generate_testcases(
  (aname, C.gcaeads[aname])
  for aname in ("des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
                "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"))
602
603###--------------------------------------------------------------------------
553d59fe
MW
class BaseTestHash (HashBufferTestMixin):
  """Base class for testing hash functions."""

  def check_hash(me, hcls, need_bufsz = True):
    """
    Check hash class HCLS.

    Pass a false NEED_BUFSZ to skip checking HCLS's `bufsz' attribute: this
    check is also reused for MACs, which lack that attribute.
    """

    ## The class should advertise its name and sizes.
    me.assertEqual(type(hcls.name), str)
    if need_bufsz: me.assertEqual(type(hcls.bufsz), int)
    me.assertEqual(type(hcls.hashsz), int)

    ## Reference digest, which should be the advertised length.
    msg = T.span(131)
    digest = hcls().hash(msg).done()
    me.assertEqual(len(digest), hcls.hashsz)

    ## Feeding the message in two pieces should make no difference.
    me.assertEqual(digest, hcls().hash(msg[0:73]).hash(msg[73:131]).done())

    ## `check' accepts the right digest and rejects a corrupted one.
    me.assertTrue(hcls().hash(msg).check(digest))
    corrupt = digest ^ hcls.hashsz*C.bytes("aa")
    me.assertFalse(hcls().hash(msg).check(corrupt))

    ## Run the generic `hashuN'/`hashbufN' battery on fresh contexts.
    def fresh(_):
      ctx = hcls()
      return ctx, ctx.done
    me.check_hashbuffer(fresh)
639
class TestHash (BaseTestHash, T.GenericTestMixin):
  """Test hash functions."""

  def _test_hash(me, hcls):
    """Run the standard hash checks on HCLS, including its `bufsz'."""
    me.check_hash(hcls, need_bufsz = True)
643
TestHash.generate_testcases(
  (hname, C.gchashes[hname])
  for hname in ("md5", "sha", "whirlpool", "sha256", "sha512/224",
                "sha3-384", "shake256", "crc32"))
647
648###--------------------------------------------------------------------------
class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
  """Test message authentication codes."""

  def _test_mac(me, mcls):
    """Check the MAC class MCLS, and the hashing objects made by its keys."""

    ## The class should advertise a name, a key-size spec and a tag size.
    me.assertEqual(type(mcls.name), str)
    me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
    me.assertEqual(type(mcls.tagsz), int)

    ## A key with the default length behaves like a hash class whose output
    ## size is the tag size; keys have no `bufsz', so don't ask for one.
    key = mcls(T.span(mcls.keysz.default))
    me.assertEqual(key.hashsz, key.tagsz)
    me.check_hash(key, need_bufsz = False)

    ## Keys of an unacceptable length must be refused.
    badlen = bad_key_size(mcls.keysz)
    if badlen is None: return
    me.assertRaises(ValueError, mcls, T.span(badlen))
668
TestMessageAuthentication.generate_testcases(
  (mname, C.gcmacs[mname])
  for mname in ("sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"))
672
class TestPoly1305 (HashBufferTestMixin):
  """Check the Poly1305 one-time message authentication function."""

  def test_poly1305(me):

    ## Check the MAC properties.
    me.assertEqual(C.poly1305.name, "poly1305")
    me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
    me.assertEqual(C.poly1305.keysz.default, 16)
    me.assertEqual(C.poly1305.keysz.set, set([16]))
    me.assertEqual(C.poly1305.tagsz, 16)
    me.assertEqual(C.poly1305.masksz, 16)

    ## Set some initial values.  Calling the key with a mask U gives a
    ## hashing context which can produce a tag.
    k = T.span(16)
    u = T.span(64)[-16:]
    m = T.span(149)
    key = C.poly1305(k)
    t = key(u).hash(m).done()

    ## Check the key properties.
    me.assertEqual(key.name, "poly1305")
    me.assertEqual(key.tagsz, 16)
    me.assertEqual(len(t), 16)

    ## Check that we get the same answer if we split the message up.
    me.assertEqual(t, key(u).hash(m[0:86]).hash(m[86:149]).done())

    ## Check the `check' method.
    me.assertTrue(key(u).hash(m).check(t))
    me.assertFalse(key(u).hash(m).check(t ^ 16*C.bytes("cc")))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      h = key(u)
      return h, h.done
    me.check_hashbuffer(mkhash)

    ## Check that we can't complete hashing without a mask.
    me.assertRaises(ValueError, key().hash(m).done)

    ## Check `concat': maskless contexts over the first two pieces, joined
    ## under a masked context which then hashes the rest, reproduce the
    ## whole-message tag.  Contexts from a different key object (even one
    ## with the same key bytes) are refused with `TypeError'; a first piece
    ## of 93 bytes is refused with `ValueError' (presumably because it isn't
    ## a whole number of 16-byte blocks).
    h0 = key().hash(m[0:96])
    h1 = key().hash(m[96:117])
    me.assertEqual(t, key(u).concat(h0, h1).hash(m[117:149]).done())
    key1 = C.poly1305(k)
    me.assertRaises(TypeError, key().concat, key1().hash(m[0:96]), h1)
    me.assertRaises(TypeError, key().concat, h0, key1().hash(m[96:117]))
    me.assertRaises(ValueError, key().concat, key().hash(m[0:93]), h1)
723
724###--------------------------------------------------------------------------
class TestHLatin (U.TestCase):
  """Test the `hsalsa20' and `hchacha20' functions."""

  def test_hlatin(me):
    """Each PRF yields 32 bytes, and rejects bad key or nonce lengths."""
    keys = [T.span(keylen) for keylen in [10, 16, 32]]
    nonce = T.span(16)
    bad_key = T.span(18)
    bad_nonce = T.span(13)
    prfs = [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
            C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]
    for prf in prfs:
      for key in keys:
        me.assertEqual(len(prf(key, nonce)), 32)
        me.assertRaises(ValueError, prf, bad_key, nonce)
        me.assertRaises(ValueError, prf, key, bad_nonce)
740
741###--------------------------------------------------------------------------
class TestKeccak (HashBufferTestMixin):
  """Test the Keccak-p[1600, n] sponge function."""

  def test_keccak(me):
    """Exercise the raw Keccak-p[1600] state object."""

    ## Absorb a string into a default state, which has 24 rounds, and permute.
    first = T.bin("some initial string")
    second = T.bin("awesome follow-up string")
    sta = C.Keccak1600()
    me.assertEqual(sta.nround, 24)
    sta.mix(first).step()

    ## The same input with 23 rounds must give different output.
    stb = C.Keccak1600(23)
    stb.mix(first).step()
    me.assertNotEqual(sta.extract(32), stb.extract(32))

    ## Mixing input into the state should be the same as setting the state
    ## to the input XORed with the state's current contents; this also
    ## exercises `copy'.
    stb = sta.copy()
    mask = stb.extract(len(second))
    sta.mix(second)
    stb.set(second ^ mask)
    me.assertEqual(sta.extract(32), stb.extract(32))

    ## Length limits: 200 bytes is acceptable but 201 isn't, and `set' also
    ## rejects 199 bytes.
    _ = sta.extract(200)
    me.assertRaises(ValueError, sta.extract, 201)
    sta.mix(T.span(200))
    me.assertRaises(ValueError, sta.mix, T.span(201))
    sta.set(T.span(200))
    me.assertRaises(ValueError, sta.set, T.span(201))
    me.assertRaises(ValueError, sta.set, T.span(199))

  def check_shake(me, xcls, c, done_matches_xof = True):
    """
    Test the SHAKE and cSHAKE XOFs.

    XCLS is called as XCLS() or XCLS(func = ..., perso = ...); C is the
    capacity in bytes, so the rate should be 200 - C.  KMAC reuses this with
    DONE_MATCHES_XOF false, because its XOF output is range-separated from
    its fixed-length outputs (unlike the basic SHAKE functions).
    """

    ## A fresh object is absorbing, with nothing buffered.
    xof = xcls()
    me.assertEqual(xof.rate, 200 - c)
    me.assertEqual(xof.buffered, 0)
    me.assertEqual(xof.state, "absorb")

    ## Reference outputs, plain and customized, which must differ.
    func = T.bin("TESTXOF")
    perso = T.bin("catacomb-python test")
    msg = T.span(167)
    plain = xcls().hash(msg).done(193)
    me.assertEqual(len(plain), 193)
    custom = xcls(func = func, perso = perso).hash(msg).done(193)
    me.assertEqual(len(custom), 193)
    me.assertNotEqual(plain, custom)

    ## Piecewise absorbing and squeezing should reproduce the XOF output.
    if done_matches_xof: want = plain
    else: want = xcls().hash(msg).xof().get(len(plain))
    xof = xcls().hash(msg[0:76]).hash(msg[76:167]).xof()
    me.assertEqual(want, xof.get(98) + xof.get(95))

    ## `mask' XORs its argument with the output stream.
    xof = xcls().hash(msg).xof()
    me.assertEqual(xof.mask(msg), msg ^ want[0:len(msg)])

    ## `check' compares against the fixed-length output.
    me.assertTrue(xcls().hash(msg).check(plain))
    me.assertFalse(xcls().hash(msg).check(custom))

    ## Run the generic `hashuN'/`hashbufN' battery on customized objects.
    def fresh(_):
      obj = xcls(func = func, perso = perso)
      return obj, obj.done
    me.check_hashbuffer(fresh)

    ## Walk through the state machine.
    xof = xcls(); me.assertEqual(xof.state, "absorb")
    xof.hash(msg); me.assertEqual(xof.state, "absorb")
    dup = xof.copy()
    out = dup.done(); me.assertEqual(len(out), 100 - xof.rate//2)
    me.assertEqual(dup.state, "dead")
    me.assertRaises(ValueError, dup.done, 1)
    me.assertRaises(ValueError, dup.get, 1)
    me.assertEqual(xof.state, "absorb")
    me.assertRaises(ValueError, xof.get, 1)
    xof.xof(); me.assertEqual(xof.state, "squeeze")
    me.assertRaises(ValueError, xof.done, 1)
    _ = xof.get(1)
    dup = xof.copy(); me.assertEqual(dup.state, "squeeze")

  def test_shake128(me): me.check_shake(C.Shake128, 32)
  def test_shake256(me): me.check_shake(C.Shake256, 64)

  def check_kmac(me, mcls, c):
    """Run the SHAKE checks on KMAC class MCLS, under a fixed 32-byte key."""
    key = T.span(32)
    def make(func = None, perso = None):
      ## FUNC is ignored; only PERSO is passed through.
      return mcls(key, perso = perso)
    me.check_shake(make, c, done_matches_xof = False)

  def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
  def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
846
847###--------------------------------------------------------------------------
class TestPRP (T.GenericTestMixin):
  """Test pseudorandom permutations (PRPs)."""

  def _test_prp(me, pcls):
    """Check the PRP class PCLS."""

    ## The class should advertise a name, key-size spec and block size.
    me.assertEqual(type(pcls.name), str)
    me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
    me.assertEqual(type(pcls.blksz), int)
    blksz = pcls.blksz

    ## Encrypting a block and decrypting it again gets us back to the start.
    prp = pcls(T.span(pcls.keysz.default))
    plain = T.span(blksz)
    cipher = prp.encrypt(plain)
    me.assertEqual(len(cipher), blksz)
    me.assertEqual(plain, prp.decrypt(cipher))

    ## Bad key lengths are refused.
    badlen = bad_key_size(pcls.keysz)
    if badlen is not None: me.assertRaises(ValueError, pcls, T.span(badlen))

    ## So are blocks of the wrong size, in either direction.
    oversized = T.span(blksz + 1)
    for op in (prp.encrypt, prp.decrypt):
      me.assertRaises(ValueError, op, oversized)
874
TestPRP.generate_testcases(
  (pname, C.gcprps[pname]) for pname in ("desx", "blowfish", "rijndael"))
877
878###----- That's all, folks --------------------------------------------------
879
if __name__ == "__main__":
  U.main()