| 1 | #! /usr/bin/python |
| 2 | |
| 3 | from sys import argv |
| 4 | from struct import pack |
| 5 | from itertools import izip |
| 6 | import catacomb as C |
| 7 | |
## Shared generator, constructed with the fixed seed 0.  The code below draws
## key-length choices (`R.range') and keys and test inputs (`R.block') from it.
R = C.FibRand(0)
| 9 | |
| 10 | ###-------------------------------------------------------------------------- |
| 11 | ### Utilities. |
| 12 | |
def combs(things, k):
    """
    Generate every K-element selection from the sequence THINGS.

    Each selection is yielded as a fresh list whose elements appear in the
    same order as in THINGS.  Selections are enumerated by advancing the
    lowest position first, so for [0, 1, 2, 3] and K = 2 the order is
    [0, 1], [0, 2], [1, 2], [0, 3], [1, 3], [2, 3].  This is not the order
    produced by `itertools.combinations'.

    If K = 0, a single empty selection is yielded; if K exceeds the length
    of THINGS then nothing is yielded.
    """
    n = len(things)

    ## There are no selections larger than the population.
    if k > n:
        return

    ## The current selection, as a strictly increasing list of indices.
    ii = list(range(k))

    while True:
        yield [things[i] for i in ii]

        ## Find the lowest position which can advance without colliding with
        ## its upper neighbour (or the end of THINGS, for the top position),
        ## resetting the positions below it to their starting values.
        for j in range(k):
            lim = n if j == k - 1 else ii[j + 1]
            i = ii[j] + 1
            if i < lim:
                ii[j] = i
                break
            ii[j] = j
        else:
            ## Nothing could advance: that was the final selection.
            return
| 28 | |
## Cache of the polynomials already found by `poly', indexed by degree.
POLYMAP = {}

def poly(nbits):
    """
    Return an irreducible polynomial of degree NBITS, as a `C.GF' object.

    Candidates always have bits NBITS and 0 set, plus K further bits chosen
    from positions 1, ..., NBITS - 1, for K = 1, 3, 5, ...; the selections
    for each K are tried in the order produced by `combs', and the first
    candidate whose `irreduciblep' test succeeds is returned.  Results are
    remembered in POLYMAP.  Raise `ValueError' if no candidate passes.
    """
    if nbits in POLYMAP:
        return POLYMAP[nbits]

    base = C.GF(0).setbit(nbits).setbit(0)
    middle = range(1, nbits)
    for k in range(1, nbits, 2):
        for cc in combs(middle, k):
            p = base + sum(C.GF(0).setbit(c) for c in cc)
            if not p.irreduciblep():
                continue
            POLYMAP[nbits] = p
            return p

    raise ValueError(nbits)
| 40 | |
def Z(n):
    """Return `C.ByteString.zero(N)': the zero byte string for length N."""
    zeroes = C.ByteString.zero(n)
    return zeroes
| 43 | |
def mul_blk_gf(m, x, p):
    """
    Multiply the block M by X, reducing modulo the polynomial P.

    M is read with `C.GF.loadb'; the reduced product is written back with
    `storeb', using (P.nbits + 6)/8 bytes (rounded down).
    """
    product = (C.GF.loadb(m)*x)%p
    return product.storeb((p.nbits + 6)//8)
| 45 | |
def with_lastp(it):
    """
    Generate pairs (ITEM, LASTP) for each ITEM of the iterable IT.

    LASTP is true for the final item only.  The input is read one item
    ahead of what has been yielded, since that's the only way to know
    whether the current item is the last.  Raise `ValueError' (when the
    generator is first advanced) if IT is empty.
    """
    it = iter(it)
    try:
        pending = next(it)
    except StopIteration:
        raise ValueError('empty iter')

    ## Each time another item turns up, the pending one wasn't the last.
    for upcoming in it:
        yield pending, False
        pending = upcoming
    yield pending, True
| 56 | |
def safehex(x):
    """
    Return `hex(X)', or the two-character text `""' if X has length zero.
    """
    return hex(x) if len(x) else '""'
| 60 | |
def keylens(ksz):
    """
    Choose a few key lengths permitted by the key-size descriptor KSZ.

    KSZ must be a `C.KeySZSet', `C.KeySZRange' or `C.KeySZAny' object;
    anything else raises `TypeError'.  At most four lengths are returned, as
    a list.  For `C.KeySZAny', the list starts with 0 and the rest are drawn
    from 0, ..., 63.  The choices are made using the shared generator R, so
    they are reproducible.
    """
    sel = []
    if isinstance(ksz, C.KeySZSet):
        kk = list(ksz.set)
    elif isinstance(ksz, C.KeySZRange):
        ## NOTE(review): this excludes `ksz.max' itself.  Confirm whether the
        ## maximum is meant to be a candidate before changing it, because
        ## that would alter the generated test vectors.
        kk = list(range(ksz.min, ksz.max, ksz.mod))
    elif isinstance(ksz, C.KeySZAny):
        kk = list(range(64))
        sel = [0]
    else:
        raise TypeError('unknown key-size descriptor %r' % (ksz,))

    ## Partial shuffle: move a randomly chosen candidate to the end of the
    ## live prefix, and take it, until there are enough or none are left.
    n = len(kk)
    while n and len(sel) < 4:
        i = R.range(n)
        n -= 1
        kk[i], kk[n] = kk[n], kk[i]
        sel.append(kk[n])
    return sel
| 74 | |
def pad0star(m, w):
    """
    Pad M with zero bytes to a whole number of W-byte blocks.

    An empty M becomes one full block of zeroes; an M which is already a
    nonzero multiple of W bytes long is left alone.  Return the result as a
    `C.ByteString'.
    """
    n = len(m)
    r = (-n)%w if n else w
    if r:
        m += Z(r)
    return C.ByteString(m)
| 81 | |
def pad10star(m, w):
    """
    Pad M to a whole number of W-byte blocks with a 0x80 byte and zeroes.

    Padding is always added: between 1 and W bytes, so an M which is already
    a multiple of W bytes long gains a complete extra block.  Return the
    result as a `C.ByteString'.
    """
    ## The shortfall is never zero, so there's always room for the marker.
    r = w - len(m)%w
    m += '\x80' + Z(r - 1)
    return C.ByteString(m)
| 86 | |
def ntz(i):
    """
    Return the number of trailing zero bits in the nonzero integer I.

    Raise `ValueError' if I is zero: it has no lowest set bit, so the search
    below would never finish.
    """
    if i == 0:
        raise ValueError('ntz is undefined for zero')
    count = 0
    while (i&1) == 0:
        i >>= 1
        count += 1
    return count
| 91 | |
def blocks(x, w):
    """
    Split X into leading W-byte blocks and a tail.

    Return a pair (BLOCKS, TAIL): BLOCKS is a list of `C.ByteString' objects,
    each W bytes long, and TAIL is a `C.ByteString' holding whatever is left.
    Splitting stops while more than zero bytes remain, so the tail is between
    1 and W bytes long unless X is empty, in which case it is empty too.
    """
    total = len(x)
    full = []
    pos = 0
    while total - pos > w:
        full.append(C.ByteString(x[pos:pos + w]))
        pos += w
    tail = C.ByteString(x[pos:])
    return full, tail
| 98 | |
## The empty byte string, as made by `C.bytes'.
EMPTY = C.bytes('')

def blocks0(x, w):
    """
    Split X into W-byte blocks and a tail which is shorter than W.

    This is like `blocks', except that a tail which is exactly W bytes long
    is counted as a block, leaving EMPTY as the tail.
    """
    v, tl = blocks(x, w)
    if len(tl) != w:
        return v, tl
    v.append(tl)
    return v, EMPTY
| 105 | |
def dummygen(bc):
    """Return an empty list of test cases, whatever BC is."""
    return list()

## Ciphers implemented in this file, indexed by name.
CUSTOM = dict()
| 109 | |
| 110 | ###-------------------------------------------------------------------------- |
| 111 | ### RC6. |
| 112 | |
class RC6Cipher (type):
    """
    Factory for RC6 cipher classes: `RC6Cipher(W, R)' makes a new class.

    The result is an ordinary subclass of `RC6Base' named `rc6-W/R', with
    class attributes `name', `w' (word size), `r' (rounds), `blksz' (W/2)
    and `keysz' (a `C.KeySZRange').
    """

    def __new__(cls, w, r):
        name = 'rc6-%d/%d' % (w, r)
        blksz = w//2
        attrs = { 'name': name,
                  'r': r,
                  'w': w,
                  'blksz': blksz,
                  'keysz': C.KeySZRange(blksz, 1, 255, 1) }
        return type(name, (RC6Base,), attrs)
| 123 | |
def rotw(w):
    """
    Return one less than the bit length of W.

    For positive W, this is the base-2 logarithm of W, rounded down.
    """
    nbits = w.bit_length()
    return nbits - 1
| 126 | |
def rol(w, x, n):
    """
    Rotate the low W bits of X left by N places.

    Bits of X above the low W are ignored, so X needn't be reduced first.
    """
    keep = C.MP(0).setbit(w - n) - 1
    wrap = C.MP(0).setbit(n) - 1

    ## The low W - N bits move up; the N bits above them wrap to the bottom.
    moved_up = (x&keep) << n
    wrapped = (x >> (w - n))&wrap
    return moved_up | wrapped
| 130 | |
def ror(w, x, n):
    """
    Rotate the low W bits of X right by N places.

    Bits of X above the low W are ignored, so X needn't be reduced first.
    """
    wrap = C.MP(0).setbit(n) - 1
    keep = C.MP(0).setbit(w - n) - 1

    ## The low N bits wrap round to the top; the W - N above them move down.
    wrapped = (x&wrap) << (w - n)
    moved_down = (x >> n)&keep
    return wrapped | moved_down
| 134 | |
class RC6Base (object):
    """
    Common implementation of the RC6-w/r block ciphers.

    The concrete ciphers are subclasses (see `RC6Cipher') which supply the
    class attributes `w' (word size in bits), `r' (number of rounds) and
    `blksz' (block size in bytes: four words).  An instance holds an expanded
    key: `me.s' is the table of 2 r + 4 subkey words, and `me.d' is the
    number of low bits of a word used as a data-dependent rotation count.
    """

    ## Magic constants.  These are 400 bits long; narrower word sizes use the
    ## top W bits, forced odd.  (The leading digits match the usual RC5/RC6
    ## constants P_w and Q_w.)
    P400 = C.MP(0xb7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190cfef324e7738926cfbe5f4bf8d8d8c31d763da06)
    Q400 = C.MP(0x9e3779b97f4a7c15f39cc0605cedc8341082276bf3a27251f86c6a11d0c18e952767f0b153d27b7f0347045b5bf1827f0188)

    def __init__(me, k):
        """Expand the key K (a byte string) into the subkey table."""

        ## Build the magic numbers.
        P = me.P400 >> (400 - me.w)
        if P%2 == 0: P += 1
        Q = me.Q400 >> (400 - me.w)
        if Q%2 == 0: Q += 1
        M = C.MP(0).setbit(me.w) - 1

        ## Convert the key into words.  The final word may be short; it is
        ## loaded with `loadl' like the others.
        wb = me.w//8
        c = (len(k) + wb - 1)//wb
        kb, ktl = blocks(k, wb)
        L = [C.MP.loadl(b) for b in kb + [ktl]]
        assert c == len(L)

        ## Build the subkey table.
        me.d = rotw(me.w)
        n = 2*me.r + 4
        S = [(P + i*Q)&M for i in range(n)]

        ## Mix the key words into the table, going round both arrays until
        ## the longer of them has been covered three times.
        i = j = 0
        A = B = C.MP(0)
        for _ in range(3*max(c, n)):
            A = S[i] = rol(me.w, S[i] + A + B, 3)
            B = L[j] = rol(me.w, L[j] + A + B, (A + B)%(1 << me.d))
            i = (i + 1)%n
            j = (j + 1)%c

        ## Done.
        me.s = S

    def encrypt(me, x):
        """Encrypt the single block X, returning a `C.ByteString'."""
        M = C.MP(0).setbit(me.w) - 1
        wb = me.blksz//4
        a, b, c, d = [C.MP.loadl(v) for v in blocks0(x, wb)[0]]

        ## Pre-whitening.
        b = (b + me.s[0])&M
        d = (d + me.s[1])&M

        ## The rounds.  Round number i/2 uses subkeys i and i + 1.
        for i in range(2, 2*me.r + 2, 2):
            t = rol(me.w, 2*b*b + b, me.d)
            u = rol(me.w, 2*d*d + d, me.d)
            a = (rol(me.w, a ^ t, u%(1 << me.d)) + me.s[i + 0])&M
            c = (rol(me.w, c ^ u, t%(1 << me.d)) + me.s[i + 1])&M
            a, b, c, d = b, c, d, a

        ## Post-whitening.
        a = (a + me.s[2*me.r + 2])&M
        c = (c + me.s[2*me.r + 3])&M

        return C.ByteString(a.storel(wb) + b.storel(wb) +
                            c.storel(wb) + d.storel(wb))

    def decrypt(me, x):
        """
        Decrypt the single block X, returning a `C.ByteString'.

        This undoes the steps of `encrypt' exactly, in the reverse order.
        """
        M = C.MP(0).setbit(me.w) - 1
        wb = me.blksz//4
        a, b, c, d = [C.MP.loadl(v) for v in blocks0(x, wb)[0]]

        ## Undo the post-whitening.
        c = (c - me.s[2*me.r + 3])&M
        a = (a - me.s[2*me.r + 2])&M

        ## Undo the rounds, last first, with the same subkey indices as
        ## `encrypt' used: i = 2 r, 2 r - 2, ..., 2.  The word rotation is
        ## undone first, which puts B and D back as they were when the round
        ## computed T and U from them.
        for i in range(2*me.r, 0, -2):
            a, b, c, d = d, a, b, c
            u = rol(me.w, 2*d*d + d, me.d)
            t = rol(me.w, 2*b*b + b, me.d)
            c = ror(me.w, (c - me.s[i + 1])&M, t%(1 << me.d)) ^ u
            a = ror(me.w, (a - me.s[i + 0])&M, u%(1 << me.d)) ^ t

        ## Undo the pre-whitening.
        d = (d - me.s[1])&M
        b = (b - me.s[0])&M

        return C.ByteString(a.storel(wb) + b.storel(wb) +
                            c.storel(wb) + d.storel(wb))
| 218 | |
## Register the RC6 variants, as (word size, rounds), under their names.
for w, r in [
    (8, 16), (16, 16), (24, 16),
    (32, 16), (32, 20),
    (48, 16), (64, 16), (96, 16), (128, 16),
    (192, 16), (256, 16), (400, 16),
]:
    CUSTOM['rc6-%d/%d' % (w, r)] = RC6Cipher(w, r)
| 223 | |
| 224 | ###-------------------------------------------------------------------------- |
| 225 | ### OMAC (or CMAC). |
| 226 | |
def omac_masks(E):
    """
    Return the pair (M0, M1) of OMAC masks for the keyed cipher E.

    With L the encryption of the all-zero block, M0 is L multiplied by 2
    modulo the polynomial `poly' gives for the block size in bits, and M1 is
    M0 multiplied by 2 again.
    """
    blksz = E.__class__.blksz
    p = poly(8*blksz)
    L = E.encrypt(Z(blksz))
    m0 = mul_blk_gf(L, 2, p)
    return m0, mul_blk_gf(m0, 2, p)
| 235 | |
def dump_omac(E):
    """
    Print the OMAC intermediate values for the keyed cipher E.

    These are: L (the encrypted zero block) and the two masks; and, for each
    of the tweaks T = 0, 1, 2, the encryption of T as a block, and the
    tweaked OMAC of the empty message.
    """
    blksz = E.__class__.blksz
    m0, m1 = omac_masks(E)
    L = E.encrypt(Z(blksz))
    print('L = %s' % hex(L))
    print('m0 = %s' % hex(m0))
    print('m1 = %s' % hex(m1))
    for t in range(3):
        v = E.encrypt(C.MP(t).storeb(blksz))
        print('v%d = %s' % (t, hex(v)))
        z = omac(E, t, '')
        print('z%d = %s' % (t, hex(z)))
| 245 | |
def omac(E, t, m):
    """
    Return the OMAC tag of the message M under the keyed cipher E.

    If T is not `None' then M is first prefixed by a block containing T, as
    stored by `C.MP.storeb'.  All but the final block are chained through
    the cipher.  A full final block is masked with the first OMAC mask; a
    short (or missing) one is padded by `pad10star' and masked with the
    second.
    """
    blksz = E.__class__.blksz
    m0, m1 = omac_masks(E)
    if t is not None:
        m = C.MP(t).storeb(blksz) + m
    v, tl = blocks(m, blksz)

    a = Z(blksz)
    for x in v:
        a = E.encrypt(a ^ x)

    if len(tl) == blksz:
        final = a ^ tl ^ m0
    else:
        final = a ^ pad10star(tl, blksz) ^ m1
    return E.encrypt(final)
| 260 | |
def cmac(E, m):
    """
    Return a one-element tuple holding the untweaked OMAC tag of M.

    If the global VERBOSE flag is set, print the OMAC intermediates first.
    """
    if VERBOSE:
        dump_omac(E)
    tag = omac(E, None, m)
    return (tag,)
| 264 | |
def cmacgen(bc):
    """
    Return the list of CMAC test cases for the cipher class BC.

    Each case is a one-element tuple giving a size: 0, 1, three times
    `BC.blksz', and five less than that.
    """
    bs = bc.blksz
    return [(n,) for n in (0, 1, 3*bs, 3*bs - 5)]
| 269 | |
| 270 | ###-------------------------------------------------------------------------- |
| 271 | ### Counter mode. |
| 272 | |
def ctr(E, m, c0):
    """
    Combine M with a counter-mode keystream from the keyed cipher E.

    The counter starts at the value read from C0 by `C.MP.loadb' and goes up
    by one for each block; each counter value is stored with `storeb' and
    encrypted to make the next piece of keystream.  Return the XOR of M with
    the first len(M) bytes of keystream.
    """
    blksz = E.__class__.blksz
    want = len(m)
    stream = C.WriteBuffer()
    counter = C.MP.loadb(c0)
    while stream.size < want:
        stream.put(E.encrypt(counter.storeb(blksz)))
        counter += 1
    return C.ByteString(m) ^ C.ByteString(stream)[:want]
| 281 | |
| 282 | ###-------------------------------------------------------------------------- |
| 283 | ### EAX. |
| 284 | |
def eaxenc(E, n, h, m, tsz = None):
    """
    Perform the EAX encryption computation with the keyed cipher E.

    N, H and M are fed to `omac' with tweaks 0 and 1, and to `ctr',
    respectively; TSZ is the tag length in bytes, defaulting to the cipher's
    block size.  Return a pair (Y, T): Y is the output of `ctr', and T is
    the XOR of the three tweaked OMACs, over N, H and Y, cut down to TSZ
    bytes.

    If the global VERBOSE flag is set then the inputs and intermediate
    values are printed; this reads the module-level key `k'.
    """
    if VERBOSE:
        print('k = %s' % hex(k))
        print('n = %s' % hex(n))
        print('h = %s' % hex(h))
        print('m = %s' % hex(m))
        dump_omac(E)

    if tsz is None:
        tsz = E.__class__.blksz
    c0 = omac(E, 0, n)
    y = ctr(E, m, c0)
    ht = omac(E, 1, h)
    yt = omac(E, 2, y)

    if VERBOSE:
        print('c0 = %s' % hex(c0))
        print('ht = %s' % hex(ht))
        print('yt = %s' % hex(yt))

    tag = (c0 ^ ht ^ yt)[:tsz]
    return y, C.ByteString(tag)
| 302 | |
def eaxdec(E, n, h, y, t):
    """
    Perform the EAX decryption computation with the keyed cipher E.

    The arguments N, H and Y are handled as in `eaxenc', with Y passed to
    `ctr' in place of the message; T is the tag to check.  Return a
    one-element tuple: the output of `ctr' if T matches the first len(T)
    bytes of the recomputed tag, or `None' if not.

    If the global VERBOSE flag is set then the inputs and intermediate
    values are printed; this reads the module-level key `k'.
    """
    if VERBOSE:
        print('k = %s' % hex(k))
        print('n = %s' % hex(n))
        print('h = %s' % hex(h))
        print('y = %s' % hex(y))
        print('t = %s' % hex(t))
        dump_omac(E)

    c0 = omac(E, 0, n)
    m = ctr(E, y, c0)
    ht = omac(E, 1, h)
    yt = omac(E, 2, y)

    if VERBOSE:
        print('c0 = %s' % hex(c0))
        print('ht = %s' % hex(ht))
        print('yt = %s' % hex(yt))

    expected = (c0 ^ ht ^ yt)[:len(t)]
    if t == expected:
        return (m,)
    return (None,)
| 321 | |
def eaxgen(bc):
    """
    Return the list of EAX test cases for the cipher class BC.

    Each case is a triple of sizes, one for each of the first three
    arguments which `eaxenc' takes after the cipher.  There are four tiny
    cases, with at most one size nonzero, and two larger ones based on
    `BC.blksz'.
    """
    bs = bc.blksz
    tiny = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)]
    large = [(bs, 3*bs, 3*bs),
             (bs - 1, 3*bs - 5, 3*bs + 5)]
    return tiny + large
| 326 | |
| 327 | ###-------------------------------------------------------------------------- |
| 328 | ### Main program. |
| 329 | |
class struct (object):
    """A plain record: the keyword arguments become instance attributes."""

    def __init__(me, **kw):
        for name, value in kw.items():
            setattr(me, name, value)
| 333 | |
## Argument types.  For each: `mk' makes a value given a generator parameter,
## `parse' reads one from a command-line argument, and `show' formats one
## for output, or is `None' if values of this type aren't printed.
binarg = struct(mk = R.block, parse = C.bytes, show = safehex)
intarg = struct(mk = lambda x: x, parse = int, show = None)

## For each mode: the test-case generator, the argument types, and the
## function which does the work.
MODEMAP = {
    'eax-enc': (eaxgen, 3*[binarg] + [intarg], eaxenc),
    'eax-dec': (dummygen, 4*[binarg], eaxdec),
    'cmac': (cmacgen, [binarg], cmac),
}

## Usage: MODE CIPHER [KEY ARG ...]
mode = argv[1]

## Find the cipher: the ones defined here take priority.
bc = None
for d in CUSTOM, C.gcprps:
    try: bc = d[argv[2]]
    except KeyError: continue
    break
if bc is None: raise KeyError(argv[2])

if len(argv) == 3:

    ## No key given: write a collection of test vectors, for a few key sizes
    ## and each of the mode's test cases.
    VERBOSE = False
    gen, argty, func = MODEMAP[mode]
    if mode.endswith('-enc'): mode = mode[:-4]
    print('%s-%s {' % (bc.name, mode))
    for ksz in keylens(bc.keysz):
        for argvals in gen(bc):
            k = R.block(ksz)
            args = [t.mk(a) for t, a in izip(argty, argvals)]
            rets = func(bc(k), *args)
            print(' %s' % safehex(k))
            for t, a in izip(argty, args):
                if t.show: print(' %s' % t.show(a))
            for r, lastp in with_lastp(rets):
                print(' %s%s' % (safehex(r), ';' if lastp else ''))
    print('}')

else:

    ## Key and arguments given: do the one computation, verbosely.  The
    ## worker functions read the globals VERBOSE and `k'.
    VERBOSE = True
    k = C.bytes(argv[3])
    gen, argty, func = MODEMAP[mode]
    args = [t.parse(a) for t, a in izip(argty, argv[4:])]
    rets = func(bc(k), *args)
    for r in rets:
        if r is None: print("X")
        else: print(hex(r))
| 373 | |
| 374 | ###----- That's all, folks -------------------------------------------------- |