symm/ocb3.h, symm/ocb3-def.h: Implement the OCB3 authenticated-encryption mode.
[catacomb] / utils / advmodes
#! /usr/bin/python

## Standard-library pieces: command-line access, fixed-width big-endian
## packing (used by the GCM counter), and Python 2's lazy `zip'.
from sys import argv, exit
from struct import unpack, pack
from itertools import izip

import catacomb as C

## The script's random source for keys and arguments, seeded with the
## constant 0 (presumably so that repeated runs produce the same vectors).
R = C.FibRand(0)

###--------------------------------------------------------------------------
### Utilities.

def combs(things, k):
  ## Generate every K-element selection from THINGS as a list, with elements
  ## kept in their order within THINGS.  The index vector is stepped like an
  ## odometer: the lowest position that can advance does so, and each
  ## position below it is reset to its smallest possible value.
  idx = list(range(k))
  top = len(things)
  while True:
    yield [things[p] for p in idx]
    pos = 0
    while pos < k:
      if pos + 1 < k: limit = idx[pos + 1]
      else: limit = top
      if idx[pos] + 1 < limit:
        idx[pos] += 1
        break
      idx[pos] = pos
      pos += 1
    else:
      return

29 POLYMAP = {}
30
31 def poly(nbits):
32 try: return POLYMAP[nbits]
33 except KeyError: pass
34 base = C.GF(0).setbit(nbits).setbit(0)
35 for k in xrange(1, nbits, 2):
36 for cc in combs(range(1, nbits), k):
37 p = base + sum(C.GF(0).setbit(c) for c in cc)
38 if p.irreduciblep(): POLYMAP[nbits] = p; return p
39 raise ValueError, nbits
40
41 def prim(nbits):
42 ## No fancy way to do this: I'd need a much cleverer factoring algorithm
43 ## than I have in my pockets.
44 if nbits == 64: cc = [64, 4, 3, 1, 0]
45 elif nbits == 96: cc = [96, 10, 9, 6, 0]
46 elif nbits == 128: cc = [128, 7, 2, 1, 0]
47 elif nbits == 192: cc = [192, 15, 11, 5, 0]
48 elif nbits == 256: cc = [256, 10, 5, 2, 0]
49 else: raise ValueError, 'no field for %d bits' % nbits
50 p = C.GF(0)
51 for c in cc: p = p.setbit(c)
52 return p
53
def Z(n):
  ## Return a byte string of N zero bytes.
  return C.ByteString.zero(n)

def mul_blk_gf(m, x, p):
  ## Read block M as a big-endian polynomial over GF(2), multiply it by X
  ## modulo P, and store the product big-endian in (p.nbits + 6)/8 bytes.
  prod = (C.GF.loadb(m)*x)%p
  return prod.storeb((p.nbits + 6)/8)

def with_lastp(it):
  ## Iterate over IT, yielding pairs (ITEM, LASTP), where LASTP is true only
  ## for the final item.  An empty IT raises ValueError when the first pair
  ## is requested.
  src = iter(it)
  try: pending = next(src)
  except StopIteration: raise ValueError('empty iter')
  for item in src:
    yield pending, False
    pending = item
  yield pending, True

def safehex(x):
  ## Hex-encode X for the output file; an empty X is written as `""' rather
  ## than as nothing at all.
  if not len(x): return '""'
  return hex(x)

def keylens(ksz):
  ## Choose up to four key lengths (in bytes) to test, given the key-size
  ## descriptor KSZ.  The candidates are drawn by a partial Fisher-Yates
  ## shuffle driven by the script's generator R.  For KeySZAny, a zero-length
  ## key is always included and up to three more lengths are drawn from 0-63.
  ##
  ## Raises TypeError for any other kind of descriptor, rather than falling
  ## through to a NameError on the unbound candidate list.
  ##
  ## NOTE(review): `range(min, max, mod)' excludes `max' itself, and 0 may be
  ## drawn a second time for KeySZAny.  Confirm whether either is intended.
  ## Changing them would alter R's draws and hence every generated vector.
  sel = []
  if isinstance(ksz, C.KeySZSet): kk = ksz.set
  elif isinstance(ksz, C.KeySZRange): kk = range(ksz.min, ksz.max, ksz.mod)
  elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0]
  else: raise TypeError('unsupported key-size descriptor %r' % (ksz,))
  ## Work on a private copy: the shuffle below swaps elements in place.
  kk = list(kk)
  n = len(kk)
  while n and len(sel) < 4:
    i = R.range(n)
    n -= 1
    kk[i], kk[n] = kk[n], kk[i]
    sel.append(kk[n])
  return sel

def pad0star(m, w):
  ## Pad M with zero bytes up to a multiple of W bytes.  An empty M becomes a
  ## single block of zeros rather than staying empty.
  if not len(m): fill = w
  else: fill = (-len(m))%w
  if fill: m += Z(fill)
  return C.ByteString(m)

def pad10star(m, w):
  ## Append a 0x80 byte followed by zero bytes until M's length is a multiple
  ## of W.  At least one byte is always added, so an already-aligned M gains
  ## a whole extra block.
  fill = w - len(m)%w
  return C.ByteString(m + ('\x80' + Z(fill - 1)))

def ntz(i):
  ## Return the number of trailing zero bits of the nonzero integer I, i.e.,
  ## the exponent of the largest power of two dividing it.  Zero has no such
  ## exponent, and a shift-until-odd loop never terminates on it, so zero
  ## raises ValueError instead.
  if i == 0: raise ValueError('ntz(0) is undefined')
  ## I & -I isolates the lowest set bit.
  return (i&-i).bit_length() - 1

def blocks(x, w):
  ## Split X into W-byte blocks, always holding back a final tail.  Returns
  ## (FULL, TAIL).  TAIL has between 1 and W bytes, unless X is empty, in
  ## which case TAIL is empty too.
  nfull = max(0, (len(x) - 1)//w)
  full = [C.ByteString(x[j*w:(j + 1)*w]) for j in xrange(nfull)]
  return full, C.ByteString(x[nfull*w:])

EMPTY = C.bytes('')

def blocks0(x, w):
  ## Like `blocks', except that a tail of exactly W bytes is moved onto the
  ## list of full blocks, so the returned tail is always shorter than W.
  full, tail = blocks(x, w)
  if len(tail) != w: return full, tail
  full.append(tail)
  return full, EMPTY

def dummygen(bc):
  ## Argument generator for modes (the `-dec' ones in MODEMAP) that produce
  ## no generated test vectors: it yields nothing.
  return list()

## Python-implemented block ciphers, keyed by name.  The main program
## consults this before Catacomb's own `gcprps'.
CUSTOM = {}

###--------------------------------------------------------------------------
### RC6.

class RC6Cipher (type):
  ## `RC6Cipher(W, R)' builds a fresh subclass of RC6Base for W-bit words
  ## and R rounds.  The subclass carries the `name', `blksz' and `keysz'
  ## attributes that the rest of this script expects of a cipher class.
  def __new__(cls, w, r):
    label = 'rc6-%d/%d' % (w, r)
    nbytes = w/2                        # four W-bit words = W/2 bytes
    attrs = { 'name': label, 'r': r, 'w': w, 'blksz': nbytes,
              'keysz': C.KeySZRange(nbytes, 1, 255, 1) }
    return type(label, (RC6Base,), attrs)

def rotw(w):
  ## Return floor(log_2 W), which is exactly lg W when W is a power of two.
  ## RC6Base keeps this as `me.d' and reduces rotation amounts mod 2^d.
  width = w.bit_length()
  return width - 1

def rol(w, x, n):
  ## Rotate the W-bit word X left by N places.  Only bits 0 to W - 1 of X
  ## are used, so X needn't be reduced modulo 2^W first.
  keep = C.MP(0).setbit(w - n) - 1
  wrap = C.MP(0).setbit(n) - 1
  return ((x&keep) << n) | ((x >> (w - n))&wrap)

def ror(w, x, n):
  ## Rotate the W-bit word X right by N places.  Only bits 0 to W - 1 of X
  ## are used.
  wrap = C.MP(0).setbit(n) - 1
  keep = C.MP(0).setbit(w - n) - 1
  return ((x&wrap) << (w - n)) | ((x >> n)&keep)

class RC6Base (object):
  ## Implementation of RC6-W/R.  Subclasses made by RC6Cipher supply the
  ## word size `w' (bits), round count `r' and block size `blksz' (bytes).
  ## Words are loaded and stored little-endian.

  ## Magic constants.
  P400 = C.MP(0xb7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190cfef324e7738926cfbe5f4bf8d8d8c31d763da06)
  Q400 = C.MP(0x9e3779b97f4a7c15f39cc0605cedc8341082276bf3a27251f86c6a11d0c18e952767f0b153d27b7f0347045b5bf1827f0188)

  def __init__(me, k):
    ## Expand key K into the 2r + 4 word subkey table `me.s'.

    ## Build the magic numbers: the top W bits of P400 and Q400, forced odd.
    P = me.P400 >> (400 - me.w)
    if P%2 == 0: P += 1
    Q = me.Q400 >> (400 - me.w)
    if Q%2 == 0: Q += 1
    M = C.MP(0).setbit(me.w) - 1

    ## Convert the key into words.
    wb = me.w/8
    c = (len(k) + wb - 1)/wb
    kb, ktl = blocks(k, me.w/8)
    L = map(C.MP.loadl, kb + [ktl])
    assert c == len(L)

    ## Build the subkey table.
    me.d = rotw(me.w)
    n = 2*me.r + 4
    S = [(P + i*Q)&M for i in xrange(n)]

    ## Mix the key words into the table.
    i = j = 0
    A = B = C.MP(0)
    for s in xrange(3*max(c, n)):
      A = S[i] = rol(me.w, S[i] + A + B, 3)
      B = L[j] = rol(me.w, L[j] + A + B, (A + B)%(1 << me.d))
      i = (i + 1)%n
      j = (j + 1)%c

    ## Done.
    me.s = S

  def encrypt(me, x):
    ## Encrypt the block X.
    M = C.MP(0).setbit(me.w) - 1
    a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4)[0])
    b = (b + me.s[0])&M
    d = (d + me.s[1])&M
    for i in xrange(2, 2*me.r + 2, 2):
      t = rol(me.w, 2*b*b + b, me.d)
      u = rol(me.w, 2*d*d + d, me.d)
      a = (rol(me.w, a ^ t, u%(1 << me.d)) + me.s[i + 0])&M
      c = (rol(me.w, c ^ u, t%(1 << me.d)) + me.s[i + 1])&M
      a, b, c, d = b, c, d, a
    a = (a + me.s[2*me.r + 2])&M
    c = (c + me.s[2*me.r + 3])&M
    return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) +
                        c.storel(me.blksz/4) + d.storel(me.blksz/4))

  def decrypt(me, x):
    ## Decrypt the block X, exactly inverting `encrypt'.  Undo the output
    ## whitening, then run the rounds backwards using subkeys s[i], s[i + 1]
    ## for i = 2r, 2r - 2, ..., 2.  Finally undo the input whitening applied
    ## to b and d with s[0] and s[1].
    M = C.MP(0).setbit(me.w) - 1
    a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4)[0])
    c = (c - me.s[2*me.r + 3])&M
    a = (a - me.s[2*me.r + 2])&M
    for i in xrange(2*me.r, 0, -2):
      a, b, c, d = d, a, b, c
      u = rol(me.w, 2*d*d + d, me.d)
      t = rol(me.w, 2*b*b + b, me.d)
      c = ror(me.w, (c - me.s[i + 1])&M, t%(1 << me.d)) ^ u
      a = ror(me.w, (a - me.s[i + 0])&M, u%(1 << me.d)) ^ t
    d = (d - me.s[1])&M
    b = (b - me.s[0])&M
    return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) +
                        c.storel(me.blksz/4) + d.storel(me.blksz/4))

## Register the RC6 variants under names such as `rc6-32/20'.
for (w, r) in [( 8, 16), ( 16, 16), ( 24, 16), ( 32, 16), ( 32, 20),
               ( 48, 16), ( 64, 16), ( 96, 16), (128, 16), (192, 16),
               (256, 16), (400, 16)]:
  CUSTOM['rc6-%d/%d' % (w, r)] = RC6Cipher(w, r)

###--------------------------------------------------------------------------
### OMAC (or CMAC).

def omac_masks(E):
  ## Return OMAC's two subkeys (m0, m1) = (x L, x^2 L), where L = E(0^n) and
  ## the arithmetic is modulo `poly(n)', n being the block size in bits.
  blksz = E.__class__.blksz
  p = poly(8*blksz)
  first = mul_blk_gf(E.encrypt(Z(blksz)), 2, p)
  second = mul_blk_gf(first, 2, p)
  return first, second

249 def dump_omac(E):
250 blksz = E.__class__.blksz
251 m0, m1 = omac_masks(E)
252 print 'L = %s' % hex(E.encrypt(Z(blksz)))
253 print 'm0 = %s' % hex(m0)
254 print 'm1 = %s' % hex(m1)
255 for t in xrange(3):
256 print 'v%d = %s' % (t, hex(E.encrypt(C.MP(t).storeb(blksz))))
257 print 'z%d = %s' % (t, hex(omac(E, t, '')))
258
def omac(E, t, m):
  ## OMAC (CMAC) of M under E.  If T is not None, M is first prefixed by T
  ## stored as a full big-endian block, giving the tweaked variant used by
  ## EAX.
  blksz = E.__class__.blksz
  m0, m1 = omac_masks(E)
  if t is not None: m = C.MP(t).storeb(blksz) + m
  full, tail = blocks(m, blksz)
  acc = Z(blksz)
  for blk in full: acc = E.encrypt(acc ^ blk)
  ## A complete final block takes mask m0; a short one is padded and takes
  ## m1.
  if len(tail) == blksz: return E.encrypt(acc ^ tail ^ m0)
  return E.encrypt(acc ^ pad10star(tail, blksz) ^ m1)

def cmac(E, m):
  ## Test-vector entry point for CMAC.  Returns a one-element tuple holding
  ## the tag, after dumping the subkeys in verbose mode.
  if VERBOSE: dump_omac(E)
  tag = omac(E, None, m)
  return (tag,)

def cmacgen(bc):
  ## Message lengths to test: empty, one byte, three whole blocks, and three
  ## blocks less five bytes.
  n = bc.blksz
  return [(size,) for size in (0, 1, 3*n, 3*n - 5)]

###--------------------------------------------------------------------------
### Counter mode.

def ctr(E, m, c0):
  ## Counter-mode encryption (or decryption) of M under E.  The counter
  ## block starts at C0 and is incremented as a big-endian integer for each
  ## further block of keystream.
  blksz = E.__class__.blksz
  nbytes = len(m)
  stream = C.WriteBuffer()
  counter = C.MP.loadb(c0)
  while stream.size < nbytes:
    stream.put(E.encrypt(counter.storeb(blksz)))
    counter += 1
  return C.ByteString(m) ^ C.ByteString(stream)[:nbytes]

###--------------------------------------------------------------------------
### GCM.

def gcm_mangle(x):
  ## Reverse the order of the bits within each byte of X.  `gcm_mul' and
  ## `gcm_pow' apply this before and after their little-endian GF
  ## conversions.
  out = C.WriteBuffer()
  for ch in x:
    byte = ord(ch)
    rev = 0
    for _ in xrange(8):
      rev = (rev << 1) | (byte&1)
      byte >>= 1
    out.putu8(rev)
  return C.ByteString(out)

def gcm_mul(x, y):
  ## Multiply the GCM field elements X and Y (byte strings of equal length)
  ## modulo `poly(8 len(X))'.
  w = len(x)
  p = poly(8*w)
  u, v = [C.GF.loadl(gcm_mangle(s)) for s in (x, y)]
  return gcm_mangle(((u*v)%p).storel(w))

def gcm_pow(x, n):
  ## Raise the GCM field element X to the integer power N.
  w = len(x)
  p = poly(8*w)
  z = pow(C.GF.loadl(gcm_mangle(x)), n, p)
  return gcm_mangle(z.storel(w))

def gcm_ctr(E, m, c0):
  ## GCM counter mode: encrypt (or decrypt) M under E.  Counter blocks keep
  ## C0's leading bytes and increment only its final 32 bits, big-endian and
  ## modulo 2^32 (GCM's inc_32).  The first keystream block uses inc_32(C0),
  ## since the callers use E(C0) itself to mask the tag.
  y = C.WriteBuffer()
  pre = c0[:-4]
  c, = unpack('>L', c0[-4:])
  while y.size < len(m):
    ## Wrap rather than overflow: `pack' rejects values of 2^32 and above,
    ## and a nonce that is GHASHed can leave the low word at 0xffffffff.
    c = (c + 1)&0xffffffff
    y.put(E.encrypt(pre + pack('>L', c)))
  return C.ByteString(m) ^ C.ByteString(y)[:len(m)]

333 def g(what, x, m, a0 = None):
334 n = len(x)
335 if a0 is None: a = Z(n)
336 else: a = a0
337 i = 0
338 for b in blocks0(m, n)[0]:
339 a = gcm_mul(a ^ b, x)
340 if VERBOSE: print '%s[%d] = %s -> %s' % (what, i, hex(b), hex(a))
341 i += 1
342 return a
343
def gcm_pad(w, x):
  ## Zero-pad X to a multiple of W bytes; nothing is added if X is aligned.
  padlen = -len(x)%w
  return C.ByteString(x + Z(padlen))

def gcm_lens(w, a, b):
  ## Build GHASH's length block(s): the bit lengths A and B, each stored
  ## big-endian in W/2 bytes.  For block sizes under 12 bytes, each gets a
  ## full W bytes instead.
  fieldsz = w if w < 12 else w/2
  return C.ByteString(C.MP(a).storeb(fieldsz) + C.MP(b).storeb(fieldsz))

352 def ghash(whata, whatb, x, a, b):
353 w = len(x)
354 ha = g(whata, x, gcm_pad(w, a))
355 hb = g(whatb, x, gcm_pad(w, b))
356 if a:
357 hc = gcm_mul(ha, gcm_pow(x, (len(b) + w - 1)/w)) ^ hb
358 if VERBOSE: print '%s || %s -> %s' % (whata, whatb, hex(hc))
359 else:
360 hc = hb
361 return g(whatb, x, gcm_lens(w, 8*len(a), 8*len(b)), hc)
362
363 def gcmenc(E, n, h, m, tsz = None):
364 w = E.__class__.blksz
365 x = E.encrypt(Z(w))
366 if VERBOSE: print 'x = %s' % hex(x)
367 if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1))
368 else: c0 = ghash('?', 'n', x, EMPTY, n)
369 if VERBOSE: print 'c0 = %s' % hex(c0)
370 y = gcm_ctr(E, m, c0)
371 t = ghash('h', 'y', x, h, y) ^ E.encrypt(c0)
372 return y, t
373
374 def gcmdec(E, n, h, y, t):
375 w = E.__class__.blksz
376 x = E.encrypt(Z(w))
377 if VERBOSE: print 'x = %s' % hex(x)
378 if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1))
379 else: c0 = ghash('?', 'n', x, EMPTY, n)
380 if VERBOSE: print 'c0 = %s' % hex(c0)
381 m = gcm_ctr(E, y, c0)
382 tt = ghash('h', 'y', x, h, y) ^ E.encrypt(c0)
383 if t == tt: return m,
384 else: return None,
385
def gcmgen(bc):
  ## (nonce, header, message) sizes for GCM test vectors: the empty and
  ## single-byte cases, then multi-block combinations.  The sizes include the
  ## BLKSZ - 4 nonce length, which takes the direct C0 path.
  w = bc.blksz
  edge = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)]
  bulk = [(w, 3*w, 3*w),
          (w - 4, w + 3, 3*w + 9),
          (w - 1, 3*w - 5, 3*w + 5)]
  return edge + bulk

###--------------------------------------------------------------------------
### CCM.

def stbe(n, w):
  ## Store the integer N big-endian in W bytes.
  big = C.MP(n)
  return big.storeb(w)

397 def ccm_fmthdr(blksz, n, hsz, msz, tsz):
398 b = C.WriteBuffer()
399 if blksz == 8:
400 q = blksz - len(n) - 1
401 f = 0
402 if hsz: f |= 0x40
403 f |= (tsz - 1) << 3
404 f |= q - 1
405 b.putu8(f).put(n).put(stbe(msz, q))
406 elif blksz == 16:
407 q = blksz - len(n) - 1
408 f = 0
409 if hsz: f |= 0x40
410 f |= (tsz - 2)/2 << 3
411 f |= q - 1
412 b.putu8(f).put(n).put(stbe(msz, q))
413 else:
414 q = blksz - len(n) - 2
415 f0 = f1 = 0
416 if hsz: f1 |= 0x80
417 f0 |= tsz
418 f1 |= q
419 b.putu8(f0).putu8(f1).put(n).put(stbe(msz, q))
420 b = C.ByteString(b)
421 if VERBOSE: print 'hdr = %s' % hex(b)
422 return b
423
424 def ccm_fmtctr(blksz, n, i = 0):
425 b = C.WriteBuffer()
426 if blksz == 8 or blksz == 16:
427 q = blksz - len(n) - 1
428 b.putu8(q - 1).put(n).put(stbe(i, q))
429 else:
430 q = blksz - len(n) - 2
431 b.putu8(0).putu8(q).put(n).put(stbe(i, q))
432 b = C.ByteString(b)
433 if VERBOSE: print 'ctr = %s' % hex(b)
434 return b
435
def ccmaad(b, h, blksz):
  ## Append CCM's encoding of the associated data H to the buffer B, then
  ## zero-pad B to a multiple of BLKSZ.  The length prefix is:
  ##   * nothing if H is empty;
  ##   * 2 bytes if len(H) < 2^16 - 2^8;
  ##   * 0xfffe and 4 bytes below 2^32;
  ##   * 0xffff and 8 bytes otherwise.
  ## The 2^16 - 2^8 bound (not 0xfffe) follows NIST SP 800-38C and RFC 3610:
  ## two-byte prefixes 0xff00-0xfffd are reserved, so such lengths must use
  ## the six-byte form.
  hsz = len(h)
  if not hsz: pass
  elif hsz < 0xff00: b.putu16(hsz)
  elif hsz <= 0xffffffff: b.putu16(0xfffe).putu32(hsz)
  else: b.putu16(0xffff).putu64(hsz)
  b.put(h); b.zero((-b.size)%blksz)

444 def ccmenc(E, n, h, m, tsz = None):
445 blksz = E.__class__.blksz
446 if tsz is None: tsz = blksz
447 b = C.WriteBuffer()
448 b.put(ccm_fmthdr(blksz, n, len(h), len(m), tsz))
449 ccmaad(b, h, blksz)
450 b.put(m); b.zero((-b.size)%blksz)
451 b = C.ByteString(b)
452 a = Z(blksz)
453 v, _ = blocks0(b, blksz)
454 i = 0
455 for x in v:
456 a = E.encrypt(a ^ x)
457 if VERBOSE:
458 print 'b[%d] = %s' % (i, hex(x))
459 print 'a[%d] = %s' % (i + 1, hex(a))
460 i += 1
461 y = ctr(E, a + m, ccm_fmtctr(blksz, n))
462 return C.ByteString(y[blksz:]), C.ByteString(y[0:tsz])
463
464 def ccmdec(E, n, h, y, t):
465 blksz = E.__class__.blksz
466 tsz = len(t)
467 b = C.WriteBuffer()
468 b.put(ccm_fmthdr(blksz, n, len(h), len(y), tsz))
469 ccmaad(b, h, blksz)
470 mm = ctr(E, t + Z(blksz - tsz) + y, ccm_fmtctr(blksz, n))
471 u, m = C.ByteString(mm[0:tsz]), C.ByteString(mm[blksz:])
472 b.put(m); b.zero((-b.size)%blksz)
473 b = C.ByteString(b)
474 a = Z(blksz)
475 v, _ = blocks0(b, blksz)
476 i = 0
477 for x in v:
478 a = E.encrypt(a ^ x)
479 if VERBOSE:
480 print 'b[%d] = %s' % (i, hex(x))
481 print 'a[%d] = %s' % (i + 1, hex(a))
482 i += 1
483 if u == a[:tsz]: return m,
484 else: return None,
485
def ccmgen(bc):
  ## (nonce, header, message[, tag]) sizes for CCM test vectors.  The first
  ## three use a BLKSZ - 5 byte nonce and 4-byte tags; the rest use a nonce
  ## of BLKSZ/2 + 1 bytes and the default tag size.
  bsz = bc.blksz
  nlong, nhalf = bsz - 5, bsz/2 + 1
  return [(nlong, 0, 0, 4), (nlong, 1, 0, 4), (nlong, 0, 1, 4),
          (nhalf, 3*bsz, 3*bsz),
          (nhalf, 3*bsz - 5, 3*bsz + 5)]

492 ###--------------------------------------------------------------------------
493 ### EAX.
494
495 def eaxenc(E, n, h, m, tsz = None):
496 if VERBOSE:
497 print 'k = %s' % hex(k)
498 print 'n = %s' % hex(n)
499 print 'h = %s' % hex(h)
500 print 'm = %s' % hex(m)
501 dump_omac(E)
502 if tsz is None: tsz = E.__class__.blksz
503 c0 = omac(E, 0, n)
504 y = ctr(E, m, c0)
505 ht = omac(E, 1, h)
506 yt = omac(E, 2, y)
507 if VERBOSE:
508 print 'c0 = %s' % hex(c0)
509 print 'ht = %s' % hex(ht)
510 print 'yt = %s' % hex(yt)
511 return y, C.ByteString((c0 ^ ht ^ yt)[:tsz])
512
513 def eaxdec(E, n, h, y, t):
514 if VERBOSE:
515 print 'k = %s' % hex(k)
516 print 'n = %s' % hex(n)
517 print 'h = %s' % hex(h)
518 print 'y = %s' % hex(y)
519 print 't = %s' % hex(t)
520 dump_omac(E)
521 c0 = omac(E, 0, n)
522 m = ctr(E, y, c0)
523 ht = omac(E, 1, h)
524 yt = omac(E, 2, y)
525 if VERBOSE:
526 print 'c0 = %s' % hex(c0)
527 print 'ht = %s' % hex(ht)
528 print 'yt = %s' % hex(yt)
529 if t == (c0 ^ ht ^ yt)[:len(t)]: return m,
530 else: return None,
531
def eaxgen(bc):
  ## (nonce, header, message) sizes for EAX test vectors.
  w = bc.blksz
  edge = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)]
  return edge + [(w, 3*w, 3*w), (w - 1, 3*w - 5, 3*w + 5)]

###--------------------------------------------------------------------------
### PMAC.

def ocb_masks(E):
  ## Return (LGAMMA, LXINV).  LGAMMA[i] = x^i L for 0 <= i < 66 and
  ## LXINV = x^-1 L, where L = E(0^n) and the arithmetic is modulo
  ## `poly(n)'.
  blksz = E.__class__.blksz
  p = poly(8*blksz)
  x = C.GF(2)
  L = E.encrypt(Z(blksz))
  Lxinv = mul_blk_gf(L, p.modinv(x), p)
  Lgamma = [L]
  while len(Lgamma) < 66: Lgamma.append(mul_blk_gf(Lgamma[-1], x, p))
  return Lgamma, Lxinv

552 def dump_ocb(E):
553 Lgamma, Lxinv = ocb_masks(E)
554 print 'L x^-1 = %s' % hex(Lxinv)
555 for i, lg in enumerate(Lgamma[:16]):
556 print 'L x^%d = %s' % (i, hex(lg))
557
558 def pmac1(E, m):
559 blksz = E.__class__.blksz
560 Lgamma, Lxinv = ocb_masks(E)
561 a = o = Z(blksz)
562 i = 0
563 v, tl = blocks(m, blksz)
564 for x in v:
565 i += 1
566 b = ntz(i)
567 o ^= Lgamma[b]
568 a ^= E.encrypt(x ^ o)
569 if VERBOSE:
570 print 'Z[%d]: %d -> %s' % (i, b, hex(o))
571 print 'A[%d]: %s' % (i, hex(a))
572 if len(tl) == blksz: a ^= tl ^ Lxinv
573 else: a ^= pad10star(tl, blksz)
574 return E.encrypt(a)
575
def pmac2(E, m):
  ## The PMAC variant OCB2 uses for its header.  The offset starts at 10 L
  ## and is doubled after each block.  The final tail is masked with 3 times
  ## the offset if it is a whole block, or padded and masked with 5 times
  ## the offset if it is short.
  blksz = E.__class__.blksz
  p = prim(8*blksz)
  off = mul_blk_gf(E.encrypt(Z(blksz)), 10, p)
  acc = Z(blksz)
  full, tail = blocks(m, blksz)
  for blk in full:
    acc ^= E.encrypt(blk ^ off)
    off = mul_blk_gf(off, 2, p)
  if len(tail) == blksz: acc ^= tail ^ mul_blk_gf(off, 3, p)
  else: acc ^= pad10star(tail, blksz) ^ mul_blk_gf(off, 5, p)
  return E.encrypt(acc)

def ocb3_masks(E):
  ## OCB3's masks, taken from the OCB1 table: L_* = L, L_$ = x L, and the
  ## list whose entry i is x^(i+2) L, used for the Gray-code offsets.
  table, _ = ocb_masks(E)
  return table[0], table[1], table[2:]

596 def dump_ocb3(E):
597 Lstar, Ldollar, Lgamma = ocb3_masks(E)
598 print 'L_* = %s' % hex(Lstar)
599 print 'L_$ = %s' % hex(Ldollar)
600 for i, lg in enumerate(Lgamma[:16]):
601 print 'L x^%d = %s' % (i, hex(lg))
602
603 def pmac3(E, m):
604 ## Note that `PMAC3' is /not/ a secure MAC. It depends on other parts of
605 ## OCB3 to prevent a rather easy linear-algebra attack.
606 blksz = E.__class__.blksz
607 Lstar, Ldollar, Lgamma = ocb3_masks(E)
608 a = o = Z(blksz)
609 i = 0
610 v, tl = blocks0(m, blksz)
611 for x in v:
612 i += 1
613 b = ntz(i)
614 o ^= Lgamma[b]
615 a ^= E.encrypt(x ^ o)
616 if VERBOSE:
617 print 'Z[%d]: %d -> %s' % (i, b, hex(o))
618 print 'A[%d]: %s' % (i, hex(a))
619 if tl:
620 o ^= Lstar
621 a ^= E.encrypt(pad10star(tl, blksz) ^ o)
622 if VERBOSE:
623 print 'Z[%d]: * -> %s' % (i, hex(o))
624 print 'A[%d]: %s' % (i, hex(a))
625 return a
626
def pmac1_pub(E, m):
  ## Test-vector entry point for PMAC1.  Returns a one-element tuple holding
  ## the tag, after dumping the masks in verbose mode.
  if VERBOSE: dump_ocb(E)
  tag = pmac1(E, m)
  return (tag,)

def pmacgen(bc):
  ## Message lengths to test: empty, one byte, three whole blocks, and three
  ## blocks less five bytes.
  n = bc.blksz
  return [(size,) for size in (0, 1, 3*n, 3*n - 5)]

636 ###--------------------------------------------------------------------------
637 ### OCB.
638
639 def ocb1enc(E, n, h, m, tsz = None):
640 ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
641 ## Associated-Data'.
642 blksz = E.__class__.blksz
643 if VERBOSE: dump_ocb(E)
644 Lgamma, Lxinv = ocb_masks(E)
645 if tsz is None: tsz = blksz
646 a = Z(blksz)
647 o = E.encrypt(n ^ Lgamma[0])
648 if VERBOSE: print 'R = %s' % hex(o)
649 i = 0
650 y = C.WriteBuffer()
651 v, tl = blocks(m, blksz)
652 for x in v:
653 i += 1
654 b = ntz(i)
655 o ^= Lgamma[b]
656 a ^= x
657 if VERBOSE:
658 print 'Z[%d]: %d -> %s' % (i, b, hex(o))
659 print 'A[%d]: %s' % (i, hex(a))
660 y.put(E.encrypt(x ^ o) ^ o)
661 i += 1
662 b = ntz(i)
663 o ^= Lgamma[b]
664 n = len(tl)
665 if VERBOSE:
666 print 'Z[%d]: %d -> %s' % (i, b, hex(o))
667 print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
668 yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
669 cfinal = tl ^ yfinal[:n]
670 a ^= o ^ (tl + yfinal[n:])
671 y.put(cfinal)
672 t = E.encrypt(a)
673 if h: t ^= pmac1(E, h)
674 return C.ByteString(y), C.ByteString(t[:tsz])
675
676 def ocb1dec(E, n, h, y, t):
677 ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
678 ## Associated-Data'.
679 blksz = E.__class__.blksz
680 if VERBOSE: dump_ocb(E)
681 Lgamma, Lxinv = ocb_masks(E)
682 a = Z(blksz)
683 o = E.encrypt(n ^ Lgamma[0])
684 if VERBOSE: print 'R = %s' % hex(o)
685 i = 0
686 m = C.WriteBuffer()
687 v, tl = blocks(y, blksz)
688 for x in v:
689 i += 1
690 b = ntz(i)
691 o ^= Lgamma[b]
692 if VERBOSE:
693 print 'Z[%d]: %d -> %s' % (i, b, hex(o))
694 print 'A[%d]: %s' % (i, hex(a))
695 u = E.decrypt(x ^ o) ^ o
696 m.put(u)
697 a ^= u
698 i += 1
699 b = ntz(i)
700 o ^= Lgamma[b]
701 n = len(tl)
702 if VERBOSE:
703 print 'Z[%d]: %d -> %s' % (i, b, hex(o))
704 print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
705 yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
706 mfinal = tl ^ yfinal[:n]
707 a ^= o ^ (mfinal + yfinal[n:])
708 m.put(mfinal)
709 u = E.encrypt(a)
710 if h: u ^= pmac1(E, h)
711 if t == u[:len(t)]: return C.ByteString(m),
712 else: return None,
713
def ocb2enc(E, n, h, m, tsz = None):
  ## OCB2 encryption of M with the full-block nonce N and header H under E.
  ## Returns (Y, T), with the tag truncated to TSZ bytes (default: a full
  ## block).
  ##
  ## For OCB2, it's important for security that n = log_x (x + 1) is large
  ## in the field representations of GF(2^w) used.  In fact we need more:
  ## i n (mod 2^w - 1) must be large for small nonzero i (the original list
  ## read {4, -3, -2, -1, 1, 2, 3, 4}; presumably -4 to 4).  The original
  ## paper lists the values for 64 and 128, but we support other block sizes,
  ## so here's the result of the (rather large, in some cases) computation.
  ##
  ## Block size   log_x (x + 1)
  ##
  ##  64          9686038906114705801
  ##  96          63214690573408919568138788065
  ## 128          338793687469689340204974836150077311399
  ## 192          161110085006042185925119981866940491651092686475226538785
  ## 256          22928580326165511958494515843249267194111962539778797914076675796261938307298
  blksz = E.__class__.blksz
  if tsz is None: tsz = blksz
  p = prim(8*blksz)
  off = mul_blk_gf(E.encrypt(n), 2, p)
  chk = Z(blksz)
  full, tail = blocks(m, blksz)
  out = C.WriteBuffer()
  for blk in full:
    chk ^= blk
    out.put(E.encrypt(blk ^ off) ^ off)
    off = mul_blk_gf(off, 2, p)
  tlen = len(tail)
  pad = E.encrypt(C.MP(8*tlen).storeb(blksz) ^ off)
  out.put(tail ^ pad[:tlen])
  chk ^= (tail + pad[tlen:]) ^ mul_blk_gf(off, 3, p)
  tag = E.encrypt(chk)
  if h: tag ^= pmac2(E, h)
  return C.ByteString(out), C.ByteString(tag[:tsz])

def ocb2dec(E, n, h, y, t):
  ## OCB2 decryption of Y with nonce N and header H under E: the inverse of
  ## `ocb2enc'.  Returns (M,) if T matches the corresponding prefix of the
  ## computed tag, else (None,).
  ##
  ## Whole blocks must go through the cipher's *decryption* direction, and
  ## the recovered plaintext must be collected in the buffer M.  (Y is the
  ## ciphertext byte string and has no `put' method.)
  blksz = E.__class__.blksz
  p = prim(8*blksz)
  off = mul_blk_gf(E.encrypt(n), 2, p)
  chk = Z(blksz)
  full, tail = blocks(y, blksz)
  m = C.WriteBuffer()
  for blk in full:
    plain = E.decrypt(blk ^ off) ^ off
    m.put(plain)
    chk ^= plain
    off = mul_blk_gf(off, 2, p)
  ## The final block is handled exactly as in `ocb2enc': XOR with a pad
  ## derived from its bit length.
  tlen = len(tail)
  pad = E.encrypt(C.MP(8*tlen).storeb(blksz) ^ off)
  mfinal = tail ^ pad[:tlen]
  chk ^= (mfinal + pad[tlen:]) ^ mul_blk_gf(off, 3, p)
  m.put(mfinal)
  tag = E.encrypt(chk)
  if h: tag ^= pmac2(E, h)
  if t == tag[:len(t)]: return C.ByteString(m),
  else: return None,

## OCB3 nonce-stretching parameters, indexed by block size in bytes.  Each
## entry is (SPLIT, SHIFT):
##   * SPLIT is the number of low-order augmented-nonce bits (`bottom') that
##     select the offset's position within the stretched key;
##   * SHIFT is the shift in the stretch's second half, Ktop XOR
##     (Ktop << SHIFT).
OCB3_STRETCH = {   4: ( 4,  17),
                   8: ( 5,  25),
                  12: ( 6,  33),
                  16: ( 6,   8),
                  24: ( 7,  40),
                  32: ( 8,   1),
                  48: ( 8,  80),
                  64: ( 8, 176),
                  96: ( 9, 160),
                 128: ( 9, 352),
                 200: (10, 192) }

786 def ocb3nonce(E, n, tsz):
787
788 ## Figure out how much we need to glue onto the nonce. This ends up being
789 ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is
790 ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and p = w - l(N) -
791 ## v - 1. But this is an annoying way to think about it because of the
792 ## byte misalignment. Instead, think of it as a byte-aligned prefix
793 ## encoding the tag and an `is the nonce full-length' flag, followed by
794 ## optional padding, and then the nonce:
795 ##
796 ## F || N if l(N) = w - f
797 ## F || 0^p || 1 || N otherwise
798 ##
799 ## where F is [t mod w]_v || 0^{f-v-1} || b; f = floor(log_2(w - 1)) + 2;
800 ## b is 1 if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1.
801 blksz = E.__class__.blksz
802 tszbits = min(C.MP(8*blksz - 1).nbits, 8)
803 fwd = tszbits/8 + 1
804 f = 8*(tsz%blksz) << + 8*fwd - tszbits
805
806 ## Form the augmented nonce.
807 nb = C.WriteBuffer()
808 nsz, nwd = len(n), blksz - fwd
809 if nsz == nwd: f |= 1
810 nb.put(C.MP(f).storeb(fwd))
811 if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1)
812 nb.put(n)
813 nn = C.ByteString(nb)
814 if VERBOSE: print 'aug-nonce = %s' % hex(nn)
815
816 ## Calculate the initial offset.
817 split, shift = OCB3_STRETCH[blksz]
818 t2pw = C.MP(0).setbit(8*blksz) - 1
819 lomask = (C.MP(0).setbit(split) - 1)
820 himask = ~lomask
821 top, bottom = nn&himask.storeb2c(blksz), C.MP.loadb(nn)&lomask
822 ktop = C.MP.loadb(E.encrypt(top))
823 stretch = (ktop << 8*blksz) | (ktop ^ (ktop << shift)&t2pw)
824 o = (stretch >> 8*blksz - bottom).storeb(blksz)
825 if VERBOSE:
826 print 'stretch = %s' % hex(stretch.storeb(2*blksz))
827 print 'Z[0] = %s' % hex(o)
828
829 return o
830
831 def ocb3enc(E, n, h, m, tsz = None):
832 blksz = E.__class__.blksz
833 if tsz is None: tsz = blksz
834 Lstar, Ldollar, Lgamma = ocb3_masks(E)
835 if VERBOSE: dump_ocb3(E)
836
837 ## Set things up.
838 o = ocb3nonce(E, n, tsz)
839 a = C.ByteString.zero(blksz)
840
841 ## Split the message into blocks.
842 i = 0
843 y = C.WriteBuffer()
844 v, tl = blocks0(m, blksz)
845 for x in v:
846 i += 1
847 b = ntz(i)
848 o ^= Lgamma[b]
849 a ^= x
850 if VERBOSE:
851 print 'Z[%d]: %d -> %s' % (i, b, hex(o))
852 print 'A[%d]: %s' % (i, hex(a))
853 y.put(E.encrypt(x ^ o) ^ o)
854 if tl:
855 o ^= Lstar
856 n = len(tl)
857 pad = E.encrypt(o)
858 a ^= pad10star(tl, blksz)
859 if VERBOSE:
860 print 'Z[%d]: * -> %s' % (i, hex(o))
861 print 'A[%d]: %s' % (i, hex(a))
862 y.put(tl ^ pad[0:n])
863 o ^= Ldollar
864 t = E.encrypt(a ^ o) ^ pmac3(E, h)
865 return C.ByteString(y), C.ByteString(t[:tsz])
866
867 def ocb3dec(E, n, h, y, t):
868 blksz = E.__class__.blksz
869 tsz = len(t)
870 Lstar, Ldollar, Lgamma = ocb3_masks(E)
871 if VERBOSE: dump_ocb3(E)
872
873 ## Set things up.
874 o = ocb3nonce(E, n, tsz)
875 a = C.ByteString.zero(blksz)
876
877 ## Split the message into blocks.
878 i = 0
879 m = C.WriteBuffer()
880 v, tl = blocks0(y, blksz)
881 for x in v:
882 i += 1
883 b = ntz(i)
884 o ^= Lgamma[b]
885 if VERBOSE:
886 print 'Z[%d]: %d -> %s' % (i, b, hex(o))
887 print 'A[%d]: %s' % (i, hex(a))
888 u = E.encrypt(x ^ o) ^ o
889 m.put(u)
890 a ^= u
891 if tl:
892 o ^= Lstar
893 n = len(tl)
894 pad = E.encrypt(o)
895 if VERBOSE:
896 print 'Z[%d]: * -> %s' % (i, hex(o))
897 print 'A[%d]: %s' % (i, hex(a))
898 u = tl ^ pad[0:n]
899 m.put(u)
900 a ^= pad10star(u, blksz)
901 o ^= Ldollar
902 u = E.encrypt(a ^ o) ^ pmac3(E, h)
903 if t == u[:tsz]: return C.ByteString(m),
904 else: return None,
905
def ocbgen(bc):
  ## (nonce, header, message) sizes for OCB1/OCB2 test vectors.  The nonce
  ## is always a full block.
  w = bc.blksz
  shapes = [(0, 0), (1, 0), (0, 1),
            (0, 3*w),
            (3*w, 3*w),
            (0, 3*w + 5),
            (3*w - 5, 3*w + 5)]
  return [(w, hsz, msz) for hsz, msz in shapes]

def ocb3gen(bc):
  ## (nonce, header, message) sizes for OCB3 test vectors.  The nonces are a
  ## few bytes shorter than a block, exercising the padded nonce encoding.
  w = bc.blksz
  return [(w - 2, 0, 0), (w - 2, 1, 0), (w - 2, 0, 1),
          (w - 5, 0, 3*w),
          (w - 3, 3*w, 3*w),
          (w - 2, 0, 3*w + 5),
          (w - 2, 3*w - 5, 3*w + 5)]

922 def ocb3_mct(bc, ksz, tsz):
923 k = C.ByteString(C.WriteBuffer().zero(ksz - 4).putu32(8*tsz))
924 E = bc(k)
925 n = C.MP(1)
926 nw = bc.blksz - 4
927 cbuf = C.WriteBuffer()
928 for i in xrange(128):
929 s = C.ByteString.zero(i)
930 y, t = ocb3enc(E, n.storeb(nw), s, s, tsz); n += 1; cbuf.put(y).put(t)
931 y, t = ocb3enc(E, n.storeb(nw), EMPTY, s, tsz); n += 1; cbuf.put(y).put(t)
932 y, t = ocb3enc(E, n.storeb(nw), s, EMPTY, tsz); n += 1; cbuf.put(y).put(t)
933 _, t = ocb3enc(E, n.storeb(nw), C.ByteString(cbuf), EMPTY, tsz)
934 print hex(t)
935
936 def ocb3_mct2(bc):
937 k = C.bytes('000102030405060708090a0b0c0d0e0f')
938 E = bc(k)
939 tsz = min(E.blksz, 32)
940 n = C.MP(1)
941 cbuf = C.WriteBuffer()
942 for i in xrange(128):
943 sbuf = C.WriteBuffer()
944 for j in xrange(i): sbuf.putu8(j)
945 s = C.ByteString(sbuf)
946 y, t = ocb3enc(E, n.storeb(2), s, s, tsz); n += 1; cbuf.put(y).put(t)
947 y, t = ocb3enc(E, n.storeb(2), EMPTY, s, tsz); n += 1; cbuf.put(y).put(t)
948 y, t = ocb3enc(E, n.storeb(2), s, EMPTY, tsz); n += 1; cbuf.put(y).put(t)
949 _, t = ocb3enc(E, n.storeb(2), C.ByteString(cbuf), EMPTY, tsz)
950 print hex(t)
951
###--------------------------------------------------------------------------
### Main program.

class struct (object):
  ## A trivial record type: keyword arguments become attributes.
  def __init__(me, **kw):
    for key, val in kw.items(): setattr(me, key, val)

## Argument kinds:
##   * `mk' turns a generated size into an actual argument;
##   * `parse' converts a command-line string;
##   * `show', if not None, renders the argument into the output file.
binarg = struct(mk = R.block, parse = C.bytes, show = safehex)
intarg = struct(mk = lambda x: x, parse = int, show = None)

## Mode name -> (argument-size generator, argument kinds, implementation).
MODEMAP = { 'eax-enc':          (eaxgen, 3*[binarg] + [intarg], eaxenc),
            'eax-dec':          (dummygen, 4*[binarg], eaxdec),
            'ccm-enc':          (ccmgen, 3*[binarg] + [intarg], ccmenc),
            'ccm-dec':          (dummygen, 4*[binarg], ccmdec),
            'cmac':             (cmacgen, [binarg], cmac),
            'gcm-enc':          (gcmgen, 3*[binarg] + [intarg], gcmenc),
            'gcm-dec':          (dummygen, 4*[binarg], gcmdec),
            'ocb1-enc':         (ocbgen, 3*[binarg] + [intarg], ocb1enc),
            'ocb1-dec':         (dummygen, 4*[binarg], ocb1dec),
            'ocb2-enc':         (ocbgen, 3*[binarg] + [intarg], ocb2enc),
            'ocb2-dec':         (dummygen, 4*[binarg], ocb2dec),
            'ocb3-enc':         (ocb3gen, 3*[binarg] + [intarg], ocb3enc),
            'ocb3-dec':         (dummygen, 4*[binarg], ocb3dec),
            'pmac1':            (pmacgen, [binarg], pmac1_pub) }

977 mode = argv[1]
978 bc = None
979 for d in CUSTOM, C.gcprps:
980 try: bc = d[argv[2]]
981 except KeyError: pass
982 else: break
983 if bc is None: raise KeyError, argv[2]
984 if len(argv) == 5 and mode == 'ocb3-mct':
985 VERBOSE = False
986 ksz, tsz = int(argv[3]), int(argv[4])
987 ocb3_mct(bc, ksz, tsz)
988 exit(0)
989 if len(argv) == 3 and mode == 'ocb3-mct2':
990 VERBOSE = False
991 ocb3_mct2(bc)
992 exit(0)
993 if len(argv) == 3:
994 VERBOSE = False
995 gen, argty, func = MODEMAP[mode]
996 if mode.endswith('-enc'): mode = mode[:-4]
997 print '%s-%s {' % (bc.name, mode)
998 for ksz in keylens(bc.keysz):
999 for argvals in gen(bc):
1000 k = R.block(ksz)
1001 args = [t.mk(a) for t, a in izip(argty, argvals)]
1002 rets = func(bc(k), *args)
1003 print ' %s' % safehex(k)
1004 for t, a in izip(argty, args):
1005 if t.show: print ' %s' % t.show(a)
1006 for r, lastp in with_lastp(rets):
1007 print ' %s%s' % (safehex(r), lastp and ';' or '')
1008 print '}'
1009 else:
1010 VERBOSE = True
1011 k = C.bytes(argv[3])
1012 gen, argty, func = MODEMAP[mode]
1013 args = [t.parse(a) for t, a in izip(argty, argv[4:])]
1014 rets = func(bc(k), *args)
1015 for r in rets:
1016 if r is None: print "X"
1017 else: print hex(r)
1018
1019 ###----- That's all, folks --------------------------------------------------