catacomb/__init__.py: Hack because Python 3 `hex' builtin works differently.
[catacomb-python] / catacomb / __init__.py
index 24265ac..22b1fdf 100644 (file)
@@ -77,6 +77,13 @@ def default_lostexchook(why, ty, val, tb):
   _sys.stderr.write("\n")
 lostexchook = default_lostexchook
 
+## Text/binary conversions.
+def _bin(s): return s
+
+## Iterating over dictionaries.
+def _iteritems(dict): return dict.iteritems()
+def _itervalues(dict): return dict.itervalues()
+
 ## How to fix a name back into the right identifier.  Alas, the rules are not
 ## consistent.
 def _fixname(name):
@@ -99,23 +106,10 @@ def _init():
   for i in b:
     if i[0] != '_':
       d[i] = b[i];
-  for i in ['ByteString',
-            'MP', 'GF', 'Field',
-            'ECPt', 'ECPtCurve', 'ECCurve', 'ECInfo',
-            'DHInfo', 'BinDHInfo', 'RSAPriv', 'BBSPriv',
-            'PrimeFilter', 'RabinMiller',
-            'Group', 'GE',
-            'KeySZ', 'KeyData']:
-    c = d[i]
-    pre = '_' + i + '_'
-    plen = len(pre)
-    for j in b:
-      if j[:plen] == pre:
-        setattr(c, j[plen:], classmethod(b[j]))
-  for i in [gcciphers, gchashes, gcmacs, gcprps]:
-    for c in i.itervalues():
+  for i in [gcciphers, gcaeads, gchashes, gcmacs, gcprps]:
+    for c in _itervalues(i):
       d[_fixname(c.name)] = c
-  for c in gccrands.itervalues():
+  for c in _itervalues(gccrands):
     d[_fixname(c.name + 'rand')] = c
 _init()
 
@@ -137,7 +131,7 @@ def _augment(c, cc):
 def _checkend(r):
   x, rest = r
   if rest != '':
-    raise SyntaxError, 'junk at end of string'
+    raise SyntaxError('junk at end of string')
   return x
 
 ## Some pretty-printing utilities.
@@ -167,7 +161,8 @@ def _pp_commas(pp, printfn, items):
     else: pp.text(','); pp.breakable()
     printfn(i)
 def _pp_dict(pp, items):
-  def p((k, v)):
+  def p(kv):
+    k, v = kv
     pp.begin_group(0)
     pp.pretty(k)
     pp.text(':')
@@ -185,15 +180,36 @@ class _tmp:
   def fromhex(x):
     return ByteString(_unhexify(x))
   fromhex = staticmethod(fromhex)
-  def __hex__(me):
-    return _hexify(me)
+  def hex(me): return _hexify(me)
+  __hex__ = hex
   def __repr__(me):
-    return 'bytes(%r)' % hex(me)
+    return 'bytes(%r)' % me.hex()
 _augment(ByteString, _tmp)
 ByteString.__hash__ = str.__hash__
 bytes = ByteString.fromhex
 
 ###--------------------------------------------------------------------------
+### Symmetric encryption.
+
+class _tmp:
+  def encrypt(me, n, m, tsz = None, h = ByteString('')):
+    if tsz is None: tsz = me.__class__.tagsz.default
+    e = me.enc(n, len(h), len(m), tsz)
+    if not len(h): a = None
+    else: a = e.aad().hash(h)
+    c0 = e.encrypt(m)
+    c1, t = e.done(aad = a)
+    return c0 + c1, t
+  def decrypt(me, n, c, t, h = ByteString('')):
+    d = me.dec(n, len(h), len(c), len(t))
+    if not len(h): a = None
+    else: a = d.aad().hash(h)
+    m = d.decrypt(c)
+    m += d.done(t, aad = a)
+    return m
+_augment(GAEKey, _tmp)
+
+###--------------------------------------------------------------------------
 ### Hashing.
 
 class _tmp:
@@ -248,7 +264,7 @@ class _ShakeBase (_HashBase):
 
   ## Python gets really confused if I try to augment `__new__' on native
   ## classes, so wrap and delegate.  Sorry.
-  def __init__(me, perso = '', *args, **kw):
+  def __init__(me, perso = _bin(''), *args, **kw):
     super(_ShakeBase, me).__init__(*args, **kw)
     me._h = me._SHAKE(perso = perso, func = me._FUNC)
 
@@ -293,10 +309,10 @@ class _tmp:
     me.bytepad_after()
 _augment(Shake, _tmp)
 _augment(_ShakeBase, _tmp)
-Shake._Z = _ShakeBase._Z = ByteString(200*'\0')
+Shake._Z = _ShakeBase._Z = ByteString.zero(200)
 
 class KMAC (_ShakeBase):
-  _FUNC = 'KMAC'
+  _FUNC = _bin('KMAC')
   def __init__(me, k, *arg, **kw):
     super(KMAC, me).__init__(*arg, **kw)
     with me.bytepad(): me.stringenc(k)
@@ -308,7 +324,7 @@ class KMAC (_ShakeBase):
     me.rightenc(0)
     return super(KMAC, me).xof()
   @classmethod
-  def _bare_new(cls): return cls("")
+  def _bare_new(cls): return cls(_bin(""))
 
 class KMAC128 (KMAC): _SHAKE = Shake128; _TAGSZ = 16
 class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32
@@ -317,21 +333,12 @@ class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32
 ### NaCl `secretbox'.
 
 def secret_box(k, n, m):
-  E = xsalsa20(k).setiv(n)
-  r = E.enczero(poly1305.keysz.default)
-  s = E.enczero(poly1305.masksz)
-  y = E.encrypt(m)
-  t = poly1305(r)(s).hash(y).done()
-  return ByteString(t + y)
+  y, t = salsa20_naclbox(k).encrypt(n, m)
+  return t + y
 
 def secret_unbox(k, n, c):
-  E = xsalsa20(k).setiv(n)
-  r = E.enczero(poly1305.keysz.default)
-  s = E.enczero(poly1305.masksz)
-  y = c[poly1305.tagsz:]
-  if not poly1305(r)(s).hash(y).check(c[0:poly1305.tagsz]):
-    raise ValueError, 'decryption failed'
-  return E.decrypt(c[poly1305.tagsz:])
+  tsz = poly1305.tagsz
+  return salsa20_naclbox(k).decrypt(n, c[tsz:], c[0:tsz])
 
 ###--------------------------------------------------------------------------
 ### Multiprecision integers and binary polynomials.
@@ -586,6 +593,7 @@ class _tmp:
   def __repr__(me): return '%s(%d)' % (_clsname(me), me.default)
   def check(me, sz): return True
   def best(me, sz): return sz
+  def pad(me, sz): return sz
 _augment(KeySZAny, _tmp)
 
 class _tmp:
@@ -602,11 +610,15 @@ class _tmp:
       pp.pretty(me.max); pp.text(','); pp.breakable()
       pp.pretty(me.mod)
     pp.end_group(ind, ')')
-  def check(me, sz): return me.min <= sz <= me.max and sz % me.mod == 0
+  def check(me, sz): return me.min <= sz <= me.max and sz%me.mod == 0
   def best(me, sz):
-    if sz < me.min: raise ValueError, 'key too small'
+    if sz < me.min: raise ValueError('key too small')
     elif sz > me.max: return me.max
-    else: return sz - (sz % me.mod)
+    else: return sz - sz%me.mod
+  def pad(me, sz):
+    if sz > me.max: raise ValueError('key too large')
+    elif sz < me.min: return me.min
+    else: sz += me.mod - 1; return sz - sz%me.mod
 _augment(KeySZRange, _tmp)
 
 class _tmp:
@@ -626,7 +638,13 @@ class _tmp:
     found = -1
     for i in me.set:
       if found < i <= sz: found = i
-    if found < 0: raise ValueError, 'key too small'
+    if found < 0: raise ValueError('key too small')
+    return found
+  def pad(me, sz):
+    found = -1
+    for i in me.set:
+      if sz <= i and (found == -1 or i < found): found = i
+    if found < 0: raise ValueError('key too large')
     return found
 _augment(KeySZSet, _tmp)
 
@@ -644,11 +662,11 @@ _augment(Key, _tmp)
 class _tmp:
   def __repr__(me):
     return '%s({%s})' % (_clsname(me),
-                         ', '.join(['%r: %r' % kv for kv in me.iteritems()]))
+                         ', '.join(['%r: %r' % kv for kv in _iteritems(me)()]))
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me)
     if cyclep: pp.text('...')
-    else: _pp_dict(pp, me.iteritems())
+    else: _pp_dict(pp, _iteritems(me))
     pp.end_group(ind, ')')
 _augment(KeyAttributes, _tmp)
 
@@ -692,11 +710,11 @@ _augment(KeyDataECPt, _tmp)
 class _tmp:
   def __repr__(me):
     return '%s({%s})' % (_clsname(me),
-                         ', '.join(['%r: %r' % kv for kv in me.iteritems()]))
+                         ', '.join(['%r: %r' % kv for kv in _iteritems(me)]))
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me, '({ ')
     if cyclep: pp.text('...')
-    else: _pp_dict(pp, me.iteritems())
+    else: _pp_dict(pp, _iteritems(me))
     pp.end_group(ind, ' })')
 _augment(KeyDataStructured, _tmp)
 
@@ -773,7 +791,7 @@ _augment(GE, _tmp)
 ### RSA encoding techniques.
 
 class PKCS1Crypt (object):
-  def __init__(me, ep = '', rng = rand):
+  def __init__(me, ep = _bin(''), rng = rand):
     me.ep = ep
     me.rng = rng
   def encode(me, msg, nbits):
@@ -782,7 +800,7 @@ class PKCS1Crypt (object):
     return _base._p1crypt_decode(ct, nbits, me.ep, me.rng)
 
 class PKCS1Sig (object):
-  def __init__(me, ep = '', rng = rand):
+  def __init__(me, ep = _bin(''), rng = rand):
     me.ep = ep
     me.rng = rng
   def encode(me, msg, nbits):
@@ -791,7 +809,7 @@ class PKCS1Sig (object):
     return _base._p1sig_decode(msg, sig, nbits, me.ep, me.rng)
 
 class OAEP (object):
-  def __init__(me, mgf = sha_mgf, hash = sha, ep = '', rng = rand):
+  def __init__(me, mgf = sha_mgf, hash = sha, ep = _bin(''), rng = rand):
     me.mgf = mgf
     me.hash = hash
     me.ep = ep
@@ -866,21 +884,23 @@ _augment(RSAPriv, _tmp)
 ### DSA and related schemes.
 
 class _tmp:
-  def __repr__(me): return '%s(G = %r, p = %r)' % (_clsname(me), me.G, me.p)
+  def __repr__(me): return '%s(G = %r, p = %r, hash = %r)' % \
+        (_clsname(me), me.G, me.p, me.hash)
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me)
     if cyclep:
       pp.text('...')
     else:
       _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable()
-      _pp_kv(pp, 'p', me.p)
+      _pp_kv(pp, 'p', me.p); pp.text(','); pp.breakable()
+      _pp_kv(pp, 'hash', me.hash)
     pp.end_group(ind, ')')
 _augment(DSAPub, _tmp)
 _augment(KCDSAPub, _tmp)
 
 class _tmp:
-  def __repr__(me): return '%s(G = %r, u = %s, p = %r)' % \
-      (_clsname(me), me.G, _repr_secret(me.u), me.p)
+  def __repr__(me): return '%s(G = %r, u = %s, p = %r, hash = %r)' % \
+      (_clsname(me), me.G, _repr_secret(me.u), me.p, me.hash)
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me)
     if cyclep:
@@ -888,7 +908,8 @@ class _tmp:
     else:
       _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable()
       _pp_kv(pp, 'u', me.u, True); pp.text(','); pp.breakable()
-      _pp_kv(pp, 'p', me.p)
+      _pp_kv(pp, 'p', me.p); pp.text(','); pp.breakable()
+      _pp_kv(pp, 'hash', me.hash)
     pp.end_group(ind, ')')
 _augment(DSAPriv, _tmp)
 _augment(KCDSAPriv, _tmp)
@@ -903,7 +924,7 @@ Z128 = ByteString.zero(16)
 
 class _BasePub (object):
   def __init__(me, pub, *args, **kw):
-    if not me._PUBSZ.check(len(pub)): raise ValueError, 'bad public key'
+    if not me._PUBSZ.check(len(pub)): raise ValueError('bad public key')
     super(_BasePub, me).__init__(*args, **kw)
     me.pub = pub
   def __repr__(me): return '%s(pub = %r)' % (_clsname(me), me.pub)
@@ -916,7 +937,7 @@ class _BasePub (object):
 
 class _BasePriv (object):
   def __init__(me, priv, pub = None, *args, **kw):
-    if not me._KEYSZ.check(len(priv)): raise ValueError, 'bad private key'
+    if not me._KEYSZ.check(len(priv)): raise ValueError('bad private key')
     if pub is None: pub = me._pubkey(priv)
     super(_BasePriv, me).__init__(pub = pub, *args, **kw)
     me.priv = priv
@@ -989,51 +1010,17 @@ class Ed448Priv (_EdDSAPriv, Ed448Pub):
     return ed448_sign(me.priv, msg, pub = me.pub, **kw)
 
 ###--------------------------------------------------------------------------
-### Built-in named curves and prime groups.
-
-class _groupmap (object):
-  def __init__(me, map, nth):
-    me.map = map
-    me.nth = nth
-    me._n = max(map.values()) + 1
-    me.i = me._n*[None]
+### Built-in algorithm and group tables.
+
+class _tmp:
   def __repr__(me):
-    return '{%s}' % ', '.join(['%r: %r' % kv for kv in me.iteritems()])
+    return '{%s}' % ', '.join(['%r: %r' % kv for kv in _iteritems(me)])
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup(pp, '{ ')
     if cyclep: pp.text('...')
-    else: _pp_dict(pp, me.iteritems())
+    else: _pp_dict(pp, _iteritems(me))
     pp.end_group(ind, ' }')
-  def __len__(me):
-    return me._n
-  def __contains__(me, k):
-    return k in me.map
-  def __getitem__(me, k):
-    i = me.map[k]
-    if me.i[i] is None:
-      me.i[i] = me.nth(i)
-    return me.i[i]
-  def __setitem__(me, k, v):
-    raise TypeError, "immutable object"
-  def __iter__(me):
-    return iter(me.map)
-  def iterkeys(me):
-    return iter(me.map)
-  def itervalues(me):
-    for k in me:
-      yield me[k]
-  def iteritems(me):
-    for k in me:
-      yield k, me[k]
-  def keys(me):
-    return [k for k in me]
-  def values(me):
-    return [me[k] for k in me]
-  def items(me):
-    return [(k, me[k]) for k in me]
-eccurves = _groupmap(_base._eccurves, ECInfo._curven)
-primegroups = _groupmap(_base._pgroups, DHInfo._groupn)
-bingroups = _groupmap(_base._bingroups, BinDHInfo._groupn)
+_augment(_base._MiscTable, _tmp)
 
 ###--------------------------------------------------------------------------
 ### Prime number generation.
@@ -1143,7 +1130,7 @@ class SimulStepper (PrimeGenEventHandler):
     me.add = add
   def _stepfn(me, step):
     if step <= 0:
-      raise ValueError, 'step must be positive'
+      raise ValueError('step must be positive')
     if step <= MPW_MAX:
       return lambda f: f.step(step)
     j = PrimeFilter(step)