Merge branch '1.2.x' into 1.3.x
[catacomb-python] / catacomb / __init__.py
index 24265ac..9178475 100644 (file)
@@ -112,7 +112,7 @@ def _init():
     for j in b:
       if j[:plen] == pre:
         setattr(c, j[plen:], classmethod(b[j]))
-  for i in [gcciphers, gchashes, gcmacs, gcprps]:
+  for i in [gcciphers, gcaeads, gchashes, gcmacs, gcprps]:
     for c in i.itervalues():
       d[_fixname(c.name)] = c
   for c in gccrands.itervalues():
@@ -194,6 +194,27 @@ ByteString.__hash__ = str.__hash__
 bytes = ByteString.fromhex
 
 ###--------------------------------------------------------------------------
+### Symmetric encryption.
+
+class _tmp:
+  def encrypt(me, n, m, tsz = None, h = ByteString('')):
+    if tsz is None: tsz = me.__class__.tagsz.default
+    e = me.enc(n, len(h), len(m), tsz)
+    if not len(h): a = None
+    else: a = e.aad().hash(h)
+    c0 = e.encrypt(m)
+    c1, t = e.done(aad = a)
+    return c0 + c1, t
+  def decrypt(me, n, c, t, h = ByteString('')):
+    d = me.dec(n, len(h), len(c), len(t))
+    if not len(h): a = None
+    else: a = d.aad().hash(h)
+    m = d.decrypt(c)
+    m += d.done(t, aad = a)
+    return m
+_augment(GAEKey, _tmp)
+
+###--------------------------------------------------------------------------
 ### Hashing.
 
 class _tmp:
@@ -293,7 +314,7 @@ class _tmp:
     me.bytepad_after()
 _augment(Shake, _tmp)
 _augment(_ShakeBase, _tmp)
-Shake._Z = _ShakeBase._Z = ByteString(200*'\0')
+Shake._Z = _ShakeBase._Z = ByteString.zero(200)
 
 class KMAC (_ShakeBase):
   _FUNC = 'KMAC'
@@ -317,21 +338,12 @@ class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32
 ### NaCl `secretbox'.
 
 def secret_box(k, n, m):
-  E = xsalsa20(k).setiv(n)
-  r = E.enczero(poly1305.keysz.default)
-  s = E.enczero(poly1305.masksz)
-  y = E.encrypt(m)
-  t = poly1305(r)(s).hash(y).done()
-  return ByteString(t + y)
+  y, t = salsa20_naclbox(k).encrypt(n, m)
+  return t + y
 
 def secret_unbox(k, n, c):
-  E = xsalsa20(k).setiv(n)
-  r = E.enczero(poly1305.keysz.default)
-  s = E.enczero(poly1305.masksz)
-  y = c[poly1305.tagsz:]
-  if not poly1305(r)(s).hash(y).check(c[0:poly1305.tagsz]):
-    raise ValueError, 'decryption failed'
-  return E.decrypt(c[poly1305.tagsz:])
+  tsz = poly1305.tagsz
+  return salsa20_naclbox(k).decrypt(n, c[tsz:], c[0:tsz])
 
 ###--------------------------------------------------------------------------
 ### Multiprecision integers and binary polynomials.
@@ -586,6 +598,7 @@ class _tmp:
   def __repr__(me): return '%s(%d)' % (_clsname(me), me.default)
   def check(me, sz): return True
   def best(me, sz): return sz
+  def pad(me, sz): return sz
 _augment(KeySZAny, _tmp)
 
 class _tmp:
@@ -602,11 +615,15 @@ class _tmp:
       pp.pretty(me.max); pp.text(','); pp.breakable()
       pp.pretty(me.mod)
     pp.end_group(ind, ')')
-  def check(me, sz): return me.min <= sz <= me.max and sz % me.mod == 0
+  def check(me, sz): return me.min <= sz <= me.max and sz%me.mod == 0
   def best(me, sz):
     if sz < me.min: raise ValueError, 'key too small'
     elif sz > me.max: return me.max
-    else: return sz - (sz % me.mod)
+    else: return sz - sz%me.mod
+  def pad(me, sz):
+    if sz > me.max: raise ValueError, 'key too large'
+    elif sz < me.min: return me.min
+    else: sz += me.mod; return sz - sz%me.mod
 _augment(KeySZRange, _tmp)
 
 class _tmp:
@@ -628,6 +645,12 @@ class _tmp:
       if found < i <= sz: found = i
     if found < 0: raise ValueError, 'key too small'
     return found
+  def pad(me, sz):
+    found = -1
+    for i in me.set:
+      if sz <= i and (found == -1 or i < found): found = i
+    if found < 0: raise ValueError, 'key too large'
+    return found
 _augment(KeySZSet, _tmp)
 
 ###--------------------------------------------------------------------------
@@ -866,21 +889,23 @@ _augment(RSAPriv, _tmp)
 ### DSA and related schemes.
 
 class _tmp:
-  def __repr__(me): return '%s(G = %r, p = %r)' % (_clsname(me), me.G, me.p)
+  def __repr__(me): return '%s(G = %r, p = %r, hash = %r)' % \
+        (_clsname(me), me.G, me.p, me.hash)
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me)
     if cyclep:
       pp.text('...')
     else:
       _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable()
-      _pp_kv(pp, 'p', me.p)
+      _pp_kv(pp, 'p', me.p); pp.text(','); pp.breakable()
+      _pp_kv(pp, 'hash', me.hash)
     pp.end_group(ind, ')')
 _augment(DSAPub, _tmp)
 _augment(KCDSAPub, _tmp)
 
 class _tmp:
-  def __repr__(me): return '%s(G = %r, u = %s, p = %r)' % \
-      (_clsname(me), me.G, _repr_secret(me.u), me.p)
+  def __repr__(me): return '%s(G = %r, u = %s, p = %r, hash = %r)' % \
+      (_clsname(me), me.G, _repr_secret(me.u), me.p, me.hash)
   def _repr_pretty_(me, pp, cyclep):
     ind = _pp_bgroup_tyname(pp, me)
     if cyclep:
@@ -888,7 +913,8 @@ class _tmp:
     else:
       _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable()
       _pp_kv(pp, 'u', me.u, True); pp.text(','); pp.breakable()
-      _pp_kv(pp, 'p', me.p)
+      _pp_kv(pp, 'p', me.p); pp.text(','); pp.breakable()
+      _pp_kv(pp, 'hash', me.hash)
     pp.end_group(ind, ')')
 _augment(DSAPriv, _tmp)
 _augment(KCDSAPriv, _tmp)