96e5f5bf2b194d4105024e6ad8a8e61a06d64b09
[catacomb] / utils / advmodes
1 #! /usr/bin/python
2
3 from sys import argv
4 from struct import pack
5 from itertools import izip
6 import catacomb as C
7
## Shared random generator, constructed with the fixed argument 0
## (presumably a seed, so that runs are reproducible -- confirm against the
## catacomb documentation).  `keylens' draws indices from it and the main
## program draws random keys from it.
R = C.FibRand(0)
9
10 ###--------------------------------------------------------------------------
11 ### Utilities.
12
def combs(things, k):
    """Generate every selection of K items from the sequence THINGS.

    Each selection is yielded as a fresh list whose items appear in the
    same relative order as in THINGS.  The lowest-numbered position is the
    one which advances most often, so the selections drawn entirely from
    the front of THINGS come out first.
    """
    n = len(things)
    idx = list(range(k))
    while True:
        yield [things[i] for i in idx]

        ## Advance the lowest position which still has room below its
        ## neighbour (or below the end, for the last position), resetting
        ## every position beneath it to its starting value.
        for j in range(k):
            ceiling = n if j == k - 1 else idx[j + 1]
            nxt = idx[j] + 1
            if nxt < ceiling:
                idx[j] = nxt
                break
            idx[j] = j
        else:
            ## Nothing could advance: all selections have been produced.
            return
28
## Cache of polynomials already found by `poly', indexed by degree.
POLYMAP = {}

def poly(nbits):
    """Return an irreducible polynomial of degree NBITS, as a `C.GF'.

    The search starts from x^NBITS + 1 and adds an odd number of middle
    terms, trying one extra term first, then three, and so on, in the order
    produced by `combs'.  The first candidate for which `irreduciblep'
    holds is cached in `POLYMAP' and returned.

    Raises `ValueError' (carrying NBITS) if no candidate is irreducible.
    """
    if nbits in POLYMAP:
        return POLYMAP[nbits]
    base = C.GF(0).setbit(nbits).setbit(0)
    middle = range(1, nbits)
    for k in xrange(1, nbits, 2):
        for cc in combs(middle, k):
            cand = base + sum(C.GF(0).setbit(c) for c in cc)
            if cand.irreduciblep():
                POLYMAP[nbits] = cand
                return cand
    raise ValueError(nbits)
40
def Z(n):
    """Return the result of `C.ByteString.zero(N)': an N-byte zero string."""
    zeros = C.ByteString.zero(n)
    return zeros
43
def mul_blk_gf(m, x, p):
    """Multiply the block M by X modulo the polynomial P.

    M is read with `C.GF.loadb', and the reduced product is written back
    with `storeb' into (P.nbits + 6)/8 bytes (rounded down).
    """
    product = (C.GF.loadb(m)*x)%p
    nbytes = (p.nbits + 6)//8
    return product.storeb(nbytes)
45
def with_lastp(it):
    """Generate (ITEM, LASTP) pairs for the items of the iterable IT.

    LASTP is true only for the final item.  The generator reads one item
    ahead so that it knows when the end has been reached.

    Raises `ValueError' (on first use) if IT yields nothing at all.
    """
    source = iter(it)
    try:
        pending = next(source)
    except StopIteration:
        raise ValueError('empty iter')
    while True:
        current = pending
        try:
            pending = next(source)
        except StopIteration:
            ## Nothing follows, so `current' was the last item.
            yield current, True
            return
        yield current, False
56
def safehex(x):
    """Return `hex(X)', or the two-character string '""' if X is empty.

    Emptiness is decided by `len(X)'.
    """
    if not len(x):
        return '""'
    return hex(x)
60
def keylens(ksz):
    """Choose a few key lengths permitted by the key-size descriptor KSZ.

    The candidate lengths depend on the type of KSZ:

      * `C.KeySZSet': the members of `KSZ.set';
      * `C.KeySZRange': `range(KSZ.min, KSZ.max, KSZ.mod)' (so `KSZ.max'
        itself is not a candidate);
      * `C.KeySZAny': the lengths 0 up to 63, with length 0 always chosen
        first.

    Candidates are then drawn at random, without replacement, using the
    global generator `R', until four lengths have been chosen or the
    candidates run out.  Returns the list of chosen lengths.

    If KSZ is none of the above types, no candidate list is bound and the
    function fails when it tries to use it.
    """
    chosen = []
    if isinstance(ksz, C.KeySZSet):
        pool = ksz.set
    elif isinstance(ksz, C.KeySZRange):
        pool = range(ksz.min, ksz.max, ksz.mod)
    elif isinstance(ksz, C.KeySZAny):
        pool = range(64)
        chosen = [0]

    ## Work on a private copy: the selection below reorders the list.
    pool = list(pool)
    remaining = len(pool)
    while remaining and len(chosen) < 4:
        ## Swap a random unused candidate to the end of the unused region
        ## and then take it.
        pick = R.range(remaining)
        remaining -= 1
        pool[pick], pool[remaining] = pool[remaining], pool[pick]
        chosen.append(pool[remaining])
    return chosen
74
def pad0star(m, w):
    """Pad M with zero bytes to a nonzero multiple of W bytes.

    An empty M becomes W zero bytes; an M whose length is already a
    multiple of W gains nothing.  Returns the result as a `C.ByteString'.
    """
    size = len(m)
    short = (-size)%w if size else w
    if short:
        m += Z(short)
    return C.ByteString(m)
81
def pad10star(m, w):
    """Pad M with a 0x80 byte, then zero bytes, to a multiple of W bytes.

    The amount of padding is W - len(M) mod W, so between 1 and W bytes
    are always added.  Returns the result as a `C.ByteString'.
    """
    short = w - len(m)%w
    if short:
        marker = '\x80'
        m += marker + Z(short - 1)
    return C.ByteString(m)
86
def ntz(i):
    """Return the number of trailing zero bits in the nonzero integer I.

    That is, the largest J such that 2^J divides I.

    Raises `ValueError' if I is zero: zero has no lowest set bit, and
    shifting it right would never terminate the search.
    """
    if i == 0:
        raise ValueError('ntz of zero is undefined')
    j = 0
    while (i&1) == 0:
        i >>= 1
        j += 1
    return j
91
def blocks(x, w):
    """Split X into W-byte blocks, always holding back a final piece.

    Returns a pair (V, TL).  V is a list of `C.ByteString' blocks of
    exactly W bytes each; TL is a `C.ByteString' holding whatever remains.
    A block is only split off while more than W bytes remain, so TL is
    between 1 and W bytes long unless X is empty (when TL is empty too).
    """
    full = []
    pos = 0
    end = len(x)
    while end - pos > w:
        nxt = pos + w
        full.append(C.ByteString(x[pos:nxt]))
        pos = nxt
    return full, C.ByteString(x[pos:])
98
## The empty string, as returned by `C.bytes'.
EMPTY = C.bytes('')

def blocks0(x, w):
    """Split X into W-byte blocks, leaving only a partial block as tail.

    This is like `blocks', except that a final piece of exactly W bytes is
    counted as a full block.  Returns a pair (V, TL): the list of full
    blocks, and the leftover, which is `EMPTY' if the final piece was full.
    """
    full, tail = blocks(x, w)
    if len(tail) != w:
        return full, tail
    full.append(tail)
    return full, EMPTY
105
## Block ciphers implemented in this file, indexed by name.  The main
## program searches this table before falling back to `C.gcprps'.
CUSTOM = {}
107
108 ###--------------------------------------------------------------------------
109 ### RC6.
110
class RC6Cipher (type):
    """Factory for RC6 variants with a given word size and round count.

    Calling `RC6Cipher(W, R)' does not produce an instance of this class:
    it returns a new subclass of `RC6Base', named `rc6-W/R', carrying the
    class attributes

      * `name': the name, as above;
      * `w', `r': the word size in bits and the number of rounds;
      * `blksz': W/2, i.e., the size in bytes of a block of four words; and
      * `keysz': `C.KeySZRange(blksz, 1, 255, 1)'.
    """
    def __new__(cls, w, r):
        name = 'rc6-%d/%d' % (w, r)
        blksz = w//2
        attrs = {
            'name': name,
            'r': r,
            'w': w,
            'blksz': blksz,
            'keysz': C.KeySZRange(blksz, 1, 255, 1),
        }
        return type(name, (RC6Base,), attrs)
121
def rotw(w):
    """Return the bit length of W, less one (i.e., floor(log2(W)) for W > 0).

    `RC6Base' uses this as the number of low bits taken from a word to form
    a rotation amount.
    """
    nbits = w.bit_length()
    return nbits - 1
124
def rol(w, x, n):
    """Rotate the low W bits of X left by N places.

    Any bits of X above the low W are discarded, so the result always fits
    in W bits.
    """
    keep = w - n
    lo_mask = C.MP(0).setbit(keep) - 1
    hi_mask = C.MP(0).setbit(n) - 1

    ## The low `keep' bits move up by N; the N bits above them wrap round
    ## to the bottom.
    moved_up = (x&lo_mask) << n
    wrapped = (x >> keep)&hi_mask
    return moved_up | wrapped
128
def ror(w, x, n):
    """Rotate the low W bits of X right by N places.

    Any bits of X above the low W are discarded, so the result always fits
    in W bits.
    """
    keep = w - n
    lo_mask = C.MP(0).setbit(n) - 1
    hi_mask = C.MP(0).setbit(keep) - 1

    ## The low N bits wrap round to the top; the `keep' bits above them
    ## move down by N.
    wrapped = (x&lo_mask) << keep
    moved_down = (x >> n)&hi_mask
    return wrapped | moved_down
132
class RC6Base (object):
    """Shared implementation of the RC6 key schedule and block operations.

    This class is not used directly.  Subclasses (as made by `RC6Cipher')
    supply the class attributes `w' (word size in bits), `r' (number of
    rounds) and `blksz' (block size in bytes: four words).  An instance is
    constructed from a key, and then `encrypt' and `decrypt' each transform
    a single block.
    """

    ## Magic constants.  These are 400 bits long; `__init__' keeps the top
    ## `w' bits of each and forces the result to be odd.
    P400 = C.MP(0xb7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190cfef324e7738926cfbe5f4bf8d8d8c31d763da06)
    Q400 = C.MP(0x9e3779b97f4a7c15f39cc0605cedc8341082276bf3a27251f86c6a11d0c18e952767f0b153d27b7f0347045b5bf1827f0188)

    def __init__(me, k):
        """Expand the key K into the subkey table.

        Sets `me.s' to the list of 2*r + 4 subkey words, and `me.d' to the
        number of bits used for a rotation amount (see `rotw').
        """

        ## Build the magic numbers.
        P = me.P400 >> (400 - me.w)
        if P%2 == 0: P += 1
        Q = me.Q400 >> (400 - me.w)
        if Q%2 == 0: Q += 1
        M = C.MP(0).setbit(me.w) - 1

        ## Convert the key into words.  The final word may be short.
        wb = me.w//8
        c = (len(k) + wb - 1)//wb
        kb, ktl = blocks(k, wb)
        L = [C.MP.loadl(b) for b in kb + [ktl]]
        assert c == len(L)

        ## Build the subkey table.
        me.d = rotw(me.w)
        n = 2*me.r + 4
        S = [(P + i*Q)&M for i in xrange(n)]

        ##for j in xrange(c):
        ##  print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0'))
        ##for i in xrange(n):
        ##  print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0'))

        i = j = 0
        A = B = C.MP(0)

        ## Mix the key words into the table: three passes over whichever of
        ## the two arrays is the longer.
        for _ in xrange(3*max(c, n)):
            A = S[i] = rol(me.w, S[i] + A + B, 3)
            B = L[j] = rol(me.w, L[j] + A + B, (A + B)%(1 << me.d))
            ##print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0'))
            ##print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0'))
            i = (i + 1)%n
            j = (j + 1)%c

        ## Done.
        me.s = S

    def encrypt(me, x):
        """Encrypt the single block X, returning a `C.ByteString'."""
        M = C.MP(0).setbit(me.w) - 1
        wsz = me.blksz//4
        a, b, c, d = [C.MP.loadl(v) for v in blocks0(x, wsz)[0]]

        ## Input whitening.
        b = (b + me.s[0])&M
        d = (d + me.s[1])&M
        ##print 'B = %s' % (hex(b).upper()[2:].rjust(me.w/4, '0'))
        ##print 'D = %s' % (hex(d).upper()[2:].rjust(me.w/4, '0'))

        ## The rounds.  Round number i/2 uses subkeys i and i + 1.
        for i in xrange(2, 2*me.r + 2, 2):
            t = rol(me.w, 2*b*b + b, me.d)
            u = rol(me.w, 2*d*d + d, me.d)
            a = (rol(me.w, a ^ t, u%(1 << me.d)) + me.s[i + 0])&M
            c = (rol(me.w, c ^ u, t%(1 << me.d)) + me.s[i + 1])&M
            ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0'))
            ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0'))
            a, b, c, d = b, c, d, a

        ## Output whitening.
        a = (a + me.s[2*me.r + 2])&M
        c = (c + me.s[2*me.r + 3])&M
        ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0'))
        ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0'))
        return C.ByteString(a.storel(wsz) + b.storel(wsz) +
                            c.storel(wsz) + d.storel(wsz))

    def decrypt(me, x):
        """Decrypt the single block X, returning a `C.ByteString'.

        This undoes `encrypt' step by step, in the reverse order.
        """
        M = C.MP(0).setbit(me.w) - 1
        wsz = me.blksz//4
        a, b, c, d = [C.MP.loadl(v) for v in blocks0(x, wsz)[0]]

        ## Undo the output whitening.
        c = (c - me.s[2*me.r + 3])&M
        a = (a - me.s[2*me.r + 2])&M

        ## Undo the rounds, last first.  The index i takes the same even
        ## values as in `encrypt', so each round sees the same subkey pair.
        for i in xrange(2*me.r, 0, -2):
            a, b, c, d = d, a, b, c
            u = rol(me.w, 2*d*d + d, me.d)
            t = rol(me.w, 2*b*b + b, me.d)
            c = ror(me.w, (c - me.s[i + 1])&M, t%(1 << me.d)) ^ u
            a = ror(me.w, (a - me.s[i + 0])&M, u%(1 << me.d)) ^ t

        ## Undo the input whitening.
        d = (d - me.s[1])&M
        b = (b - me.s[0])&M
        return C.ByteString(a.storel(wsz) + b.storel(wsz) +
                            c.storel(wsz) + d.storel(wsz))
216
## Make an RC6 variant for each (word size in bits, number of rounds) pair
## listed, and register it in `CUSTOM' under its `rc6-W/R' name.
for (w, r) in [(8, 16), (16, 16), (24, 16), (32, 16),
               (32, 20), (48, 16), (64, 16), (96, 16), (128, 16),
               (192, 16), (256, 16), (400, 16)]:
    CUSTOM['rc6-%d/%d' % (w, r)] = RC6Cipher(w, r)
221
222 ###--------------------------------------------------------------------------
223 ### OMAC (or CMAC).
224
def omac_masks(E):
    """Derive the two OMAC masks for the keyed block cipher E.

    With L the encryption of the all-zero block, the first mask is L
    multiplied by 2 (via `mul_blk_gf', modulo the polynomial from `poly'
    for the block size in bits), and the second mask is the first
    multiplied by 2 again.  Returns the pair (M0, M1).
    """
    blksz = E.__class__.blksz
    p = poly(8*blksz)
    L = E.encrypt(Z(blksz))
    m0 = mul_blk_gf(L, 2, p)
    return m0, mul_blk_gf(m0, 2, p)
233
def dump_omac(E):
    """Print the intermediate OMAC values for the keyed block cipher E.

    The output consists of the encrypted zero block `L' and the two masks
    from `omac_masks'; then, for each tweak T in 0, 1, 2, the encryption
    `vT' of the block encoding T and the tweaked OMAC `zT' of the empty
    message.
    """
    blksz = E.__class__.blksz
    m0, m1 = omac_masks(E)
    zero_enc = E.encrypt(Z(blksz))
    print('L = %s' % hex(zero_enc))
    print('m0 = %s' % hex(m0))
    print('m1 = %s' % hex(m1))
    for t in range(3):
        tweak_enc = E.encrypt(C.MP(t).storeb(blksz))
        print('v%d = %s' % (t, hex(tweak_enc)))
        empty_tag = omac(E, t, '')
        print('z%d = %s' % (t, hex(empty_tag)))
243
def omac(E, t, m):
    """Compute the OMAC tag of the message M under the keyed cipher E.

    If T is not `None', a block encoding T (via `C.MP(T).storeb') is
    prefixed to the message first.  All but the final piece of the message
    are chained through E; the final piece is combined with the first mask
    if it is a full block, or padded with `pad10star' and combined with
    the second mask if it is not.  Returns the final cipher output.
    """
    blksz = E.__class__.blksz
    m0, m1 = omac_masks(E)
    state = Z(blksz)
    if t is not None:
        m = C.MP(t).storeb(blksz) + m
    full, tail = blocks(m, blksz)
    for blk in full:
        state = E.encrypt(state ^ blk)
    if len(tail) == blksz:
        return E.encrypt(state ^ tail ^ m0)
    padded = pad10star(tail, blksz)
    return E.encrypt(state ^ padded ^ m1)
258
def cmac(E, m):
    """Compute the CMAC of M under the keyed cipher E.

    This is `omac' with no tweak.  The tag is returned as a one-element
    tuple.  If the global `VERBOSE' is set, the intermediate values are
    printed first (see `dump_omac').
    """
    if VERBOSE:
        dump_omac(E)
    tag = omac(E, None, m)
    return (tag,)
262
def cmacgen(bc):
    """Return the argument tuples to generate CMAC tests for the cipher BC.

    Each is a one-element tuple holding a message length: empty, a single
    byte, exactly three blocks, and five bytes short of three blocks.
    """
    three = 3*bc.blksz
    return [(n,) for n in (0, 1, three, three - 5)]
267
268 ###--------------------------------------------------------------------------
269 ### Main program.
270
class struct (object):
    """A plain record: the keyword arguments become instance attributes."""
    def __init__(me, **kw):
        vars(me).update(kw)
274
## Argument kinds.  For each, `mk' makes a value from a generator
## parameter, `parse' makes one from a command-line string, and `show'
## formats one for output (or is `None' if the value is not printed).
binarg = struct(mk = R.block, parse = C.bytes, show = safehex)
intarg = struct(mk = lambda x: x, parse = int, show = None)

## For each mode: the parameter generator, the argument kinds, and the
## function which does the work.
MODEMAP = { 'cmac': (cmacgen, [binarg], cmac) }

mode = argv[1]

## Find the block cipher: our own implementations take precedence.
bc = None
for table in (CUSTOM, C.gcprps):
    try:
        bc = table[argv[2]]
    except KeyError:
        continue
    break
if bc is None:
    raise KeyError(argv[2])

if len(argv) == 3:

    ## No key given: generate a block of test vectors for random keys.
    VERBOSE = False
    gen, argty, func = MODEMAP[mode]
    print('%s-%s {' % (bc.name, mode))
    for ksz in keylens(bc.keysz):
        for argvals in gen(bc):
            key = R.block(ksz)
            args = [ty.mk(val) for ty, val in izip(argty, argvals)]
            rets = func(bc(key), *args)
            print(' %s' % safehex(key))
            for ty, arg in izip(argty, args):
                if ty.show:
                    print(' %s' % ty.show(arg))
            for ret, lastp in with_lastp(rets):
                print(' %s%s' % (safehex(ret), ';' if lastp else ''))
    print('}')

else:

    ## Key and arguments given: run once, showing intermediate values.
    VERBOSE = True
    key = C.bytes(argv[3])
    gen, argty, func = MODEMAP[mode]
    args = [ty.parse(val) for ty, val in izip(argty, argv[4:])]
    rets = func(bc(key), *args)
    for ret in rets:
        print(hex(ret))
307 rets = func(bc(k), *args)
308 for r in rets: print hex(r)
309
310 ###----- That's all, folks --------------------------------------------------