X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb-python/blobdiff_plain/7f38dc76ee0809207e67be7b2a2ddc600aba54d5..71574cbaff3942fd35ceb2754cbfc36449856644:/catacomb/__init__.py diff --git a/catacomb/__init__.py b/catacomb/__init__.py index 9178475..26463b4 100644 --- a/catacomb/__init__.py +++ b/catacomb/__init__.py @@ -55,7 +55,8 @@ if _dlflags >= 0: else: pass # can't do this. _sys.setdlopenflags(_dlflags) -import _base +if _sys.version_info >= (3,): from . import _base +else: import _base if _odlflags >= 0: _sys.setdlopenflags(_odlflags) @@ -77,6 +78,25 @@ def default_lostexchook(why, ty, val, tb): _sys.stderr.write("\n") lostexchook = default_lostexchook +## Text/binary conversions. +if _sys.version_info >= (3,): + def _bin(s): return s.encode('iso8859-1') +else: + def _bin(s): return s + +## Iterating over dictionaries. +if _sys.version_info >= (3,): + def _iteritems(dict): return dict.items() + def _itervalues(dict): return dict.values() +else: + def _iteritems(dict): return dict.iteritems() + def _itervalues(dict): return dict.itervalues() + +## The built-in bignum type. +try: long +except NameError: _long = int +else: _long = long + ## How to fix a name back into the right identifier. Alas, the rules are not ## consistent. def _fixname(name): @@ -99,23 +119,10 @@ def _init(): for i in b: if i[0] != '_': d[i] = b[i]; - for i in ['ByteString', - 'MP', 'GF', 'Field', - 'ECPt', 'ECPtCurve', 'ECCurve', 'ECInfo', - 'DHInfo', 'BinDHInfo', 'RSAPriv', 'BBSPriv', - 'PrimeFilter', 'RabinMiller', - 'Group', 'GE', - 'KeySZ', 'KeyData']: - c = d[i] - pre = '_' + i + '_' - plen = len(pre) - for j in b: - if j[:plen] == pre: - setattr(c, j[plen:], classmethod(b[j])) for i in [gcciphers, gcaeads, gchashes, gcmacs, gcprps]: - for c in i.itervalues(): + for c in _itervalues(i): d[_fixname(c.name)] = c - for c in gccrands.itervalues(): + for c in _itervalues(gccrands): d[_fixname(c.name + 'rand')] = c _init() @@ -137,7 +144,7 @@ def _augment(c, cc): def _checkend(r): x, rest = r if rest != '': - raise SyntaxError, 'junk at end of string' + raise SyntaxError('junk at end of string') return x ## Some pretty-printing utilities. @@ -167,7 +174,8 @@ def _pp_commas(pp, printfn, items): else: pp.text(','); pp.breakable() printfn(i) def _pp_dict(pp, items): - def p((k, v)): + def p(kv): + k, v = kv pp.begin_group(0) pp.pretty(k) pp.text(':') @@ -179,16 +187,34 @@ def _pp_dict(pp, items): _pp_commas(pp, p, items) ###-------------------------------------------------------------------------- +### Mappings. + +if _sys.version_info >= (3,): + class _tmp: + def __str__(me): return '%s(%r)' % (type(me).__name__, list(me)) + __repr__ = __str__ + def _repr_pretty_(me, pp, cyclep): + ind = _pp_bgroup_tyname(pp, me, '([') + _pp_commas(pp, pp.pretty, me) + pp.end_group(ind, '])') + _augment(_base._KeyView, _tmp) + _augment(_base._ValueView, _tmp) + _augment(_base._ItemView, _tmp) + +###-------------------------------------------------------------------------- ### Bytestrings. class _tmp: def fromhex(x): return ByteString(_unhexify(x)) fromhex = staticmethod(fromhex) - def __hex__(me): - return _hexify(me) + if _sys.version_info >= (3,): + def hex(me): return _hexify(me).decode() + else: + def hex(me): return _hexify(me) + __hex__ = hex def __repr__(me): - return 'bytes(%r)' % hex(me) + return 'bytes(%r)' % me.hex() _augment(ByteString, _tmp) ByteString.__hash__ = str.__hash__ bytes = ByteString.fromhex @@ -197,7 +223,7 @@ bytes = ByteString.fromhex ### Symmetric encryption. class _tmp: - def encrypt(me, n, m, tsz = None, h = ByteString('')): + def encrypt(me, n, m, tsz = None, h = ByteString.zero(0)): if tsz is None: tsz = me.__class__.tagsz.default e = me.enc(n, len(h), len(m), tsz) if not len(h): a = None @@ -205,7 +231,7 @@ class _tmp: c0 = e.encrypt(m) c1, t = e.done(aad = a) return c0 + c1, t - def decrypt(me, n, c, t, h = ByteString('')): + def decrypt(me, n, c, t, h = ByteString.zero(0)): d = me.dec(n, len(h), len(c), len(t)) if not len(h): a = None else: a = d.aad().hash(h) @@ -269,7 +295,7 @@ class _ShakeBase (_HashBase): ## Python gets really confused if I try to augment `__new__' on native ## classes, so wrap and delegate. Sorry. - def __init__(me, perso = '', *args, **kw): + def __init__(me, perso = _bin(''), *args, **kw): super(_ShakeBase, me).__init__(*args, **kw) me._h = me._SHAKE(perso = perso, func = me._FUNC) @@ -317,7 +343,7 @@ _augment(_ShakeBase, _tmp) Shake._Z = _ShakeBase._Z = ByteString.zero(200) class KMAC (_ShakeBase): - _FUNC = 'KMAC' + _FUNC = _bin('KMAC') def __init__(me, k, *arg, **kw): super(KMAC, me).__init__(*arg, **kw) with me.bytepad(): me.stringenc(k) @@ -329,7 +355,7 @@ class KMAC (_ShakeBase): me.rightenc(0) return super(KMAC, me).xof() @classmethod - def _bare_new(cls): return cls("") + def _bare_new(cls): return cls(_bin("")) class KMAC128 (KMAC): _SHAKE = Shake128; _TAGSZ = 16 class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32 @@ -390,17 +416,25 @@ class BaseRat (object): def __rtruediv__(me, you): n, d = _split_rat(you) return type(me)(me._d*n, me._n*d) - __div__ = __truediv__ - __rdiv__ = __rtruediv__ - def __cmp__(me, you): - n, d = _split_rat(you) - return cmp(me._n*d, n*me._d) - def __rcmp__(me, you): + if _sys.version_info < (3,): + __div__ = __truediv__ + __rdiv__ = __rtruediv__ + def _order(me, you, op): n, d = _split_rat(you) - return cmp(n*me._d, me._n*d) + return op(me._n*d, n*me._d) + def __eq__(me, you): return me._order(you, lambda x, y: x == y) + def __ne__(me, you): return me._order(you, lambda x, y: x != y) + def __le__(me, you): return me._order(you, lambda x, y: x <= y) + def __lt__(me, you): return me._order(you, lambda x, y: x < y) + def __gt__(me, you): return me._order(you, lambda x, y: x > y) + def __ge__(me, you): return me._order(you, lambda x, y: x >= y) class IntRat (BaseRat): RING = MP + def __new__(cls, a, b): + if isinstance(a, float) or isinstance(b, float): return a/b + return super(IntRat, cls).__new__(cls, a, b) + def __float__(me): return float(me._n)/float(me._d) class GFRat (BaseRat): RING = GF @@ -414,10 +448,15 @@ class _tmp: def mont(x): return MPMont(x) def barrett(x): return MPBarrett(x) def reduce(x): return MPReduce(x) - def __truediv__(me, you): return IntRat(me, you) - def __rtruediv__(me, you): return IntRat(you, me) - __div__ = __truediv__ - __rdiv__ = __rtruediv__ + def __truediv__(me, you): + if isinstance(you, float): return _long(me)/you + else: return IntRat(me, you) + def __rtruediv__(me, you): + if isinstance(you, float): return you/_long(me) + else: return IntRat(you, me) + if _sys.version_info < (3,): + __div__ = __truediv__ + __rdiv__ = __rtruediv__ _repr_pretty_ = _pp_str _augment(MP, _tmp) @@ -430,8 +469,9 @@ class _tmp: def quadsolve(x, y): return x.reduce().quadsolve(y) def __truediv__(me, you): return GFRat(me, you) def __rtruediv__(me, you): return GFRat(you, me) - __div__ = __truediv__ - __rdiv__ = __rtruediv__ + if _sys.version_info < (3,): + __div__ = __truediv__ + __rdiv__ = __rtruediv__ _repr_pretty_ = _pp_str _augment(GF, _tmp) @@ -451,7 +491,7 @@ class _tmp: _augment(Field, _tmp) class _tmp: - def __repr__(me): return '%s(%sL)' % (_clsname(me), me.p) + def __repr__(me): return '%s(%s)' % (_clsname(me), me.p) def __hash__(me): return 0x114401de ^ hash(me.p) def _repr_pretty_(me, pp, cyclep): ind = _pp_bgroup_tyname(pp, me) @@ -462,7 +502,7 @@ class _tmp: _augment(PrimeField, _tmp) class _tmp: - def __repr__(me): return '%s(%#xL)' % (_clsname(me), me.p) + def __repr__(me): return '%s(%#x)' % (_clsname(me), me.p) def ec(me, a, b): return ECBinProjCurve(me, a, b) def _repr_pretty_(me, pp, cyclep): ind = _pp_bgroup_tyname(pp, me) @@ -617,13 +657,13 @@ class _tmp: pp.end_group(ind, ')') def check(me, sz): return me.min <= sz <= me.max and sz%me.mod == 0 def best(me, sz): - if sz < me.min: raise ValueError, 'key too small' - elif sz > me.max: return me.max + if sz < me.min: raise ValueError('key too small') + elif me.max is not None and sz > me.max: return me.max else: return sz - sz%me.mod def pad(me, sz): - if sz > me.max: raise ValueError, 'key too large' + if me.max is not None and sz > me.max: raise ValueError('key too large') elif sz < me.min: return me.min - else: sz += me.mod; return sz - sz%me.mod + else: sz += me.mod - 1; return sz - sz%me.mod _augment(KeySZRange, _tmp) class _tmp: @@ -643,13 +683,13 @@ class _tmp: found = -1 for i in me.set: if found < i <= sz: found = i - if found < 0: raise ValueError, 'key too small' + if found < 0: raise ValueError('key too small') return found def pad(me, sz): found = -1 for i in me.set: if sz <= i and (found == -1 or i < found): found = i - if found < 0: raise ValueError, 'key too large' + if found < 0: raise ValueError('key too large') return found _augment(KeySZSet, _tmp) @@ -657,21 +697,34 @@ _augment(KeySZSet, _tmp) ### Key data objects. class _tmp: + def merge(me, file, report = None): + """KF.merge(FILE, [report = ])""" + name = file.name + lno = 1 + for line in file: + me.mergeline(name, lno, line, report) + lno += 1 + return me def __repr__(me): return '%s(%r)' % (_clsname(me), me.name) _augment(KeyFile, _tmp) class _tmp: + def extract(me, file, filter = ''): + """KEY.extract(FILE, [filter = ])""" + line = me.extractline(filter) + file.write(line) + return me def __repr__(me): return '%s(%r)' % (_clsname(me), me.fulltag) _augment(Key, _tmp) class _tmp: def __repr__(me): return '%s({%s})' % (_clsname(me), - ', '.join(['%r: %r' % kv for kv in me.iteritems()])) + ', '.join(['%r: %r' % kv for kv in _iteritems(me)()])) def _repr_pretty_(me, pp, cyclep): ind = _pp_bgroup_tyname(pp, me) if cyclep: pp.text('...') - else: _pp_dict(pp, me.iteritems()) + else: _pp_dict(pp, _iteritems(me)) pp.end_group(ind, ')') _augment(KeyAttributes, _tmp) @@ -690,38 +743,66 @@ class _tmp: pp.text(','); pp.breakable() pp.pretty(me.writeflags(me.flags)) pp.end_group(ind, ')') + def __hash__(me): return me._HASHBASE ^ hash(me._guts()) + def __eq__(me, kd): + return type(me) == type(kd) and \ + me._guts() == kd._guts() and \ + me.flags == kd.flags + def __ne__(me, kd): + return not me == kd _augment(KeyData, _tmp) class _tmp: def _guts(me): return me.bin + def __eq__(me, kd): + return isinstance(kd, KeyDataBinary) and me.bin == kd.bin _augment(KeyDataBinary, _tmp) +KeyDataBinary._HASHBASE = 0x961755c3 class _tmp: def _guts(me): return me.ct _augment(KeyDataEncrypted, _tmp) +KeyDataEncrypted._HASHBASE = 0xffe000d4 class _tmp: def _guts(me): return me.mp _augment(KeyDataMP, _tmp) +KeyDataMP._HASHBASE = 0x1cb64d69 class _tmp: def _guts(me): return me.str _augment(KeyDataString, _tmp) +KeyDataString._HASHBASE = 0x349c33ea class _tmp: def _guts(me): return me.ecpt _augment(KeyDataECPt, _tmp) +KeyDataECPt._HASHBASE = 0x2509718b class _tmp: def __repr__(me): return '%s({%s})' % (_clsname(me), - ', '.join(['%r: %r' % kv for kv in me.iteritems()])) + ', '.join(['%r: %r' % kv for kv in _iteritems(me)])) def _repr_pretty_(me, pp, cyclep): ind = _pp_bgroup_tyname(pp, me, '({ ') if cyclep: pp.text('...') - else: _pp_dict(pp, me.iteritems()) + else: _pp_dict(pp, _iteritems(me)) pp.end_group(ind, ' })') + def __hash__(me): + h = me._HASHBASE + for k, v in _iteritems(me): + h = ((h << 1) ^ 3*hash(k) ^ 5*hash(v))&0xffffffff + return h + def __eq__(me, kd): + if type(me) != type(kd) or me.flags != kd.flags or len(me) != len(kd): + return False + for k, v in _iteritems(me): + try: vv = kd[k] + except KeyError: return False + if v != vv: return False + return True _augment(KeyDataStructured, _tmp) +KeyDataStructured._HASHBASE = 0x85851b21 ###-------------------------------------------------------------------------- ### Abstract groups. @@ -796,7 +877,7 @@ _augment(GE, _tmp) ### RSA encoding techniques. class PKCS1Crypt (object): - def __init__(me, ep = '', rng = rand): + def __init__(me, ep = _bin(''), rng = rand): me.ep = ep me.rng = rng def encode(me, msg, nbits): @@ -805,7 +886,7 @@ class PKCS1Crypt (object): return _base._p1crypt_decode(ct, nbits, me.ep, me.rng) class PKCS1Sig (object): - def __init__(me, ep = '', rng = rand): + def __init__(me, ep = _bin(''), rng = rand): me.ep = ep me.rng = rng def encode(me, msg, nbits): @@ -814,7 +895,7 @@ class PKCS1Sig (object): return _base._p1sig_decode(msg, sig, nbits, me.ep, me.rng) class OAEP (object): - def __init__(me, mgf = sha_mgf, hash = sha, ep = '', rng = rand): + def __init__(me, mgf = sha_mgf, hash = sha, ep = _bin(''), rng = rand): me.mgf = mgf me.hash = hash me.ep = ep @@ -929,7 +1010,7 @@ Z128 = ByteString.zero(16) class _BasePub (object): def __init__(me, pub, *args, **kw): - if not me._PUBSZ.check(len(pub)): raise ValueError, 'bad public key' + if not me._PUBSZ.check(len(pub)): raise ValueError('bad public key') super(_BasePub, me).__init__(*args, **kw) me.pub = pub def __repr__(me): return '%s(pub = %r)' % (_clsname(me), me.pub) @@ -942,7 +1023,7 @@ class _BasePub (object): class _BasePriv (object): def __init__(me, priv, pub = None, *args, **kw): - if not me._KEYSZ.check(len(priv)): raise ValueError, 'bad private key' + if not me._KEYSZ.check(len(priv)): raise ValueError('bad private key') if pub is None: pub = me._pubkey(priv) super(_BasePriv, me).__init__(pub = pub, *args, **kw) me.priv = priv @@ -1015,51 +1096,17 @@ class Ed448Priv (_EdDSAPriv, Ed448Pub): return ed448_sign(me.priv, msg, pub = me.pub, **kw) ###-------------------------------------------------------------------------- -### Built-in named curves and prime groups. - -class _groupmap (object): - def __init__(me, map, nth): - me.map = map - me.nth = nth - me._n = max(map.values()) + 1 - me.i = me._n*[None] +### Built-in algorithm and group tables. + +class _tmp: def __repr__(me): - return '{%s}' % ', '.join(['%r: %r' % kv for kv in me.iteritems()]) + return '{%s}' % ', '.join(['%r: %r' % kv for kv in _iteritems(me)]) def _repr_pretty_(me, pp, cyclep): ind = _pp_bgroup(pp, '{ ') if cyclep: pp.text('...') - else: _pp_dict(pp, me.iteritems()) + else: _pp_dict(pp, _iteritems(me)) pp.end_group(ind, ' }') - def __len__(me): - return me._n - def __contains__(me, k): - return k in me.map - def __getitem__(me, k): - i = me.map[k] - if me.i[i] is None: - me.i[i] = me.nth(i) - return me.i[i] - def __setitem__(me, k, v): - raise TypeError, "immutable object" - def __iter__(me): - return iter(me.map) - def iterkeys(me): - return iter(me.map) - def itervalues(me): - for k in me: - yield me[k] - def iteritems(me): - for k in me: - yield k, me[k] - def keys(me): - return [k for k in me] - def values(me): - return [me[k] for k in me] - def items(me): - return [(k, me[k]) for k in me] -eccurves = _groupmap(_base._eccurves, ECInfo._curven) -primegroups = _groupmap(_base._pgroups, DHInfo._groupn) -bingroups = _groupmap(_base._bingroups, BinDHInfo._groupn) +_augment(_base._MiscTable, _tmp) ###-------------------------------------------------------------------------- ### Prime number generation. @@ -1169,7 +1216,7 @@ class SimulStepper (PrimeGenEventHandler): me.add = add def _stepfn(me, step): if step <= 0: - raise ValueError, 'step must be positive' + raise ValueError('step must be positive') if step <= MPW_MAX: return lambda f: f.step(step) j = PrimeFilter(step)