algorithms.c (KeySZSet): Represent the set as an actual (frozen) set.
[catacomb-python] / t / t-algorithms.py
1 ### -*- mode: python, coding: utf-8 -*-
2 ###
3 ### Test symmetric algorithms
4 ###
5 ### (c) 2019 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the Python interface to Catacomb.
11 ###
12 ### Catacomb/Python is free software: you can redistribute it and/or
13 ### modify it under the terms of the GNU General Public License as
14 ### published by the Free Software Foundation; either version 2 of the
15 ### License, or (at your option) any later version.
16 ###
17 ### Catacomb/Python is distributed in the hope that it will be useful, but
18 ### WITHOUT ANY WARRANTY; without even the implied warranty of
19 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20 ### General Public License for more details.
21 ###
22 ### You should have received a copy of the GNU General Public License
23 ### along with Catacomb/Python. If not, write to the Free Software
24 ### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25 ### USA.
26
27 ###--------------------------------------------------------------------------
28 ### Imported modules.
29
30 import catacomb as C
31 import unittest as U
32 import testutils as T
33
34 ###--------------------------------------------------------------------------
35 ### Utilities.
36
def bad_key_size(ksz):
  """
  Return a key length which the key-size spec KSZ rejects.

  Returns `None' if no such length is known: for `KeySZAny', for a
  `KeySZRange' with modulus 1, no upper bound and a zero minimum, and for
  any spec type not handled here.
  """

  ## Every length is acceptable, so there's nothing to offer.
  if isinstance(ksz, C.KeySZAny):
    return None

  if isinstance(ksz, C.KeySZRange):
    ## With a modulus other than 1, the length just past the minimum is
    ## off-grid.  Otherwise, try stepping outside whichever bound exists.
    if ksz.mod != 1: return ksz.min + 1
    if ksz.max is not None: return ksz.max + 1
    if ksz.min != 0: return ksz.min - 1
    return None

  if isinstance(ksz, C.KeySZSet):
    ## Find the first permitted length whose successor isn't permitted.
    ## For a non-empty set the largest member always qualifies, so we
    ## shouldn't fall out of the loop.
    allowed = ksz.set
    for size in sorted(allowed):
      if size + 1 not in allowed: return size + 1
    assert False, "That should have worked."

  ## Some other kind of spec: we don't know how to break it.
  return None
50
def different_key_size(ksz, sz):
  """
  Return a key length other than SZ which the key-size spec KSZ accepts.

  This is intended for an SZ which KSZ already accepts.  Returns `None' if
  KSZ offers no alternative, or if KSZ is of a type not handled here.
  """

  if isinstance(ksz, C.KeySZAny):
    return sz + 1

  if isinstance(ksz, C.KeySZRange):
    ## Prefer stepping down by one modulus; step up only if we're already
    ## at the minimum and there's room above.
    if sz > ksz.min: return sz - ksz.mod
    if ksz.max is None or sz < ksz.max: return sz + ksz.mod
    return None

  if isinstance(ksz, C.KeySZSet):
    ## The smallest member which differs from SZ, if any.
    return next((other for other in sorted(ksz.set) if other != sz), None)

  return None
63
class HashBufferTestMixin (U.TestCase):
  """Mixin class for testing all of the various `hash...' methods."""

  def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
    """
    Check a `hashuN' method.

    W is the integer width in bytes, and BIGENDP says whether HASHFN encodes
    big-endian.  MAKEFN(SZ) returns a pair (HASHER, DONEFN); SZ is presumably
    the number of bytes which will be hashed.  HASHFN(H, N) feeds the integer
    N to H using the method under test, and returns the result of that call.
    """

    ## A zero byte, a W-byte integer and the byte W + 1 should hash the same
    ## as the raw span of W + 2 bytes (presumably `T.bytes_as_int' yields the
    ## integer whose encoding is the bytes 1, ..., W).
    total = w + 2
    hx, finish_x = makefn(total)
    hashfn(hx.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
    hy, finish_y = makefn(total)
    hy.hash(T.span(total))
    me.assertEqual(finish_x(), finish_y())

    ## An integer too wide for W bytes must be refused.
    hz, _ = makefn(w)
    me.assertRaises((OverflowError, ValueError), hashfn, hz, 1 << 8*w)

  def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
    """
    Check a `hashbufN' method, whose length prefix is W bytes wide.

    The arguments are as for `check_hashbuffer_hashn', except that HASHFN
    takes a buffer rather than an integer.
    """

    ## Try a selection of lengths, skipping those the prefix can't express.
    limit = 1 << 8*w
    for n in (0, 1, 7, 8, 19, 255, 12345, 65535, 123456):
      if n >= limit: continue
      total = 2 + w + n
      hx, finish_x = makefn(total)
      hashfn(hx.hashu8(0x00), T.span(n)).hashu8(0xff)
      hy, finish_y = makefn(total)
      hy.hash(T.prep_lenseq(w, n, bigendp, True))
      me.assertEqual(finish_x(), finish_y())

    ## A buffer too long for its length prefix must be refused.  Only try
    ## this for narrow prefixes, since wider ones would need a huge buffer.
    if w <= 3:
      n = limit
      hz, _ = makefn(w + n)
      me.assertRaises((ValueError, OverflowError, TypeError),
                      hashfn, hz, C.ByteString.zero(n))

  def check_hashbuffer(me, makefn):
    """
    Test the various `hash...' methods.

    MAKEFN is as for `check_hashbuffer_hashn'.  The 24- and 64-bit variants
    are checked only if the object returned by MAKEFN(0) has them.
    """

    def exercise(check, prefix):
      ## The 8-bit method has no endianness variants.
      check(1, True, makefn, lambda h, x: getattr(h, prefix + "8")(x))
      for w in (2, 3, 4, 8):
        bits = str(8*w)
        ## Not every object offers 24- or 64-bit widths; probe for them.
        if w in (3, 8) and not hasattr(makefn(0)[0], prefix + bits):
          continue
        ## Unsuffixed and `b' methods are big-endian; `l' is little-endian.
        for suffix, bigendp in (("", True), ("b", True), ("l", False)):
          meth = prefix + bits + suffix
          check(w, bigendp, makefn,
                lambda h, x, meth = meth: getattr(h, meth)(x))

    ## Check `hashuN', then `hashbufN'.
    exercise(me.check_hashbuffer_hashn, "hashu")
    exercise(me.check_hashbuffer_bufn, "hashbuf")
137
138 ###--------------------------------------------------------------------------
class TestKeysize (U.TestCase):
  """Test the `KeySZ' family of key-size specifications."""

  def _check_any_spec(me, ksz, default):
    """
    Check that KSZ is a `KeySZAny' with the given DEFAULT.

    Also check that it accepts, and leaves alone, a selection of lengths.
    """
    me.assertEqual(type(ksz), C.KeySZAny)
    me.assertEqual(ksz.default, default)
    me.assertEqual(ksz.min, 0)
    me.assertEqual(ksz.max, None)
    for sz in (0, 12, 20, 5000):
      me.assertTrue(ksz.check(sz))
      me.assertEqual(ksz.best(sz), sz)
      me.assertEqual(ksz.pad(sz), sz)

  def _check_lengths(me, ksz, table):
    """
    Check KSZ's `check', `best' and `pad' methods against TABLE.

    Each entry is a triple (X, BEST, PAD).  X should be accepted exactly when
    it equals both BEST and PAD.  `best' and `pad' should return BEST and PAD
    respectively, or raise `ValueError' where the expected value is `None'.
    """
    for x, best, pad in table:
      (me.assertTrue if x == best == pad else me.assertFalse)(ksz.check(x))
      for method, want in ((ksz.best, best), (ksz.pad, pad)):
        if want is None: me.assertRaises(ValueError, method, x)
        else: me.assertEqual(method(x), want)

  def test_any(me):

    ## A typical one-byte spec.
    me._check_any_spec(C.seal.keysz, 20)

    ## A typical two-byte spec.  (No published algorithms actually /need/ a
    ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
    me._check_any_spec(C.sha256_hmac.keysz, 32)

    ## Check construction, including refusal of a negative default.
    ksz = C.KeySZAny(15)
    me.assertEqual((ksz.default, ksz.min, ksz.max), (15, 0, None))
    me.assertRaises(ValueError, C.KeySZAny, -8)
    me.assertEqual(C.KeySZAny(0).default, 0)

  def test_set(me):
    ## Note that no published algorithm uses a 16-bit `set' spec.

    ## A typical spec.
    ksz = C.salsa20.keysz
    me.assertEqual(type(ksz), C.KeySZSet)
    me.assertEqual((ksz.default, ksz.min, ksz.max), (32, 10, 32))
    me.assertEqual(ksz.set, set([10, 16, 32]))
    me._check_lengths(ksz, [(9, None, 10), (10, 10, 10), (11, 10, 16),
                            (15, 10, 16), (16, 16, 16), (17, 16, 32),
                            (31, 16, 32), (32, 32, 32), (33, 32, None)])

    ## Check construction.  The default ends up in the set too.
    ksz = C.KeySZSet(7)
    me.assertEqual(ksz.default, 7)
    me.assertEqual(ksz.set, set([7]))
    me.assertEqual((ksz.min, ksz.max), (7, 7))
    ksz = C.KeySZSet(7, iter([3, 6, 9]))
    me.assertEqual(ksz.default, 7)
    me.assertEqual(ksz.set, set([3, 6, 7, 9]))
    me.assertEqual((ksz.min, ksz.max), (3, 9))

  def test_range(me):
    ## Note that no published algorithm uses a 16-bit `range' spec, or an
    ## unbounded `range'.

    ## A typical spec.
    ksz = C.rijndael.keysz
    me.assertEqual(type(ksz), C.KeySZRange)
    me.assertEqual((ksz.default, ksz.min, ksz.max, ksz.mod), (32, 4, 32, 4))
    me._check_lengths(ksz, [(3, None, 4), (4, 4, 4), (5, 4, 8),
                            (15, 12, 16), (16, 16, 16), (17, 16, 20),
                            (31, 28, 32), (32, 32, 32), (33, 32, None)])

    ## Check construction: first a bounded range, then an unbounded one.
    ksz = C.KeySZRange(28, 21, 35, 7)
    me.assertEqual((ksz.default, ksz.min, ksz.max, ksz.mod), (28, 21, 35, 7))
    ksz = C.KeySZRange(28, 21, None, 7)
    me.assertEqual((ksz.min, ksz.max, ksz.mod), (21, None, 7))
    me.assertEqual(ksz.pad(36), 42)

    ## Inconsistent parameters are refused: values off the modulus grid, a
    ## negative minimum, reversed bounds, and defaults outside the bounds.
    for args in [(29, 21, 35, 7), (28, 20, 35, 7), (28, 21, 34, 7),
                 (28, -7, 35, 7), (28, 35, 21, 7), (35, 21, 28, 7),
                 (21, 28, 35, 7)]:
      me.assertRaises(ValueError, C.KeySZRange, *args)

  def test_conversions(me):
    ## Check the conversions between symmetric key sizes and the
    ## corresponding `ec', `schnorr', `dl' and `if' sizes, both ways.
    ksz = C.KeySZ
    me.assertEqual(ksz.fromec(256), 128)
    me.assertEqual(ksz.fromschnorr(256), 128)
    me.assertEqual(round(ksz.fromdl(2958.6875)), 128)
    me.assertEqual(round(ksz.fromif(2958.6875)), 128)
    me.assertEqual(ksz.toec(128), 256)
    me.assertEqual(ksz.toschnorr(128), 256)
    me.assertEqual(ksz.todl(128), 2958.6875)
    me.assertEqual(ksz.toif(128), 2958.6875)
255
256 ###--------------------------------------------------------------------------
class TestCipher (T.GenericTestMixin):
  """Test basic symmetric ciphers."""

  def _test_cipher(me, ccls):

    ## Check the class properties.
    me.assertEqual(type(ccls.name), str)
    me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
    me.assertEqual(type(ccls.blksz), int)

    ## Check round-tripping.  Not every cipher supports `setiv' or `bdry':
    ## those which don't raise `ValueError', and we note which is which.
    k = T.span(ccls.keysz.default)
    iv = T.span(ccls.blksz)
    m = T.span(253)
    enc = ccls(k)
    dec = ccls(k)
    try: enc.setiv(iv)
    except ValueError: can_setiv = False
    else:
      can_setiv = True
      dec.setiv(iv)
    c0 = enc.encrypt(m[0:57])
    m0 = dec.decrypt(c0)
    c1 = enc.encrypt(m[57:189])
    m1 = dec.decrypt(c1)
    try: enc.bdry()
    except ValueError: can_bdry = False
    else:
      dec.bdry()
      can_bdry = True
    c2 = enc.encrypt(m[189:253])
    m2 = dec.decrypt(c2)
    me.assertEqual(len(c0) + len(c1) + len(c2), len(m))
    me.assertEqual(m0, m[0:57])
    me.assertEqual(m1, m[57:189])
    me.assertEqual(m2, m[189:253])

    ## Check the `enczero' and `deczero' methods.
    c3 = enc.enczero(32)
    me.assertEqual(dec.decrypt(c3), C.ByteString.zero(32))
    m4 = dec.deczero(32)
    me.assertEqual(enc.encrypt(m4), C.ByteString.zero(32))

    ## Check the boundary operation, for ciphers which have one.
    if can_bdry:

      ## Before the boundary, it shouldn't matter how the ciphertext is
      ## split: decrypting the first two pieces in one go must still
      ## recover the corresponding plaintext.
      dec = ccls(k)
      if can_setiv: dec.setiv(iv)
      me.assertEqual(dec.decrypt(c0 + c1), m[0:189])

      ## But decrypting straight through, without a matching boundary, must
      ## not recover the whole message, which shows that the boundary
      ## actually does something.
      dec = ccls(k)
      if can_setiv: dec.setiv(iv)
      me.assertNotEqual(dec.decrypt(c0 + c1 + c2), m)

    ## Check that bad key lengths are rejected.
    badlen = bad_key_size(ccls.keysz)
    if badlen is not None: me.assertRaises(ValueError, ccls, T.span(badlen))
318
## Run the cipher tests over a sample of the generic cipher classes.
TestCipher.generate_testcases(
  (nm, C.gcciphers[nm])
  for nm in ("des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
             "blowfish-counter", "rc4", "seal", "salsa20/8", "shake128-xof"))
322
323 ###--------------------------------------------------------------------------
class TestAuthenticatedEncryption \
      (HashBufferTestMixin, T.GenericTestMixin):
  """Test authenticated encryption schemes."""

  def _test_aead(me, aecls):

    ## Check the class properties.
    me.assertEqual(type(aecls.name), str)
    me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
    me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
    me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
    me.assertEqual(type(aecls.blksz), int)
    me.assertEqual(type(aecls.bufsz), int)
    me.assertEqual(type(aecls.ohd), int)
    me.assertEqual(type(aecls.flags), int)

    ## Check round-tripping, with full precommitment.  First, select some
    ## parameters.  (It's conceivable that some AEAD schemes are more
    ## restrictive than advertised by the various properties, but this works
    ## out OK in practice.)
    k = T.span(aecls.keysz.default)
    n = T.span(aecls.noncesz.default)
    if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
    else: h = T.span(131)
    m = T.span(253)
    tsz = aecls.tagsz.default
    key = aecls(k)

    ## Next, encrypt a message, checking that things are proper as we go.
    enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
    me.assertEqual(enc.hsz, len(h))
    me.assertEqual(enc.msz, len(m))
    me.assertEqual(enc.mlen, 0)
    me.assertEqual(enc.tsz, tsz)
    aad = enc.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    if not aecls.flags&C.AEADF_NOAAD:
      aad.hash(h[0:83])
      me.assertEqual(aad.hlen, 83)
      aad.hash(h[83:131])
      me.assertEqual(aad.hlen, 131)
    c0 = enc.encrypt(m[0:57])
    me.assertEqual(enc.mlen, 57)
    me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
    c1 = enc.encrypt(m[57:189])
    me.assertEqual(enc.mlen, 189)
    me.assertTrue(132 - aecls.bufsz <= len(c1) <=
                  132 + aecls.bufsz + aecls.ohd)
    c2 = enc.encrypt(m[189:253])
    me.assertEqual(enc.mlen, 253)
    me.assertTrue(64 - aecls.bufsz <= len(c2) <=
                  64 + aecls.bufsz + aecls.ohd)
    c3, t = enc.done(aad = aad)
    me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
    c = c0 + c1 + c2 + c3
    me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
    me.assertEqual(len(t), tsz)

    ## And now decrypt it again, with different record boundaries.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    me.assertEqual(dec.hsz, len(h))
    me.assertEqual(dec.csz, len(c))
    me.assertEqual(dec.clen, 0)
    me.assertEqual(dec.tsz, tsz)
    aad = dec.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    aad.hash(h)
    m0 = dec.decrypt(c[0:156])
    me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
    m1 = dec.decrypt(c[156:])
    me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
                  len(c) - 156 + aecls.bufsz)
    m2 = dec.done(tag = t, aad = aad)
    me.assertEqual(m0 + m1 + m2, m)

    ## And again, with the wrong tag.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    aad = dec.aad(); aad.hash(h)
    _ = dec.decrypt(c)
    me.assertRaises(ValueError, dec.done, tag = t ^ tsz*C.bytes("55"))

    ## Check that the all-in-one methods work.
    me.assertEqual((c, t),
                   key.encrypt(n = n, h = h, m = m, tsz = tsz))
    me.assertEqual(m,
                   key.decrypt(n = n, h = h, c = c, t = t))

    ## Check that bad key, nonce, and tag lengths are rejected.
    badlen = bad_key_size(aecls.keysz)
    if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
    badlen = bad_key_size(aecls.noncesz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
                      hsz = len(h), msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
                      hsz = len(h), csz = len(c), tsz = tsz)
    badlen = bad_key_size(aecls.tagsz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m), tsz = badlen)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c), tsz = badlen)

      ## If the scheme lets us defer choosing the tag length until `done',
      ## then a bad tag length must also be rejected there.  (This must use
      ## the bad /tag/ length, not the bad nonce length.)
      if not aecls.flags&C.AEADF_PCTSZ:
        enc = key.enc(nonce = n, hsz = 0, msz = len(m))
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, tsz = badlen)

    ## Check that we can't get a loose `aad' object from a scheme which has
    ## nonce-dependent AAD processing.
    if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)

    ## Check the menagerie of AAD hashing methods.
    if not aecls.flags&C.AEADF_NOAAD:
      def mkhash(hsz):
        enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
        aad = enc.aad()
        return aad, lambda: enc.done(aad = aad)[1]
      me.check_hashbuffer(mkhash)

    ## Check that encryption/decryption works with the given precommitments.
    def quick_enc_check(**kw):
      enc = key.enc(**kw)
      aad = enc.aad().hash(h)
      c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
      me.assertEqual((c, t), (c0 + c1, tt))
    def quick_dec_check(**kw):
      dec = key.dec(**kw)
      aad = dec.aad().hash(h)
      m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
      me.assertEqual(m, m0 + m1)

    ## Check that we can get away without precommitting to the header length
    ## if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCHSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      csz = len(c), tsz = tsz)
    else:
      quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
      quick_dec_check(nonce = n, csz = len(c), tsz = tsz)

    ## Check that we can get away without precommitting to the message/
    ## ciphertext length if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCMSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), tsz = tsz)
    else:
      quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
      quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)

    ## Check that we can get away without precommitting to the tag length if
    ## and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCTSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m))
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c))
    else:
      quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
      quick_dec_check(nonce = n, hsz = len(h), csz = len(c))

    ## Check that if we precommit to the header length, we're properly held
    ## to the commitment.
    if not aecls.flags&C.AEADF_NOAAD:

      ## First, check encryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to encrypt;
      ## otherwise, checking is delayed until `done'.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, enc.encrypt, m)
      else:
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, aad = aad)

      ## Next, check decryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to decrypt;
      ## otherwise, checking is delayed until `done'.
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
      aad = dec.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, dec.decrypt, c)
      else:
        _ = dec.decrypt(c)
        me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

      ## If AAD processing is nonce-dependent then an overrun will be
      ## detected immediately.
      if aecls.flags&C.AEADF_AADNDEP:
        enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
        aad = enc.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])
        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
        aad = dec.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])

    ## Some additional tests for nonce-dependent `aad' objects.
    if aecls.flags&C.AEADF_AADNDEP:

      ## Check that `aad' objects can't be used once their parents are gone.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad()
      del enc
      me.assertRaises(ValueError, aad.hash, h)

      ## Check that they can't be crossed over.
      enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc0.aad().hash(h)
      aad1 = enc1.aad().hash(h)
      _ = enc0.encrypt(m)
      me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)

    ## Test copying AAD.
    if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
      aad0 = key.aad()
      aad0.hash(h[0:83])
      aad1 = aad0.copy()
      aad2 = aad1.copy()
      aad0.hash(h[83:131])
      aad1.hash(h[83:131])
      aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
      me.assertEqual(key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad0),
                     key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad1))
      me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad0),
                        key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad2))

    ## Check that if we precommit to the message length, we're properly held
    ## to the commitment.  (Fortunately, this is way simpler than the AAD
    ## case above.)  First, try an underrun.
    enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
    _ = enc.encrypt(m[0:183])
    me.assertRaises(ValueError, enc.done, tsz = tsz)
    dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
    _ = dec.decrypt(c[0:183])
    me.assertRaises(ValueError, dec.done, tag = t)

    ## And now an overrun.
    enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
    me.assertRaises(ValueError, enc.encrypt, m)
    dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
    me.assertRaises(ValueError, dec.decrypt, c)

    ## Finally, check that if we precommit to a tag length, we're properly
    ## held to the commitment.  This depends on being able to find a tag size
    ## which isn't the default.
    tsz1 = different_key_size(aecls.tagsz, tsz)
    if tsz1 is not None:

      ## Encryption: asking `done' for a different tag length must fail.
      enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
      _ = enc.encrypt(m)
      me.assertRaises(ValueError, enc.done, tsz = tsz)

      ## Decryption: offering a tag of the wrong length (T has length TSZ,
      ## but we committed to TSZ1) must fail.  This must exercise the
      ## decryption context, not the encryption context above.
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
      aad = dec.aad().hash(h)
      _ = dec.decrypt(c)
      me.assertRaises(ValueError, dec.done, tag = t, aad = aad)
591
## Run the AEAD tests over a sample of the generic AEAD classes.
TestAuthenticatedEncryption.generate_testcases(
  (nm, C.gcaeads[nm])
  for nm in ("des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
             "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"))
596
597 ###--------------------------------------------------------------------------
class BaseTestHash (HashBufferTestMixin):
  """Base class for testing hash functions."""

  def check_hash(me, hcls, need_bufsz = True):
    """
    Check hash class HCLS.

    Calling HCLS() must produce a fresh hashing context.  If NEED_BUFSZ is
    false, then don't insist that HCLS have working `bufsz', `name', or
    `hashsz' attributes: this check is reused for MACs, which lack them.
    """

    ## Check the class properties, where they're expected to exist.
    if need_bufsz:
      for attr, ty in (("name", str), ("bufsz", int), ("hashsz", int)):
        me.assertEqual(type(getattr(hcls, attr)), ty)

    ## Compute a reference digest, and check its length if we know it.
    msg = T.span(131)
    digest = hcls().hash(msg).done()
    if need_bufsz: me.assertEqual(len(digest), hcls.hashsz)

    ## Feeding the message in two pieces must give the same answer.
    me.assertEqual(digest,
                   hcls().hash(msg[0:73]).hash(msg[73:131]).done())

    ## `check' accepts the right digest and refuses a corrupted one.
    me.assertTrue(hcls().hash(msg).check(digest))
    me.assertFalse(hcls().hash(msg).check(digest ^
                                          len(digest)*C.bytes("aa")))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      ctx = hcls()
      return ctx, ctx.done
    me.check_hashbuffer(mkhash)
634
class TestHash (BaseTestHash, T.GenericTestMixin):
  """Test hash functions."""

  def _test_hash(me, hcls):
    ## Plain hashes carry all of the class attributes, so check them too.
    me.check_hash(hcls, need_bufsz = True)
638
## Run the hash tests over a sample of the generic hash classes.
TestHash.generate_testcases(
  (nm, C.gchashes[nm])
  for nm in ("md5", "sha", "whirlpool", "sha256", "sha512/224", "sha3-384",
             "shake256", "crc32"))
642
643 ###--------------------------------------------------------------------------
class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
  """Test message authentication codes."""

  def _test_mac(me, mcls):

    ## Check the MAC properties.
    me.assertEqual(type(mcls.name), str)
    me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
    me.assertEqual(type(mcls.tagsz), int)

    ## A keyed MAC object is run through the generic hash checks, calling it
    ## to make hashing contexts; the hash-only attribute checks are skipped.
    keyed = mcls(T.span(mcls.keysz.default))
    me.check_hash(keyed, need_bufsz = False)

    ## Keys of an unacceptable length must be refused.
    badlen = bad_key_size(mcls.keysz)
    if badlen is not None:
      me.assertRaises(ValueError, mcls, T.span(badlen))
662
## Run the MAC tests over a sample of the generic MAC classes.
TestMessageAuthentication.generate_testcases(
  (nm, C.gcmacs[nm])
  for nm in ("sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"))
666
class TestPoly1305 (HashBufferTestMixin):
  """Check the Poly1305 one-time message authentication function."""

  def test_poly1305(me):
    P = C.poly1305

    ## Check the MAC properties.
    me.assertEqual(P.name, "poly1305")
    me.assertEqual(type(P.keysz), C.KeySZSet)
    me.assertEqual(P.keysz.default, 16)
    me.assertEqual(P.keysz.set, set([16]))
    me.assertEqual(P.tagsz, 16)
    me.assertEqual(P.masksz, 16)

    ## Make a key, a mask and a message, and compute a reference tag.
    k = T.span(16)
    mask = T.span(64)[-16:]
    msg = T.span(149)
    key = P(k)
    tag = key(mask).hash(msg).done()

    ## Check the tag length.
    me.assertEqual(len(tag), 16)

    ## Feeding the message in two pieces must give the same answer.
    me.assertEqual(tag, key(mask).hash(msg[0:86]).hash(msg[86:149]).done())

    ## `check' accepts the right tag and refuses a corrupted one.
    me.assertTrue(key(mask).hash(msg).check(tag))
    me.assertFalse(key(mask).hash(msg).check(tag ^ 16*C.bytes("cc")))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      ctx = key(mask)
      return ctx, ctx.done
    me.check_hashbuffer(mkhash)

    ## A context made without a mask can't produce a tag.
    me.assertRaises(ValueError, key().hash(msg).done)

    ## Check `concat': two unmasked partial contexts can be joined into a
    ## masked one.  Contexts made from a different key object (even with
    ## the same key bytes) are refused with `TypeError'.  A 93-byte first
    ## piece is refused with `ValueError' -- presumably because it doesn't
    ## end on a block boundary.
    head = key().hash(msg[0:96])
    mid = key().hash(msg[96:117])
    me.assertEqual(tag,
                   key(mask).concat(head, mid).hash(msg[117:149]).done())
    other = P(k)
    me.assertRaises(TypeError, key().concat, other().hash(msg[0:96]), mid)
    me.assertRaises(TypeError, key().concat, head, other().hash(msg[96:117]))
    me.assertRaises(ValueError, key().concat, key().hash(msg[0:93]), mid)
714
715 ###--------------------------------------------------------------------------
class TestHLatin (U.TestCase):
  """Test the `hsalsa20' and `hchacha20' functions."""

  def test_hlatin(me):
    kk = [T.span(sz) for sz in [10, 16, 32]]
    n = T.span(16)
    bad_k = T.span(18)
    bad_n = T.span(13)
    for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
               C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:

      ## Every acceptable key length gives a 32-byte output.  A bad nonce
      ## must be refused with each key, rather than only with whichever key
      ## happened to be left over from the loop.
      for k in kk:
        h = fn(k, n)
        me.assertEqual(len(h), 32)
        me.assertRaises(ValueError, fn, k, bad_n)

      ## A key of an unacceptable length must be refused.
      me.assertRaises(ValueError, fn, bad_k, n)
731
732 ###--------------------------------------------------------------------------
class TestKeccak (HashBufferTestMixin):
  """Test the Keccak-p[1600, n] sponge function."""

  def test_keccak(me):

    ## Make a state and feed some stuff into it.
    m0 = T.bin("some initial string")
    m1 = T.bin("awesome follow-up string")
    st0 = C.Keccak1600()
    me.assertEqual(st0.nround, 24)
    st0.mix(m0).step()

    ## Make another step with a different round count.
    st1 = C.Keccak1600(23)
    st1.mix(m0).step()
    me.assertNotEqual(st0.extract(32), st1.extract(32))

    ## Check state copying.  A fresh copy must extract the same bytes as the
    ## original, and the two must stay in step when fed the same input.
    st1 = st0.copy()
    mask = st1.extract(len(m1))
    me.assertEqual(st0.extract(len(m1)), mask)
    st0.mix(m1)
    st1.mix(m1)
    me.assertEqual(st0.extract(32), st1.extract(32))

    ## Check error conditions.
    _ = st0.extract(200)
    me.assertRaises(ValueError, st0.extract, 201)
    st0.mix(T.span(200))
    me.assertRaises(ValueError, st0.mix, T.span(201))

  def check_shake(me, xcls, c, done_matches_xof = True):
    """
    Test the SHAKE and cSHAKE XOFs.

    XCLS(func = ..., perso = ...) makes a fresh context; C is the value such
    that the context's `rate' is 200 - C.  This is also used for testing
    KMAC, but that sets DONE_MATCHES_XOF false to indicate that the XOF
    output is range-separated from the fixed-length outputs (unlike the basic
    SHAKE functions).
    """

    ## Check the hash attributes.
    x = xcls()
    me.assertEqual(x.rate, 200 - c)
    me.assertEqual(x.buffered, 0)
    me.assertEqual(x.state, "absorb")

    ## Set some initial values.
    func = T.bin("TESTXOF")
    perso = T.bin("catacomb-python test")
    m = T.span(167)
    h0 = xcls().hash(m).done(193)
    me.assertEqual(len(h0), 193)
    h1 = xcls(func = func, perso = perso).hash(m).done(193)
    me.assertEqual(len(h1), 193)
    me.assertNotEqual(h0, h1)

    ## Check input and output in pieces, and the state machine.
    if done_matches_xof: h = h0
    else: h = xcls().hash(m).xof().get(len(h0))
    x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
    me.assertEqual(h, x.get(98) + x.get(95))

    ## Check masking.
    x = xcls().hash(m).xof()
    me.assertEqual(x.mask(m), m ^ h[0:len(m)])

    ## Check the `check' method.
    me.assertTrue(xcls().hash(m).check(h0))
    me.assertFalse(xcls().hash(m).check(h1))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      x = xcls(func = func, perso = perso)
      return x, x.done
    me.check_hashbuffer(mkhash)

    ## Check the state machine tracking.  A copy's `done' kills only the
    ## copy; `xof' switches the original into the squeezing state.
    x = xcls(); me.assertEqual(x.state, "absorb")
    x.hash(m); me.assertEqual(x.state, "absorb")
    xx = x.copy()
    h = xx.done(); me.assertEqual(len(h), 100 - x.rate//2)
    me.assertEqual(xx.state, "dead")
    me.assertRaises(ValueError, xx.done, 1)
    me.assertRaises(ValueError, xx.get, 1)
    me.assertEqual(x.state, "absorb")
    me.assertRaises(ValueError, x.get, 1)
    x.xof(); me.assertEqual(x.state, "squeeze")
    me.assertRaises(ValueError, x.done, 1)
    _ = x.get(1)
    yy = x.copy(); me.assertEqual(yy.state, "squeeze")

  def test_shake128(me): me.check_shake(C.Shake128, 32)
  def test_shake256(me): me.check_shake(C.Shake256, 64)

  def check_kmac(me, mcls, c):
    """
    Run the SHAKE checks against the KMAC class MCLS, with a fixed key.

    KMAC takes no function-name input, so the wrapper ignores FUNC.
    """
    k = T.span(32)
    me.check_shake(lambda func = None, perso = None:
                     mcls(k, perso = perso),
                   c, done_matches_xof = False)

  def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
  def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
834
835 ###--------------------------------------------------------------------------
class TestPRP (T.GenericTestMixin):
  """Test pseudorandom permutations (PRPs)."""

  def _test_prp(me, pcls):

    ## Check the PRP properties.
    me.assertEqual(type(pcls.name), str)
    me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
    me.assertEqual(type(pcls.blksz), int)

    ## Encrypting a block and decrypting it again must round-trip, and the
    ## ciphertext must be exactly one block long.
    blksz = pcls.blksz
    key = pcls(T.span(pcls.keysz.default))
    plain = T.span(blksz)
    cipher = key.encrypt(plain)
    me.assertEqual(len(cipher), blksz)
    me.assertEqual(plain, key.decrypt(cipher))

    ## Keys of an unacceptable length must be refused.
    badlen = bad_key_size(pcls.keysz)
    if badlen is not None:
      me.assertRaises(ValueError, pcls, T.span(badlen))

    ## So must blocks of the wrong length, in either direction.
    oversized = T.span(blksz + 1)
    for op in (key.encrypt, key.decrypt):
      me.assertRaises(ValueError, op, oversized)
862
## Run the PRP tests over a sample of the generic PRP classes.
TestPRP.generate_testcases(
  (nm, C.gcprps[nm]) for nm in ("desx", "blowfish", "rijndael"))
865
866 ###----- That's all, folks --------------------------------------------------
867
## Run the tests when executed as a script.
if __name__ == "__main__":
  U.main()