mp.c: Tighten up the `MP' and `GF' implicit conversions.
authorMark Wooding <mdw@distorted.org.uk>
Sun, 24 Nov 2019 00:28:14 +0000 (00:28 +0000)
committerMark Wooding <mdw@distorted.org.uk>
Sat, 11 Apr 2020 11:49:31 +0000 (12:49 +0100)
Allow implicit conversions to `MP' from `PrimeFilter' (for old time's
sake), and anything /other than/ field elements and polynomials which
implements the `__index__' implicit conversion to integer.

Don't allow implicit conversions to `GF' at all.

mp.c
t/t-convert.py
t/t-mp.py

diff --git a/mp.c b/mp.c
index 163e690..ce04819 100644 (file)
--- a/mp.c
+++ b/mp.c
@@ -289,24 +289,27 @@ int convgf(PyObject *o, void *p)
 
 mp *implicitmp(PyObject *o)
 {
-  if (!o ||
-      GF_PYCHECK(o) ||
-      ECPT_PYCHECK(o) ||
-      FE_PYCHECK(o) ||
-      GE_PYCHECK(o))
-    return (0);
-  return (tomp(o));
+  PyObject *l;
+
+  if (!o || GF_PYCHECK(o) || FE_PYCHECK(o)) return (0);
+  else if (MP_PYCHECK(o)) return (MP_COPY(MP_X(o)));
+  else if (PFILT_PYCHECK(o)) return (MP_COPY(PFILT_F(o)->m));
+#ifdef PY2
+  else if (PyInt_Check(o)) return (mp_fromlong(MP_NEW, PyInt_AS_LONG(o)));
+#endif
+  else if ((l = PyNumber_Index(o)) != 0) {
+#ifdef PY2
+    if (PyInt_Check(o)) return (mp_fromlong(MP_NEW, PyInt_AS_LONG(o)));
+#endif
+    if (PyLong_Check(o)) return (mp_frompylong(o));
+  }
+  PyErr_Clear(); return (0);
 }
 
 mp *implicitgf(PyObject *o)
 {
-  if (!o ||
-      MP_PYCHECK(o) ||
-      ECPT_PYCHECK(o) ||
-      FE_PYCHECK(o) ||
-      GE_PYCHECK(o))
-    return (0);
-  return (tomp(o));
+  if (GF_PYCHECK(o)) return (MP_COPY(MP_X(o)));
+  return (0);
 }
 
 static int mpbinop(PyObject *x, PyObject *y, mp **xx, mp **yy)
index a5f7ff0..19d6918 100644 (file)
@@ -51,7 +51,7 @@ class TestConvert (U.TestCase):
     me.assertEqual(pow(C.MP(5), 2, 7), 4)
     me.assertEqual(pow(5, C.MP(2), 7), 4)
     me.assertEqual(pow(5, 2, C.MP(7)), 4)
-    for bad in [C.GF, k, kk, float, lambda x: [x]]:
+    for bad in [C.GF, k, kk, str, float, lambda x: [x]]:
       me.assertRaises(TypeError, pow, C.MP(5), bad(2))
       me.assertRaises(TypeError, pow, C.MP(5), bad(2), 7)
       if not (T.PY2 and T.DEBUGP):
@@ -76,10 +76,10 @@ class TestConvert (U.TestCase):
     me.assertEqual(pow(C.GF(0x5), 2), C.GF(0x11))
     me.assertEqual(pow(C.GF(0x5), C.MP(2)), C.GF(0x11))
     me.assertEqual(pow(C.GF(5), 2, C.GF(0x13)), C.GF(0x2))
-    for bad in [k, kk, float, lambda x: [x]]:
+    for bad in [k, kk, str, float, lambda x: [x]]:
       me.assertRaises(TypeError, pow, C.GF(5), bad(2))
       me.assertRaises(TypeError, T.lsl, C.GF(5), bad(2))
-    for bad in [C.MP, k, kk, float, lambda x: [x]]:
+    for bad in [int, C.MP, k, kk, str, float, lambda x: [x]]:
       me.assertRaises(TypeError, pow, bad(5), C.GF(2))
       me.assertRaises(TypeError, pow, bad(5), C.GF(2), bad(7))
       me.assertRaises(TypeError, pow, bad(5), bad(2), C.GF(7))
@@ -91,13 +91,18 @@ class TestConvert (U.TestCase):
     ## `MP' and `GF'.
     me.assertEqual(C.MP(5), 5)
     me.assertEqual(5, C.MP(5))
+    me.assertNotEqual(C.GF(5), 5)
+    me.assertNotEqual(5, C.GF(5))
     me.assertNotEqual(C.MP(5), C.GF(5))
     me.assertNotEqual(C.GF(5), C.MP(5))
 
     me.assertEqual(C.MP(5) + 3, 8)
     me.assertEqual(3 + C.MP(5), 8)
+    me.assertRaises(TypeError, T.add, C.GF(5), 3)
+    me.assertRaises(TypeError, T.add, 3, C.GF(5))
     me.assertRaises(TypeError, T.add, C.MP(5), C.GF(3))
     me.assertRaises(TypeError, T.add, C.GF(3), C.MP(5))
+    if T.PY3: me.assertRaises(TypeError, T.le, C.MP(4), C.GF(5))
 
     ## Field elements.
     me.assertEqual(k(3) + 4, 7)
@@ -109,11 +114,15 @@ class TestConvert (U.TestCase):
     me.assertEqual(kk(3) + C.GF(7), C.GF(4))
 
     me.assertRaises(TypeError, T.add, k(3), 3.0)
+    me.assertRaises(TypeError, T.add, k(3), "3")
     me.assertRaises(TypeError, T.add, k(3), kk(3))
     me.assertRaises(TypeError, T.add, kk(3), k(3))
     me.assertRaises(TypeError, T.add, k(3), C.GF(7))
     me.assertRaises(TypeError, T.add, C.GF(7), k(3))
+    me.assertRaises(TypeError, T.add, kk(3), 7)
     me.assertRaises(TypeError, T.add, kk(3), 7.0)
+    me.assertRaises(TypeError, T.add, kk(3), "7")
+    me.assertRaises(TypeError, T.add, 7, kk(3))
     me.assertRaises(TypeError, T.add, kk(3), C.MP(7))
     me.assertRaises(TypeError, T.add, C.MP(7), kk(3))
 
index a1dd657..24e80cd 100644 (file)
--- a/t/t-mp.py
+++ b/t/t-mp.py
@@ -154,6 +154,14 @@ class TestMP (U.TestCase):
     while z == z + 1: z *= 2.0
     me.assertNotEqual(C.MP(int(z)) + 1, z)
 
+  def test_strconv(me):
+    x, y = C.MP(169), "24"
+    for fn in [T.add, T.sub]:
+      me.assertRaises(TypeError, fn, x, y)
+      me.assertRaises(TypeError, fn, y, x)
+    me.assertEqual(x*y, 169*"24")
+    me.assertEqual(y*x, 169*"24")
+
   def test_bits(me):
     x, y, zero = C.MP(169), C.MP(-24), C.MP(0)
     me.assertTrue(x.testbit(0))
@@ -396,6 +404,7 @@ class TestGF (U.TestCase):
     me.assertEqual(C.GF(E(1, 4)), C.GF(1))
     me.assertRaises(TypeError, C.GF, E())
 
+    me.assertNotEqual(x, 5) # no implicit conversion to int
     me.assertEqual(int(x), 5)
     y = C.GF(0x4eeb684a0954ec4ceb255e3e9778d41)
     me.assertEqual(type(int(y)), T.long)