X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb-python/blobdiff_plain/f76230157bd427829e49628de37d53f9c8ae7842..e5c26109231763a2464a8f8076e7b633996b5d5c:/catacomb/__init__.py diff --git a/catacomb/__init__.py b/catacomb/__init__.py index 0e5c31c..08ec3d7 100644 --- a/catacomb/__init__.py +++ b/catacomb/__init__.py @@ -250,115 +250,13 @@ class _tmp: _augment(GHash, _tmp) _augment(Poly1305Hash, _tmp) -class _HashBase (object): - ## The standard hash methods. Assume that `hash' is defined and returns - ## the receiver. - def _check_range(me, n, max): - if not (0 <= n <= max): raise OverflowError("out of range") - def hashu8(me, n): - me._check_range(n, 0xff) - return me.hash(_pack('B', n)) - def hashu16l(me, n): - me._check_range(n, 0xffff) - return me.hash(_pack('H', n)) - hashu16 = hashu16b - def hashu32l(me, n): - me._check_range(n, 0xffffffff) - return me.hash(_pack('L', n)) - hashu32 = hashu32b - def hashu64l(me, n): - me._check_range(n, 0xffffffffffffffff) - return me.hash(_pack('Q', n)) - hashu64 = hashu64b - def hashbuf8(me, s): return me.hashu8(len(s)).hash(s) - def hashbuf16l(me, s): return me.hashu16l(len(s)).hash(s) - def hashbuf16b(me, s): return me.hashu16b(len(s)).hash(s) - hashbuf16 = hashbuf16b - def hashbuf32l(me, s): return me.hashu32l(len(s)).hash(s) - def hashbuf32b(me, s): return me.hashu32b(len(s)).hash(s) - hashbuf32 = hashbuf32b - def hashbuf64l(me, s): return me.hashu64l(len(s)).hash(s) - def hashbuf64b(me, s): return me.hashu64b(len(s)).hash(s) - hashbuf64 = hashbuf64b - def hashstrz(me, s): return me.hash(s).hashu8(0) - -class _ShakeBase (_HashBase): - - ## Python gets really confused if I try to augment `__new__' on native - ## classes, so wrap and delegate. Sorry. - def __init__(me, perso = _bin(''), *args, **kw): - super(_ShakeBase, me).__init__(*args, **kw) - me._h = me._SHAKE(perso = perso, func = me._FUNC) - - ## Delegate methods... - def copy(me): new = me.__class__._bare_new(); new._copy(me); return new - def _copy(me, other): me._h = other._h.copy() - def hash(me, m): me._h.hash(m); return me - def xof(me): me._h.xof(); return me - def get(me, n): return me._h.get(n) - def mask(me, m): return me._h.mask(m) - def done(me, n): return me._h.done(n) - def check(me, h): return ctstreq(h, me.done(len(h))) - @property - def state(me): return me._h.state - @property - def buffered(me): return me._h.buffered - @property - def rate(me): return me._h.rate - @classmethod - def _bare_new(cls): return cls() - class _tmp: def check(me, h): return ctstreq(h, me.done(len(h))) - def leftenc(me, n): - nn = MP(n).storeb() - return me.hashu8(len(nn)).hash(nn) - def rightenc(me, n): - nn = MP(n).storeb() - return me.hash(nn).hashu8(len(nn)) - def stringenc(me, str): - return me.leftenc(8*len(str)).hash(str) - def bytepad_before(me): - return me.leftenc(me.rate) - def bytepad_after(me): - if me.buffered: me.hash(me._Z[:me.rate - me.buffered]) - return me - @_ctxmgr - def bytepad(me): - me.bytepad_before() - yield me - me.bytepad_after() _augment(Shake, _tmp) -_augment(_ShakeBase, _tmp) -Shake._Z = _ShakeBase._Z = ByteString.zero(200) - -class KMAC (_ShakeBase): - _FUNC = _bin('KMAC') - def __init__(me, k, *arg, **kw): - super(KMAC, me).__init__(*arg, **kw) - with me.bytepad(): me.stringenc(k) - def done(me, n = -1): - if n < 0: n = me._TAGSZ - me.rightenc(8*n) - return super(KMAC, me).done(n) - def xof(me): - me.rightenc(0) - return super(KMAC, me).xof() - @classmethod - def _bare_new(cls): return cls(_bin("")) -class KMAC128 (KMAC): _SHAKE = Shake128; _TAGSZ = 16 -class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32 +KMAC128.keysz = KeySZAny(16); KMAC128.tagsz = 16 +KMAC256.keysz = KeySZAny(32); KMAC256.tagsz = 32 ###-------------------------------------------------------------------------- ### NaCl `secretbox'. @@ -374,15 +272,12 @@ def secret_unbox(k, n, c): ###-------------------------------------------------------------------------- ### Multiprecision integers and binary polynomials. -def _split_rat(x): - if isinstance(x, BaseRat): return x._n, x._d - else: return x, 1 class BaseRat (object): """Base class implementing fields of fractions over Euclidean domains.""" def __new__(cls, a, b): - a, b = cls.RING(a), cls.RING(b) + a, b = cls.RING._implicit(a), cls.RING._implicit(b) q, r = divmod(a, b) - if r == 0: return q + if r == cls.ZERO: return q g = b.gcd(r) me = super(BaseRat, cls).__new__(cls) me._n = a//g @@ -396,31 +291,34 @@ class BaseRat (object): def __repr__(me): return '%s(%s, %s)' % (_clsname(me), me._n, me._d) _repr_pretty_ = _pp_str + def _split_rat(me, x): + if isinstance(x, me.__class__): return x._n, x._d + else: return x, me.ONE def __add__(me, you): - n, d = _split_rat(you) + n, d = me._split_rat(you) return type(me)(me._n*d + n*me._d, d*me._d) __radd__ = __add__ def __sub__(me, you): - n, d = _split_rat(you) + n, d = me._split_rat(you) return type(me)(me._n*d - n*me._d, d*me._d) def __rsub__(me, you): - n, d = _split_rat(you) + n, d = me._split_rat(you) return type(me)(n*me._d - me._n*d, d*me._d) def __mul__(me, you): - n, d = _split_rat(you) + n, d = me._split_rat(you) return type(me)(me._n*n, me._d*d) __rmul__ = __mul__ def __truediv__(me, you): - n, d = _split_rat(you) + n, d = me._split_rat(you) return type(me)(me._n*d, me._d*n) def __rtruediv__(me, you): - n, d = _split_rat(you) + n, d = me._split_rat(you) return type(me)(me._d*n, me._n*d) if _sys.version_info < (3,): __div__ = __truediv__ __rdiv__ = __rtruediv__ def _order(me, you, op): - n, d = _split_rat(you) + n, d = me._split_rat(you) return op(me._n*d, n*me._d) def __eq__(me, you): return me._order(you, lambda x, y: x == y) def __ne__(me, you): return me._order(you, lambda x, y: x != y) @@ -431,6 +329,7 @@ class BaseRat (object): class IntRat (BaseRat): RING = MP + ZERO, ONE = MP(0), MP(1) def __new__(cls, a, b): if isinstance(a, float) or isinstance(b, float): return a/b return super(IntRat, cls).__new__(cls, a, b) @@ -438,6 +337,7 @@ class IntRat (BaseRat): class GFRat (BaseRat): RING = GF + ZERO, ONE = GF(0), GF(1) class _tmp: def negp(x): return x < 0 @@ -544,6 +444,8 @@ class _tmp: pp.pretty(me.a); pp.text(','); pp.breakable() pp.pretty(me.b) pp.end_group(ind, ')') + def fromstring(str): return _checkend(ECCurve.parse(str)) + fromstring = staticmethod(fromstring) def frombuf(me, s): return ecpt.frombuf(me, s) def fromraw(me, s): @@ -608,6 +510,8 @@ class _tmp: h ^= hash(me.curve) h ^= 2*hash(me.G) & 0xffffffff return h + def fromstring(str): return _checkend(ECInfo.parse(str)) + fromstring = staticmethod(fromstring) def group(me): return ECGroup(me) _augment(ECInfo, _tmp) @@ -658,10 +562,10 @@ class _tmp: def check(me, sz): return me.min <= sz <= me.max and sz%me.mod == 0 def best(me, sz): if sz < me.min: raise ValueError('key too small') - elif sz > me.max: return me.max + elif me.max is not None and sz > me.max: return me.max else: return sz - sz%me.mod def pad(me, sz): - if sz > me.max: raise ValueError('key too large') + if me.max is not None and sz > me.max: raise ValueError('key too large') elif sz < me.min: return me.min else: sz += me.mod - 1; return sz - sz%me.mod _augment(KeySZRange, _tmp)