algorithms.c (FOO.hashbufN): Consistently raise `ValueError' if too big.
[catacomb-python] / t / t-algorithms.py
CommitLineData
553d59fe
MW
1### -*- mode: python, coding: utf-8 -*-
2###
3### Test symmetric algorithms
4###
5### (c) 2019 Straylight/Edgeware
6###
7
8###----- Licensing notice ---------------------------------------------------
9###
10### This file is part of the Python interface to Catacomb.
11###
12### Catacomb/Python is free software: you can redistribute it and/or
13### modify it under the terms of the GNU General Public License as
14### published by the Free Software Foundation; either version 2 of the
15### License, or (at your option) any later version.
16###
17### Catacomb/Python is distributed in the hope that it will be useful, but
18### WITHOUT ANY WARRANTY; without even the implied warranty of
19### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20### General Public License for more details.
21###
22### You should have received a copy of the GNU General Public License
23### along with Catacomb/Python. If not, write to the Free Software
24### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25### USA.
26
27###--------------------------------------------------------------------------
28### Imported modules.
29
30import catacomb as C
31import unittest as U
32import testutils as T
33
34###--------------------------------------------------------------------------
35### Utilities.
36
def bad_key_size(ksz):
  """
  Return a key size which the key-size specification KSZ will reject.

  Returns `None' if no such size is found: KSZ is a `KeySZAny', or a
  `KeySZRange' with unit modulus, no upper bound and a zero minimum, or KSZ
  isn't one of the recognized key-size classes at all.
  """
  if isinstance(ksz, C.KeySZAny):
    ## Every size is acceptable.
    return None

  if isinstance(ksz, C.KeySZRange):
    ## Try, in order: one more than the minimum (off-step when the modulus
    ## isn't one), one past the maximum, and one below the minimum.
    if ksz.mod != 1: return ksz.min + 1
    if ksz.max is not None: return ksz.max + 1
    if ksz.min != 0: return ksz.min - 1
    return None

  if isinstance(ksz, C.KeySZSet):
    ## The smallest successor of a member which isn't itself a member.  The
    ## successor of the largest member always qualifies, so a nonempty set
    ## must produce an answer.
    members = ksz.set
    for size in sorted(members):
      candidate = size + 1
      if candidate not in members: return candidate
    assert False, "That should have worked."

  return None
50
def different_key_size(ksz, sz):
  """
  Return a key size acceptable to KSZ which differs from SZ.

  Returns `None' if there's no alternative, or if KSZ isn't one of the
  recognized key-size classes.
  """
  if isinstance(ksz, C.KeySZAny):
    return sz + 1

  if isinstance(ksz, C.KeySZRange):
    ## Step down by one modulus if that stays above the minimum; otherwise
    ## step up, provided there's room below the maximum (or no maximum).
    if sz > ksz.min: return sz - ksz.mod
    if ksz.max is None or sz < ksz.max: return sz + ksz.mod
    return None

  if isinstance(ksz, C.KeySZSet):
    ## The smallest member not equal to SZ, if there is one.
    return next((other for other in sorted(ksz.set) if other != sz), None)

  return None
63
class HashBufferTestMixin (U.TestCase):
  """
  Mixin providing checks for the assorted `hash...' methods.

  Every check is driven by a factory function MAKEFN: MAKEFN(SZ) returns a
  pair (H, DONE), where H is a fresh hashing object which will be fed SZ
  bytes in total (factories are free to ignore SZ), and DONE() finishes the
  computation and returns a result to compare.
  """

  def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
    """
    Check a `hashuN' method, which hashes a W-byte integer.

    BIGENDP says whether the method encodes big-endian.  HASHFN(H, X) applies
    the method under test to the hashing object H and the integer X, and
    returns a hashing object.
    """
    ## Feed an integer framed by single bytes, and compare against hashing
    ## `T.span(W + 2)' directly: the two must agree.
    hx, finish_x = makefn(w + 2)
    framed = hashfn(hx.hashu8(0x00), T.bytes_as_int(w, bigendp))
    framed.hashu8(w + 1)
    hy, finish_y = makefn(w + 2)
    hy.hash(T.span(w + 2))
    me.assertEqual(finish_x(), finish_y())

    ## The smallest integer which doesn't fit in W bytes must be refused.
    hz, _ = makefn(w)
    me.assertRaises((OverflowError, ValueError), hashfn, hz, 1 << 8*w)

  def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
    """
    Check a `hashbufN' method, which hashes a string with a W-byte length
    prefix.

    BIGENDP says whether the prefix is big-endian.  HASHFN(H, X) applies the
    method under test to the hashing object H and the string X, and returns a
    hashing object.
    """
    limit = 1 << 8*w

    ## Try a selection of lengths, skipping any too long for the prefix.
    sizes = [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]
    for n in [sz for sz in sizes if sz < limit]:
      hx, finish_x = makefn(2 + w + n)
      hashfn(hx.hashu8(0x00), T.span(n)).hashu8(0xff)
      hy, finish_y = makefn(2 + w + n)
      hy.hash(T.prep_lenseq(w, n, bigendp, True))
      me.assertEqual(finish_x(), finish_y())

    ## A string whose length can't be represented must be refused.  This is
    ## only practical for the narrower prefixes.
    if w <= 3:
      hz, _ = makefn(w + limit)
      me.assertRaises(ValueError, hashfn, hz, C.ByteString.zero(limit))

  def check_hashbuffer(me, makefn):
    """
    Check every `hashuN' and `hashbufN' method, using the factory MAKEFN.

    The 64-bit variants are checked only if the object returned by MAKEFN(0)
    provides them.
    """
    ## (width in bytes, big-endian?, method-name suffix)
    required = [(1, True, "8"),
                (2, True, "16"), (2, True, "16b"), (2, False, "16l"),
                (3, True, "24"), (3, True, "24b"), (3, False, "24l"),
                (4, True, "32"), (4, True, "32b"), (4, False, "32l")]
    optional = [(8, True, "64"), (8, True, "64b"), (8, False, "64l")]

    def invoker(meth):
      ## Return a function calling the method named METH on H with X.
      return lambda h, x: getattr(h, meth)(x)

    for prefix, checker in [("hashu", me.check_hashbuffer_hashn),
                            ("hashbuf", me.check_hashbuffer_bufn)]:
      for w, bigendp, suffix in required:
        checker(w, bigendp, makefn, invoker(prefix + suffix))
      if hasattr(makefn(0)[0], prefix + "64"):
        for w, bigendp, suffix in optional:
          checker(w, bigendp, makefn, invoker(prefix + suffix))
134
135###--------------------------------------------------------------------------
class TestKeysize (U.TestCase):
  """Test the key-size specification classes."""

  def _check_sizes(me, ksz, cases):
    """
    Check KSZ's `check', `best' and `pad' methods against CASES.

    Each case is a triple (X, BEST, PAD): BEST and PAD are the expected
    results of `best(X)' and `pad(X)', or `None' where the method should
    raise `ValueError'.  X should be accepted exactly when it equals both.
    """
    for x, best, pad in cases:
      if x == best == pad:
        me.assertTrue(ksz.check(x))
      else:
        me.assertFalse(ksz.check(x))
      if best is None:
        me.assertRaises(ValueError, ksz.best, x)
      else:
        me.assertEqual(ksz.best(x), best)
      if pad is None:
        me.assertRaises(ValueError, ksz.pad, x)
      else:
        me.assertEqual(ksz.pad(x), pad)

  def test_any(me):

    ## A typical one-byte spec, then a two-byte one.  (No published
    ## algorithm actually /needs/ a two-byte key-size spec, but all of the
    ## HMAC variants use one anyway.)
    for ksz, default in [(C.seal.keysz, 20), (C.sha256_hmac.keysz, 32)]:
      me.assertEqual(type(ksz), C.KeySZAny)
      me.assertEqual(ksz.default, default)
      me.assertEqual(ksz.min, 0)
      me.assertEqual(ksz.max, None)
      for sz in [0, 12, 20, 5000]:
        me.assertTrue(ksz.check(sz))
        me.assertEqual(ksz.best(sz), sz)
        me.assertEqual(ksz.pad(sz), sz)

    ## Construction.
    spec = C.KeySZAny(15)
    me.assertEqual(spec.default, 15)
    me.assertEqual(spec.min, 0)
    me.assertEqual(spec.max, None)
    me.assertRaises(ValueError, C.KeySZAny, -8)
    me.assertEqual(C.KeySZAny(0).default, 0)

  def test_set(me):
    ## No published algorithm uses a 16-bit `set' spec, so just check a
    ## typical one.
    ksz = C.salsa20.keysz
    me.assertEqual(type(ksz), C.KeySZSet)
    me.assertEqual(ksz.default, 32)
    me.assertEqual(ksz.min, 10)
    me.assertEqual(ksz.max, 32)
    me.assertEqual(ksz.set, set([10, 16, 32]))
    me._check_sizes(ksz, [(9, None, 10), (10, 10, 10), (11, 10, 16),
                          (15, 10, 16), (16, 16, 16), (17, 16, 32),
                          (31, 16, 32), (32, 32, 32), (33, 32, None)])

    ## Construction, without and with additional members.
    for args, members in [((7,), [7]),
                          ((7, iter([3, 6, 9])), [3, 6, 7, 9])]:
      spec = C.KeySZSet(*args)
      me.assertEqual(spec.default, 7)
      me.assertEqual(spec.set, set(members))
      me.assertEqual(spec.min, min(members))
      me.assertEqual(spec.max, max(members))

  def test_range(me):
    ## No published algorithm uses a 16-bit `range' spec, or an unbounded
    ## `range', so just check a typical one.
    ksz = C.rijndael.keysz
    me.assertEqual(type(ksz), C.KeySZRange)
    me.assertEqual(ksz.default, 32)
    me.assertEqual(ksz.min, 4)
    me.assertEqual(ksz.max, 32)
    me.assertEqual(ksz.mod, 4)
    me._check_sizes(ksz, [(3, None, 4), (4, 4, 4), (5, 4, 8),
                          (15, 12, 16), (16, 16, 16), (17, 16, 20),
                          (31, 28, 32), (32, 32, 32), (33, 32, None)])

    ## Construction: bounded, then unbounded.
    spec = C.KeySZRange(28, 21, 35, 7)
    me.assertEqual(spec.default, 28)
    me.assertEqual(spec.min, 21)
    me.assertEqual(spec.max, 35)
    me.assertEqual(spec.mod, 7)
    spec = C.KeySZRange(28, 21, None, 7)
    me.assertEqual(spec.min, 21)
    me.assertEqual(spec.max, None)
    me.assertEqual(spec.mod, 7)
    me.assertEqual(spec.pad(36), 42)

    ## Inconsistent parameters must be refused.
    for args in [(29, 21, 35, 7), (28, 20, 35, 7), (28, 21, 34, 7),
                 (28, -7, 35, 7), (28, 35, 21, 7), (35, 21, 28, 7),
                 (21, 28, 35, 7)]:
      me.assertRaises(ValueError, C.KeySZRange, *args)

  def test_conversions(me):
    ks = C.KeySZ
    me.assertEqual(ks.fromec(256), 128)
    me.assertEqual(ks.fromschnorr(256), 128)
    for conv in [ks.fromdl, ks.fromif]:
      me.assertEqual(round(conv(2958.6875)), 128)
    for conv in [ks.toec, ks.toschnorr]:
      me.assertEqual(conv(128), 256)
    for conv in [ks.todl, ks.toif]:
      me.assertEqual(conv(128), 2958.6875)
252
253###--------------------------------------------------------------------------
class TestCipher (T.GenericTestMixin):
  """Test basic symmetric ciphers."""

  def _test_cipher(me, ccls):
    """
    Check the symmetric cipher class CCLS.

    The optional `setiv' and `bdry' operations are probed: a `ValueError'
    from the encrypting instance is taken to mean that CCLS lacks them.
    """

    ## Class properties.
    me.assertEqual(type(ccls.name), str)
    me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
    me.assertEqual(type(ccls.blksz), int)

    ## Round-trip a message in three pieces, with a boundary (if supported)
    ## between the second and third.
    key = T.span(ccls.keysz.default)
    iv = T.span(ccls.blksz)
    msg = T.span(253)
    enc = ccls(key)
    dec = ccls(key)
    try:
      enc.setiv(iv)
    except ValueError:
      can_setiv = False
    else:
      can_setiv = True
      dec.setiv(iv)
    c0 = enc.encrypt(msg[0:57]); p0 = dec.decrypt(c0)
    c1 = enc.encrypt(msg[57:189]); p1 = dec.decrypt(c1)
    try:
      enc.bdry()
    except ValueError:
      can_bdry = False
    else:
      dec.bdry()
      can_bdry = True
    c2 = enc.encrypt(msg[189:253]); p2 = dec.decrypt(c2)
    me.assertEqual(len(c0) + len(c1) + len(c2), len(msg))
    for got, (lo, hi) in zip([p0, p1, p2], [(0, 57), (57, 189), (189, 253)]):
      me.assertEqual(got, msg[lo:hi])

    ## `enczero' and `deczero' agree with encrypting and decrypting zeros.
    zeros = C.ByteString.zero(32)
    me.assertEqual(dec.decrypt(enc.enczero(32)), zeros)
    me.assertEqual(enc.encrypt(dec.deczero(32)), zeros)

    if can_bdry:
      ## Decrypting the first two pieces in one go works without a boundary,
      ## since the boundary came after them ...
      fresh = ccls(key)
      if can_setiv: fresh.setiv(iv)
      me.assertEqual(fresh.decrypt(c0 + c1), msg[0:189])

      ## ... but ignoring the boundary before the third piece garbles it.
      fresh = ccls(key)
      if can_setiv: fresh.setiv(iv)
      me.assertNotEqual(fresh.decrypt(c0 + c1 + c2), msg)

    ## Keys of the wrong length are refused.
    bad = bad_key_size(ccls.keysz)
    if bad is not None: me.assertRaises(ValueError, ccls, T.span(bad))

TestCipher.generate_testcases(
  (name, C.gcciphers[name])
  for name in ["des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
               "blowfish-counter", "rc4", "seal", "salsa20/8",
               "shake128-xof"])
319
320###--------------------------------------------------------------------------
10f3f611
MW
class TestAuthenticatedEncryption \
    (HashBufferTestMixin, T.GenericTestMixin):
  """Test authenticated encryption schemes."""

  def _test_aead(me, aecls):
    """
    Check the authenticated-encryption class AECLS.

    Many checks depend on AECLS's `flags': which lengths must be committed
    to in advance (`AEADF_PCHSZ', `AEADF_PCMSZ', `AEADF_PCTSZ'), whether AAD
    is supported at all (`AEADF_NOAAD'), whether AAD must be supplied before
    the message (`AEADF_AADFIRST'), and whether AAD processing depends on
    the nonce (`AEADF_AADNDEP').
    """

    ## Check the class properties.
    me.assertEqual(type(aecls.name), str)
    me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
    me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
    me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
    me.assertEqual(type(aecls.blksz), int)
    me.assertEqual(type(aecls.bufsz), int)
    me.assertEqual(type(aecls.ohd), int)
    me.assertEqual(type(aecls.flags), int)

    ## Check round-tripping, with full precommitment.  First, select some
    ## parameters.  (It's conceivable that some AEAD schemes are more
    ## restrictive than advertised by the various properties, but this works
    ## out OK in practice.)
    k = T.span(aecls.keysz.default)
    n = T.span(aecls.noncesz.default)
    if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
    else: h = T.span(131)
    m = T.span(253)
    tsz = aecls.tagsz.default
    key = aecls(k)

    ## Next, encrypt a message, checking that things are proper as we go.
    enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
    me.assertEqual(enc.hsz, len(h))
    me.assertEqual(enc.msz, len(m))
    me.assertEqual(enc.mlen, 0)
    me.assertEqual(enc.tsz, tsz)
    aad = enc.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    if not aecls.flags&C.AEADF_NOAAD:
      aad.hash(h[0:83])
      me.assertEqual(aad.hlen, 83)
      aad.hash(h[83:131])
      me.assertEqual(aad.hlen, 131)
    c0 = enc.encrypt(m[0:57])
    me.assertEqual(enc.mlen, 57)
    me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
    c1 = enc.encrypt(m[57:189])
    me.assertEqual(enc.mlen, 189)
    me.assertTrue(132 - aecls.bufsz <= len(c1) <=
                  132 + aecls.bufsz + aecls.ohd)
    c2 = enc.encrypt(m[189:253])
    me.assertEqual(enc.mlen, 253)
    me.assertTrue(64 - aecls.bufsz <= len(c2) <=
                  64 + aecls.bufsz + aecls.ohd)
    c3, t = enc.done(aad = aad)
    me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
    c = c0 + c1 + c2 + c3
    me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
    me.assertEqual(len(t), tsz)

    ## And now decrypt it again, with different record boundaries.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    me.assertEqual(dec.hsz, len(h))
    me.assertEqual(dec.csz, len(c))
    me.assertEqual(dec.clen, 0)
    me.assertEqual(dec.tsz, tsz)
    aad = dec.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    aad.hash(h)
    m0 = dec.decrypt(c[0:156])
    me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
    m1 = dec.decrypt(c[156:])
    me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
                  len(c) - 156 + aecls.bufsz)
    m2 = dec.done(tag = t, aad = aad)
    me.assertEqual(m0 + m1 + m2, m)

    ## And again, with the wrong tag.  Supply the correct AAD, so that the
    ## only thing wrong is the tag (omitting it would leave the committed
    ## header length unsatisfied, which is a different failure).
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    aad = dec.aad(); aad.hash(h)
    _ = dec.decrypt(c)
    me.assertRaises(ValueError, dec.done,
                    tag = t ^ tsz*C.bytes("55"), aad = aad)

    ## Check that the all-in-one methods work.
    me.assertEqual((c, t),
                   key.encrypt(n = n, h = h, m = m, tsz = tsz))
    me.assertEqual(m,
                   key.decrypt(n = n, h = h, c = c, t = t))

    ## Check that bad key, nonce, and tag lengths are rejected.
    badlen = bad_key_size(aecls.keysz)
    if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
    badlen = bad_key_size(aecls.noncesz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
                      hsz = len(h), msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
                      hsz = len(h), csz = len(c), tsz = tsz)
    badlen = bad_key_size(aecls.tagsz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m), tsz = badlen)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c), tsz = badlen)

      ## If the tag size needn't be committed up front, then a bad tag size
      ## must still be caught when it's finally supplied to `done'.
      if not aecls.flags&C.AEADF_PCTSZ:
        enc = key.enc(nonce = n, hsz = 0, msz = len(m))
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, tsz = badlen)

    ## Check that we can't get a loose `aad' object from a scheme which has
    ## nonce-dependent AAD processing.
    if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)

    ## Check the menagerie of AAD hashing methods.
    if not aecls.flags&C.AEADF_NOAAD:
      def mkhash(hsz):
        enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
        aad = enc.aad()
        return aad, lambda: enc.done(aad = aad)[1]
      me.check_hashbuffer(mkhash)

    ## Check that encryption/decryption works with the given precommitments.
    def quick_enc_check(**kw):
      enc = key.enc(**kw)
      aad = enc.aad().hash(h)
      c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
      me.assertEqual((c, t), (c0 + c1, tt))
    def quick_dec_check(**kw):
      dec = key.dec(**kw)
      aad = dec.aad().hash(h)
      m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
      me.assertEqual(m, m0 + m1)

    ## Check that we can get away without precommitting to the header length
    ## if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCHSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      csz = len(c), tsz = tsz)
    else:
      quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
      quick_dec_check(nonce = n, csz = len(c), tsz = tsz)

    ## Check that we can get away without precommitting to the message/
    ## ciphertext length if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCMSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), tsz = tsz)
    else:
      quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
      quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)

    ## Check that we can get away without precommitting to the tag length if
    ## and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCTSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m))
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c))
    else:
      quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
      quick_dec_check(nonce = n, hsz = len(h), csz = len(c))

    ## Check that if we precommit to the header length, we're properly held
    ## to the commitment.
    if not aecls.flags&C.AEADF_NOAAD:

      ## First, check encryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to encrypt;
      ## otherwise, checking is delayed until `done'.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, enc.encrypt, m)
      else:
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, aad = aad)

      ## Next, check decryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to decrypt;
      ## otherwise, checking is delayed until `done'.
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
      aad = dec.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, dec.decrypt, c)
      else:
        _ = dec.decrypt(c)
        me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

      ## If AAD processing is nonce-dependent then an overrun will be
      ## detected immediately.
      if aecls.flags&C.AEADF_AADNDEP:
        enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
        aad = enc.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])
        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
        aad = dec.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])

    ## Some additional tests for nonce-dependent `aad' objects.
    if aecls.flags&C.AEADF_AADNDEP:

      ## Check that `aad' objects can't be used once their parents are gone.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad()
      del enc
      me.assertRaises(ValueError, aad.hash, h)

      ## Check that they can't be crossed over.
      enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc0.aad().hash(h)
      aad1 = enc1.aad().hash(h)
      _ = enc0.encrypt(m)
      me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)

    ## Test copying AAD.
    if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
      aad0 = key.aad()
      aad0.hash(h[0:83])
      aad1 = aad0.copy()
      aad2 = aad1.copy()
      aad0.hash(h[83:131])
      aad1.hash(h[83:131])
      aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
      me.assertEqual(key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad0),
                     key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad1))
      me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad0),
                        key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad2))

    ## Check that if we precommit to the message length, we're properly held
    ## to the commitment.  (Fortunately, this is way simpler than the AAD
    ## case above.)  First, try an underrun.
    enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
    _ = enc.encrypt(m[0:183])
    me.assertRaises(ValueError, enc.done, tsz = tsz)
    dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
    _ = dec.decrypt(c[0:183])
    me.assertRaises(ValueError, dec.done, tag = t)

    ## And now an overrun.
    enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
    me.assertRaises(ValueError, enc.encrypt, m)
    dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
    me.assertRaises(ValueError, dec.decrypt, c)

    ## Finally, check that if we precommit to a tag length, we're properly
    ## held to the commitment.  This depends on being able to find a tag size
    ## which isn't the default.
    tsz1 = different_key_size(aecls.tagsz, tsz)
    if tsz1 is not None:
      enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
      _ = enc.encrypt(m)
      me.assertRaises(ValueError, enc.done, tsz = tsz)

      ## The decryptor committed to TSZ1 must refuse a TSZ-byte tag.
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
      aad = dec.aad().hash(h)
      _ = dec.decrypt(c)
      me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

TestAuthenticatedEncryption.generate_testcases \
  ((name, C.gcaeads[name]) for name in
   ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
    "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])
593
594###--------------------------------------------------------------------------
553d59fe
MW
class BaseTestHash (HashBufferTestMixin):
  """Common checks for hash functions, also reused for MACs."""

  def check_hash(me, hcls, need_bufsz = True):
    """
    Check hash class HCLS.

    Calling HCLS with no arguments must produce a fresh hashing object, and
    HCLS must have `name' and `hashsz' attributes.  If NEED_BUFSZ is true,
    then HCLS must also have an integer `bufsz' attribute; MACs, which reuse
    these checks, don't have one, so pass false.
    """
    ## Properties.
    me.assertEqual(type(hcls.name), str)
    if need_bufsz: me.assertEqual(type(hcls.bufsz), int)
    me.assertEqual(type(hcls.hashsz), int)

    msg = T.span(131)
    digest = hcls().hash(msg).done()

    ## The digest has the advertised length, and doesn't depend on how the
    ## message is split up.
    me.assertEqual(len(digest), hcls.hashsz)
    me.assertEqual(digest, hcls().hash(msg[0:73]).hash(msg[73:131]).done())

    ## `check' accepts the right digest and rejects a corrupted one.
    me.assertTrue(hcls().hash(msg).check(digest))
    corrupted = digest ^ hcls.hashsz*C.bytes("aa")
    me.assertFalse(hcls().hash(msg).check(corrupted))

    ## The typed hashing methods.
    def fresh(_):
      obj = hcls()
      return obj, obj.done
    me.check_hashbuffer(fresh)
630
class TestHash (BaseTestHash, T.GenericTestMixin):
  """Test hash functions."""

  def _test_hash(me, hcls):
    """Run the common hash checks on HCLS, insisting on a `bufsz'."""
    me.check_hash(hcls, need_bufsz = True)

TestHash.generate_testcases(
  (name, C.gchashes[name])
  for name in ["md5", "sha", "whirlpool", "sha256", "sha512/224",
               "sha3-384", "shake256", "crc32"])
638
639###--------------------------------------------------------------------------
class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
  """Test message authentication codes."""

  def _test_mac(me, mcls):
    """Check MAC class MCLS, and a key made from it."""

    ## Class properties.
    me.assertEqual(type(mcls.name), str)
    me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
    me.assertEqual(type(mcls.tagsz), int)

    ## A key behaves like a hash class whose output size is the tag size.
    key = mcls(T.span(mcls.keysz.default))
    me.assertEqual(key.hashsz, key.tagsz)
    me.check_hash(key, need_bufsz = False)

    ## Keys of the wrong length are refused.
    bad = bad_key_size(mcls.keysz)
    if bad is not None: me.assertRaises(ValueError, mcls, T.span(bad))

TestMessageAuthentication.generate_testcases(
  (name, C.gcmacs[name])
  for name in ["sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"])
663
class TestPoly1305 (HashBufferTestMixin):
  """Check the Poly1305 one-time message authentication function."""

  def test_poly1305(me):

    ## Check the MAC properties.
    me.assertEqual(C.poly1305.name, "poly1305")
    me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
    me.assertEqual(C.poly1305.keysz.default, 16)
    me.assertEqual(C.poly1305.keysz.set, set([16]))
    me.assertEqual(C.poly1305.tagsz, 16)
    me.assertEqual(C.poly1305.masksz, 16)

    ## Set some initial values.  Calling the key object with a mask U gives
    ## a hashing object.
    k = T.span(16)
    u = T.span(64)[-16:]
    m = T.span(149)
    key = C.poly1305(k)
    t = key(u).hash(m).done()

    ## Check the key properties.
    me.assertEqual(key.name, "poly1305")
    me.assertEqual(key.tagsz, 16)
    me.assertEqual(len(t), 16)

    ## Check that we get the same answer if we split the message up.
    me.assertEqual(t, key(u).hash(m[0:86]).hash(m[86:149]).done())

    ## Check the `check' method.
    me.assertTrue(key(u).hash(m).check(t))
    me.assertFalse(key(u).hash(m).check(t ^ 16*C.bytes("cc")))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      h = key(u)
      return h, h.done
    me.check_hashbuffer(mkhash)

    ## Check that we can't complete hashing without a mask.
    me.assertRaises(ValueError, key().hash(m).done)

    ## Check `concat': the pieces must come from the same key, and (judging
    ## by the final check) the first piece's length must suit concatenation.
    h0 = key().hash(m[0:96])
    h1 = key().hash(m[96:117])
    me.assertEqual(t, key(u).concat(h0, h1).hash(m[117:149]).done())
    key1 = C.poly1305(k)
    me.assertRaises(TypeError, key().concat, key1().hash(m[0:96]), h1)
    me.assertRaises(TypeError, key().concat, h0, key1().hash(m[96:117]))
    me.assertRaises(ValueError, key().concat, key().hash(m[0:93]), h1)
714
715###--------------------------------------------------------------------------
class TestHLatin (U.TestCase):
  """Test the `hsalsa20' and `hchacha20' functions."""

  def test_hlatin(me):
    kk = [T.span(sz) for sz in [10, 16, 32]]
    n = T.span(16)
    bad_k = T.span(18)
    bad_n = T.span(13)
    for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
               C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:
      for k in kk:
        ## Every acceptable key size gives a 32-byte output, and a bad nonce
        ## is refused whatever the key size.
        me.assertEqual(len(fn(k, n)), 32)
        me.assertRaises(ValueError, fn, k, bad_n)

      ## A key of a bad size is refused.
      me.assertRaises(ValueError, fn, bad_k, n)
731
732###--------------------------------------------------------------------------
class TestKeccak (HashBufferTestMixin):
  """Test the Keccak-p[1600, n] sponge function."""

  def test_keccak(me):

    ## Make a state and feed some stuff into it.
    m0 = T.bin("some initial string")
    m1 = T.bin("awesome follow-up string")
    st0 = C.Keccak1600()
    me.assertEqual(st0.nround, 24)
    st0.mix(m0).step()

    ## Make another step with a different round count.
    st1 = C.Keccak1600(23)
    st1.mix(m0).step()
    me.assertNotEqual(st0.extract(32), st1.extract(32))

    ## Check state copying.  The copy must be independent of the original
    ## (mixing into the original mustn't disturb it), and must evolve
    ## identically given the same input.
    st1 = st0.copy()
    mask = st1.extract(len(m1))
    st0.mix(m1)
    me.assertEqual(st1.extract(len(m1)), mask)
    st1.mix(m1)
    me.assertEqual(st0.extract(32), st1.extract(32))

    ## Check error conditions: at most 200 bytes can be extracted or mixed.
    _ = st0.extract(200)
    me.assertRaises(ValueError, st0.extract, 201)
    st0.mix(T.span(200))
    me.assertRaises(ValueError, st0.mix, T.span(201))

  def check_shake(me, xcls, c, done_matches_xof = True):
    """
    Test the SHAKE and cSHAKE XOFs.

    This is also used for testing KMAC, but that sets DONE_MATCHES_XOF false
    to indicate that the XOF output is range-separated from the fixed-length
    outputs (unlike the basic SHAKE functions).
    """

    ## Check the hash attributes.
    x = xcls()
    me.assertEqual(x.rate, 200 - c)
    me.assertEqual(x.buffered, 0)
    me.assertEqual(x.state, "absorb")

    ## Set some initial values.
    func = T.bin("TESTXOF")
    perso = T.bin("catacomb-python test")
    m = T.span(167)
    h0 = xcls().hash(m).done(193)
    me.assertEqual(len(h0), 193)
    h1 = xcls(func = func, perso = perso).hash(m).done(193)
    me.assertEqual(len(h1), 193)
    me.assertNotEqual(h0, h1)

    ## Check input and output in pieces, and the state machine.
    if done_matches_xof: h = h0
    else: h = xcls().hash(m).xof().get(len(h0))
    x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
    me.assertEqual(h, x.get(98) + x.get(95))

    ## Check masking.
    x = xcls().hash(m).xof()
    me.assertEqual(x.mask(m), m ^ h[0:len(m)])

    ## Check the `check' method.
    me.assertTrue(xcls().hash(m).check(h0))
    me.assertFalse(xcls().hash(m).check(h1))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      x = xcls(func = func, perso = perso)
      return x, x.done
    me.check_hashbuffer(mkhash)

    ## Check the state machine tracking.
    x = xcls(); me.assertEqual(x.state, "absorb")
    x.hash(m); me.assertEqual(x.state, "absorb")
    xx = x.copy()
    h = xx.done(); me.assertEqual(len(h), 100 - x.rate//2)
    me.assertEqual(xx.state, "dead")
    me.assertRaises(ValueError, xx.done, 1)
    me.assertRaises(ValueError, xx.get, 1)
    me.assertEqual(x.state, "absorb")
    me.assertRaises(ValueError, x.get, 1)
    x.xof(); me.assertEqual(x.state, "squeeze")
    me.assertRaises(ValueError, x.done, 1)
    _ = x.get(1)
    yy = x.copy(); me.assertEqual(yy.state, "squeeze")

  def test_shake128(me): me.check_shake(C.Shake128, 32)
  def test_shake256(me): me.check_shake(C.Shake256, 64)

  def check_kmac(me, mcls, c):
    """Run the SHAKE checks on KMAC class MCLS, keyed with a fixed key."""
    k = T.span(32)
    ## KMAC takes no function name, so FUNC is accepted and ignored.
    me.check_shake(lambda func = None, perso = None:
                     mcls(k, perso = perso),
                   c, done_matches_xof = False)

  def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
  def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
834
835###--------------------------------------------------------------------------
class TestPRP (T.GenericTestMixin):
  """Test pseudorandom permutations (PRPs)."""

  def _test_prp(me, pcls):
    """Check PRP class PCLS."""

    ## Class properties.
    me.assertEqual(type(pcls.name), str)
    me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
    me.assertEqual(type(pcls.blksz), int)

    ## One block encrypts to one block, and decrypts back again.
    prp = pcls(T.span(pcls.keysz.default))
    plain = T.span(pcls.blksz)
    cipher = prp.encrypt(plain)
    me.assertEqual(len(cipher), pcls.blksz)
    me.assertEqual(plain, prp.decrypt(cipher))

    ## Keys of the wrong length are refused.
    bad = bad_key_size(pcls.keysz)
    if bad is not None: me.assertRaises(ValueError, pcls, T.span(bad))

    ## Blocks of the wrong length are refused in both directions.
    oversized = T.span(pcls.blksz + 1)
    for op in [prp.encrypt, prp.decrypt]:
      me.assertRaises(ValueError, op, oversized)

TestPRP.generate_testcases(
  (name, C.gcprps[name]) for name in ["desx", "blowfish", "rijndael"])
865
866###----- That's all, folks --------------------------------------------------
867
if __name__ == "__main__":
  ## Run the tests when executed as a script.
  U.main()