8346fad1b8da6fbcb63273ad4f9a6e8c22d53689
[catacomb-python] / t / t-algorithms.py
1 ### -*- mode: python, coding: utf-8 -*-
2 ###
3 ### Test symmetric algorithms
4 ###
5 ### (c) 2019 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the Python interface to Catacomb.
11 ###
12 ### Catacomb/Python is free software: you can redistribute it and/or
13 ### modify it under the terms of the GNU General Public License as
14 ### published by the Free Software Foundation; either version 2 of the
15 ### License, or (at your option) any later version.
16 ###
17 ### Catacomb/Python is distributed in the hope that it will be useful, but
18 ### WITHOUT ANY WARRANTY; without even the implied warranty of
19 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20 ### General Public License for more details.
21 ###
22 ### You should have received a copy of the GNU General Public License
23 ### along with Catacomb/Python. If not, write to the Free Software
24 ### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25 ### USA.
26
27 ###--------------------------------------------------------------------------
28 ### Imported modules.
29
30 import catacomb as C
31 import unittest as U
32 import testutils as T
33
34 ###--------------------------------------------------------------------------
35 ### Utilities.
36
def bad_key_size(ksz):
    """
    Return a key size which the specification KSZ rejects, or `None'.

    KSZ should be a `C.KeySZAny', `C.KeySZRange' or `C.KeySZSet'.  Anything
    else, like `KeySZAny', is treated as accepting every size, so there's no
    bad size to offer and `None' is returned.
    """
    if isinstance(ksz, C.KeySZAny):
        return None

    if isinstance(ksz, C.KeySZRange):
        ## If sizes go up in steps bigger than one, then one past the minimum
        ## falls between steps.  Otherwise, step just outside the range, if it
        ## has an edge to step over.
        if ksz.mod != 1:
            return ksz.min + 1
        if ksz.max is not None:
            return ksz.max + 1
        if ksz.min != 0:
            return ksz.min - 1
        return None

    if isinstance(ksz, C.KeySZSet):
        ## Walk the permitted sizes in order: the first whose successor isn't
        ## permitted gives the answer.  The largest member always qualifies,
        ## so the loop can't fall off the end.
        for sz in sorted(ksz.set):
            if sz + 1 not in ksz.set:
                return sz + 1
        assert False, "That should have worked."

    return None
50
def different_key_size(ksz, sz):
    """
    Return a size accepted by the specification KSZ but different from SZ.

    SZ is presumably itself acceptable to KSZ.  Returns `None' if no such
    size can be found, or if KSZ is of an unrecognized kind.
    """
    if isinstance(ksz, C.KeySZAny):
        return sz + 1

    if isinstance(ksz, C.KeySZRange):
        ## Step down by one modulus if that stays in range; otherwise step up
        ## if there's room.
        if sz > ksz.min:
            return sz - ksz.mod
        if ksz.max is None or sz < ksz.max:
            return sz + ksz.mod
        return None

    if isinstance(ksz, C.KeySZSet):
        ## The smallest permitted size which isn't SZ, if any.
        return next((other for other in sorted(ksz.set) if other != sz), None)

    return None
63
class HashBufferTestMixin (U.TestCase):
    """Mixin class for testing all of the various `hash...' methods."""

    def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
        """
        Check a `hashuN' method, which hashes a W-byte integer encoding.

        BIGENDP is true if the encoding is big-endian.  MAKEFN(SZ) returns a
        pair (H, DONEFN): a fresh hashing object H which will be fed exactly SZ
        bytes, and a function DONEFN returning the final output.  HASHFN(H, N)
        feeds the integer N to H using the method under test, and returns a
        hashing object on which further `hash...' methods can be called.
        """

        ## Feeding the integer, bracketed by a leading and a trailing byte, must
        ## give the same result as hashing a plain span of W + 2 bytes.
        ## (Presumably `T.bytes_as_int' yields the integer whose W-byte encoding
        ## is the middle of that span -- see `testutils'.)
        h, done = makefn(w + 2)
        hashfn(h.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
        ref, ref_done = makefn(w + 2)
        ref.hash(T.span(w + 2))
        me.assertEqual(done(), ref_done())

        ## An integer too wide for W bytes must be refused.
        h, _ = makefn(w)
        me.assertRaises((OverflowError, ValueError), hashfn, h, 1 << 8*w)

    def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
        """
        Check a `hashbufN' method, which hashes a buffer with a W-byte length
        prefix.

        BIGENDP is true if the length prefix is big-endian.  MAKEFN is as for
        `check_hashbuffer_hashn'.  HASHFN(H, X) feeds the buffer X to H using
        the method under test, and returns a hashing object.
        """

        ## For a variety of buffer lengths which fit in the prefix, feeding a
        ## bracketed buffer must match hashing the length-prefixed string built
        ## by `T.prep_lenseq'.
        limit = 1 << 8*w
        sizes = [n for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]
                 if n < limit]
        for n in sizes:
            h, done = makefn(2 + w + n)
            hashfn(h.hashu8(0x00), T.span(n)).hashu8(0xff)
            ref, ref_done = makefn(2 + w + n)
            ref.hash(T.prep_lenseq(w, n, bigendp, True))
            me.assertEqual(done(), ref_done())

        ## A buffer just too long for the prefix must be refused.  Only try this
        ## for narrow prefixes: otherwise the buffer would be enormous.
        if w <= 3:
            h, _ = makefn(w + limit)
            me.assertRaises((ValueError, TypeError),
                            hashfn, h, C.ByteString.zero(limit))

    def check_hashbuffer(me, makefn):
        """
        Test the various `hash...' methods.

        MAKEFN is as for `check_hashbuffer_hashn'.  All the `hashuN' methods
        are checked first, then all the `hashbufN' methods.  The 64-bit
        variants are only checked if the object made by MAKEFN(0) has them.
        """

        ## Each entry is (WIDTH, BIG-ENDIAN?, NAME-SUFFIX); the unsuffixed
        ## variants are big-endian.
        narrow = [(1, True, "8"),
                  (2, True, "16"), (2, True, "16b"), (2, False, "16l"),
                  (3, True, "24"), (3, True, "24b"), (3, False, "24l"),
                  (4, True, "32"), (4, True, "32b"), (4, False, "32l")]
        wide = [(8, True, "64"), (8, True, "64b"), (8, False, "64l")]

        def method(name):
            ## Bind NAME now, so each check gets its own method.
            return lambda h, x: getattr(h, name)(x)

        for prefix, checkfn in [("hashu", me.check_hashbuffer_hashn),
                                ("hashbuf", me.check_hashbuffer_bufn)]:
            variants = list(narrow)
            if hasattr(makefn(0)[0], prefix + "64"):
                variants += wide
            for w, bigendp, suffix in variants:
                checkfn(w, bigendp, makefn, method(prefix + suffix))
135
136 ###--------------------------------------------------------------------------
class TestKeysize (U.TestCase):
    """
    Test the key-size specification classes.

    Covers `KeySZAny', `KeySZSet', and `KeySZRange', and the `KeySZ.from...'
    and `KeySZ.to...' conversion functions.
    """

    def _check_unrestricted(me, ksz, default):
        """Check that KSZ is a `KeySZAny' with the given DEFAULT size."""
        me.assertEqual(type(ksz), C.KeySZAny)
        me.assertEqual(ksz.default, default)
        me.assertEqual(ksz.min, 0)
        me.assertEqual(ksz.max, None)
        for n in [0, 12, 20, 5000]:
            me.assertTrue(ksz.check(n))
            me.assertEqual(ksz.best(n), n)
            me.assertEqual(ksz.pad(n), n)

    def _check_table(me, ksz, table):
        """
        Check KSZ's `check', `best', and `pad' methods against TABLE.

        TABLE is a sequence of triples (X, BEST, PAD) giving the expected results
        of `best' and `pad' for the size X; `None' means that the method should
        raise `ValueError'.  X should be acceptable exactly when all three are
        equal.
        """
        for x, best, pad in table:
            if x == best == pad:
                me.assertTrue(ksz.check(x))
            else:
                me.assertFalse(ksz.check(x))
            if best is None:
                me.assertRaises(ValueError, ksz.best, x)
            else:
                me.assertEqual(ksz.best(x), best)
            if pad is None:
                me.assertRaises(ValueError, ksz.pad, x)
            else:
                me.assertEqual(ksz.pad(x), pad)

    def test_any(me):

        ## A typical one-byte spec.
        me._check_unrestricted(C.seal.keysz, 20)

        ## A typical two-byte spec.  (No published algorithms actually /need/ a
        ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
        me._check_unrestricted(C.sha256_hmac.keysz, 32)

        ## Check construction.
        ksz = C.KeySZAny(15)
        me.assertEqual(ksz.default, 15)
        me.assertEqual(ksz.min, 0)
        me.assertEqual(ksz.max, None)
        me.assertRaises(ValueError, C.KeySZAny, -8)
        me.assertEqual(C.KeySZAny(0).default, 0)

    def test_set(me):
        ## Note that no published algorithm uses a 16-bit `set' spec.

        ## A typical spec.
        ksz = C.salsa20.keysz
        me.assertEqual(type(ksz), C.KeySZSet)
        me.assertEqual(ksz.default, 32)
        me.assertEqual(ksz.min, 10)
        me.assertEqual(ksz.max, 32)
        me.assertEqual(ksz.set, {10, 16, 32})
        me._check_table(ksz, [(9, None, 10), (10, 10, 10), (11, 10, 16),
                              (15, 10, 16), (16, 16, 16), (17, 16, 32),
                              (31, 16, 32), (32, 32, 32), (33, 32, None)])

        ## Check construction: the default is always a member of the set.
        ksz = C.KeySZSet(7)
        me.assertEqual(ksz.default, 7)
        me.assertEqual(ksz.set, {7})
        me.assertEqual(ksz.min, 7)
        me.assertEqual(ksz.max, 7)
        ksz = C.KeySZSet(7, iter([3, 6, 9]))
        me.assertEqual(ksz.default, 7)
        me.assertEqual(ksz.set, {3, 6, 7, 9})
        me.assertEqual(ksz.min, 3)
        me.assertEqual(ksz.max, 9)

    def test_range(me):
        ## Note that no published algorithm uses a 16-bit `range' spec, or an
        ## unbounded `range'.

        ## A typical spec.
        ksz = C.rijndael.keysz
        me.assertEqual(type(ksz), C.KeySZRange)
        me.assertEqual(ksz.default, 32)
        me.assertEqual(ksz.min, 4)
        me.assertEqual(ksz.max, 32)
        me.assertEqual(ksz.mod, 4)
        me._check_table(ksz, [(3, None, 4), (4, 4, 4), (5, 4, 8),
                              (15, 12, 16), (16, 16, 16), (17, 16, 20),
                              (31, 28, 32), (32, 32, 32), (33, 32, None)])

        ## Check construction.  Arguments are (DEFAULT, MIN, MAX, MOD).
        ksz = C.KeySZRange(28, 21, 35, 7)
        me.assertEqual(ksz.default, 28)
        me.assertEqual(ksz.min, 21)
        me.assertEqual(ksz.max, 35)
        me.assertEqual(ksz.mod, 7)
        ksz = C.KeySZRange(28, 21, None, 7)
        me.assertEqual(ksz.min, 21)
        me.assertEqual(ksz.max, None)
        me.assertEqual(ksz.mod, 7)
        me.assertEqual(ksz.pad(36), 42)

        ## Inconsistent parameters must be refused.
        for args in [(29, 21, 35, 7), (28, 20, 35, 7), (28, 21, 34, 7),
                     (28, -7, 35, 7), (28, 35, 21, 7), (35, 21, 28, 7),
                     (21, 28, 35, 7)]:
            me.assertRaises(ValueError, C.KeySZRange, *args)

    def test_conversions(me):
        me.assertEqual(C.KeySZ.fromec(256), 128)
        me.assertEqual(C.KeySZ.fromschnorr(256), 128)
        me.assertEqual(round(C.KeySZ.fromdl(2958.6875)), 128)
        me.assertEqual(round(C.KeySZ.fromif(2958.6875)), 128)
        me.assertEqual(C.KeySZ.toec(128), 256)
        me.assertEqual(C.KeySZ.toschnorr(128), 256)
        me.assertEqual(C.KeySZ.todl(128), 2958.6875)
        me.assertEqual(C.KeySZ.toif(128), 2958.6875)
253
254 ###--------------------------------------------------------------------------
class TestCipher (T.GenericTestMixin):
    """Test basic symmetric ciphers."""

    def _test_cipher(me, ccls):
        """Run the generic checks against the cipher class CCLS."""

        ## The class must advertise its name, key sizes, and block size.
        me.assertEqual(type(ccls.name), str)
        me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
        me.assertEqual(type(ccls.blksz), int)

        ## Set up a matching encryptor/decryptor pair.  `setiv' raises
        ## `ValueError' for ciphers which don't support it; remember whether it
        ## worked.
        k = T.span(ccls.keysz.default)
        iv = T.span(ccls.blksz)
        m = T.span(253)
        enc = ccls(k)
        dec = ccls(k)
        try:
            enc.setiv(iv)
        except ValueError:
            can_setiv = False
        else:
            can_setiv = True
            dec.setiv(iv)

        ## Round-trip the message in three unevenly sized pieces.  Between the
        ## second and third, insert a boundary if the cipher supports one
        ## (again, `ValueError' means that it doesn't).
        c0 = enc.encrypt(m[0:57])
        m0 = dec.decrypt(c0)
        c1 = enc.encrypt(m[57:189])
        m1 = dec.decrypt(c1)
        try:
            enc.bdry()
        except ValueError:
            can_bdry = False
        else:
            dec.bdry()
            can_bdry = True
        c2 = enc.encrypt(m[189:253])
        m2 = dec.decrypt(c2)
        me.assertEqual(len(c0) + len(c1) + len(c2), len(m))
        for got, want in [(m0, m[0:57]), (m1, m[57:189]), (m2, m[189:253])]:
            me.assertEqual(got, want)

        ## `enczero' and `deczero' must act like encrypting or decrypting a
        ## string of zero bytes.
        c3 = enc.enczero(32)
        me.assertEqual(dec.decrypt(c3), C.ByteString.zero(32))
        m4 = dec.deczero(32)
        me.assertEqual(enc.encrypt(m4), C.ByteString.zero(32))

        if can_bdry:
            def fresh_decryptor():
                d = ccls(k)
                if can_setiv:
                    d.setiv(iv)
                return d

            ## Up to the boundary, how the ciphertext is chopped up doesn't
            ## matter...
            me.assertEqual(fresh_decryptor().decrypt(c0 + c1), m[0:189])

            ## ... but the boundary itself does: decrypting straight through it
            ## mustn't recover the message.
            me.assertNotEqual(fresh_decryptor().decrypt(c0 + c1 + c2), m)

        ## Check that bad key lengths are rejected.
        badlen = bad_key_size(ccls.keysz)
        if badlen is not None:
            me.assertRaises(ValueError, ccls, T.span(badlen))
315 if badlen is not None: me.assertRaises(ValueError, ccls, T.span(badlen))
316
## Run the generic cipher checks against a selection of ciphers.
TestCipher.generate_testcases(
    (name, C.gcciphers[name])
    for name in ["des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
                 "blowfish-counter", "rc4", "seal", "salsa20/8",
                 "shake128-xof"])
320
321 ###--------------------------------------------------------------------------
class TestAuthenticatedEncryption (HashBufferTestMixin, T.GenericTestMixin):
    """Test authenticated encryption schemes."""

    def _test_aead(me, aecls):
        """
        Run the generic checks against the AEAD scheme class AECLS.

        Which checks apply depends on the `C.AEADF_...' bits in `AECLS.flags'.
        """

        ## Check the class properties.
        me.assertEqual(type(aecls.name), str)
        me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
        me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
        me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
        me.assertEqual(type(aecls.blksz), int)
        me.assertEqual(type(aecls.bufsz), int)
        me.assertEqual(type(aecls.ohd), int)
        me.assertEqual(type(aecls.flags), int)

        ## Check round-tripping, with full precommitment.  First, select some
        ## parameters.  (It's conceivable that some AEAD schemes are more
        ## restrictive than advertised by the various properties, but this
        ## works out OK in practice.)
        k = T.span(aecls.keysz.default)
        n = T.span(aecls.noncesz.default)
        if aecls.flags & C.AEADF_NOAAD:
            h = T.span(0)
        else:
            h = T.span(131)
        m = T.span(253)
        tsz = aecls.tagsz.default
        key = aecls(k)

        ## Next, encrypt a message, checking that things are proper as we go.
        ## The length bounds allow each `encrypt' call to hold back up to
        ## `bufsz' bytes and to add up to `ohd' bytes of overhead.
        enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
        me.assertEqual(enc.hsz, len(h))
        me.assertEqual(enc.msz, len(m))
        me.assertEqual(enc.mlen, 0)
        me.assertEqual(enc.tsz, tsz)
        aad = enc.aad()
        if aecls.flags & C.AEADF_AADNDEP:
            me.assertEqual(aad.hsz, len(h))
        else:
            me.assertEqual(aad.hsz, None)
        me.assertEqual(aad.hlen, 0)
        if not aecls.flags & C.AEADF_NOAAD:
            aad.hash(h[0:83])
            me.assertEqual(aad.hlen, 83)
            aad.hash(h[83:131])
            me.assertEqual(aad.hlen, 131)
        c0 = enc.encrypt(m[0:57])
        me.assertEqual(enc.mlen, 57)
        me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
        c1 = enc.encrypt(m[57:189])
        me.assertEqual(enc.mlen, 189)
        me.assertTrue(132 - aecls.bufsz <= len(c1) <=
                      132 + aecls.bufsz + aecls.ohd)
        c2 = enc.encrypt(m[189:253])
        me.assertEqual(enc.mlen, 253)
        me.assertTrue(64 - aecls.bufsz <= len(c2) <=
                      64 + aecls.bufsz + aecls.ohd)
        c3, t = enc.done(aad = aad)
        me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
        c = c0 + c1 + c2 + c3
        me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
        me.assertEqual(len(t), tsz)

        ## And now decrypt it again, with different record boundaries.
        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
        me.assertEqual(dec.hsz, len(h))
        me.assertEqual(dec.csz, len(c))
        me.assertEqual(dec.clen, 0)
        me.assertEqual(dec.tsz, tsz)
        aad = dec.aad()
        if aecls.flags & C.AEADF_AADNDEP:
            me.assertEqual(aad.hsz, len(h))
        else:
            me.assertEqual(aad.hsz, None)
        me.assertEqual(aad.hlen, 0)
        aad.hash(h)
        m0 = dec.decrypt(c[0:156])
        me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
        m1 = dec.decrypt(c[156:])
        me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
                      len(c) - 156 + aecls.bufsz)
        m2 = dec.done(tag = t, aad = aad)
        me.assertEqual(m0 + m1 + m2, m)

        ## And again, with the wrong tag.  Supply the correct AAD, so that the
        ## tag is the only thing which can cause the failure (otherwise the
        ## unfulfilled header-length commitment might trigger it instead).
        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
        aad = dec.aad()
        aad.hash(h)
        _ = dec.decrypt(c)
        me.assertRaises(ValueError, dec.done,
                        tag = t ^ tsz*C.bytes("55"), aad = aad)

        ## Check that the all-in-one methods work.
        me.assertEqual((c, t),
                       key.encrypt(n = n, h = h, m = m, tsz = tsz))
        me.assertEqual(m,
                       key.decrypt(n = n, h = h, c = c, t = t))

        ## Check that bad key, nonce, and tag lengths are rejected.
        badlen = bad_key_size(aecls.keysz)
        if badlen is not None:
            me.assertRaises(ValueError, aecls, T.span(badlen))
        badlen = bad_key_size(aecls.noncesz)
        if badlen is not None:
            me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
                            hsz = len(h), msz = len(m), tsz = tsz)
            me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
                            hsz = len(h), csz = len(c), tsz = tsz)
        badlen = bad_key_size(aecls.tagsz)
        if badlen is not None:
            me.assertRaises(ValueError, key.enc, nonce = n,
                            hsz = len(h), msz = len(m), tsz = badlen)
            me.assertRaises(ValueError, key.dec, nonce = n,
                            hsz = len(h), csz = len(c), tsz = badlen)

            ## If the tag size needn't be committed up front, then a bad tag
            ## size must be refused when it's finally given to `done'.
            if not aecls.flags & C.AEADF_PCTSZ:
                enc = key.enc(nonce = n, hsz = 0, msz = len(m))
                _ = enc.encrypt(m)
                me.assertRaises(ValueError, enc.done, tsz = badlen)

        ## Check that we can't get a loose `aad' object from a scheme which has
        ## nonce-dependent AAD processing.
        if aecls.flags & C.AEADF_AADNDEP:
            me.assertRaises(ValueError, key.aad)

        ## Check the menagerie of AAD hashing methods.  The tag is the output
        ## being compared.
        if not aecls.flags & C.AEADF_NOAAD:
            def mkhash(hsz):
                enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
                aad = enc.aad()
                return aad, lambda: enc.done(aad = aad)[1]
            me.check_hashbuffer(mkhash)

        ## Check that encryption/decryption works with the given
        ## precommitments, reproducing the reference ciphertext, tag, and
        ## message.
        def quick_enc_check(**kw):
            enc = key.enc(**kw)
            aad = enc.aad().hash(h)
            c0 = enc.encrypt(m)
            c1, tt = enc.done(aad = aad, tsz = tsz)
            me.assertEqual((c, t), (c0 + c1, tt))
        def quick_dec_check(**kw):
            dec = key.dec(**kw)
            aad = dec.aad().hash(h)
            m0 = dec.decrypt(c)
            m1 = dec.done(aad = aad, tag = t)
            me.assertEqual(m, m0 + m1)

        ## Check that we can get away without precommitting to the header
        ## length if and only if the AEAD scheme says it will let us.
        if aecls.flags & C.AEADF_PCHSZ:
            me.assertRaises(ValueError, key.enc, nonce = n,
                            msz = len(m), tsz = tsz)
            me.assertRaises(ValueError, key.dec, nonce = n,
                            csz = len(c), tsz = tsz)
        else:
            quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
            quick_dec_check(nonce = n, csz = len(c), tsz = tsz)

        ## Check that we can get away without precommitting to the message/
        ## ciphertext length if and only if the AEAD scheme says it will let
        ## us.
        if aecls.flags & C.AEADF_PCMSZ:
            me.assertRaises(ValueError, key.enc, nonce = n,
                            hsz = len(h), tsz = tsz)
            me.assertRaises(ValueError, key.dec, nonce = n,
                            hsz = len(h), tsz = tsz)
        else:
            quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
            quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)

        ## Check that we can get away without precommitting to the tag length
        ## if and only if the AEAD scheme says it will let us.
        if aecls.flags & C.AEADF_PCTSZ:
            me.assertRaises(ValueError, key.enc, nonce = n,
                            hsz = len(h), msz = len(m))
            me.assertRaises(ValueError, key.dec, nonce = n,
                            hsz = len(h), csz = len(c))
        else:
            quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
            quick_dec_check(nonce = n, hsz = len(h), csz = len(c))

        ## Check that if we precommit to the header length, we're properly held
        ## to the commitment.
        if not aecls.flags & C.AEADF_NOAAD:

            ## First, check encryption with underrun.  If we must supply AAD
            ## first, then the underrun will be reported when we start trying
            ## to encrypt; otherwise, checking is delayed until `done'.
            enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
            aad = enc.aad().hash(h[0:83])
            if aecls.flags & C.AEADF_AADFIRST:
                me.assertRaises(ValueError, enc.encrypt, m)
            else:
                _ = enc.encrypt(m)
                me.assertRaises(ValueError, enc.done, aad = aad)

            ## Next, check decryption with underrun.  If we must supply AAD
            ## first, then the underrun will be reported when we start trying
            ## to decrypt; otherwise, checking is delayed until `done'.
            dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
            aad = dec.aad().hash(h[0:83])
            if aecls.flags & C.AEADF_AADFIRST:
                me.assertRaises(ValueError, dec.decrypt, c)
            else:
                _ = dec.decrypt(c)
                me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

            ## If AAD processing is nonce-dependent then an overrun will be
            ## detected immediately.  (83 + 49 bytes is one more than the 131
            ## committed.)
            if aecls.flags & C.AEADF_AADNDEP:
                enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
                aad = enc.aad().hash(h[0:83])
                me.assertRaises(ValueError, aad.hash, h[82:131])
                dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
                aad = dec.aad().hash(h[0:83])
                me.assertRaises(ValueError, aad.hash, h[82:131])

        ## Some additional tests for nonce-dependent `aad' objects.
        if aecls.flags & C.AEADF_AADNDEP:

            ## Check that `aad' objects can't be used once their parents are
            ## gone.
            enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
            aad = enc.aad()
            del enc
            me.assertRaises(ValueError, aad.hash, h)

            ## Check that they can't be crossed over.
            enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
            enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
            enc0.aad().hash(h)
            aad1 = enc1.aad().hash(h)
            _ = enc0.encrypt(m)
            me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)

        ## Test copying AAD: copies continue independently of the original.
        if not aecls.flags & C.AEADF_AADNDEP and \
           not aecls.flags & C.AEADF_NOAAD:
            aad0 = key.aad()
            aad0.hash(h[0:83])
            aad1 = aad0.copy()
            aad2 = aad1.copy()
            aad0.hash(h[83:131])
            aad1.hash(h[83:131])
            aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
            me.assertEqual(key.enc(nonce = n, hsz = len(h),
                                   msz = 0, tsz = tsz).done(aad = aad0),
                           key.enc(nonce = n, hsz = len(h),
                                   msz = 0, tsz = tsz).done(aad = aad1))
            me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
                                      msz = 0, tsz = tsz).done(aad = aad0),
                              key.enc(nonce = n, hsz = len(h),
                                      msz = 0, tsz = tsz).done(aad = aad2))

        ## Check that if we precommit to the message length, we're properly
        ## held to the commitment.  (Fortunately, this is way simpler than the
        ## AAD case above.)  First, try an underrun.
        enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
        _ = enc.encrypt(m[0:183])
        me.assertRaises(ValueError, enc.done, tsz = tsz)
        dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
        _ = dec.decrypt(c[0:183])
        me.assertRaises(ValueError, dec.done, tag = t)

        ## And now an overrun.
        enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
        me.assertRaises(ValueError, enc.encrypt, m)
        dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
        me.assertRaises(ValueError, dec.decrypt, c)

        ## Finally, check that if we precommit to a tag length, we're properly
        ## held to the commitment.  This depends on being able to find a tag
        ## size which isn't the default.  The decryption check hands `done' a
        ## tag of the default size TSZ where TSZ1 was committed.
        tsz1 = different_key_size(aecls.tagsz, tsz)
        if tsz1 is not None:
            enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
            _ = enc.encrypt(m)
            me.assertRaises(ValueError, enc.done, tsz = tsz)
            dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
            aad = dec.aad().hash(h)
            _ = dec.decrypt(c)
            me.assertRaises(ValueError, dec.done, tag = t, aad = aad)
589
## Run the generic AEAD checks against a selection of schemes.
TestAuthenticatedEncryption.generate_testcases(
    (name, C.gcaeads[name])
    for name in ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
                 "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])
594
595 ###--------------------------------------------------------------------------
class BaseTestHash (HashBufferTestMixin):
    """Base class for testing hash functions."""

    def check_hash(me, hcls, need_bufsz = True):
        """
        Check hash class HCLS.

        HCLS() must make a fresh hashing context.  If NEED_BUFSZ is false, then
        don't insist that HCLS has a working `bufsz' attribute.  This test is
        mostly reused for MACs, which don't have this attribute.
        """

        ## Check the class properties.
        me.assertEqual(type(hcls.name), str)
        if need_bufsz:
            me.assertEqual(type(hcls.bufsz), int)
        me.assertEqual(type(hcls.hashsz), int)

        ## Hash a reference message in one go; the output must be `hashsz'
        ## bytes long.
        msg = T.span(131)
        digest = hcls().hash(msg).done()
        me.assertEqual(len(digest), hcls.hashsz)

        ## Splitting the message up mustn't change the answer.
        me.assertEqual(digest,
                       hcls().hash(msg[0:73]).hash(msg[73:131]).done())

        ## `check' must accept the right answer and reject a corrupted one.
        me.assertTrue(hcls().hash(msg).check(digest))
        me.assertFalse(hcls().hash(msg).check(
            digest ^ hcls.hashsz*C.bytes("aa")))

        ## Check the menagerie of random hashing methods.  The expected input
        ## size is irrelevant to a plain hash.
        def fresh(_):
            ctx = hcls()
            return ctx, ctx.done
        me.check_hashbuffer(fresh)
631
class TestHash (BaseTestHash, T.GenericTestMixin):
    """Test hash functions."""

    def _test_hash(me, hcls):
        """Run the generic hash checks on HCLS, insisting on `bufsz'."""
        me.check_hash(hcls, need_bufsz = True)
635
## Run the generic hash checks against a selection of hash functions.
TestHash.generate_testcases(
    (name, C.gchashes[name])
    for name in ["md5", "sha", "whirlpool", "sha256", "sha512/224",
                 "sha3-384", "shake256", "crc32"])
639
640 ###--------------------------------------------------------------------------
class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
    """Test message authentication codes."""

    def _test_mac(me, mcls):
        """Run the generic checks against the MAC class MCLS."""

        ## Check the MAC properties.
        me.assertEqual(type(mcls.name), str)
        me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
        me.assertEqual(type(mcls.tagsz), int)

        ## A keyed MAC must behave like a hash function, except that it needn't
        ## have a `bufsz' attribute.
        key = mcls(T.span(mcls.keysz.default))
        me.assertEqual(key.hashsz, key.tagsz)
        me.check_hash(key, need_bufsz = False)

        ## Check that bad key lengths are rejected.
        badlen = bad_key_size(mcls.keysz)
        if badlen is not None:
            me.assertRaises(ValueError, mcls, T.span(badlen))
660
## Run the generic MAC checks against a selection of MACs.
TestMessageAuthentication.generate_testcases(
    (name, C.gcmacs[name])
    for name in ["sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"])
664
class TestPoly1305 (HashBufferTestMixin):
    """Check the Poly1305 one-time message authentication function."""

    def test_poly1305(me):
        pcls = C.poly1305

        ## Check the MAC properties.
        me.assertEqual(pcls.name, "poly1305")
        me.assertEqual(type(pcls.keysz), C.KeySZSet)
        me.assertEqual(pcls.keysz.default, 16)
        me.assertEqual(pcls.keysz.set, {16})
        me.assertEqual(pcls.tagsz, 16)
        me.assertEqual(pcls.masksz, 16)

        ## Set up a key, a mask, and a message, and compute the reference tag.
        k = T.span(16)
        mask = T.span(64)[-16:]
        msg = T.span(149)
        key = pcls(k)
        tag = key(mask).hash(msg).done()

        ## Check the key properties.
        ## NOTE(review): only `tagsz' is checked on the key object; perhaps
        ## `masksz' was meant to be checked too -- confirm the key exposes it.
        me.assertEqual(key.name, "poly1305")
        me.assertEqual(key.tagsz, 16)
        me.assertEqual(len(tag), 16)

        ## Splitting the message up mustn't change the answer.
        me.assertEqual(tag,
                       key(mask).hash(msg[0:86]).hash(msg[86:149]).done())

        ## `check' must accept the right tag and reject a corrupted one.
        me.assertTrue(key(mask).hash(msg).check(tag))
        me.assertFalse(key(mask).hash(msg).check(tag ^ 16*C.bytes("cc")))

        ## Check the menagerie of random hashing methods.
        def fresh(_):
            ctx = key(mask)
            return ctx, ctx.done
        me.check_hashbuffer(fresh)

        ## Without a mask, hashing can't be completed.
        me.assertRaises(ValueError, key().hash(msg).done)

        ## `concat' joins two unmasked contexts into a masked one which carries
        ## on from where they left off.
        first = key().hash(msg[0:96])
        second = key().hash(msg[96:117])
        me.assertEqual(tag,
                       key(mask).concat(first, second)
                                .hash(msg[117:149]).done())

        ## Contexts from a different key object are refused, even one made
        ## from the same key bytes...
        other = pcls(k)
        me.assertRaises(TypeError, key().concat,
                        other().hash(msg[0:96]), second)
        me.assertRaises(TypeError, key().concat,
                        first, other().hash(msg[96:117]))

        ## ... as is a 93-byte prefix (presumably because it isn't a whole
        ## number of 16-byte blocks).
        me.assertRaises(ValueError, key().concat,
                        key().hash(msg[0:93]), second)
713 me.assertRaises(TypeError, key().concat, h0, key1().hash(m[96:117]))
714 me.assertRaises(ValueError, key().concat, key().hash(m[0:93]), h1)
715
716 ###--------------------------------------------------------------------------
class TestHLatin (U.TestCase):
    """Test the `hsalsa20' and `hchacha20' functions."""

    def test_hlatin(me):
        """
        Check every HSalsa20/HChaCha PRF variant.

        Each variant must produce 32 bytes for each valid key size, and must
        refuse a bad key or a bad nonce.
        """
        keys = [T.span(sz) for sz in [10, 16, 32]]
        n = T.span(16)
        bad_k = T.span(18)
        bad_n = T.span(13)
        for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
                   C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:
            for k in keys:
                me.assertEqual(len(fn(k, n)), 32)

                ## A bad nonce must be refused whatever the key size.  (Doing
                ## this inside the loop covers every key size, not just
                ## whichever one the loop variable happens to end on.)
                me.assertRaises(ValueError, fn, k, bad_n)
            me.assertRaises(ValueError, fn, bad_k, n)
732
733 ###--------------------------------------------------------------------------
class TestKeccak (HashBufferTestMixin):
    """Test the Keccak-p[1600, n] sponge function."""

    def test_keccak(me):

        ## Make a state (24 rounds by default) and feed some stuff into it.
        m0 = T.bin("some initial string")
        m1 = T.bin("awesome follow-up string")
        st0 = C.Keccak1600()
        me.assertEqual(st0.nround, 24)
        st0.mix(m0).step()

        ## Make another step with a different round count; the results must
        ## differ.
        st1 = C.Keccak1600(23)
        st1.mix(m0).step()
        me.assertNotEqual(st0.extract(32), st1.extract(32))

        ## A copy must evolve exactly like its original.
        st1 = st0.copy()
        _ = st1.extract(len(m1))
        st0.mix(m1)
        st1.mix(m1)
        me.assertEqual(st0.extract(32), st1.extract(32))

        ## At most 200 bytes can be extracted or mixed in at a time.
        _ = st0.extract(200)
        me.assertRaises(ValueError, st0.extract, 201)
        st0.mix(T.span(200))
        me.assertRaises(ValueError, st0.mix, T.span(201))

    def check_shake(me, xcls, c, done_matches_xof = True):
        """
        Test the SHAKE and cSHAKE XOFs.

        XCLS(func = ..., perso = ...) makes a fresh XOF object, and C is the
        expected capacity in bytes, so that the rate is 200 - C.

        This is also used for testing KMAC, but that sets DONE_MATCHES_XOF false
        to indicate that the XOF output is range-separated from the fixed-length
        outputs (unlike the basic SHAKE functions).
        """

        ## Check the hash attributes of a fresh object.
        x = xcls()
        me.assertEqual(x.rate, 200 - c)
        me.assertEqual(x.buffered, 0)
        me.assertEqual(x.state, "absorb")

        ## Hash the same message plainly and with a function name and
        ## personalization string; the outputs must differ.
        func = T.bin("TESTXOF")
        perso = T.bin("catacomb-python test")
        m = T.span(167)
        h0 = xcls().hash(m).done(193)
        me.assertEqual(len(h0), 193)
        h1 = xcls(func = func, perso = perso).hash(m).done(193)
        me.assertEqual(len(h1), 193)
        me.assertNotEqual(h0, h1)

        ## Check input and output in pieces.  H is the expected XOF output.
        h = h0 if done_matches_xof else xcls().hash(m).xof().get(len(h0))
        x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
        me.assertEqual(h, x.get(98) + x.get(95))

        ## `mask' XORs the XOF output into its argument.
        x = xcls().hash(m).xof()
        me.assertEqual(x.mask(m), m ^ h[0:len(m)])

        ## Check the `check' method.
        me.assertTrue(xcls().hash(m).check(h0))
        me.assertFalse(xcls().hash(m).check(h1))

        ## Check the menagerie of random hashing methods.
        def fresh(_):
            xx = xcls(func = func, perso = perso)
            return xx, xx.done
        me.check_hashbuffer(fresh)

        ## Check the state machine tracking: objects start in `absorb'; `done'
        ## moves a copy to `dead', where nothing more is allowed; `xof' moves
        ## the original to `squeeze', where `get' works but `done' doesn't.
        x = xcls()
        me.assertEqual(x.state, "absorb")
        x.hash(m)
        me.assertEqual(x.state, "absorb")
        xx = x.copy()
        h = xx.done()
        me.assertEqual(len(h), 100 - x.rate//2)
        me.assertEqual(xx.state, "dead")
        me.assertRaises(ValueError, xx.done, 1)
        me.assertRaises(ValueError, xx.get, 1)
        me.assertEqual(x.state, "absorb")
        me.assertRaises(ValueError, x.get, 1)
        x.xof()
        me.assertEqual(x.state, "squeeze")
        me.assertRaises(ValueError, x.done, 1)
        _ = x.get(1)
        yy = x.copy()
        me.assertEqual(yy.state, "squeeze")

    def test_shake128(me):
        me.check_shake(C.Shake128, 32)

    def test_shake256(me):
        me.check_shake(C.Shake256, 64)

    def check_kmac(me, mcls, c):
        """Run the XOF checks against the KMAC class MCLS, with capacity C."""
        k = T.span(32)

        def make(func = None, perso = None):
            ## FUNC is accepted for `check_shake''s benefit, but ignored.
            return mcls(k, perso = perso)

        me.check_shake(make, c, done_matches_xof = False)

    def test_kmac128(me):
        me.check_kmac(C.KMAC128, 32)

    def test_kmac256(me):
        me.check_kmac(C.KMAC256, 64)
835
836 ###--------------------------------------------------------------------------
class TestPRP (T.GenericTestMixin):
    """Test pseudorandom permutations (PRPs)."""

    def _test_prp(me, pcls):
        """Run the generic checks against the PRP class PCLS."""

        ## Check the PRP properties.
        me.assertEqual(type(pcls.name), str)
        me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
        me.assertEqual(type(pcls.blksz), int)

        ## A single block must survive encryption and decryption.
        key = pcls(T.span(pcls.keysz.default))
        plain = T.span(pcls.blksz)
        cipher = key.encrypt(plain)
        me.assertEqual(len(cipher), pcls.blksz)
        me.assertEqual(plain, key.decrypt(cipher))

        ## Bad key lengths must be rejected.
        badlen = bad_key_size(pcls.keysz)
        if badlen is not None:
            me.assertRaises(ValueError, pcls, T.span(badlen))

        ## So must blocks of the wrong size, in either direction.
        oversized = T.span(pcls.blksz + 1)
        for op in (key.encrypt, key.decrypt):
            me.assertRaises(ValueError, op, oversized)
862 me.assertRaises(ValueError, key.decrypt, badblk)
863
## Run the generic PRP checks against a selection of block ciphers.
TestPRP.generate_testcases(
    (name, C.gcprps[name])
    for name in ["desx", "blowfish", "rijndael"])
866
867 ###----- That's all, folks --------------------------------------------------
868
## Run the tests when invoked as a script.
if __name__ == "__main__":
    U.main()