pubkey.c (RSAPriv): Fix segfault if `p' is provided but not `q'.
[catacomb-python] / pgen.c
diff --git a/pgen.c b/pgen.c
index afeeb33..9460310 100644 (file)
--- a/pgen.c
+++ b/pgen.c
@@ -379,6 +379,7 @@ static PyTypeObject rabin_pytype_skel = {
 
 typedef struct pgevent_pyobj {
   PyObject_HEAD
+  PyObject *r;
   pgen_event *ev;
 } pgevent_pyobj;
 
@@ -388,18 +389,30 @@ static PyTypeObject *pgevent_pytype;
 static PyObject *pgevent_pywrap(pgen_event *ev)
 {
   pgevent_pyobj *o = PyObject_New(pgevent_pyobj, pgevent_pytype);
-  o->ev = ev;
+  o->ev = ev; o->r = 0;
   return ((PyObject *)o);
 }
 
 static CONVFUNC(pgevent, pgen_event *, PGEVENT_EV)
 
-static void pgevent_kill(PyObject *me) { PGEVENT_EV(me) = 0; }
-static void pgevent_pydealloc(PyObject *me) { FREEOBJ(me); }
+static void pgevent_kill(PyObject *me)
+{
+  pgevent_pyobj *ev = (pgevent_pyobj *)me;
+
+  ev->ev = 0;
+  if (ev->r) GRAND_R(ev->r) = 0;
+}
+
+static void pgevent_pydealloc(PyObject *me)
+{
+  pgevent_pyobj *ev = (pgevent_pyobj *)me;
+  if (ev->r) Py_DECREF(ev->r);
+  FREEOBJ(me);
+}
 
 #define PGEVENT_CHECK(me) do {                                         \
   if (!PGEVENT_EV(me)) {                                               \
-    PyErr_SetString(PyExc_ValueError, "event object is dead");         \
+    PyErr_SetString(PyExc_ValueError, "event object is no longer valid"); \
     return (0);                                                                \
   }                                                                    \
 } while (0)
@@ -417,7 +430,13 @@ static PyObject *peget_tests(PyObject *me, void *hunoz)
   { PGEVENT_CHECK(me); return (PyInt_FromLong(PGEVENT_EV(me)->tests)); }
 
 static PyObject *peget_rng(PyObject *me, void *hunoz)
-  { PGEVENT_CHECK(me); return (grand_pywrap(PGEVENT_EV(me)->r, 0)); }
+{
+  pgevent_pyobj *ev = (pgevent_pyobj *)me;
+
+  PGEVENT_CHECK(me);
+  if (!ev->r) ev->r = grand_pywrap(ev->ev->r, 0);
+  Py_INCREF(ev->r); return ((PyObject *)ev->r);
+}
 
 static int peset_x(PyObject *me, PyObject *xobj, void *hunoz)
 {
@@ -525,7 +544,7 @@ static PyTypeObject *pgtest_pytype;
 
 static int pgev_python(int rq, pgen_event *ev, void *p)
 {
-  PyObject *py = p;
+  pypgev *pg = p;
   PyObject *pyev = 0;
   PyObject *rc = 0;
   int st = PGEN_ABORT;
@@ -534,11 +553,10 @@ static int pgev_python(int rq, pgen_event *ev, void *p)
     "pg_abort", "pg_done", "pg_begin", "pg_try", "pg_fail", "pg_pass"
   };
 
-  Py_INCREF(py);
   rq++;
   if (rq > N(meth)) SYSERR("event code out of range");
   pyev = pgevent_pywrap(ev);
-  if ((rc = PyObject_CallMethod(py, meth[rq], "(O)", pyev)) == 0)
+  if ((rc = PyObject_CallMethod(pg->obj, meth[rq], "(O)", pyev)) == 0)
     goto end;
   if (rc == Py_None)
     st = PGEN_TRY;
@@ -549,12 +567,13 @@ static int pgev_python(int rq, pgen_event *ev, void *p)
   else
     st = l;
 end:
+  if (PyErr_Occurred())
+    stash_exception(pg->exc, "exception from `pgen' handler");
   if (pyev) {
     pgevent_kill(pyev);
     Py_DECREF(pyev);
   }
   Py_XDECREF(rc);
-  Py_DECREF(py);
   return (st);
 }
 
@@ -569,24 +588,22 @@ static PyObject *pgev_pywrap(const pgev *pg)
 
 int convpgev(PyObject *o, void *p)
 {
-  pgev *pg = p;
+  pypgev *pg = p;
 
   if (PGEV_PYCHECK(o))
-    *pg = *PGEV_PG(o);
+    pg->ev = *PGEV_PG(o);
   else {
-    pg->proc = pgev_python;
-    pg->ctx = o;
-    Py_INCREF(o);
+    pg->ev.proc = pgev_python;
+    pg->ev.ctx = pg;
+    pg->obj = o; Py_INCREF(o);
   }
   return (1);
 }
 
-void droppgev(pgev *p)
+void droppgev(pypgev *pg)
 {
-  if (p->proc == pgev_python) {
-    PyObject *py = p->ctx;
-    Py_DECREF(py);
-  }
+  if (pg->ev.proc == pgev_python)
+    { assert(pg->ev.ctx == pg); Py_DECREF(pg->obj); }
 }
 
 static PyObject *pgmeth_common(PyObject *me, PyObject *arg, int rq)
@@ -894,10 +911,10 @@ static PyTypeObject pgtest_pytype_skel = {
 
 /*----- Prime generation functions ----------------------------------------*/
 
-void pgenerr(void)
+void pgenerr(struct excinfo *exc)
 {
-  if (!PyErr_Occurred())
-    PyErr_SetString(PyExc_ValueError, "prime generation failed");
+  if (exc->ty) RESTORE_EXCINFO(exc);
+  else PyErr_SetString(PyExc_ValueError, "prime generation failed");
 }
 
 static PyObject *meth_pgen(PyObject *me, PyObject *arg, PyObject *kw)
@@ -908,26 +925,26 @@ static PyObject *meth_pgen(PyObject *me, PyObject *arg, PyObject *kw)
   char *p = "p";
   pgen_filterctx fc = { 2 };
   rabin tc;
-  pgev step = { 0 }, test = { 0 }, evt = { 0 };
+  struct excinfo exc = EXCINFO_INIT;
+  pypgev step = { { 0 } }, test = { { 0 } }, evt = { { 0 } };
   unsigned nsteps = 0, ntests = 0;
   char *kwlist[] = { "start", "name", "stepper", "tester", "event",
                     "nsteps", "ntests", 0 };
 
-  step.proc = pgen_filter; step.ctx = &fc;
-  test.proc = pgen_test; test.ctx = &tc;
+  step.exc = &exc; step.ev.proc = pgen_filter; step.ev.ctx = &fc;
+  test.exc = &exc; test.ev.proc = pgen_test; test.ev.ctx = &tc;
+  evt.exc = &exc;
   if (!PyArg_ParseTupleAndKeywords(arg, kw, "O&|sO&O&O&O&O&:pgen", kwlist,
                                   convmp, &x, &p, convpgev, &step,
                                   convpgev, &test, convpgev, &evt,
                                   convuint, &nsteps, convuint, &ntests))
     goto end;
   if (!ntests) ntests = rabin_iters(mp_bits(x));
-  if ((r = pgen(p, MP_NEW, x, evt.proc, evt.ctx,
-               nsteps, step.proc, step.ctx,
-               ntests, test.proc, test.ctx)) == 0)
-    PGENERR;
-  if (PyErr_Occurred()) goto end;
-  rc = mp_pywrap(r);
-  r = 0;
+  if ((r = pgen(p, MP_NEW, x, evt.ev.proc, evt.ev.ctx,
+               nsteps, step.ev.proc, step.ev.ctx,
+               ntests, test.ev.proc, test.ev.ctx)) == 0)
+    PGENERR(&exc);
+  rc = mp_pywrap(r); r = 0;
 end:
   mp_drop(r); mp_drop(x);
   droppgev(&step); droppgev(&test); droppgev(&evt);
@@ -943,18 +960,20 @@ static PyObject *meth_strongprime_setup(PyObject *me,
   unsigned nbits;
   char *name = "p";
   unsigned n = 0;
-  pgev evt = { 0 };
+  struct excinfo exc = EXCINFO_INIT;
+  pypgev evt = { { 0 } };
   PyObject *rc = 0;
   char *kwlist[] = { "nbits", "name", "event", "rng", "nsteps", 0 };
 
+  evt.exc = &exc;
   if (!PyArg_ParseTupleAndKeywords(arg, kw, "O&|sO&O&O&", kwlist,
                                   convuint, &nbits, &name,
                                   convpgev, &evt, convgrand, &r,
                                   convuint, &n))
     goto end;
   if ((x = strongprime_setup(name, MP_NEW, &f, nbits,
-                            r, n, evt.proc, evt.ctx)) == 0)
-    PGENERR;
+                            r, n, evt.ev.proc, evt.ev.ctx)) == 0)
+    PGENERR(&exc);
   rc = Py_BuildValue("(NN)", mp_pywrap(x), pfilt_pywrap(&f));
   x = 0;
 end:
@@ -970,18 +989,20 @@ static PyObject *meth_strongprime(PyObject *me, PyObject *arg, PyObject *kw)
   unsigned nbits;
   char *name = "p";
   unsigned n = 0;
-  pgev evt = { 0 };
+  struct excinfo exc = EXCINFO_INIT;
+  pypgev evt = { { 0 } };
   PyObject *rc = 0;
   char *kwlist[] = { "nbits", "name", "event", "rng", "nsteps", 0 };
 
+  evt.exc = &exc;
   if (!PyArg_ParseTupleAndKeywords(arg, kw, "O&|sO&O&O&", kwlist,
                                   convuint, &nbits, &name,
                                   convpgev, &evt, convgrand, &r,
                                   convuint, &n))
     goto end;
   if ((x = strongprime(name, MP_NEW, nbits,
-                      r, n, evt.proc, evt.ctx)) == 0)
-    PGENERR;
+                      r, n, evt.ev.proc, evt.ev.ctx)) == 0)
+    PGENERR(&exc);
   rc = mp_pywrap(x);
   x = 0;
 end:
@@ -993,7 +1014,8 @@ end:
 static PyObject *meth_limlee(PyObject *me, PyObject *arg, PyObject *kw)
 {
   char *p = "p";
-  pgev ie = { 0 }, oe = { 0 };
+  struct excinfo exc = EXCINFO_INIT;
+  pypgev ie = { { 0 } }, oe = { { 0 } };
   unsigned ql, pl;
   grand *r = &rand_global;
   unsigned on = 0;
@@ -1003,14 +1025,16 @@ static PyObject *meth_limlee(PyObject *me, PyObject *arg, PyObject *kw)
                     "rng", "nsteps", 0 };
   mp *x = 0, **v = 0;
 
+  ie.exc = oe.exc = &exc;
   if (!PyArg_ParseTupleAndKeywords(arg, kw, "O&O&|sO&O&O&O&:limlee", kwlist,
                                   convuint, &pl, convuint, &ql,
                                   &p, convpgev, &oe, convpgev, &ie,
                                   convgrand, &r, convuint, &on))
     goto end;
   if ((x = limlee(p, MP_NEW, MP_NEW, ql, pl, r, on,
-                 oe.proc, oe.ctx, ie.proc, ie.ctx, &nf, &v)) == 0)
-    PGENERR;
+                 oe.ev.proc, oe.ev.ctx, ie.ev.proc, ie.ev.ctx,
+                 &nf, &v)) == 0)
+    PGENERR(&exc);;
   vec = PyList_New(nf);
   for (i = 0; i < nf; i++)
     PyList_SetItem(vec, i, mp_pywrap(v[i]));