From 74836df7a6a201ee0e87500621ae7d4f581de1e5 Mon Sep 17 00:00:00 2001 From: Mark Wooding Date: Mon, 25 Nov 2019 13:19:45 +0000 Subject: [PATCH] catacomb.c (mexp_common): Accept an arbitrary iterable. This means that any iterable producing either (BASE, EXP) pair objects or alternating BASE/EXP pairs can be used as input to any of the multiple-exponentiation functions. Of course, expecting an object with unpredictable iteration order to produce a useful sequence of BASE/EXP pairs will probably cause disappointment, but, for example, a dictionary mapping bases to exponents seems like a reasonable thing to maintain. --- catacomb.c | 92 ++++++++++++++++++++++++++++++++++---------------------------- t/t-ec.py | 1 + t/t-mp.py | 3 ++ 3 files changed, 54 insertions(+), 42 deletions(-) diff --git a/catacomb.c b/catacomb.c index 1f38213..23906eb 100644 --- a/catacomb.c +++ b/catacomb.c @@ -38,61 +38,69 @@ PyObject *mexp_common(PyObject *me, PyObject *arg, PyObject *(*exp)(PyObject *, void *, size_t), void (*drop)(void *)) { - size_t i = 0, j, n; + size_t i = 0, o, n; int flat; - PyObject *qq, *x, *y, *z = 0; - char *v = 0, *vv; - - if (PyTuple_GET_SIZE(arg) == 1) - arg = PyTuple_GET_ITEM(arg, 0); - if (!PySequence_Check(arg)) TYERR("not a sequence"); - n = PySequence_Size(arg); if (n < 0) goto end; - if (!n) { z = id(me); goto end; } - x = PySequence_GetItem(arg, 0); - if (PySequence_Check(x)) - flat = 0; + PyObject *qq = 0, *x = 0, *y = 0, *z = 0, *it = 0; + char *v = 0; + + if (PyTuple_Size(arg) == 1) arg = PyTuple_GET_ITEM(arg, 0); + it = PyObject_GetIter(arg); if (!it) goto end; + qq = PyIter_Next(it); + if (!qq) { + if (!PyErr_Occurred()) z = id(me); + else goto end; + } + flat = !PySequence_Check(qq); + if (!PySequence_Check(arg)) + n = 16; else { - if (n % 2) VALERR("must have even number of arguments"); - n /= 2; - flat = 1; + n = PySequence_Size(arg); + if (n == (size_t)-1 && PyErr_Occurred()) goto end; + if (flat) n /= 2; + if (!n) n = 16; } - Py_DECREF(x); - v = xmalloc(n * efsz); - for (i = j = 0, vv = v; i < n; i++, vv += efsz) { - if (flat) { - x = PySequence_GetItem(arg, j++); - y = PySequence_GetItem(arg, j++); - } else { - qq = PySequence_GetItem(arg, j++); - if (!qq) goto end; - if (!PySequence_Check(qq) || PySequence_Size(qq) != 2) { - Py_DECREF(qq); + v = xmalloc(n*efsz); + o = 0; + for (;;) { + if (!flat) { + if (!PySequence_Check(qq) || PySequence_Size(qq) != 2) TYERR("want a sequence of pairs"); - } x = PySequence_GetItem(qq, 0); y = PySequence_GetItem(qq, 1); - Py_DECREF(qq); + } else { + x = qq; qq = 0; + y = PyIter_Next(it); + if (!y) { + if (PyErr_Occurred()) goto end; + VALERR("must have even number of operands"); + } } if (!x || !y) goto end; - if (fill(vv, me, x, y)) { - Py_DECREF(x); - Py_DECREF(y); - if (!PyErr_Occurred()) - PyErr_SetString(PyExc_TypeError, "type mismatch"); - goto end; + + if (i >= n) { n *= 2; v = xrealloc(v, n*efsz, i*efsz); } + if (fill(v + o, me, x, y)) { + if (PyErr_Occurred()) goto end; + TYERR("type mismatch"); + } + i++; o += efsz; + Py_DECREF(x); x = 0; + Py_DECREF(y); y = 0; + Py_XDECREF(qq); + + qq = PyIter_Next(it); + if (!qq) { + if (PyErr_Occurred()) goto end; + else break; } - Py_DECREF(x); - Py_DECREF(y); } - z = exp(me, v, n); + + z = exp(me, v, i); end: - if (v) { - for (j = 0, vv = v; j < i; j++, vv += efsz) - drop(vv); - xfree(v); - } + while (i--) { o -= efsz; drop(v + o); } + xfree(v); + Py_XDECREF(it); Py_XDECREF(qq); Py_XDECREF(x); Py_XDECREF(y); return (z); } diff --git a/t/t-ec.py b/t/t-ec.py index d854216..b85105d 100644 --- a/t/t-ec.py +++ b/t/t-ec.py @@ -242,6 +242,7 @@ class TestCurves (T.GenericTestMixin): ## Simultaneous multiplication. Q, R, S = 5*P, 7*P, 11*P + me.assertEqual(E.mmul(set([(Q, 9), (R, 8), (S, 5)])), 156*P) me.assertEqual(E.mmul([Q, 9, R, 8, S, 5]), 156*P) me.assertEqual(E.mmul(Q, 9, R, 8, S, 5), 156*P) diff --git a/t/t-mp.py b/t/t-mp.py index 9927c12..f856f21 100644 --- a/t/t-mp.py +++ b/t/t-mp.py @@ -333,10 +333,12 @@ class TestMPMont (U.TestCase): me.assertEqual(m.expr(m.int(2), p - 1), m.r) q, r, s, z = 32, 128, 2048, pow(g, 156, p) + me.assertEqual(m.mexp(set([(q, 9), (r, 8), (s, 5)])), z) me.assertEqual(m.mexp([(q, 9), (r, 8), (s, 5)]), z) me.assertEqual(m.mexp(q, 9, r, 8, s, 5), z) q, r, s, z = T.imap(m.int, [32, 128, 2048, pow(g, 156, p)]) + me.assertEqual(m.mexpr(set([(q, 9), (r, 8), (s, 5)])), z) me.assertEqual(m.mexpr([(q, 9), (r, 8), (s, 5)]), z) me.assertEqual(m.mexpr(q, 9, r, 8, s, 5), z) @@ -359,6 +361,7 @@ class TestMPBarrett (U.TestCase): me.assertEqual(m.exp(2, p - 1), 1) q, r, s, z = 32, 128, 2048, pow(g, 156, p) + me.assertEqual(m.mexp(set([(q, 9), (r, 8), (s, 5)])), z) me.assertEqual(m.mexp([(q, 9), (r, 8), (s, 5)]), z) me.assertEqual(m.mexp(q, 9, r, 8, s, 5), z) -- 2.11.0