Makefile, ocbgen: Handle 512-bit blocks.
[ocb-tv] / ocbgen
CommitLineData
54482987
MW
1#! /usr/bin/python
2### -*-python-*-
3###
4### Generalization of OCB mode for other block sizes
5###
6### (c) 2017 Mark Wooding
7###
8
9###----- Licensing notice ---------------------------------------------------
10###
11### This program is free software; you can redistribute it and/or modify
12### it under the terms of the GNU General Public License as published by
13### the Free Software Foundation; either version 2 of the License, or
14### (at your option) any later version.
15###
16### This program is distributed in the hope that it will be useful,
17### but WITHOUT ANY WARRANTY; without even the implied warranty of
18### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19### GNU General Public License for more details.
20###
21### You should have received a copy of the GNU General Public License
22### along with this program; if not, write to the Free Software Foundation,
23### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
24
25from sys import argv, stderr
26from struct import pack
27from itertools import izip
86082bbc 28from contextlib import contextmanager
54482987
MW
29import catacomb as C
30
## Random source used for the `random' choices below (e.g., in `keylens').
## It is seeded with a fixed value, 0, presumably so that the choices come
## out the same on every run -- confirm against catacomb's FibRand docs.
R = C.FibRand(0)
32
33###--------------------------------------------------------------------------
34### Utilities.
35
def combs(things, k):
  """Generate the K-element combinations of THINGS, each as a fresh list.

  Combinations are produced in colexicographic order of their index sets:
  for four things and K = 2 the index sets are 01, 02, 12, 03, 13, 23.  The
  order matters to callers such as `poly', which keeps the first hit.  If K
  exceeds the number of things there are no combinations, and nothing is
  generated; if K is zero, a single empty combination is generated.
  """
  n = len(things)
  if k > n: return
  ii = list(range(k))
  while True:
    yield [things[i] for i in ii]
    ## Advance the lowest index which can move up without colliding with its
    ## successor (or, for the top index, running off the end), resetting
    ## every index below it to its smallest possible value.
    for j in range(k):
      if j == k - 1: lim = n
      else: lim = ii[j + 1]
      i = ii[j] + 1
      if i < lim:
        ii[j] = i
        break
      ii[j] = j
    else:
      return
51
## Cache of the irreducible polynomials found by `poly', keyed by degree.
POLYMAP = {}

def poly(nbits):
  """Return an irreducible polynomial of degree NBITS over GF(2).

  Candidates are x^NBITS + 1 plus an odd number of further terms of
  intermediate degree, so every candidate has an odd number of terms (a
  polynomial with an even number of terms is divisible by x + 1).  Fewer
  terms are tried first, and within a given count the order is that of
  `combs'.  The first irreducible candidate is cached in POLYMAP and
  returned.  Raises ValueError if there is none.
  """
  if nbits in POLYMAP: return POLYMAP[nbits]
  base = C.GF(0).setbit(nbits).setbit(0)
  for nterms in xrange(1, nbits, 2):
    for degrees in combs(range(1, nbits), nterms):
      cand = base
      for d in degrees: cand = cand.setbit(d)
      if cand.irreduciblep():
        POLYMAP[nbits] = cand
        return cand
  raise ValueError(nbits)
63
def prim(nbits):
  """Return a primitive polynomial of degree NBITS over GF(2).

  Only degrees 64, 96, 128, 192 and 256 are known; anything else raises
  ValueError.
  """
  ## No fancy way to do this: I'd need a much cleverer factoring algorithm
  ## than I have in my pockets.  So the polynomials are simply tabulated, as
  ## the exponents of their nonzero terms.
  table = { 64: [64, 4, 3, 1, 0],
            96: [96, 10, 9, 6, 0],
            128: [128, 7, 2, 1, 0],
            192: [192, 15, 11, 5, 0],
            256: [256, 10, 5, 2, 0] }
  terms = table.get(nbits)
  if terms is None: raise ValueError('no field for %d bits' % nbits)
  p = C.GF(0)
  for t in terms: p = p.setbit(t)
  return p
76
def Z(n):
  """Return a byte string consisting of N zero bytes."""
  zero = C.ByteString.zero
  return zero(n)
79
def mul_blk_gf(m, x, p):
  """Multiply block M by X in GF(2)[x]/(P), and return the product as a block.

  M is loaded as a big-endian polynomial.  The reduced product is stored in
  ceil(deg(P)/8) bytes: P.nbits is deg(P) + 1, hence the `+ 6'.
  """
  nbytes = (p.nbits + 6)//8
  product = (C.GF.loadb(m)*x)%p
  return product.storeb(nbytes)
81
def with_lastp(it):
  """Generate pairs (ITEM, LASTP) for the items of IT.

  LASTP is true only for the final item.  One item of lookahead is kept.
  Raises ValueError (when first advanced) if IT is empty.
  """
  it = iter(it)
  missing = object()
  prev = next(it, missing)
  if prev is missing: raise ValueError('empty iter')
  for cur in it:
    yield prev, False
    prev = cur
  yield prev, True
92
def safehex(x):
  """Return hex(X), or the literal two-character string '""' if X is empty."""
  return hex(x) if len(x) else '""'
96
def keylens(ksz):
  """Choose up to four key lengths, in bytes, permitted by the descriptor KSZ.

  KSZ is one of Catacomb's key-size descriptors:

    * KeySZSet -- sizes are chosen from KSZ.set;
    * KeySZRange -- sizes are chosen from KSZ.min to KSZ.max inclusive, in
      steps of KSZ.mod;
    * KeySZAny -- the empty key (length 0) is always chosen first, and the
      rest are chosen from 1 to 63.

  Choices are made without repetition, using the fixed-seed generator R.
  Raises TypeError for any other kind of descriptor.
  """
  sel = []
  if isinstance(ksz, C.KeySZSet):
    kk = list(ksz.set)
  elif isinstance(ksz, C.KeySZRange):
    ## The maximum is itself a valid size (e.g., 32-byte Rijndael keys), so
    ## the range must include it.
    kk = list(range(ksz.min, ksz.max + 1, ksz.mod))
  elif isinstance(ksz, C.KeySZAny):
    ## Zero is selected up front, so leave it out of the pool to avoid
    ## choosing it a second time.
    kk = list(range(1, 64)); sel = [0]
  else:
    raise TypeError('unexpected key-size descriptor %r' % (ksz,))

  ## Partial Fisher--Yates shuffle: repeatedly move a randomly chosen
  ## remaining candidate to the end of the live region and take it.
  n = len(kk)
  while n and len(sel) < 4:
    i = R.range(n)
    n -= 1
    kk[i], kk[n] = kk[n], kk[i]
    sel.append(kk[n])
  return sel
110
def pad0star(m, w):
  """Zero-pad M to a multiple of W bytes, as a byte string.

  An empty M becomes a whole block of W zero bytes.  Otherwise only as many
  zeros as are needed are appended, so an M that is already a multiple of W
  bytes long is unchanged.
  """
  size = len(m)
  fill = w if size == 0 else (-size)%w
  if fill: m += Z(fill)
  return C.ByteString(m)
117
def pad10star(m, w):
  """Pad M with a 0x80 byte and then zeros, up to the next multiple of W bytes.

  At least one byte is always added: an M that is already a multiple of W
  bytes long gains a whole extra block.
  """
  ## The amount of fill is between 1 and W, so there is always room for the
  ## 0x80 marker.
  fill = w - len(m)%w
  return C.ByteString(m + ('\x80' + Z(fill - 1)))
122
def ntz(i):
  """Return the number of trailing zero bits in the nonzero integer I.

  Raises ValueError if I is zero.  Zero has no lowest set bit, and the loop
  below would never terminate on it.
  """
  if i == 0: raise ValueError('ntz(0) is undefined')
  j = 0
  while (i&1) == 0: i >>= 1; j += 1
  return j
127
def blocks(x, w):
  """Split X into W-byte blocks, keeping the last block back as a tail.

  Returns a pair (V, TL).  V lists the leading full blocks.  TL is what
  remains, between 1 and W bytes long, or empty if X is empty.  A final
  full block therefore ends up in TL, not in V.
  """
  out = []
  start, total = 0, len(x)
  while total - start > w:
    out.append(C.ByteString(x[start:start + w]))
    start += w
  tail = C.ByteString(x[start:])
  return out, tail
134
## The empty byte string, as returned for an absent tail by `blocks0'.
EMPTY = C.bytes('')

def blocks0(x, w):
  """Like `blocks', but a final full block goes into the list.

  The returned tail is therefore always strictly shorter than W bytes, and
  it is EMPTY when X is a whole number of blocks.
  """
  v, tl = blocks(x, w)
  if len(tl) != w: return v, tl
  v.append(tl)
  return v, EMPTY
141
142###--------------------------------------------------------------------------
143### Luby--Rackoff large-block ciphers.
144
class LubyRackoffCipher (type):
  """Factory for wide-block Luby--Rackoff ciphers built from a block cipher.

  LubyRackoffCipher(BC, BLKSZ) returns a fresh subclass of LubyRackoffBase
  with a BLKSZ-byte block.  BC is used for the key schedule and for the
  round functions, and the subclass takes BC's key sizes.  BLKSZ must be
  even, and each half-block must fit within one block of BC.
  """
  def __new__(cls, bc, blksz):
    assert blksz%2 == 0
    assert blksz <= 2*bc.blksz
    label = '%s-lr[%d]' % (bc.name, 8*blksz)
    attrs = dict(name = label, blksz = blksz, keysz = bc.keysz, bc = bc)
    return type(label, (LubyRackoffBase,), attrs)
156
86082bbc
MW
@contextmanager
def muffle():
  """Context manager: switch off VERBOSE and LRVERBOSE inside a `with' block.

  The previous settings are restored on exit, whether or not the block
  raises.
  """
  global VERBOSE, LRVERBOSE
  saved = (VERBOSE, LRVERBOSE)
  VERBOSE = LRVERBOSE = False
  try:
    yield None
  finally:
    VERBOSE, LRVERBOSE = saved
166
54482987
MW
167class LubyRackoffBase (object):
168 NR = 4 # for strong-PRP security
169 def __init__(me, k):
170 if LRVERBOSE: print 'K = %s' % hex(k)
171 bc, blksz = me.__class__.bc, me.__class__.blksz
86082bbc 172 with muffle(): E = bc(k)
54482987
MW
173 me.f = []
174 ksz = len(k)
175 i = C.MP(0)
176 for j in xrange(me.NR):
177 b = C.WriteBuffer()
178 while b.size < ksz:
86082bbc 179 with muffle(): x = E.encrypt(i.storeb(bc.blksz))
54482987
MW
180 b.put(x)
181 if LRVERBOSE: print 'E(K; [%d]) = %s' % (i, hex(x))
182 i += 1
183 kj = C.ByteString(C.ByteString(b)[0:ksz])
184 if LRVERBOSE: print 'K_%d = %s' % (j, hex(kj))
86082bbc 185 with muffle(): me.f.append(bc(kj))
54482987
MW
186 def encrypt(me, m):
187 bc, blksz = me.__class__.bc, me.__class__.blksz
188 assert len(m) == blksz
189 l, r = C.ByteString(m[:blksz/2]), C.ByteString(m[blksz/2:])
190 if LRVERBOSE: print 'L_0, R_0 = %s, %s' % (hex(l), hex(r))
191 for j in xrange(me.NR):
192 l0 = pad0star(l, bc.blksz)
86082bbc 193 with muffle(): t = me.f[j].encrypt(l0)
54482987
MW
194 l, r = r ^ t[:blksz/2], l
195 if LRVERBOSE:
196 print 'E(K_%d; L_%d || 0^*) = %s' % (j, j, hex(t))
197 print 'L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r))
198 return C.ByteString(r + l)
199 def decrypt(me, c):
200 bc, blksz = me.__class__.bc, me.__class__.blksz
201 assert len(c) == blksz
202 l, r = C.ByteString(c[:blksz/2]), C.ByteString(c[blksz/2:])
203 for j in xrange(me.NR - 1, -1, -1):
204 l0 = pad0star(l, bc.blksz)
86082bbc 205 with muffle(): t = me.f[j].encrypt(l0)
54482987
MW
206 if LRVERBOSE:
207 print 'L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r))
208 print 'E(K_%d; L_%d || 0^*) = %s' % (j + 1, j + 1, hex(t))
209 l, r = r ^ t[:blksz/2], l
210 if LRVERBOSE: print 'L_0, R_0 = %s, %s' % (hex(l), hex(r))
211 return C.ByteString(r + l)
212
## Luby--Rackoff ciphers built on Rijndael, indexed by name.  `lraesN' has an
## N-bit block.  `dlraes512' applies the construction a second time, on top
## of the 256-bit one, to reach a 512-bit block.
LRAES = dict(('lraes%d' % (8*sz), LubyRackoffCipher(C.rijndael, sz))
             for sz in [8, 12, 16, 24, 32])
LRAES['dlraes512'] = LubyRackoffCipher(LubyRackoffCipher(C.rijndael, 32), 64)
217
218###--------------------------------------------------------------------------
219### PMAC.
220
def ocb_masks(E):
  """Return (LGAMMA, LXINV), the OCB mask values for block-cipher instance E.

  With L = E(0^w), LGAMMA is the list of L x^i for i = 0, 1, ..., 65, and
  LXINV is L x^-1.  All arithmetic is in GF(2^w) modulo poly(w), where w is
  E's block size in bits.
  """
  blksz = E.__class__.blksz
  p = poly(8*blksz)
  x = C.GF(2)
  L = E.encrypt(Z(blksz))
  Lgamma = [L]
  for _ in xrange(65): Lgamma.append(mul_blk_gf(Lgamma[-1], x, p))
  return Lgamma, mul_blk_gf(L, p.modinv(x), p)
232
233def dump_ocb(E):
234 Lgamma, Lxinv = ocb_masks(E)
235 print 'L x^-1 = %s' % hex(Lxinv)
236 for i, lg in enumerate(Lgamma):
237 print 'L x^%d = %s' % (i, hex(lg))
238
239def pmac1(E, m):
240 blksz = E.__class__.blksz
241 Lgamma, Lxinv = ocb_masks(E)
242 a = o = Z(blksz)
243 i = 1
244 v, tl = blocks(m, blksz)
245 for x in v:
246 b = ntz(i)
247 o ^= Lgamma[b]
248 a ^= E.encrypt(x ^ o)
249 if VERBOSE:
250 print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
251 print 'A[%d]: %s' % (i - 1, hex(a))
252 i += 1
253 if len(tl) == blksz: a ^= tl ^ Lxinv
254 else: a ^= pad10star(tl, blksz)
255 return E.encrypt(a)
256
def pmac2(E, m):
  """Return the PMAC tag of M under E, in the variant used by `ocb2'.

  Offsets start at 10 L, where L = E(0^w), and double after each full
  block.  The factor 10 is, presumably, the polynomial x^3 + x, taking
  integers as GF(2) polynomials by bit pattern.  A final full block is
  whitened with 3 times the offset; a partial one is 10*-padded and
  whitened with 5 times the offset.  Arithmetic is modulo prim(w).
  """
  blksz = E.__class__.blksz
  p = prim(8*blksz)
  L = E.encrypt(Z(blksz))
  off = mul_blk_gf(L, 10, p)
  acc = Z(blksz)
  full, tail = blocks(m, blksz)
  for blk in full:
    acc ^= E.encrypt(blk ^ off)
    off = mul_blk_gf(off, 2, p)
  if len(tail) == blksz: final = tail ^ mul_blk_gf(off, 3, p)
  else: final = pad10star(tail, blksz) ^ mul_blk_gf(off, 5, p)
  return E.encrypt(acc ^ final)
270
def ocb3_masks(E):
  """Return (L_*, L_$, [L_0, L_1, ...]) for OCB3 under E.

  These are simply L x^0, L x^1 and L x^(i+2) from `ocb_masks'.
  """
  Lgamma = ocb_masks(E)[0]
  return Lgamma[0], Lgamma[1], Lgamma[2:]
276
277def dump_ocb3(E):
278 Lstar, Ldollar, Lgamma = ocb3_masks(E)
279 print 'L_* : %s' % hex(Lstar)
280 print 'L_$ : %s' % hex(Ldollar)
281 for i, lg in enumerate(Lgamma[:4]):
282 print 'L_%-8d: %s' % (i, hex(lg))
283
284def pmac3(E, m):
285 blksz = E.__class__.blksz
286 Lstar, Ldollar, Lgamma = ocb3_masks(E)
287 a = o = Z(blksz)
288 i = 1
289 v, tl = blocks0(m, blksz)
290 for x in v:
291 b = ntz(i)
292 o ^= Lgamma[b]
293 a ^= E.encrypt(x ^ o)
294 if VERBOSE:
295 print 'Offset\'_%-2d: %s' % (i, hex(o))
296 print 'AuthSum_%-2d: %s' % (i, hex(a))
297 i += 1
298 if tl:
299 o ^= Lstar
300 a ^= E.encrypt(pad10star(tl, blksz) ^ o)
301 if VERBOSE:
302 print 'Offset\'_* : %s' % hex(o)
303 print 'AuthSum_* : %s' % hex(a)
304 return a
305
def pmac1_pub(E, m):
  """Return (PMAC1 tag of M under E,), first dumping the masks if VERBOSE."""
  if VERBOSE: dump_ocb(E)
  tag = pmac1(E, m)
  return (tag,)

def pmac2_pub(E, m):
  """Return (PMAC2 tag of M under E,) as a one-element tuple."""
  tag = pmac2(E, m)
  return (tag,)

def pmac3_pub(E, m):
  """Return (OCB3 header hash of M under E,) as a one-element tuple."""
  tag = pmac3(E, m)
  return (tag,)
315
def pmacgen(bc):
  """Return PMAC test cases for block cipher BC, as 1-tuples of message sizes.

  The sizes, in bytes, are 0, 1, three blocks, and three blocks less five.
  """
  w = bc.blksz
  return [(size,) for size in (0, 1, 3*w, 3*w - 5)]
320
321###--------------------------------------------------------------------------
322### OCB.
323
324## For OCB2, it's important for security that n = log_x (x + 1) is large in
325## the field representations of GF(2^w) used -- in fact, we need more, that
326## i n (mod 2^w - 1) is large for i in {-4, -3, -2, -1, 1, 2, 3, 4}. The
327## original paper lists the values for 64 and 128, but we support other block
328## sizes, so here's the result of the (rather large, in some cases)
329## computation.
330##
331## Block size log_x (x + 1)
332##
333## 64 9686038906114705801
334## 96 63214690573408919568138788065
335## 128 338793687469689340204974836150077311399
336## 192 161110085006042185925119981866940491651092686475226538785
337## 256 22928580326165511958494515843249267194111962539778797914076675796261938307298
338
def ocb1(E, n, h, m, tsz = None):
  """Encrypt M under OCB1 with block-cipher instance E, nonce N and header H.

  Returns a pair (C, T): the ciphertext, which is as long as M, and the
  first TSZ bytes of the tag.  TSZ defaults to a whole block.  A nonempty
  header is authenticated by XORing its PMAC1 into the tag.
  """
  ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
  ## Associated-Data'.
  blksz = E.__class__.blksz
  if VERBOSE: dump_ocb(E)
  Lgamma, Lxinv = ocb_masks(E)
  if tsz is None: tsz = blksz
  a = Z(blksz)                          # checksum of the plaintext
  o = E.encrypt(n ^ Lgamma[0])          # R = E(N xor L): the first offset
  if VERBOSE: print 'R = %s' % hex(o)
  i = 1
  y = C.WriteBuffer()
  v, tl = blocks(m, blksz)
  for x in v:
    ## Update the offset by L x^ntz(i) for the i-th block.
    b = ntz(i)
    o ^= Lgamma[b]
    a ^= x
    if VERBOSE:
      print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
      print 'A[%d]: %s' % (i - 1, hex(a))
    y.put(E.encrypt(x ^ o) ^ o)
    i += 1
  ## The final block (1 to BLKSZ bytes, or empty if M is empty): encrypt
  ## its bit length, masked with L x^-1 and the next offset, to make a pad.
  b = ntz(i)
  o ^= Lgamma[b]
  n = len(tl)
  if VERBOSE:
    print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
    print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
  yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
  cfinal = tl ^ yfinal[:n]
  ## The checksum takes the final plaintext, extended with the unused part
  ## of the pad, and also the final offset, before it is encrypted.
  a ^= o ^ (tl + yfinal[n:])
  y.put(cfinal)
  t = E.encrypt(a)
  if h: t ^= pmac1(E, h)
  return C.ByteString(y), C.ByteString(t[:tsz])
374
def ocb2(E, n, h, m, tsz = None):
  """Encrypt M under OCB2 with block-cipher instance E, nonce N and header H.

  Returns a pair (C, T): the ciphertext, which is as long as M, and the
  first TSZ bytes of the tag.  TSZ defaults to a whole block.  Offsets
  start at 2 E(N) and double per block, modulo prim(w).  A nonempty
  header is authenticated by XORing its `pmac2' tag into the tag.
  """
  blksz = E.__class__.blksz
  if tsz is None: tsz = blksz
  p = prim(8*blksz)
  double = lambda blk: mul_blk_gf(blk, 2, p)
  off = double(E.encrypt(n))
  chk = Z(blksz)
  full, tail = blocks(m, blksz)
  out = C.WriteBuffer()
  for blk in full:
    chk ^= blk
    out.put(E.encrypt(blk ^ off) ^ off)
    off = double(off)
  ## The final block is enciphered by XOR with a pad: the encryption of its
  ## bit length, masked by the current offset.
  nt = len(tail)
  pad = E.encrypt(C.MP(8*nt).storeb(blksz) ^ off)
  chk ^= (tail + pad[nt:]) ^ mul_blk_gf(off, 3, p)
  out.put(tail ^ pad[:nt])
  tag = E.encrypt(chk)
  if h: tag ^= pmac2(E, h)
  return C.ByteString(out), C.ByteString(tag[:tsz])
396
## OCB3 nonce-stretch parameters, keyed by block size in bytes.  Each entry
## is (SPLIT, SHIFT): the low SPLIT bits of the augmented nonce select the
## initial offset's position within the stretch, which is 2^SPLIT bits long.
## SHIFT is the shift used when forming the stretch.  (For 16-byte blocks
## these are 6 and 8.)
OCB3_STRETCH = dict((blksz, (split, shift))
                    for blksz, split, shift in [( 8, 5, 25),
                                                (12, 6, 33),
                                                (16, 6, 8),
                                                (24, 7, 40),
                                                (32, 7, 120),
                                                (64, 8, 240)])
403
404def ocb3(E, n, h, m, tsz = None):
405 blksz = E.__class__.blksz
406 if tsz is None: tsz = blksz
407 Lstar, Ldollar, Lgamma = ocb3_masks(E)
408 if VERBOSE: dump_ocb3(E)
409
410 ## Figure out how much we need to glue onto the nonce. This ends up being
411 ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is
412 ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and p = w - l(N) -
413 ## v - 1. But this is an annoying way to think about it because of the
414 ## byte misalignment. Instead, think of it as a byte-aligned prefix
415 ## encoding the tag and an `is the nonce full-length' flag, followed by
416 ## optional padding, and then the nonce:
417 ##
418 ## F || N if l(N) = w - f
419 ## F || 0^p || 1 || N otherwise
420 ##
421 ## where F is [t mod w]_v || 0^{f-v-1} || b; f = floor(log_2(w - 1)) + 2;
422 ## b is 1 if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1.
423 tszbits = C.MP(8*blksz - 1).nbits
424 fwd = tszbits/8 + 1
425 f = tsz << 3 + 8*fwd - tszbits
426
427 ## Form the augmented nonce.
428 nb = C.WriteBuffer()
429 nsz, nwd = len(n), blksz - fwd
430 if nsz == nwd: f |= 1
431 nb.put(C.MP(f).storeb(fwd))
432 if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1)
433 nb.put(n)
434 nn = C.ByteString(nb)
435 if VERBOSE: print 'N\' : %s' % hex(nn)
436
437 ## Calculate the initial offset.
438 split, shift = OCB3_STRETCH[blksz]
439 splitbits = 1 << split
440 t2ps = C.MP(0).setbit(splitbits)
441 lomask = (C.MP(0).setbit(split) - 1)
442 himask = ~lomask
443 top, bottom = nn&himask.storeb2c(blksz), C.MP.loadb(nn)&lomask
444 ktop = C.MP.loadb(E.encrypt(top))
445 stretch = (ktop << splitbits) | \
446 (((ktop ^ (ktop << shift)) >> (8*blksz - splitbits))%t2ps)
447 o = (stretch >> splitbits - bottom).storeb(blksz)
448 a = C.ByteString.zero(blksz)
449 if VERBOSE:
450 print 'bottom : %d' % bottom
451 print 'Ktop : %s' % hex(ktop.storeb(blksz))
452 print 'Stretch : %s' % hex(stretch.storeb(blksz + (1 << split - 3)))
453 print 'Offset_0 : %s' % hex(o)
454
455 ## Split the message into blocks.
456 i = 1
457 y = C.WriteBuffer()
458 v, tl = blocks0(m, blksz)
459 for x in v:
460 b = ntz(i)
461 o ^= Lgamma[b]
462 a ^= x
463 if VERBOSE:
464 print 'Offset_%-3d: %s' % (i, hex(o))
465 print 'Checksum_%d: %s' % (i, hex(a))
466 y.put(E.encrypt(x ^ o) ^ o)
467 i += 1
468 if tl:
469 o ^= Lstar
470 n = len(tl)
471 pad = E.encrypt(o)
472 a ^= pad10star(tl, blksz)
473 if VERBOSE:
474 print 'Offset_* : %s' % hex(o)
475 print 'Checksum_*: %s' % hex(a)
476 y.put(tl ^ pad[0:n])
477 o ^= Ldollar
478 t = E.encrypt(a ^ o) ^ pmac3(E, h)
479 return C.ByteString(y), C.ByteString(t[:tsz])
480
def ocbgen(bc):
  """Return OCB1/OCB2 test cases for block cipher BC, as triples.

  The triples are presumably (nonce size, header size, message size) in
  bytes, matching the argument order of the `ocb' functions; the nonce is
  always a full block.
  """
  w = bc.blksz
  sizes = [(0, 0), (1, 0), (0, 1),
           (0, 3*w),
           (3*w, 3*w),
           (0, 3*w + 5),
           (3*w - 5, 3*w + 5)]
  return [(w, hsz, msz) for hsz, msz in sizes]
488
def ocb3gen(bc):
  """Return OCB3 test cases for block cipher BC, as triples.

  These are like `ocbgen', but with nonces shorter than a block (by 2, 3
  or 5 bytes), since OCB3 needs room for its nonce prefix.
  """
  w = bc.blksz
  cases = []
  for nsz, hsz, msz in [(w - 2, 0, 0), (w - 2, 1, 0), (w - 2, 0, 1),
                        (w - 5, 0, 3*w),
                        (w - 3, 3*w, 3*w),
                        (w - 2, 0, 3*w + 5),
                        (w - 2, 3*w - 5, 3*w + 5)]:
    cases.append((nsz, hsz, msz))
  return cases
496
497###--------------------------------------------------------------------------
498### Main program.
499
## Diagnostic switches.  VERBOSE traces the OCB and PMAC computations and is
## set by the `-v' option.  LRVERBOSE traces the Luby--Rackoff constructions
## and is set in `lraes' mode.  `muffle' clears both temporarily.
VERBOSE = False
LRVERBOSE = False
501
class struct (object):
  """A simple record: struct(a = 1, b = 2) has attributes `a' and `b'."""
  def __init__(me, **kw):
    vars(me).update(kw)
505
506def mct(ocb, bc, ksz, nsz, tsz):
507 k = C.MP(8*tsz).storeb(ksz)
508 E = bc(k)
509 e = C.ByteString('')
510 n = C.MP(1)
511 cbuf = C.WriteBuffer()
512 for i in xrange(128):
513 s = C.ByteString.zero(i)
514 y, t = ocb(E, n.storeb(nsz), s, s, tsz); n += 1; cbuf.put(y).put(t)
515 y, t = ocb(E, n.storeb(nsz), e, s, tsz); n += 1; cbuf.put(y).put(t)
516 y, t = ocb(E, n.storeb(nsz), s, e, tsz); n += 1; cbuf.put(y).put(t)
517 _, t = ocb(E, n.storeb(nsz), C.ByteString(cbuf), e, tsz)
518 print hex(t)
519
## Command-line cursor: ARGI indexes the next unconsumed argument in ARGV.
argc = len(argv)
argi = 1

def usage():
  """Write a usage summary to stderr and exit with status 2."""
  stderr.write(("""\
usage: %s [-v] OCB BLKC OP ARGS...
 mct KSZ NSZ TSZ
 kat K N0 TSZ HSZ,MSZ ...
 lraes W K M""" % argv[0]) + '\n')
  ## Raise SystemExit directly: the `exit' builtin is installed by the
  ## `site' module for interactive use, and isn't guaranteed to exist.
  raise SystemExit(2)

def arg(must = True, default = None):
  """Return the next command-line argument, advancing the cursor.

  If the arguments are exhausted, return DEFAULT when MUST is false;
  otherwise report usage and exit.
  """
  global argi
  if argi < argc:
    argi += 1
    return argv[argi - 1]
  if not must: return default
  usage()
536
## The OCB variants selectable on the command line, by name.
MODEMAP = dict(ocb1 = ocb1, ocb2 = ocb2, ocb3 = ocb3)
540
def pat(sz):
  """Return SZ bytes of test pattern 00 01 02 ... ff 00 01 ..., repeating."""
  buf = C.WriteBuffer()
  for i in xrange(sz): buf.putu8(i & 0xff)
  return C.ByteString(buf)
545
546opt = arg()
547if opt == '-v': VERBOSE = True; opt = arg()
548ocb = MODEMAP[opt]
549
550bcname = arg()
551bc = None
552for d in LRAES, C.gcprps:
553 try: bc = d[bcname]
554 except KeyError: pass
555 else: break
556if bc is None: raise KeyError, bcname
557
558mode = arg()
559if mode == 'mct':
560 ksz = int(arg()); nsz = int(arg()); tsz = int(arg())
561 mct(ocb, bc, ksz, nsz, tsz)
562 exit(0)
563
564elif mode == 'kat':
565 k = C.bytes(arg())
566 E = bc(k)
567 nspec = arg()
568 if nspec.endswith('+'): ninc = 1; nspec = nspec[:-1]
569 else: ninc = 0
570 n0 = C.bytes(nspec)
571 nz = C.MP.loadb(n0)
572 nsz = len(n0)
573 tsz = int(arg())
574
575 print 'K: %s' % hex(k)
576
577 while True:
578 hmsz = arg(must = False)
579 if hmsz is None: break
580 hsz, msz = map(int, hmsz.split(','))
581 n = nz.storeb(nsz)
582 h = pat(hsz)
583 m = pat(msz)
584 y, t = ocb(E, n, h, m, tsz)
585 print
586 print 'N: %s' % hex(n)
587 print 'A: %s' % hex(h)
588 print 'P: %s' % hex(m)
589 print 'C: %s%s' % (hex(y), hex(t))
590 nz += ninc
591
592elif mode == 'lraes':
593 w = int(arg())
594 k = C.bytes(arg())
595 m = C.bytes(arg())
596 LRVERBOSE = True
597 lr = LubyRackoffCipher(bc, w)
598 E = lr(k)
599 print
600 c = E.encrypt(m)
601 print 'E\'(K, m) = %s' % hex(c)
602
603else:
604 usage()
605
606###----- That's all, folks --------------------------------------------------