algorithms.c: Add bindings for STROBE.
[catacomb-python] / ec.c
diff --git a/ec.c b/ec.c
index 459c272..40d73e7 100644 (file)
--- a/ec.c
+++ b/ec.c
@@ -186,7 +186,11 @@ static PyObject *ecpt_pymul(PyObject *x, PyObject *y)
   ec zz = EC_INIT;
 
   if (ECPT_PYCHECK(x)) { PyObject *t; t = x; x = y; y = t; }
-  if (!ECPT_PYCHECK(y) || (xx = tomp(x)) == 0) RETURN_NOTIMPL;
+  if (!ECPT_PYCHECK(y)) RETURN_NOTIMPL;
+  if (FE_PYCHECK(x) && FE_F(x)->ops->ty == FTY_PRIME)
+    xx = F_OUT(FE_F(x), MP_NEW, FE_X(x));
+  else if ((xx = implicitmp(x)) == 0)
+    RETURN_NOTIMPL;
   ec_imul(ECPT_C(y), &zz, ECPT_P(y), xx);
   MP_DROP(xx);
   return (ecpt_pywrap(ECPT_COBJ(y), &zz));
@@ -329,11 +333,15 @@ static PyObject *epmeth_parse(PyObject *me, PyObject *arg)
   char *p;
   qd_parse qd;
   PyObject *rc = 0;
+  int paren;
   ec pp = EC_INIT;
 
   if (!PyArg_ParseTuple(arg, "s:parse", &p)) goto end;
   qd.p = p; qd.e = 0;
+  qd_skipspc(&qd); paren = qd_delim(&qd, '(');
   if (!ec_ptparse(&qd, &pp)) VALERR(qd.e);
+  qd_skipspc(&qd); if (paren && !qd_delim(&qd, ')'))
+    { EC_DESTROY(&pp); VALERR("missing `)'"); }
   rc = Py_BuildValue("(Ns)", ecpt_pywrapout(me, &pp), qd.p);
 end:
   return (rc);
@@ -463,7 +471,7 @@ static PyObject *epget__z(PyObject *me, void *hunoz)
 static mp *coord_in(field *f, PyObject *x)
 {
   mp *xx;
-  if (FE_PYCHECK(x) && FE_F(x) == f)
+  if (FE_PYCHECK(x) && (FE_F(x) == f || field_samep(FE_F(x), f)))
     return (MP_COPY(FE_X(x)));
   else if ((xx = getmp(x)) == 0)
     return (0);
@@ -505,10 +513,10 @@ end:
 static int ecptxl_1(ec_curve *c, ec *p, PyObject *x)
 {
   int rc = -1;
-  PyObject *y = 0, *z = 0, *t = 0;
+  PyObject *y = 0, *z = 0, *t = 0, *u = 0;
   mp *xx = 0;
-  Py_ssize_t n;
   qd_parse qd;
+  int paren;
 
   Py_XINCREF(x);
   if (!x || x == Py_None)
@@ -517,35 +525,43 @@ static int ecptxl_1(ec_curve *c, ec *p, PyObject *x)
     getecptout(p, x);
     goto fix;
   } else if (TEXT_CHECK(x)) {
-    qd.p = TEXT_PTR(x);
-    qd.e = 0;
-    if (!ec_ptparse(&qd, p))
-      VALERR(qd.e);
+    qd.p = TEXT_PTR(x); qd.e = 0;
+    qd_skipspc(&qd); paren = qd_delim(&qd, '(');
+    if (!ec_ptparse(&qd, p)) VALERR(qd.e);
+    qd_skipspc(&qd); if (paren && !qd_delim(&qd, ')'))
+      { EC_DESTROY(p); VALERR("missing `)'"); }
+    qd_skipspc(&qd); if (!qd_eofp(&qd)) VALERR("junk at eof");
     goto fix;
   } else if (c && (xx = tomp(x)) != 0) {
     xx = F_IN(c->f, xx, xx);
     if (!EC_FIND(c, p, xx)) VALERR("not on the curve");
-  } else if (PySequence_Check(x)) {
-    t = x; x = 0;
-    n = PySequence_Size(t); if (n < 0) goto end;
-    if (n != 2 && (n != 3 || !c))
-      TYERR("want sequence of two or three items");
-    if ((x = PySequence_GetItem(t, 0)) == 0 ||
-       (y = PySequence_GetItem(t, 1)) == 0 ||
-       (n == 3 && (z = PySequence_GetItem(t, 2)) == 0))
-      goto end;
-    rc = (n == 2) ? ecptxl_2(c, p, x, y) : ecptxl_3(c, p, x, y, z);
+  } else if ((t = PyObject_GetIter(x)) != 0) {
+    Py_DECREF(x);
+    x = PyIter_Next(t); if (!x) goto enditer;
+    y = PyIter_Next(t); if (!y) goto enditer;
+    z = PyIter_Next(t); if (!z && PyErr_Occurred()) goto end;
+    if (z) {
+      u = PyIter_Next(t);
+      if (u) goto enditer;
+      else if (PyErr_Occurred()) goto end;
+    }
+    rc = !z ? ecptxl_2(c, p, x, y) : ecptxl_3(c, p, x, y, z);
     goto end;
-  } else
+  } else {
+    PyErr_Clear();
     TYERR("can't convert to curve point");
+  }
   goto ok;
 
+enditer:
+  if (PyErr_Occurred()) goto end;
+  TYERR("expected sequence of 2 or 3 items");
 fix:
   if (c) EC_IN(c, p, p);
 ok:
   rc = 0;
 end:
-  Py_XDECREF(x); Py_XDECREF(y); Py_XDECREF(z); Py_XDECREF(t);
+  Py_XDECREF(x); Py_XDECREF(y); Py_XDECREF(z); Py_XDECREF(t); Py_XDECREF(u);
   mp_drop(xx);
   return (rc);
 }
@@ -902,7 +918,7 @@ static int ecmmul_fill(void *pp, PyObject *me, PyObject *x, PyObject *m)
   return (0);
 }
 
-static PyObject *ecmmul_exp(PyObject *me, void *pp, int n)
+static PyObject *ecmmul_exp(PyObject *me, void *pp, size_t n)
 {
   ec p = EC_INIT;
   ec_immul(ECCURVE_C(me), &p, pp, n);
@@ -975,7 +991,7 @@ static PyObject *ecmeth_parse(PyObject *me, PyObject *arg)
   if (!PyArg_ParseTuple(arg, "s:parse", &p)) goto end;
   qd.p = p; qd.e = 0;
   if ((c = ec_curveparse(&qd)) == 0) VALERR(qd.e);
-  rc = eccurve_pywrap(0, c);
+  rc = Py_BuildValue("(Ns)", eccurve_pywrap(0, c), qd.p);
 end:
   return (rc);
 }