Split 'pyke/' into commit 'c80de12d8d0827e0553fed2e4d392cb9bf3a378f'
[catacomb-python] / catacomb / __init__.py
index 80b6095..0e5c31c 100644 (file)
@@ -55,7 +55,8 @@ if _dlflags >= 0:
       else: pass # can't do this.
   _sys.setdlopenflags(_dlflags)
 
-import _base
+if _sys.version_info >= (3,): from . import _base
+else: import _base
 
 if _odlflags >= 0:
   _sys.setdlopenflags(_odlflags)
@@ -77,6 +78,25 @@ def default_lostexchook(why, ty, val, tb):
   _sys.stderr.write("\n")
 lostexchook = default_lostexchook
 
+## Text/binary conversions.
+if _sys.version_info >= (3,):
+  def _bin(s): return s.encode('iso8859-1')
+else:
+  def _bin(s): return s
+
+## Iterating over dictionaries.
+if _sys.version_info >= (3,):
+  def _iteritems(dict): return dict.items()
+  def _itervalues(dict): return dict.values()
+else:
+  def _iteritems(dict): return dict.iteritems()
+  def _itervalues(dict): return dict.itervalues()
+
+## The built-in bignum type.
+try: long
+except NameError: _long = int
+else: _long = long
+
 ## How to fix a name back into the right identifier.  Alas, the rules are not
 ## consistent.
 def _fixname(name):
@@ -100,9 +120,9 @@ def _init():
     if i[0] != '_':
       d[i] = b[i];
   for i in [gcciphers, gcaeads, gchashes, gcmacs, gcprps]:
-    for c in i.itervalues():
+    for c in _itervalues(i):
       d[_fixname(c.name)] = c
-  for c in gccrands.itervalues():
+  for c in _itervalues(gccrands):
     d[_fixname(c.name + 'rand')] = c
 _init()
 
@@ -154,7 +174,8 @@ def _pp_commas(pp, printfn, items):
     else: pp.text(','); pp.breakable()
     printfn(i)
 def _pp_dict(pp, items):
-  def p((k, v)):
+  def p(kv):
+    k, v = kv
     pp.begin_group(0)
     pp.pretty(k)
     pp.text(':')
@@ -166,16 +187,34 @@ def _pp_dict(pp, items):
   _pp_commas(pp, p, items)
 
 ###--------------------------------------------------------------------------
+### Mappings.
+
+if _sys.version_info >= (3,):
+  class _tmp:
+    def __str__(me): return '%s(%r)' % (type(me).__name__, list(me))
+    __repr__ = __str__
+    def _repr_pretty_(me, pp, cyclep):
+      ind = _pp_bgroup_tyname(pp, me, '([')
+      _pp_commas(pp, pp.pretty, me)
+      pp.end_group(ind, '])')
+  _augment(_base._KeyView, _tmp)
+  _augment(_base._ValueView, _tmp)
+  _augment(_base._ItemView, _tmp)
+
+###--------------------------------------------------------------------------
 ### Bytestrings.
 
 class _tmp:
   def fromhex(x):
     return ByteString(_unhexify(x))
   fromhex = staticmethod(fromhex)
-  def __hex__(me):
-    return _hexify(me)
+  if _sys.version_info >= (3,):
+    def hex(me): return _hexify(me).decode()
+  else:
+    def hex(me): return _hexify(me)
+    __hex__ = hex
   def __repr__(me):
-    return 'bytes(%r)' % hex(me)
+    return 'bytes(%r)' % me.hex()
 _augment(ByteString, _tmp)
 ByteString.__hash__ = str.__hash__
 bytes = ByteString.fromhex
@@ -184,7 +223,7 @@ bytes = ByteString.fromhex
 ### Symmetric encryption.
 
 class _tmp:
-  def encrypt(me, n, m, tsz = None, h = ByteString('')):
+  def encrypt(me, n, m, tsz = None, h = ByteString.zero(0)):
     if tsz is None: tsz = me.__class__.tagsz.default
     e = me.enc(n, len(h), len(m), tsz)
     if not len(h): a = None
@@ -192,7 +231,7 @@ class _tmp:
     c0 = e.encrypt(m)
     c1, t = e.done(aad = a)
     return c0 + c1, t
-  def decrypt(me, n, c, t, h = ByteString('')):
+  def decrypt(me, n, c, t, h = ByteString.zero(0)):
     d = me.dec(n, len(h), len(c), len(t))
     if not len(h): a = None
     else: a = d.aad().hash(h)
@@ -256,7 +295,7 @@ class _ShakeBase (_HashBase):
 
   ## Python gets really confused if I try to augment `__new__' on native
   ## classes, so wrap and delegate.  Sorry.
-  def __init__(me, perso = '', *args, **kw):
+  def __init__(me, perso = _bin(''), *args, **kw):
     super(_ShakeBase, me).__init__(*args, **kw)
     me._h = me._SHAKE(perso = perso, func = me._FUNC)
 
@@ -304,7 +343,7 @@ _augment(_ShakeBase, _tmp)
 Shake._Z = _ShakeBase._Z = ByteString.zero(200)
 
 class KMAC (_ShakeBase):
-  _FUNC = 'KMAC'
+  _FUNC = _bin('KMAC')
   def __init__(me, k, *arg, **kw):
     super(KMAC, me).__init__(*arg, **kw)
     with me.bytepad(): me.stringenc(k)
@@ -316,7 +355,7 @@ class KMAC (_ShakeBase):
     me.rightenc(0)
     return super(KMAC, me).xof()
   @classmethod
-  def _bare_new(cls): return cls("")
+  def _bare_new(cls): return cls(_bin(""))
 
 class KMAC128 (KMAC): _SHAKE = Shake128; _TAGSZ = 16
 class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32
@@ -377,17 +416,25 @@ class BaseRat (object):
   def __rtruediv__(me, you):
     n, d = _split_rat(you)
     return type(me)(me._d*n, me._n*d)
-  __div__ = __truediv__
-  __rdiv__ = __rtruediv__
-  def __cmp__(me, you):
+  if _sys.version_info < (3,):
+    __div__ = __truediv__
+    __rdiv__ = __rtruediv__
+  def _order(me, you, op):
     n, d = _split_rat(you)
-    return cmp(me._n*d, n*me._d)
-  def __rcmp__(me, you):
-    n, d = _split_rat(you)
-    return cmp(n*me._d, me._n*d)
+    return op(me._n*d, n*me._d)
+  def __eq__(me, you): return me._order(you, lambda x, y: x == y)
+  def __ne__(me, you): return me._order(you, lambda x, y: x != y)
+  def __le__(me, you): return me._order(you, lambda x, y: x <= y)
+  def __lt__(me, you): return me._order(you, lambda x, y: x <  y)
+  def __gt__(me, you): return me._order(you, lambda x, y: x >  y)
+  def __ge__(me, you): return me._order(you, lambda x, y: x >= y)
 
 class IntRat (BaseRat):
   RING = MP
+  def __new__(cls, a, b):
+    if isinstance(a, float) or isinstance(b, float): return a/b
+    return super(IntRat, cls).__new__(cls, a, b)
+  def __float__(me): return float(me._n)/float(me._d)
 
 class GFRat (BaseRat):
   RING = GF
@@ -401,10 +448,15 @@ class _tmp:
   def mont(x): return MPMont(x)
   def barrett(x): return MPBarrett(x)
   def reduce(x): return MPReduce(x)
-  def __truediv__(me, you): return IntRat(me, you)
-  def __rtruediv__(me, you): return IntRat(you, me)
-  __div__ = __truediv__
-  __rdiv__ = __rtruediv__
+  def __truediv__(me, you):
+    if isinstance(you, float): return _long(me)/you
+    else: return IntRat(me, you)
+  def __rtruediv__(me, you):
+    if isinstance(you, float): return you/_long(me)
+    else: return IntRat(you, me)
+  if _sys.version_info < (3,):
+    __div__ = __truediv__
+    __rdiv__ = __rtruediv__
   _repr_pretty_ = _pp_str
 _augment(MP, _tmp)
 
@@ -417,8 +469,9 @@ class _tmp:
   def quadsolve(x, y): return x.reduce().quadsolve(y)
   def __truediv__(me, you): return GFRat(me, you)
   def __rtruediv__(me, you): return GFRat(you, me)
-  __div__ = __truediv__
-  __rdiv__ = __rtruediv__
+  if _sys.version_info < (3,):
+    __div__ = __truediv__
+    __rdiv__ = __rtruediv__
   _repr_pretty_ = _pp_str
 _augment(GF, _tmp)
 
@@ -438,7 +491,7 @@ class _tmp:
 _augment(Field, _tmp)
 
 class _tmp:
-  def __repr__(me): return '%s(%sL)' % (_clsname(me), me.p)
+  def __repr__(me): return '%s(%s)' % (_clsname(me), me.p)
   def __hash__(me): return 0x114401de ^ hash(me.p)
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me)
@@ -449,7 +502,7 @@ class _tmp:
 _augment(PrimeField, _tmp)
 
 class _tmp:
-  def __repr__(me): return '%s(%#xL)' % (_clsname(me), me.p)
+  def __repr__(me): return '%s(%#x)' % (_clsname(me), me.p)
   def ec(me, a, b): return ECBinProjCurve(me, a, b)
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me)
@@ -644,21 +697,34 @@ _augment(KeySZSet, _tmp)
 ### Key data objects.
 
 class _tmp:
+  def merge(me, file, report = None):
+    """KF.merge(FILE, [report = <built-in-reporter>])"""
+    name = file.name
+    lno = 1
+    for line in file:
+      me.mergeline(name, lno, line, report)
+      lno += 1
+    return me
   def __repr__(me): return '%s(%r)' % (_clsname(me), me.name)
 _augment(KeyFile, _tmp)
 
 class _tmp:
+  def extract(me, file, filter = ''):
+    """KEY.extract(FILE, [filter = <any>])"""
+    line = me.extractline(filter)
+    file.write(line)
+    return me
   def __repr__(me): return '%s(%r)' % (_clsname(me), me.fulltag)
 _augment(Key, _tmp)
 
 class _tmp:
   def __repr__(me):
     return '%s({%s})' % (_clsname(me),
-                         ', '.join(['%r: %r' % kv for kv in me.iteritems()]))
+                         ', '.join(['%r: %r' % kv for kv in _iteritems(me)()]))
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me)
     if cyclep: pp.text('...')
-    else: _pp_dict(pp, me.iteritems())
+    else: _pp_dict(pp, _iteritems(me))
     pp.end_group(ind, ')')
 _augment(KeyAttributes, _tmp)
 
@@ -677,38 +743,66 @@ class _tmp:
       pp.text(','); pp.breakable()
       pp.pretty(me.writeflags(me.flags))
     pp.end_group(ind, ')')
+  def __hash__(me): return me._HASHBASE ^ hash(me._guts())
+  def __eq__(me, kd):
+    return type(me) == type(kd) and \
+      me._guts() == kd._guts() and \
+      me.flags == kd.flags
+  def __ne__(me, kd):
+    return not me == kd
 _augment(KeyData, _tmp)
 
 class _tmp:
   def _guts(me): return me.bin
+  def __eq__(me, kd):
+    return isinstance(kd, KeyDataBinary) and me.bin == kd.bin
 _augment(KeyDataBinary, _tmp)
+KeyDataBinary._HASHBASE = 0x961755c3
 
 class _tmp:
   def _guts(me): return me.ct
 _augment(KeyDataEncrypted, _tmp)
+KeyDataEncrypted._HASHBASE = 0xffe000d4
 
 class _tmp:
   def _guts(me): return me.mp
 _augment(KeyDataMP, _tmp)
+KeyDataMP._HASHBASE = 0x1cb64d69
 
 class _tmp:
   def _guts(me): return me.str
 _augment(KeyDataString, _tmp)
+KeyDataString._HASHBASE = 0x349c33ea
 
 class _tmp:
   def _guts(me): return me.ecpt
 _augment(KeyDataECPt, _tmp)
+KeyDataECPt._HASHBASE = 0x2509718b
 
 class _tmp:
   def __repr__(me):
     return '%s({%s})' % (_clsname(me),
-                         ', '.join(['%r: %r' % kv for kv in me.iteritems()]))
+                         ', '.join(['%r: %r' % kv for kv in _iteritems(me)]))
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me, '({ ')
     if cyclep: pp.text('...')
-    else: _pp_dict(pp, me.iteritems())
+    else: _pp_dict(pp, _iteritems(me))
     pp.end_group(ind, ' })')
+  def __hash__(me):
+    h = me._HASHBASE
+    for k, v in _iteritems(me):
+      h = ((h << 1) ^ 3*hash(k) ^ 5*hash(v))&0xffffffff
+    return h
+  def __eq__(me, kd):
+    if type(me) != type(kd) or me.flags != kd.flags or len(me) != len(kd):
+      return False
+    for k, v in _iteritems(me):
+      try: vv = kd[k]
+      except KeyError: return False
+      if v != vv: return False
+    return True
 _augment(KeyDataStructured, _tmp)
+KeyDataStructured._HASHBASE = 0x85851b21
 
 ###--------------------------------------------------------------------------
 ### Abstract groups.
@@ -783,7 +877,7 @@ _augment(GE, _tmp)
 ### RSA encoding techniques.
 
 class PKCS1Crypt (object):
-  def __init__(me, ep = '', rng = rand):
+  def __init__(me, ep = _bin(''), rng = rand):
     me.ep = ep
     me.rng = rng
   def encode(me, msg, nbits):
@@ -792,7 +886,7 @@ class PKCS1Crypt (object):
     return _base._p1crypt_decode(ct, nbits, me.ep, me.rng)
 
 class PKCS1Sig (object):
-  def __init__(me, ep = '', rng = rand):
+  def __init__(me, ep = _bin(''), rng = rand):
     me.ep = ep
     me.rng = rng
   def encode(me, msg, nbits):
@@ -801,7 +895,7 @@ class PKCS1Sig (object):
     return _base._p1sig_decode(msg, sig, nbits, me.ep, me.rng)
 
 class OAEP (object):
-  def __init__(me, mgf = sha_mgf, hash = sha, ep = '', rng = rand):
+  def __init__(me, mgf = sha_mgf, hash = sha, ep = _bin(''), rng = rand):
     me.mgf = mgf
     me.hash = hash
     me.ep = ep
@@ -1006,11 +1100,11 @@ class Ed448Priv (_EdDSAPriv, Ed448Pub):
 
 class _tmp:
   def __repr__(me):
-    return '{%s}' % ', '.join(['%r: %r' % kv for kv in me.iteritems()])
+    return '{%s}' % ', '.join(['%r: %r' % kv for kv in _iteritems(me)])
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup(pp, '{ ')
     if cyclep: pp.text('...')
-    else: _pp_dict(pp, me.iteritems())
+    else: _pp_dict(pp, _iteritems(me))
     pp.end_group(ind, ' }')
 _augment(_base._MiscTable, _tmp)