X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb-python/blobdiff_plain/688625b6d288ff893f5e96dd3bd7f10110e23639..bbb113f66ef45881f595c8dacbcb492554848878:/catacomb/__init__.py diff --git a/catacomb/__init__.py b/catacomb/__init__.py index 8713ab7..80b6095 100644 --- a/catacomb/__init__.py +++ b/catacomb/__init__.py @@ -99,20 +99,7 @@ def _init(): for i in b: if i[0] != '_': d[i] = b[i]; - for i in ['ByteString', - 'MP', 'GF', 'Field', - 'ECPt', 'ECPtCurve', 'ECCurve', 'ECInfo', - 'DHInfo', 'BinDHInfo', 'RSAPriv', 'BBSPriv', - 'PrimeFilter', 'RabinMiller', - 'Group', 'GE', - 'KeySZ', 'KeyData']: - c = d[i] - pre = '_' + i + '_' - plen = len(pre) - for j in b: - if j[:plen] == pre: - setattr(c, j[plen:], classmethod(b[j])) - for i in [gcciphers, gchashes, gcmacs, gcprps]: + for i in [gcciphers, gcaeads, gchashes, gcmacs, gcprps]: for c in i.itervalues(): d[_fixname(c.name)] = c for c in gccrands.itervalues(): @@ -137,7 +124,7 @@ def _augment(c, cc): def _checkend(r): x, rest = r if rest != '': - raise SyntaxError, 'junk at end of string' + raise SyntaxError('junk at end of string') return x ## Some pretty-printing utilities. @@ -194,6 +181,27 @@ ByteString.__hash__ = str.__hash__ bytes = ByteString.fromhex ###-------------------------------------------------------------------------- +### Symmetric encryption. + +class _tmp: + def encrypt(me, n, m, tsz = None, h = ByteString('')): + if tsz is None: tsz = me.__class__.tagsz.default + e = me.enc(n, len(h), len(m), tsz) + if not len(h): a = None + else: a = e.aad().hash(h) + c0 = e.encrypt(m) + c1, t = e.done(aad = a) + return c0 + c1, t + def decrypt(me, n, c, t, h = ByteString('')): + d = me.dec(n, len(h), len(c), len(t)) + if not len(h): a = None + else: a = d.aad().hash(h) + m = d.decrypt(c) + m += d.done(t, aad = a) + return m +_augment(GAEKey, _tmp) + +###-------------------------------------------------------------------------- ### Hashing. class _tmp: @@ -206,15 +214,31 @@ _augment(Poly1305Hash, _tmp) class _HashBase (object): ## The standard hash methods. Assume that `hash' is defined and returns ## the receiver. - def hashu8(me, n): return me.hash(_pack('B', n)) - def hashu16l(me, n): return me.hash(_pack('H', n)) + def _check_range(me, n, max): + if not (0 <= n <= max): raise OverflowError("out of range") + def hashu8(me, n): + me._check_range(n, 0xff) + return me.hash(_pack('B', n)) + def hashu16l(me, n): + me._check_range(n, 0xffff) + return me.hash(_pack('H', n)) hashu16 = hashu16b - def hashu32l(me, n): return me.hash(_pack('L', n)) + def hashu32l(me, n): + me._check_range(n, 0xffffffff) + return me.hash(_pack('L', n)) hashu32 = hashu32b - def hashu64l(me, n): return me.hash(_pack('Q', n)) + def hashu64l(me, n): + me._check_range(n, 0xffffffffffffffff) + return me.hash(_pack('Q', n)) hashu64 = hashu64b def hashbuf8(me, s): return me.hashu8(len(s)).hash(s) def hashbuf16l(me, s): return me.hashu16l(len(s)).hash(s) @@ -277,7 +301,7 @@ class _tmp: me.bytepad_after() _augment(Shake, _tmp) _augment(_ShakeBase, _tmp) -Shake._Z = _ShakeBase._Z = ByteString(200*'\0') +Shake._Z = _ShakeBase._Z = ByteString.zero(200) class KMAC (_ShakeBase): _FUNC = 'KMAC' @@ -301,21 +325,12 @@ class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32 ### NaCl `secretbox'. def secret_box(k, n, m): - E = xsalsa20(k).setiv(n) - r = E.enczero(poly1305.keysz.default) - s = E.enczero(poly1305.masksz) - y = E.encrypt(m) - t = poly1305(r)(s).hash(y).done() - return ByteString(t + y) + y, t = salsa20_naclbox(k).encrypt(n, m) + return t + y def secret_unbox(k, n, c): - E = xsalsa20(k).setiv(n) - r = E.enczero(poly1305.keysz.default) - s = E.enczero(poly1305.masksz) - y = c[poly1305.tagsz:] - if not poly1305(r)(s).hash(y).check(c[0:poly1305.tagsz]): - raise ValueError, 'decryption failed' - return E.decrypt(c[poly1305.tagsz:]) + tsz = poly1305.tagsz + return salsa20_naclbox(k).decrypt(n, c[tsz:], c[0:tsz]) ###-------------------------------------------------------------------------- ### Multiprecision integers and binary polynomials. @@ -570,6 +585,7 @@ class _tmp: def __repr__(me): return '%s(%d)' % (_clsname(me), me.default) def check(me, sz): return True def best(me, sz): return sz + def pad(me, sz): return sz _augment(KeySZAny, _tmp) class _tmp: @@ -586,11 +602,15 @@ class _tmp: pp.pretty(me.max); pp.text(','); pp.breakable() pp.pretty(me.mod) pp.end_group(ind, ')') - def check(me, sz): return me.min <= sz <= me.max and sz % me.mod == 0 + def check(me, sz): return me.min <= sz <= me.max and sz%me.mod == 0 def best(me, sz): - if sz < me.min: raise ValueError, 'key too small' + if sz < me.min: raise ValueError('key too small') elif sz > me.max: return me.max - else: return sz - (sz % me.mod) + else: return sz - sz%me.mod + def pad(me, sz): + if sz > me.max: raise ValueError('key too large') + elif sz < me.min: return me.min + else: sz += me.mod - 1; return sz - sz%me.mod _augment(KeySZRange, _tmp) class _tmp: @@ -610,7 +630,13 @@ class _tmp: found = -1 for i in me.set: if found < i <= sz: found = i - if found < 0: raise ValueError, 'key too small' + if found < 0: raise ValueError('key too small') + return found + def pad(me, sz): + found = -1 + for i in me.set: + if sz <= i and (found == -1 or i < found): found = i + if found < 0: raise ValueError('key too large') return found _augment(KeySZSet, _tmp) @@ -850,21 +876,23 @@ _augment(RSAPriv, _tmp) ### DSA and related schemes. class _tmp: - def __repr__(me): return '%s(G = %r, p = %r)' % (_clsname(me), me.G, me.p) + def __repr__(me): return '%s(G = %r, p = %r, hash = %r)' % \ + (_clsname(me), me.G, me.p, me.hash) def _repr_pretty_(me, pp, cyclep): ind = _pp_bgroup_tyname(pp, me) if cyclep: pp.text('...') else: _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable() - _pp_kv(pp, 'p', me.p) + _pp_kv(pp, 'p', me.p); pp.text(','); pp.breakable() + _pp_kv(pp, 'hash', me.hash) pp.end_group(ind, ')') _augment(DSAPub, _tmp) _augment(KCDSAPub, _tmp) class _tmp: - def __repr__(me): return '%s(G = %r, u = %s, p = %r)' % \ - (_clsname(me), me.G, _repr_secret(me.u), me.p) + def __repr__(me): return '%s(G = %r, u = %s, p = %r, hash = %r)' % \ + (_clsname(me), me.G, _repr_secret(me.u), me.p, me.hash) def _repr_pretty_(me, pp, cyclep): ind = _pp_bgroup_tyname(pp, me) if cyclep: @@ -872,7 +900,8 @@ class _tmp: else: _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable() _pp_kv(pp, 'u', me.u, True); pp.text(','); pp.breakable() - _pp_kv(pp, 'p', me.p) + _pp_kv(pp, 'p', me.p); pp.text(','); pp.breakable() + _pp_kv(pp, 'hash', me.hash) pp.end_group(ind, ')') _augment(DSAPriv, _tmp) _augment(KCDSAPriv, _tmp) @@ -887,7 +916,7 @@ Z128 = ByteString.zero(16) class _BasePub (object): def __init__(me, pub, *args, **kw): - if not me._PUBSZ.check(len(pub)): raise ValueError, 'bad public key' + if not me._PUBSZ.check(len(pub)): raise ValueError('bad public key') super(_BasePub, me).__init__(*args, **kw) me.pub = pub def __repr__(me): return '%s(pub = %r)' % (_clsname(me), me.pub) @@ -900,7 +929,7 @@ class _BasePub (object): class _BasePriv (object): def __init__(me, priv, pub = None, *args, **kw): - if not me._KEYSZ.check(len(priv)): raise ValueError, 'bad private key' + if not me._KEYSZ.check(len(priv)): raise ValueError('bad private key') if pub is None: pub = me._pubkey(priv) super(_BasePriv, me).__init__(pub = pub, *args, **kw) me.priv = priv @@ -973,14 +1002,9 @@ class Ed448Priv (_EdDSAPriv, Ed448Pub): return ed448_sign(me.priv, msg, pub = me.pub, **kw) ###-------------------------------------------------------------------------- -### Built-in named curves and prime groups. - -class _groupmap (object): - def __init__(me, map, nth): - me.map = map - me.nth = nth - me._n = max(map.values()) + 1 - me.i = me._n*[None] +### Built-in algorithm and group tables. + +class _tmp: def __repr__(me): return '{%s}' % ', '.join(['%r: %r' % kv for kv in me.iteritems()]) def _repr_pretty_(me, pp, cyclep): @@ -988,36 +1012,7 @@ class _groupmap (object): if cyclep: pp.text('...') else: _pp_dict(pp, me.iteritems()) pp.end_group(ind, ' }') - def __len__(me): - return me._n - def __contains__(me, k): - return k in me.map - def __getitem__(me, k): - i = me.map[k] - if me.i[i] is None: - me.i[i] = me.nth(i) - return me.i[i] - def __setitem__(me, k, v): - raise TypeError, "immutable object" - def __iter__(me): - return iter(me.map) - def iterkeys(me): - return iter(me.map) - def itervalues(me): - for k in me: - yield me[k] - def iteritems(me): - for k in me: - yield k, me[k] - def keys(me): - return [k for k in me] - def values(me): - return [me[k] for k in me] - def items(me): - return [(k, me[k]) for k in me] -eccurves = _groupmap(_base._eccurves, ECInfo._curven) -primegroups = _groupmap(_base._pgroups, DHInfo._groupn) -bingroups = _groupmap(_base._bingroups, BinDHInfo._groupn) +_augment(_base._MiscTable, _tmp) ###-------------------------------------------------------------------------- ### Prime number generation. @@ -1127,7 +1122,7 @@ class SimulStepper (PrimeGenEventHandler): me.add = add def _stepfn(me, step): if step <= 0: - raise ValueError, 'step must be positive' + raise ValueError('step must be positive') if step <= MPW_MAX: return lambda f: f.step(step) j = PrimeFilter(step)