X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb-python/blobdiff_plain/378ceeef4e0663d913cb448c32022522d39e7848..7f38dc76ee0809207e67be7b2a2ddc600aba54d5:/catacomb/__init__.py diff --git a/catacomb/__init__.py b/catacomb/__init__.py index 0ce4d04..9178475 100644 --- a/catacomb/__init__.py +++ b/catacomb/__init__.py @@ -27,6 +27,9 @@ from __future__ import with_statement from binascii import hexlify as _hexify, unhexlify as _unhexify from contextlib import contextmanager as _ctxmgr +try: import DLFCN as _dlfcn +except ImportError: _dlfcn = None +import os as _os from struct import pack as _pack import sys as _sys import types as _types @@ -34,14 +37,46 @@ import types as _types ###-------------------------------------------------------------------------- ### Import the main C extension module. +try: + _dlflags = _odlflags = _sys.getdlopenflags() +except AttributeError: + _dlflags = _odlflags = -1 + +## Set the `deep binding' flag. Python has its own different MD5 +## implementation, and some distributions export `md5_init' and friends so +## they override our versions, which doesn't end well. Figure out how to +## turn this flag on so we don't have the problem. +if _dlflags >= 0: + try: _dlflags |= _dlfcn.RTLD_DEEPBIND + except AttributeError: + try: _dlflags |= _os.RTLD_DEEPBIND + except AttributeError: + if _os.uname()[0] == 'Linux': _dlflags |= 8 # magic knowledge + else: pass # can't do this. + _sys.setdlopenflags(_dlflags) + import _base +if _odlflags >= 0: + _sys.setdlopenflags(_odlflags) + +del _dlflags, _odlflags + ###-------------------------------------------------------------------------- ### Basic stuff. ## For the benefit of the default keyreporter, we need the program name. _base._ego(_sys.argv[0]) +## Register our module. +_base._set_home_module(_sys.modules[__name__]) +def default_lostexchook(why, ty, val, tb): + """`catacomb.lostexchook(WHY, TY, VAL, TB)' reports lost exceptions.""" + _sys.stderr.write("\n\n!!! LOST EXCEPTION: %s\n" % why) + _sys.excepthook(ty, val, tb) + _sys.stderr.write("\n") +lostexchook = default_lostexchook + ## How to fix a name back into the right identifier. Alas, the rules are not ## consistent. def _fixname(name): @@ -50,7 +85,7 @@ def _fixname(name): name = name.replace('-', '_') ## But slashes might become underscores or just vanish. - if name.startswith('salsa20'): name = name.translate(None, '/') + if name.startswith('salsa20'): name = name.replace('/', '') else: name = name.replace('/', '_') ## Done. @@ -77,7 +112,7 @@ def _init(): for j in b: if j[:plen] == pre: setattr(c, j[plen:], classmethod(b[j])) - for i in [gcciphers, gchashes, gcmacs, gcprps]: + for i in [gcciphers, gcaeads, gchashes, gcmacs, gcprps]: for c in i.itervalues(): d[_fixname(c.name)] = c for c in gccrands.itervalues(): @@ -159,6 +194,27 @@ ByteString.__hash__ = str.__hash__ bytes = ByteString.fromhex ###-------------------------------------------------------------------------- +### Symmetric encryption. + +class _tmp: + def encrypt(me, n, m, tsz = None, h = ByteString('')): + if tsz is None: tsz = me.__class__.tagsz.default + e = me.enc(n, len(h), len(m), tsz) + if not len(h): a = None + else: a = e.aad().hash(h) + c0 = e.encrypt(m) + c1, t = e.done(aad = a) + return c0 + c1, t + def decrypt(me, n, c, t, h = ByteString('')): + d = me.dec(n, len(h), len(c), len(t)) + if not len(h): a = None + else: a = d.aad().hash(h) + m = d.decrypt(c) + m += d.done(t, aad = a) + return m +_augment(GAEKey, _tmp) + +###-------------------------------------------------------------------------- ### Hashing. class _tmp: @@ -171,15 +227,31 @@ _augment(Poly1305Hash, _tmp) class _HashBase (object): ## The standard hash methods. Assume that `hash' is defined and returns ## the receiver. - def hashu8(me, n): return me.hash(_pack('B', n)) - def hashu16l(me, n): return me.hash(_pack('H', n)) + def _check_range(me, n, max): + if not (0 <= n <= max): raise OverflowError("out of range") + def hashu8(me, n): + me._check_range(n, 0xff) + return me.hash(_pack('B', n)) + def hashu16l(me, n): + me._check_range(n, 0xffff) + return me.hash(_pack('H', n)) hashu16 = hashu16b - def hashu32l(me, n): return me.hash(_pack('L', n)) + def hashu32l(me, n): + me._check_range(n, 0xffffffff) + return me.hash(_pack('L', n)) hashu32 = hashu32b - def hashu64l(me, n): return me.hash(_pack('Q', n)) + def hashu64l(me, n): + me._check_range(n, 0xffffffffffffffff) + return me.hash(_pack('Q', n)) hashu64 = hashu64b def hashbuf8(me, s): return me.hashu8(len(s)).hash(s) def hashbuf16l(me, s): return me.hashu16l(len(s)).hash(s) @@ -202,8 +274,8 @@ class _ShakeBase (_HashBase): me._h = me._SHAKE(perso = perso, func = me._FUNC) ## Delegate methods... - def copy(me): new = me.__class__(); new._copy(me) - def _copy(me, other): me._h = other._h + def copy(me): new = me.__class__._bare_new(); new._copy(me); return new + def _copy(me, other): me._h = other._h.copy() def hash(me, m): me._h.hash(m); return me def xof(me): me._h.xof(); return me def get(me, n): return me._h.get(n) @@ -216,6 +288,8 @@ class _ShakeBase (_HashBase): def buffered(me): return me._h.buffered @property def rate(me): return me._h.rate + @classmethod + def _bare_new(cls): return cls() class _tmp: def check(me, h): @@ -240,7 +314,7 @@ class _tmp: me.bytepad_after() _augment(Shake, _tmp) _augment(_ShakeBase, _tmp) -Shake._Z = _ShakeBase._Z = ByteString(200*'\0') +Shake._Z = _ShakeBase._Z = ByteString.zero(200) class KMAC (_ShakeBase): _FUNC = 'KMAC' @@ -254,6 +328,8 @@ class KMAC (_ShakeBase): def xof(me): me.rightenc(0) return super(KMAC, me).xof() + @classmethod + def _bare_new(cls): return cls("") class KMAC128 (KMAC): _SHAKE = Shake128; _TAGSZ = 16 class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32 @@ -262,21 +338,12 @@ class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32 ### NaCl `secretbox'. def secret_box(k, n, m): - E = xsalsa20(k).setiv(n) - r = E.enczero(poly1305.keysz.default) - s = E.enczero(poly1305.masksz) - y = E.encrypt(m) - t = poly1305(r)(s).hash(y).done() - return ByteString(t + y) + y, t = salsa20_naclbox(k).encrypt(n, m) + return t + y def secret_unbox(k, n, c): - E = xsalsa20(k).setiv(n) - r = E.enczero(poly1305.keysz.default) - s = E.enczero(poly1305.masksz) - y = c[poly1305.tagsz:] - if not poly1305(r)(s).hash(y).check(c[0:poly1305.tagsz]): - raise ValueError, 'decryption failed' - return E.decrypt(c[poly1305.tagsz:]) + tsz = poly1305.tagsz + return salsa20_naclbox(k).decrypt(n, c[tsz:], c[0:tsz]) ###-------------------------------------------------------------------------- ### Multiprecision integers and binary polynomials. @@ -316,15 +383,18 @@ class BaseRat (object): def __mul__(me, you): n, d = _split_rat(you) return type(me)(me._n*n, me._d*d) - def __div__(me, you): + __rmul__ = __mul__ + def __truediv__(me, you): n, d = _split_rat(you) return type(me)(me._n*d, me._d*n) - def __rdiv__(me, you): + def __rtruediv__(me, you): n, d = _split_rat(you) return type(me)(me._d*n, me._n*d) + __div__ = __truediv__ + __rdiv__ = __rtruediv__ def __cmp__(me, you): n, d = _split_rat(you) - return type(me)(me._n*d, n*me._d) + return cmp(me._n*d, n*me._d) def __rcmp__(me, you): n, d = _split_rat(you) return cmp(n*me._d, me._n*d) @@ -344,8 +414,10 @@ class _tmp: def mont(x): return MPMont(x) def barrett(x): return MPBarrett(x) def reduce(x): return MPReduce(x) - def __div__(me, you): return IntRat(me, you) - def __rdiv__(me, you): return IntRat(you, me) + def __truediv__(me, you): return IntRat(me, you) + def __rtruediv__(me, you): return IntRat(you, me) + __div__ = __truediv__ + __rdiv__ = __rtruediv__ _repr_pretty_ = _pp_str _augment(MP, _tmp) @@ -356,8 +428,10 @@ class _tmp: def halftrace(x, y): return x.reduce().halftrace(y) def modsqrt(x, y): return x.reduce().sqrt(y) def quadsolve(x, y): return x.reduce().quadsolve(y) - def __div__(me, you): return GFRat(me, you) - def __rdiv__(me, you): return GFRat(you, me) + def __truediv__(me, you): return GFRat(me, you) + def __rtruediv__(me, you): return GFRat(you, me) + __div__ = __truediv__ + __rdiv__ = __rtruediv__ _repr_pretty_ = _pp_str _augment(GF, _tmp) @@ -524,6 +598,7 @@ class _tmp: def __repr__(me): return '%s(%d)' % (_clsname(me), me.default) def check(me, sz): return True def best(me, sz): return sz + def pad(me, sz): return sz _augment(KeySZAny, _tmp) class _tmp: @@ -540,11 +615,15 @@ class _tmp: pp.pretty(me.max); pp.text(','); pp.breakable() pp.pretty(me.mod) pp.end_group(ind, ')') - def check(me, sz): return me.min <= sz <= me.max and sz % me.mod == 0 + def check(me, sz): return me.min <= sz <= me.max and sz%me.mod == 0 def best(me, sz): if sz < me.min: raise ValueError, 'key too small' elif sz > me.max: return me.max - else: return sz - (sz % me.mod) + else: return sz - sz%me.mod + def pad(me, sz): + if sz > me.max: raise ValueError, 'key too large' + elif sz < me.min: return me.min + else: sz += me.mod; return sz - sz%me.mod _augment(KeySZRange, _tmp) class _tmp: @@ -566,6 +645,12 @@ class _tmp: if found < i <= sz: found = i if found < 0: raise ValueError, 'key too small' return found + def pad(me, sz): + found = -1 + for i in me.set: + if sz <= i and (found == -1 or i < found): found = i + if found < 0: raise ValueError, 'key too large' + return found _augment(KeySZSet, _tmp) ###--------------------------------------------------------------------------