catacomb.c (mexp_common): Accept an arbitrary iterable.
authorMark Wooding <mdw@distorted.org.uk>
Mon, 25 Nov 2019 13:19:45 +0000 (13:19 +0000)
committerMark Wooding <mdw@distorted.org.uk>
Mon, 25 Nov 2019 17:51:34 +0000 (17:51 +0000)
This means that any iterable producing either (BASE, EXP) pair objects
or alternating BASE/EXP pairs can be used as input to any of the
multiple-exponentiation functions.  Of course, expecting an object with
unpredictable iteration order to produce a useful sequence of BASE/EXP
pairs will probably cause disappointment, but, for example, a dictionary
mapping bases to exponents seems like a reasonable thing to maintain.

catacomb.c
t/t-ec.py
t/t-mp.py

index 1f38213..23906eb 100644 (file)
@@ -38,61 +38,69 @@ PyObject *mexp_common(PyObject *me, PyObject *arg,
                      PyObject *(*exp)(PyObject *, void *, size_t),
                      void (*drop)(void *))
 {
-  size_t i = 0, j, n;
+  size_t i = 0, o, n;
   int flat;
-  PyObject *qq, *x, *y, *z = 0;
-  char *v = 0, *vv;
-
-  if (PyTuple_GET_SIZE(arg) == 1)
-    arg = PyTuple_GET_ITEM(arg, 0);
-  if (!PySequence_Check(arg)) TYERR("not a sequence");
-  n = PySequence_Size(arg); if (n < 0) goto end;
-  if (!n) { z = id(me); goto end; }
-  x = PySequence_GetItem(arg, 0);
-  if (PySequence_Check(x))
-    flat = 0;
+  PyObject *qq = 0, *x = 0, *y = 0, *z = 0, *it = 0;
+  char *v = 0;
+
+  if (PyTuple_Size(arg) == 1) arg = PyTuple_GET_ITEM(arg, 0);
+  it = PyObject_GetIter(arg); if (!it) goto end;
+  qq = PyIter_Next(it);
+  if (!qq) {
+    if (!PyErr_Occurred()) z = id(me);
+    else goto end;
+  }
+  flat = !PySequence_Check(qq);
+  if (!PySequence_Check(arg))
+    n = 16;
   else {
-    if (n % 2) VALERR("must have even number of arguments");
-    n /= 2;
-    flat = 1;
+    n = PySequence_Size(arg);
+    if (n == (size_t)-1 && PyErr_Occurred()) goto end;
+    if (flat) n /= 2;
+    if (!n) n = 16;
   }
-  Py_DECREF(x);
 
-  v = xmalloc(n * efsz);
-  for (i = j = 0, vv = v; i < n; i++, vv += efsz) {
-    if (flat) {
-      x = PySequence_GetItem(arg, j++);
-      y = PySequence_GetItem(arg, j++);
-    } else {
-      qq = PySequence_GetItem(arg, j++);
-      if (!qq) goto end;
-      if (!PySequence_Check(qq) || PySequence_Size(qq) != 2) {
-       Py_DECREF(qq);
+  v = xmalloc(n*efsz);
+  o = 0;
+  for (;;) {
+    if (!flat) {
+      if (!PySequence_Check(qq) || PySequence_Size(qq) != 2)
        TYERR("want a sequence of pairs");
-      }
       x = PySequence_GetItem(qq, 0);
       y = PySequence_GetItem(qq, 1);
-      Py_DECREF(qq);
+    } else {
+      x = qq; qq = 0;
+      y = PyIter_Next(it);
+      if (!y) {
+       if (PyErr_Occurred()) goto end;
+       VALERR("must have even number of operands");
+      }
     }
     if (!x || !y) goto end;
-    if (fill(vv, me, x, y)) {
-      Py_DECREF(x);
-      Py_DECREF(y);
-      if (!PyErr_Occurred())
-       PyErr_SetString(PyExc_TypeError, "type mismatch");
-      goto end;
+
+    if (i >= n) { n *= 2; v = xrealloc(v, n*efsz, i*efsz); }
+    if (fill(v + o, me, x, y)) {
+      if (PyErr_Occurred()) goto end;
+      TYERR("type mismatch");
+    }
+    i++; o += efsz;
+    Py_DECREF(x); x = 0;
+    Py_DECREF(y); y = 0;
+    Py_XDECREF(qq);
+
+    qq = PyIter_Next(it);
+    if (!qq) {
+      if (PyErr_Occurred()) goto end;
+      else break;
     }
-    Py_DECREF(x);
-    Py_DECREF(y);
   }
-  z = exp(me, v, n);
+
+  z = exp(me, v, i);
 
 end:
-  if (v) {
-    for (j = 0, vv = v; j < i; j++, vv += efsz)
-      drop(vv);
-    xfree(v);
-  }
+  while (i--) { o -= efsz; drop(v + o); }
+  xfree(v);
+  Py_XDECREF(it); Py_XDECREF(qq); Py_XDECREF(x); Py_XDECREF(y);
   return (z);
 }
 
index d854216..b85105d 100644 (file)
--- a/t/t-ec.py
+++ b/t/t-ec.py
@@ -242,6 +242,7 @@ class TestCurves (T.GenericTestMixin):
 
     ## Simultaneous multiplication.
     Q, R, S = 5*P, 7*P, 11*P
+    me.assertEqual(E.mmul(set([(Q, 9), (R, 8), (S, 5)])), 156*P)
     me.assertEqual(E.mmul([Q, 9, R, 8, S, 5]), 156*P)
     me.assertEqual(E.mmul(Q, 9, R, 8, S, 5), 156*P)
 
index 9927c12..f856f21 100644 (file)
--- a/t/t-mp.py
+++ b/t/t-mp.py
@@ -333,10 +333,12 @@ class TestMPMont (U.TestCase):
     me.assertEqual(m.expr(m.int(2), p - 1), m.r)
 
     q, r, s, z = 32, 128, 2048, pow(g, 156, p)
+    me.assertEqual(m.mexp(set([(q, 9), (r, 8), (s, 5)])), z)
     me.assertEqual(m.mexp([(q, 9), (r, 8), (s, 5)]), z)
     me.assertEqual(m.mexp(q, 9, r, 8, s, 5), z)
 
     q, r, s, z = T.imap(m.int, [32, 128, 2048, pow(g, 156, p)])
+    me.assertEqual(m.mexpr(set([(q, 9), (r, 8), (s, 5)])), z)
     me.assertEqual(m.mexpr([(q, 9), (r, 8), (s, 5)]), z)
     me.assertEqual(m.mexpr(q, 9, r, 8, s, 5), z)
 
@@ -359,6 +361,7 @@ class TestMPBarrett (U.TestCase):
     me.assertEqual(m.exp(2, p - 1), 1)
 
     q, r, s, z = 32, 128, 2048, pow(g, 156, p)
+    me.assertEqual(m.mexp(set([(q, 9), (r, 8), (s, 5)])), z)
     me.assertEqual(m.mexp([(q, 9), (r, 8), (s, 5)]), z)
     me.assertEqual(m.mexp(q, 9, r, 8, s, 5), z)