Makefile: Deploy countermeasures for dc(1) line splitting.
[ocb-tv] / ocbgen
CommitLineData
54482987
MW
1#! /usr/bin/python
2### -*-python-*-
3###
4### Generalization of OCB mode for other block sizes
5###
6### (c) 2017 Mark Wooding
7###
8
9###----- Licensing notice ---------------------------------------------------
10###
11### This program is free software; you can redistribute it and/or modify
12### it under the terms of the GNU General Public License as published by
13### the Free Software Foundation; either version 2 of the License, or
14### (at your option) any later version.
15###
16### This program is distributed in the hope that it will be useful,
17### but WITHOUT ANY WARRANTY; without even the implied warranty of
18### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19### GNU General Public License for more details.
20###
21### You should have received a copy of the GNU General Public License
22### along with this program; if not, write to the Free Software Foundation,
23### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
24
25from sys import argv, stderr
26from struct import pack
27from itertools import izip
28import catacomb as C
29
## Module-wide random source with a fixed seed of zero, presumably so that
## random choices (e.g., those made by `keylens') repeat from run to run.
R = C.FibRand(0)
31
32###--------------------------------------------------------------------------
33### Utilities.
34
def combs(things, k):
  """Generate the K-element combinations of the sequence THINGS.

  Each combination is yielded as a fresh list whose elements appear in
  their THINGS order.  Combinations come out in colexicographic order:
  every combination drawn only from the first j elements of THINGS comes
  before any that uses element j + 1.  `poly' returns the first
  irreducible candidate it finds, so this order decides which polynomial
  it picks.  If K exceeds len(THINGS) there are no combinations and
  nothing is yielded.
  """
  n = len(things)
  if k > n: return
  ii = list(range(k))
  while True:
    yield [things[i] for i in ii]
    ## Advance: find the lowest index that can move up without colliding
    ## with the one above it, bump it, and reset every index below it to
    ## its lowest position.  If none can move, we've seen them all.
    for j in range(k):
      if j == k - 1: lim = n
      else: lim = ii[j + 1]
      i = ii[j] + 1
      if i < lim:
        ii[j] = i
        break
      ii[j] = j
    else:
      return
50
## Cache of the polynomials found by `poly', keyed by degree.
POLYMAP = {}

def poly(nbits):
  """Return a low-weight irreducible polynomial of degree NBITS over GF(2).

  Candidates are x^NBITS + 1 plus K middle terms x^c (0 < c < NBITS), tried
  for K = 1, 3, 5, ... in turn.  Only odd K is tried because a polynomial
  with an even number of terms has 1 as a root.  Within each K, the middle
  terms come in the order `combs' generates.  The first irreducible
  candidate is cached in POLYMAP and returned; if there is none, raise
  ValueError.
  """
  if nbits in POLYMAP: return POLYMAP[nbits]
  base = C.GF(0).setbit(nbits).setbit(0)
  middle = range(1, nbits)
  for k in xrange(1, nbits, 2):
    for cc in combs(middle, k):
      p = base + sum(C.GF(0).setbit(c) for c in cc)
      if p.irreduciblep():
        POLYMAP[nbits] = p
        return p
  raise ValueError(nbits)
62
def prim(nbits):
  """Return the tabulated primitive polynomial of degree NBITS over GF(2).

  Only NBITS in {64, 96, 128, 192, 256} is supported; anything else raises
  ValueError.
  """
  ## There's no clever way to find these: it would need a much better
  ## factoring algorithm than the one to hand.  So they are simply listed
  ## here, as the exponents of their nonzero terms.
  exponents = { 64: [64, 4, 3, 1, 0],
                96: [96, 10, 9, 6, 0],
                128: [128, 7, 2, 1, 0],
                192: [192, 15, 11, 5, 0],
                256: [256, 10, 5, 2, 0] }
  if nbits not in exponents:
    raise ValueError('no field for %d bits' % nbits)
  p = C.GF(0)
  for e in exponents[nbits]: p = p.setbit(e)
  return p
75
def Z(n):
  """Return a byte string of N zero bytes."""
  zero = C.ByteString.zero
  return zero(n)
78
def mul_blk_gf(m, x, p):
  """Multiply the block M by X modulo the polynomial P; return a block.

  M is loaded with GF.loadb.  The reduced product is stored with storeb
  into (P.nbits + 6)//8 bytes, i.e. ceil(deg P / 8) bytes for a degree-d
  P with d + 1 bits.
  """
  prod = (C.GF.loadb(m)*x) % p
  return prod.storeb((p.nbits + 6)//8)
80
def with_lastp(it):
  """Yield (ITEM, LASTP) for each ITEM of IT; LASTP is true only for the
  final item.

  Raises ValueError, when first advanced, if IT is empty.
  """
  it = iter(it)
  try: pending = next(it)
  except StopIteration: raise ValueError('empty iter')
  ## Hold back one item, so we know whether it was the last one before
  ## handing it out.
  for item in it:
    yield pending, False
    pending = item
  yield pending, True
91
def safehex(x):
  """Return hex(X), or the two-character string '""' when X is empty."""
  return hex(x) if len(x) else '""'
95
def keylens(ksz):
  """Pick up to four distinct key lengths at random from the descriptor KSZ.

  KSZ may be one of the following:
    * a KeySZSet: its explicit list of sizes;
    * a KeySZRange: MIN to MAX inclusive, in steps of MOD;
    * a KeySZAny: lengths 0 to 63 are considered, and 0 is always chosen.
  Choices use the module's fixed-seed generator R.  Any other kind of
  descriptor raises TypeError.
  """
  sel = []
  if isinstance(ksz, C.KeySZSet): kk = ksz.set
  elif isinstance(ksz, C.KeySZRange):
    ## MAX is itself an acceptable key size, so the range must include it.
    kk = range(ksz.min, ksz.max + 1, ksz.mod)
  elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0]
  else: raise TypeError('unknown key-size descriptor %r' % (ksz,))
  kk = list(kk)
  ## Partial Fisher--Yates shuffle: move a random survivor to the end of
  ## the live region and take it.
  n = len(kk)
  while n and len(sel) < 4:
    i = R.range(n)
    n -= 1
    kk[i], kk[n] = kk[n], kk[i]
    sel.append(kk[n])
  return sel
109
def pad0star(m, w):
  """Zero-pad M to a whole number of W-byte blocks.

  An empty M becomes one whole block of zeros; otherwise only as many zero
  bytes as needed are appended.  Returns a ByteString.
  """
  if len(m): r = (-len(m)) % w
  else: r = w
  if r: m += Z(r)
  return C.ByteString(m)
116
def pad10star(m, w):
  """Pad M with a 0x80 byte followed by zeros up to a multiple of W bytes.

  At least one byte is always added, so an M that already fills whole
  blocks gains a complete extra block.  Returns a ByteString.
  """
  gap = w - len(m) % w
  if gap:
    m += '\x80' + Z(gap - 1)
  return C.ByteString(m)
121
def ntz(i):
  """Return the number of trailing zero bits in the binary form of I.

  Raises ValueError if I is zero, which has no lowest set bit (looping
  on it would never terminate).
  """
  if i == 0: raise ValueError('ntz of zero')
  j = 0
  while not i & 1:
    i >>= 1
    j += 1
  return j
126
def blocks(x, w):
  """Split X into W-byte pieces, returning (FULL, TAIL).

  FULL lists the leading whole blocks; TAIL is the rest.  The final block
  always goes into TAIL, even when whole, so TAIL has between 1 and W
  bytes unless X is empty.
  """
  full, pos, end = [], 0, len(x)
  while end - pos > w:
    full.append(C.ByteString(x[pos:pos + w]))
    pos += w
  return full, C.ByteString(x[pos:])
133
## The empty byte string.
EMPTY = C.bytes('')

def blocks0(x, w):
  """Split X into W-byte pieces, returning (FULL, TAIL).

  Like `blocks', except that a whole final block is moved into FULL, so
  TAIL is always shorter than W bytes (empty if X is whole blocks).
  """
  full, tail = blocks(x, w)
  if len(tail) != w: return full, tail
  full.append(tail)
  return full, EMPTY
140
141###--------------------------------------------------------------------------
142### Luby--Rackoff large-block ciphers.
143
class LubyRackoffCipher (type):
  """Factory for Luby--Rackoff wide-block ciphers.

  LubyRackoffCipher(BC, BLKSZ) returns a new subclass of LubyRackoffBase,
  named `BCNAME-lr[BITS]', that acts on BLKSZ-byte blocks using the block
  cipher class BC.  BLKSZ must be even, and each half must fit in one BC
  block.
  """
  def __new__(cls, bc, blksz):
    assert blksz%2 == 0
    assert blksz <= 2*bc.blksz
    name = '%s-lr[%d]' % (bc.name, 8*blksz)
    attrs = { 'name': name, 'blksz': blksz, 'keysz': bc.keysz, 'bc': bc }
    return type(name, (LubyRackoffBase,), attrs)
155
class LubyRackoffBase (object):
  """Common implementation of the ciphers built by `LubyRackoffCipher'.

  Subclasses supply the class attributes BC (the underlying block cipher
  class) and BLKSZ (the wide block size, in bytes).  The cipher is a
  balanced Feistel network of NR rounds.  The round-j function is BC
  keyed with K_j, applied to the left half zero-padded to a BC block, with
  the output truncated to half of BLKSZ.  When the global LRVERBOSE is
  set, every step is traced on stdout.
  """

  NR = 4                                # for strong-PRP security

  def __init__(me, k):
    """Derive the NR round keys from the key K.

    BC keyed with K encrypts the counter values 0, 1, 2, ... (each stored
    with storeb into one BC block).  Each round key K_j is the first |K|
    bytes of the next run of that output.  The counter carries on between
    round keys rather than restarting.
    """
    if LRVERBOSE: print('K = %s' % hex(k))
    bc = me.__class__.bc
    E = bc(k)
    ksz = len(k)
    me.f = []
    i = C.MP(0)
    for j in xrange(me.NR):
      b = C.WriteBuffer()
      while b.size < ksz:
        x = E.encrypt(i.storeb(bc.blksz))
        b.put(x)
        if LRVERBOSE: print('E(K; [%d]) = %s' % (i, hex(x)))
        i += 1
      kj = C.ByteString(C.ByteString(b)[0:ksz])
      if LRVERBOSE: print('K_%d = %s' % (j, hex(kj)))
      me.f.append(bc(kj))

  def encrypt(me, m):
    """Encrypt the BLKSZ-byte block M and return the ciphertext."""
    bc, blksz = me.__class__.bc, me.__class__.blksz
    assert len(m) == blksz
    half = blksz//2
    l, r = C.ByteString(m[:half]), C.ByteString(m[half:])
    if LRVERBOSE: print('L_0, R_0 = %s, %s' % (hex(l), hex(r)))
    for j in xrange(me.NR):
      t = me.f[j].encrypt(pad0star(l, bc.blksz))
      l, r = r ^ t[:half], l
      if LRVERBOSE:
        print('E(K_%d; L_%d || 0^*) = %s' % (j, j, hex(t)))
        print('L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r)))
    return C.ByteString(r + l)

  def decrypt(me, c):
    """Decrypt the BLKSZ-byte block C and return the plaintext.

    This runs `encrypt''s rounds in reverse.  Going into round j, the
    local pair (l, r) holds (R_{j+1}, L_{j+1}) in `encrypt''s naming, i.e.
    its halves swapped.  The trace prints them in that order so that the
    labels agree with `encrypt''s trace.
    """
    bc, blksz = me.__class__.bc, me.__class__.blksz
    assert len(c) == blksz
    half = blksz//2
    l, r = C.ByteString(c[:half]), C.ByteString(c[half:])
    for j in xrange(me.NR - 1, -1, -1):
      ## Here l = R_{j+1} = L_j, the input to round j's function.
      t = me.f[j].encrypt(pad0star(l, bc.blksz))
      if LRVERBOSE:
        print('L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(r), hex(l)))
        print('E(K_%d; L_%d || 0^*) = %s' % (j, j, hex(t)))
      l, r = r ^ t[:half], l
    if LRVERBOSE: print('L_0, R_0 = %s, %s' % (hex(r), hex(l)))
    return C.ByteString(r + l)
201
## Luby--Rackoff wide-block ciphers built on Rijndael, one per supported
## block size, keyed by name: `lraes64', `lraes96', `lraes128', `lraes192'
## and `lraes256'.
LRAES = {}
for i in (8, 12, 16, 24, 32):
  LRAES['lraes%d' % (8*i)] = LubyRackoffCipher(C.rijndael, i)
205
206###--------------------------------------------------------------------------
207### PMAC.
208
def ocb_masks(E):
  """Compute the OCB mask values for the keyed block cipher E.

  Let L = E(0^w).  Returns (LGAMMA, LXINV), where LGAMMA[i] = L x^i for
  0 <= i < 66 and LXINV = L x^-1.  The arithmetic is modulo the
  polynomial `poly' picks for the block size.
  """
  blksz = E.__class__.blksz
  p = poly(8*blksz)
  x = C.GF(2)
  L = E.encrypt(Z(blksz))
  Lxinv = mul_blk_gf(L, p.modinv(x), p)
  Lgamma = [L]
  while len(Lgamma) < 66:
    Lgamma.append(mul_blk_gf(Lgamma[-1], x, p))
  return Lgamma, Lxinv
220
def dump_ocb(E):
  """Print E's OCB masks: L x^-1 first, then L x^i for each i in turn."""
  Lgamma, Lxinv = ocb_masks(E)
  lines = ['L x^-1 = %s' % hex(Lxinv)]
  lines.extend('L x^%d = %s' % (i, hex(lg)) for i, lg in enumerate(Lgamma))
  print('\n'.join(lines))
226
def pmac1(E, m):
  """Return the PMAC1 tag of M under the keyed block cipher E.

  Every block except the last is XORed with a Gray-code offset (updated
  with LGAMMA[ntz(i)]), enciphered, and summed.  The last block is then
  added to the sum.  If it is whole, it is first masked with L x^-1;
  otherwise it is 10*-padded.  The tag is the encryption of the sum.
  """
  blksz = E.__class__.blksz
  Lgamma, Lxinv = ocb_masks(E)
  acc = off = Z(blksz)
  full, tail = blocks(m, blksz)
  for i, blk in enumerate(full, 1):
    b = ntz(i)
    off ^= Lgamma[b]
    acc ^= E.encrypt(blk ^ off)
    if VERBOSE:
      print('Z[%d]: %d -> %s' % (i - 1, b, hex(off)))
      print('A[%d]: %s' % (i - 1, hex(acc)))
  if len(tail) == blksz: acc ^= tail ^ Lxinv
  else: acc ^= pad10star(tail, blksz)
  return E.encrypt(acc)
244
def pmac2(E, m):
  """Return the PMAC tag of M under E used for OCB2's associated data.

  Arithmetic is modulo the polynomial from `prim'.  The offset starts as
  10 L, where L = E(0^w), and is doubled after each block except the
  last.  Those blocks are XORed with the offset, enciphered and summed.
  The last block is then added to the sum.  If whole, it is masked with 3
  times the offset; otherwise it is 10*-padded and masked with 5 times
  the offset.  The tag is the encryption of the sum.
  """
  blksz = E.__class__.blksz
  p = prim(8*blksz)
  acc = Z(blksz)
  off = mul_blk_gf(E.encrypt(Z(blksz)), 10, p)
  full, tail = blocks(m, blksz)
  for blk in full:
    acc ^= E.encrypt(blk ^ off)
    off = mul_blk_gf(off, 2, p)
  if len(tail) == blksz: last = tail ^ mul_blk_gf(off, 3, p)
  else: last = pad10star(tail, blksz) ^ mul_blk_gf(off, 5, p)
  return E.encrypt(acc ^ last)
258
def ocb3_masks(E):
  """Return OCB3's masks (L_*, L_$, [L_0, L_1, ...]) for E.

  These are the OCB masks L x^i under other names: L_* = L, L_$ = L x,
  and L_i = L x^(i + 2).
  """
  Lgamma = ocb_masks(E)[0]
  return Lgamma[0], Lgamma[1], Lgamma[2:]
264
def dump_ocb3(E):
  """Print E's OCB3 masks L_*, L_$ and L_0 to L_3."""
  Lstar, Ldollar, Lgamma = ocb3_masks(E)
  print('L_* : %s' % hex(Lstar))
  print('L_$ : %s' % hex(Ldollar))
  for i in xrange(4):
    print('L_%-8d: %s' % (i, hex(Lgamma[i])))
271
def pmac3(E, m):
  """Return OCB3's hash of the associated data M under E.

  Unlike `pmac1' and `pmac2', the sum is returned directly, with no final
  encryption.  Each whole block is masked with a Gray-code offset,
  enciphered and summed.  A trailing partial block is 10*-padded, masked
  with the offset XOR L_*, and enciphered too.  An empty M hashes to the
  zero block.
  """
  blksz = E.__class__.blksz
  Lstar, Ldollar, Lgamma = ocb3_masks(E)
  total = off = Z(blksz)
  full, tail = blocks0(m, blksz)
  for i, blk in enumerate(full, 1):
    off ^= Lgamma[ntz(i)]
    total ^= E.encrypt(blk ^ off)
    if VERBOSE:
      print('Offset\'_%-2d: %s' % (i, hex(off)))
      print('AuthSum_%-2d: %s' % (i, hex(total)))
  if tail:
    off ^= Lstar
    total ^= E.encrypt(pad10star(tail, blksz) ^ off)
    if VERBOSE:
      print('Offset\'_* : %s' % hex(off))
      print('AuthSum_* : %s' % hex(total))
  return total
293
def pmac1_pub(E, m):
  """Return (PMAC1 tag of M,), dumping E's masks first if VERBOSE."""
  if VERBOSE: dump_ocb(E)
  tag = pmac1(E, m)
  return (tag,)
297
def pmac2_pub(E, m):
  """Return (PMAC2 tag of M,)."""
  tag = pmac2(E, m)
  return (tag,)
300
def pmac3_pub(E, m):
  """Return (OCB3 hash of M,)."""
  tag = pmac3(E, m)
  return (tag,)
303
def pmacgen(bc):
  """Return PMAC test-case size tuples for the block cipher class BC.

  The sizes are: 0, 1, exactly three blocks, and three blocks less five
  bytes.  They are presumably message lengths in bytes -- confirm with
  the caller.
  """
  w = bc.blksz
  return [(sz,) for sz in (0, 1, 3*w, 3*w - 5)]
308
309###--------------------------------------------------------------------------
310### OCB.
311
## For OCB2, it's important for security that n = log_x (x + 1) is large in
## the field representations of GF(2^w) used -- in fact, we need more: that
## i n (mod 2^w - 1) is large for i in {-4, -3, -2, -1, 1, 2, 3, 4}.  The
## original paper lists the values for 64 and 128, but we support other block
## sizes, so here's the result of the (rather large, in some cases)
## computation.
##
## Block size   log_x (x + 1)
##
##       64     9686038906114705801
##       96     63214690573408919568138788065
##      128     338793687469689340204974836150077311399
##      192     161110085006042185925119981866940491651092686475226538785
##      256     22928580326165511958494515843249267194111962539778797914076675796261938307298
326
def ocb1(E, n, h, m, tsz = None):
  """Encrypt M with associated data H under nonce N using OCB1.

  This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
  Associated-Data'.  E is the keyed block cipher, and N must be one block
  long.  TSZ is the tag length in bytes and defaults to the block size.
  Returns (C, T).
  """
  blksz = E.__class__.blksz
  if VERBOSE: dump_ocb(E)
  Lgamma, Lxinv = ocb_masks(E)
  if tsz is None: tsz = blksz
  sigma = Z(blksz)
  off = E.encrypt(n ^ Lgamma[0])
  if VERBOSE: print('R = %s' % hex(off))
  out = C.WriteBuffer()
  full, tail = blocks(m, blksz)
  i = 1
  for x in full:
    b = ntz(i)
    off ^= Lgamma[b]
    sigma ^= x
    if VERBOSE:
      print('Z[%d]: %d -> %s' % (i - 1, b, hex(off)))
      print('A[%d]: %s' % (i - 1, hex(sigma)))
    out.put(E.encrypt(x ^ off) ^ off)
    i += 1

  ## The final (possibly short, possibly empty) block is encrypted by XORing
  ## it with a pad made from its length in bits.
  b = ntz(i)
  off ^= Lgamma[b]
  tlen = len(tail)
  lenblk = C.MP(8*tlen).storeb(blksz)
  if VERBOSE:
    print('Z[%d]: %d -> %s' % (i - 1, b, hex(off)))
    print('LEN = %s' % hex(lenblk))
  pad = E.encrypt(lenblk ^ Lxinv ^ off)
  out.put(tail ^ pad[:tlen])
  ## The checksum takes C[m] 0* XOR Y[m], which is just the tail followed by
  ## the rest of the pad.
  sigma ^= off ^ (tail + pad[tlen:])
  tag = E.encrypt(sigma)
  if h: tag ^= pmac1(E, h)
  return C.ByteString(out), C.ByteString(tag[:tsz])
362
def ocb2(E, n, h, m, tsz = None):
  """Encrypt M with associated data H under nonce N using OCB2.

  Arithmetic is modulo the polynomial from `prim'.  Let L = E(N).  The
  offsets are 2 L, 4 L, and so on.  The final block is XORed with the
  encryption of its bit length XOR the offset.  The tag is the encryption
  of the checksum XOR 3 times the offset, XORed with `pmac2' of H when H
  is nonempty.  TSZ (bytes) defaults to the block size.  Returns (C, T).
  """
  blksz = E.__class__.blksz
  if tsz is None: tsz = blksz
  p = prim(8*blksz)
  double = lambda z: mul_blk_gf(z, 2, p)
  off = double(E.encrypt(n))
  sigma = Z(blksz)
  out = C.WriteBuffer()
  full, tail = blocks(m, blksz)
  for x in full:
    sigma ^= x
    out.put(E.encrypt(x ^ off) ^ off)
    off = double(off)
  tlen = len(tail)
  pad = E.encrypt(C.MP(8*tlen).storeb(blksz) ^ off)
  out.put(tail ^ pad[:tlen])
  sigma ^= (tail + pad[tlen:]) ^ mul_blk_gf(off, 3, p)
  tag = E.encrypt(sigma)
  if h: tag ^= pmac2(E, h)
  return C.ByteString(out), C.ByteString(tag[:tsz])
384
## Nonce-stretch parameters for OCB3, indexed by block size in bytes.  Each
## entry is (SPLIT, SHIFT): the low SPLIT bits of the formatted nonce give
## the bit offset BOTTOM, Ktop is extended by 2^SPLIT bits taken from
## Ktop XOR (Ktop << SHIFT), and the initial offset starts BOTTOM bits
## into the result.  The 16-byte entry (6, 8) matches RFC 7253.
OCB3_STRETCH = {
  8: (5, 25),
  12: (6, 33),
  16: (6, 8),
  24: (7, 40),
  32: (7, 120),
}
390
def ocb3(E, n, h, m, tsz = None):
  """Encrypt M with associated data H under nonce N using (generalized) OCB3.

  E is the keyed block cipher; its block size must appear in OCB3_STRETCH.
  TSZ is the tag length in bytes and defaults to the block size.  Returns
  (C, T).  Raises ValueError if N is longer than the space left in a
  block after the tag-length prefix.
  """
  blksz = E.__class__.blksz
  if tsz is None: tsz = blksz
  Lstar, Ldollar, Lgamma = ocb3_masks(E)
  if VERBOSE: dump_ocb3(E)

  ## Figure out how much we need to glue onto the nonce.  This ends up being
  ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is
  ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and
  ## p = w - l(N) - v - 1.  But this is an annoying way to think about it
  ## because of the byte misalignment.  Instead, think of it as a
  ## byte-aligned prefix encoding the tag and an `is the nonce full-length'
  ## flag, followed by optional padding, and then the nonce:
  ##
  ##    F || N                  if l(N) = w - f
  ##    F || 0^p || 1 || N      otherwise
  ##
  ## where F is [t mod w]_v || 0^{f-v-1} || b; f = 8 (floor(v/8) + 1) is the
  ## length of F, i.e., v + 1 rounded up to a whole number of bytes; b is 1
  ## if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1.
  tszbits = C.MP(8*blksz - 1).nbits              # this is v
  fwd = tszbits//8 + 1                           # f/8, in bytes
  ## Put t in the top v bits of the F field.  Bits above the field are
  ## presumably dropped by storeb below (the Stretch computation also relies
  ## on storeb truncating).
  ## NOTE(review): that reduces t mod 2^v rather than mod w; the two differ
  ## only when t = w and w is not a power of two (96- and 192-bit blocks) --
  ## confirm against the reference implementation.
  f = (8*tsz) << (8*fwd - tszbits)

  ## Form the augmented nonce.
  nsz, nwd = len(n), blksz - fwd
  if nsz > nwd:
    raise ValueError('nonce too long: %d > %d bytes' % (nsz, nwd))
  if nsz == nwd: f |= 1
  nb = C.WriteBuffer()
  nb.put(C.MP(f).storeb(fwd))
  if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1)
  nb.put(n)
  nn = C.ByteString(nb)
  if VERBOSE: print('N\' : %s' % hex(nn))

  ## Calculate the initial offset.  Encrypt the nonce with its low SPLIT
  ## bits cleared to get Ktop, and stretch Ktop by 2^SPLIT bits.  The offset
  ## is the block-sized window starting BOTTOM bits in, where BOTTOM is the
  ## value of those low bits.
  split, shift = OCB3_STRETCH[blksz]
  splitbits = 1 << split
  t2ps = C.MP(0).setbit(splitbits)
  lomask = C.MP(0).setbit(split) - 1
  himask = ~lomask
  top = nn & himask.storeb2c(blksz)
  bottom = C.MP.loadb(nn) & lomask
  ktop = C.MP.loadb(E.encrypt(top))
  stretch = (ktop << splitbits) | \
      (((ktop ^ (ktop << shift)) >> (8*blksz - splitbits)) % t2ps)
  o = (stretch >> (splitbits - bottom)).storeb(blksz)
  a = C.ByteString.zero(blksz)
  if VERBOSE:
    print('bottom : %d' % bottom)
    print('Ktop : %s' % hex(ktop.storeb(blksz)))
    print('Stretch : %s' % hex(stretch.storeb(blksz + (1 << (split - 3)))))
    print('Offset_0 : %s' % hex(o))

  ## Process the message blocks.
  i = 1
  y = C.WriteBuffer()
  v, tl = blocks0(m, blksz)
  for x in v:
    b = ntz(i)
    o ^= Lgamma[b]
    a ^= x
    if VERBOSE:
      print('Offset_%-3d: %s' % (i, hex(o)))
      print('Checksum_%d: %s' % (i, hex(a)))
    y.put(E.encrypt(x ^ o) ^ o)
    i += 1
  if tl:
    o ^= Lstar
    tlen = len(tl)
    pad = E.encrypt(o)
    a ^= pad10star(tl, blksz)
    if VERBOSE:
      print('Offset_* : %s' % hex(o))
      print('Checksum_*: %s' % hex(a))
    y.put(tl ^ pad[0:tlen])

  ## Compute the tag, folding in the hash of the associated data.
  o ^= Ldollar
  t = E.encrypt(a ^ o) ^ pmac3(E, h)
  return C.ByteString(y), C.ByteString(t[:tsz])
467
def ocbgen(bc):
  """Return OCB test-case size triples for the block cipher class BC.

  Every triple starts with the block size.  The remaining two entries are
  presumably header and message lengths in bytes (confirm with the caller).
  """
  w = bc.blksz
  rest = [(0, 0), (1, 0), (0, 1),
          (0, 3*w),
          (3*w, 3*w),
          (0, 3*w + 5),
          (3*w - 5, 3*w + 5)]
  return [(w,) + pair for pair in rest]
475
def ocb3gen(bc):
  """Return OCB3 test-case size triples for the block cipher class BC.

  The first entry of each triple is a nonce length a few bytes short of a
  block.  The other two are presumably header and message lengths in bytes
  (confirm with the caller).
  """
  w = bc.blksz
  three = 3*w
  return [(w - 2, 0, 0),
          (w - 2, 1, 0),
          (w - 2, 0, 1),
          (w - 5, 0, three),
          (w - 3, three, three),
          (w - 2, 0, three + 5),
          (w - 2, three - 5, three + 5)]
483
484###--------------------------------------------------------------------------
485### Main program.
486
## Tracing switches.  VERBOSE (set by `-v') traces the OCB and PMAC
## computations; LRVERBOSE (set in `lraes' mode) traces the Luby--Rackoff
## ciphers.
VERBOSE = False
LRVERBOSE = False
488
class struct (object):
  """A plain record: struct(a = 1, b = 2) has attributes A and B."""
  def __init__(me, **kw):
    vars(me).update(kw)
492
def mct(ocb, bc, ksz, nsz, tsz):
  """Run a Monte-Carlo test of mode OCB with block cipher class BC.

  The key is the tag length in bits, stored with storeb into KSZ bytes.
  Nonces are successive integers from 1, stored into NSZ bytes.  Let S_i
  be the zero string of length i, for each i below 128.  Encrypt with
  (header, message) set to (S_i, S_i), then ('', S_i), then (S_i, ''),
  collecting every ciphertext and tag.  Finally print the tag for an
  empty message whose header is everything collected.
  """
  E = bc(C.MP(8*tsz).storeb(ksz))
  empty = C.ByteString('')
  ctr = C.MP(1)
  cbuf = C.WriteBuffer()
  for i in xrange(128):
    s = C.ByteString.zero(i)
    for hdr, msg in ((s, s), (empty, s), (s, empty)):
      y, t = ocb(E, ctr.storeb(nsz), hdr, msg, tsz)
      ctr += 1
      cbuf.put(y).put(t)
  _, t = ocb(E, ctr.storeb(nsz), C.ByteString(cbuf), empty, tsz)
  print(hex(t))
506
## Command-line cursor: ARGI indexes the next unread word of ARGV (which has
## ARGC words in total); `arg' advances it.
argc = len(argv)
argi = 1
509
def usage():
  """Report the command-line syntax on stderr and exit with status 2.

  This raises SystemExit directly rather than calling the `exit' builtin.
  That builtin is added by the `site' module for interactive use and is
  absent under `python -S'.
  """
  stderr.write("""\
usage: %s [-v] OCB BLKC OP ARGS...
 mct KSZ NSZ TSZ
 kat K N0 TSZ HSZ,MSZ ...
 lraes W K M
""" % argv[0])
  raise SystemExit(2)
517
def arg(must = True, default = None):
  """Return the next command-line word and step past it.

  If the words have run out, report usage and exit when MUST is true;
  otherwise return DEFAULT.
  """
  global argi
  if argi >= argc:
    if must: usage()
    return default
  argi += 1
  return argv[argi - 1]
523
## The OCB variants that can be chosen on the command line, by name.
MODEMAP = dict(ocb1 = ocb1, ocb2 = ocb2, ocb3 = ocb3)
527
def pat(sz):
  """Return the SZ-byte test pattern 00 01 02 ... ff 00 01 ...."""
  buf = C.WriteBuffer()
  for i in xrange(sz): buf.putu8(i & 0xff)
  return C.ByteString(buf)
532
## Main program: [-v] OCB BLKC OP ARGS...

## Choose the OCB variant, after an optional `-v' to switch on tracing.  An
## unrecognized mode name gets the usage message, just like an unknown OP.
opt = arg()
if opt == '-v': VERBOSE = True; opt = arg()
if opt not in MODEMAP: usage()
ocb = MODEMAP[opt]

## Find the block cipher, looking at our Luby--Rackoff ciphers first.
bcname = arg()
bc = None
for d in LRAES, C.gcprps:
  try: bc = d[bcname]
  except KeyError: pass
  else: break
if bc is None: raise KeyError(bcname)

mode = arg()
if mode == 'mct':
  ## Monte-Carlo test: KSZ NSZ TSZ.
  ksz = int(arg()); nsz = int(arg()); tsz = int(arg())
  mct(ocb, bc, ksz, nsz, tsz)

elif mode == 'kat':
  ## Known-answer tests: K N0 TSZ, then any number of HSZ,MSZ pairs.  A
  ## trailing `+' on N0 means the nonce increments after each test.
  k = C.bytes(arg())
  E = bc(k)
  nspec = arg()
  if nspec.endswith('+'): ninc = 1; nspec = nspec[:-1]
  else: ninc = 0
  n0 = C.bytes(nspec)
  nz = C.MP.loadb(n0)
  nsz = len(n0)
  tsz = int(arg())

  print('K: %s' % hex(k))

  while True:
    hmsz = arg(must = False)
    if hmsz is None: break
    hsz, msz = map(int, hmsz.split(','))
    n = nz.storeb(nsz)
    h = pat(hsz)
    m = pat(msz)
    y, t = ocb(E, n, h, m, tsz)
    print('')
    print('N: %s' % hex(n))
    print('A: %s' % hex(h))
    print('P: %s' % hex(m))
    print('C: %s%s' % (hex(y), hex(t)))
    nz += ninc

elif mode == 'lraes':
  ## Trace one Luby--Rackoff encryption: W K M.
  w = int(arg())
  k = C.bytes(arg())
  m = C.bytes(arg())
  LRVERBOSE = True
  lr = LubyRackoffCipher(bc, w)
  E = lr(k)
  print('')
  c = E.encrypt(m)
  print('E\'(K, m) = %s' % hex(c))

else:
  usage()
592
593###----- That's all, folks --------------------------------------------------