algorithms.c (KeySZSet): Represent the set as an actual (frozen) set.
authorMark Wooding <mdw@distorted.org.uk>
Sun, 13 Oct 2019 23:52:23 +0000 (00:52 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Mon, 25 Nov 2019 17:51:31 +0000 (17:51 +0000)
algorithms.c
t/t-algorithms.py

index a195de9..071894a 100644 (file)
@@ -86,12 +86,13 @@ PyObject *keysz_pywrap(const octet *k)
     case KSZ_SET: {
       keyszset_pyobj *o =
        PyObject_New(keyszset_pyobj, keyszset_pytype);
+      PyObject *l;
       int i, n;
       o->dfl = ARG(0);
       for (i = 0; ARG(i); i++) ;
-      n = i; o->set = PyTuple_New(n);
-      for (i = 0; i < n; i++)
-       PyTuple_SET_ITEM(o->set, i, PyInt_FromLong(ARG(i)));
+      n = i; l = PyList_New(n);
+      for (i = 0; i < n; i++) PyList_SET_ITEM(l, i, PyInt_FromLong(ARG(i)));
+      o->set = PyFrozenSet_New(l); Py_DECREF(l);
       return ((PyObject *)o);
     } break;
     default:
@@ -152,41 +153,30 @@ static PyObject *keyszset_pynew(PyTypeObject *ty,
                                PyObject *arg, PyObject *kw)
 {
   static const char *const kwlist[] = { "default", "set", 0 };
-  int dfl, i, n, xx;
+  int dfl, xx;
   PyObject *set = 0;
-  PyObject *x = 0, *l = 0;
+  PyObject *x = 0, *l = 0, *i = 0;
   keyszset_pyobj *o = 0;
 
   if (!PyArg_ParseTupleAndKeywords(arg, kw, "i|O:new", KWLIST, &dfl, &set))
     goto end;
-  if (!set) set = PyTuple_New(0);
-  else Py_INCREF(set);
-  if (!PySequence_Check(set)) TYERR("want a sequence");
-  n = PySequence_Size(set); if (n < 0) goto end;
+  if (set) i = PyObject_GetIter(set);
+  else { set = PyTuple_New(0); i = PyObject_GetIter(set); Py_DECREF(set); }
+  if (!i) goto end;
   l = PyList_New(0); if (!l) goto end;
   if (dfl < 0) VALERR("key size cannot be negative");
-  x = PyInt_FromLong(dfl);
-  PyList_Append(l, x);
-  Py_DECREF(x);
-  x = 0;
-  for (i = 0; i < n; i++) {
-    if ((x = PySequence_GetItem(set, i)) == 0) goto end;
-    xx = PyInt_AsLong(x);
-    if (PyErr_Occurred()) goto end;
-    if (xx == dfl) continue;
+  x = PyInt_FromLong(dfl); PyList_Append(l, x); Py_DECREF(x); x = 0;
+  for (;;) {
+    x = PyIter_Next(i); if (!x) break;
+    xx = PyInt_AsLong(x); if (xx == -1 && PyErr_Occurred()) goto end;
     if (xx < 0) VALERR("key size cannot be negative");
-    PyList_Append(l, x);
-    Py_DECREF(x);
-    x = 0;
+    PyList_Append(l, x); Py_DECREF(x); x = 0;
   }
-  Py_DECREF(set);
-  if ((set = PySequence_Tuple(l)) == 0) goto end;
+  if ((set = PyFrozenSet_New(l)) == 0) goto end;
   o = (keyszset_pyobj *)ty->tp_alloc(ty, 0);
   o->dfl = dfl;
   o->set = set;
-  Py_INCREF(set);
 end:
-  Py_XDECREF(set);
   Py_XDECREF(l);
   Py_XDECREF(x);
   return ((PyObject *)o);
@@ -206,25 +196,31 @@ static PyObject *krget_max(PyObject *me, void *hunoz)
 
 static PyObject *ksget_min(PyObject *me, void *hunoz)
 {
-  PyObject *set = ((keyszset_pyobj *)me)->set;
-  int i, n, y, x = -1;
-  n = PyTuple_GET_SIZE(set);
-  for (i = 0; i < n; i++) {
-    y = PyInt_AS_LONG(PyTuple_GET_ITEM(set, i));
+  PyObject *i = PyObject_GetIter(((keyszset_pyobj *)me)->set);
+  PyObject *v = 0;
+  int y, x = -1;
+  for (;;) {
+    v = PyIter_Next(i); if (!v) break;
+    y = PyInt_AsLong(v); assert(y >= 0);
     if (x == -1 || y < x) x = y;
   }
+  Py_DECREF(i); Py_XDECREF(v);
+  if (PyErr_Occurred()) return (0);
   return (PyInt_FromLong(x));
 }
 
 static PyObject *ksget_max(PyObject *me, void *hunoz)
 {
-  PyObject *set = ((keyszset_pyobj *)me)->set;
-  int i, n, y, x = -1;
-  n = PyTuple_GET_SIZE(set);
-  for (i = 0; i < n; i++) {
-    y = PyInt_AS_LONG(PyTuple_GET_ITEM(set, i));
+  PyObject *i = PyObject_GetIter(((keyszset_pyobj *)me)->set);
+  PyObject *v = 0;
+  int y, x = -1;
+  for (;;) {
+    v = PyIter_Next(i); if (!v) break;
+    y = PyInt_AsLong(v); assert(y >= 0);
     if (y > x) x = y;
   }
+  Py_DECREF(i); Py_XDECREF(v);
+  if (PyErr_Occurred()) return (0);
   return (PyInt_FromLong(x));
 }
 
@@ -488,8 +484,8 @@ static const PyTypeObject keyszset_pytype_skel = {
     Py_TPFLAGS_BASETYPE,
 
   /* @tp_doc@ */
-  "KeySZSet(DEFAULT, SEQ)\n"
-  "  Key size constraints: size must be DEFAULT or an element of SEQ.",
+  "KeySZSet(DEFAULT, ITER)\n"
+  "  Key size constraints: size must be DEFAULT or an element of ITER.",
 
   0,                                   /* @tp_traverse@ */
   0,                                   /* @tp_clear@ */
index cff1ebd..52decd6 100644 (file)
@@ -180,7 +180,7 @@ class TestKeysize (U.TestCase):
     me.assertEqual(ksz.default, 32)
     me.assertEqual(ksz.min, 10)
     me.assertEqual(ksz.max, 32)
-    me.assertEqual(set(ksz.set), set([10, 16, 32]))
+    me.assertEqual(ksz.set, set([10, 16, 32]))
     for x, best, pad in [(9, None, 10), (10, 10, 10), (11, 10, 16),
                          (15, 10, 16), (16, 16, 16), (17, 16, 32),
                          (31, 16, 32), (32, 32, 32), (33, 32, None)]:
@@ -194,12 +194,12 @@ class TestKeysize (U.TestCase):
     ## Check construction.
     ksz = C.KeySZSet(7)
     me.assertEqual(ksz.default, 7)
-    me.assertEqual(set(ksz.set), set([7]))
+    me.assertEqual(ksz.set, set([7]))
     me.assertEqual(ksz.min, 7)
     me.assertEqual(ksz.max, 7)
-    ksz = C.KeySZSet(7, [3, 6, 9])
+    ksz = C.KeySZSet(7, iter([3, 6, 9]))
     me.assertEqual(ksz.default, 7)
-    me.assertEqual(set(ksz.set), set([3, 6, 7, 9]))
+    me.assertEqual(ksz.set, set([3, 6, 7, 9]))
     me.assertEqual(ksz.min, 3)
     me.assertEqual(ksz.max, 9)
 
@@ -673,7 +673,7 @@ class TestPoly1305 (HashBufferTestMixin):
     me.assertEqual(C.poly1305.name, "poly1305")
     me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
     me.assertEqual(C.poly1305.keysz.default, 16)
-    me.assertEqual(set(C.poly1305.keysz.set), set([16]))
+    me.assertEqual(C.poly1305.keysz.set, set([16]))
     me.assertEqual(C.poly1305.tagsz, 16)
     me.assertEqual(C.poly1305.masksz, 16)