From: Mark Wooding
Date: Sun, 24 Nov 2019 00:28:14 +0000 (+0000)
Subject: mp.c: Tighten up the `MP' and `GF' implicit conversions.
X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb-python/commitdiff_plain/e7d2e2e29042a980933261a173d96c71982514fa

mp.c: Tighten up the `MP' and `GF' implicit conversions.

Allow implicit conversions to `MP' from `PrimeFilter' (for old time's
sake), and anything /other than/ field elements and polynomials which
implements the `__index__' implicit conversion to integer.

Don't allow implicit conversions to `GF' at all.
---
NOTE(review): this copy departs from upstream commit e7d2e2e in the
mp.c hunk only, so the `index' line for mp.c below still names
upstream's post-image blob and will not match it.

  * `implicitmp': the `__index__' branch now converts L, the integer
    returned by `PyNumber_Index', instead of re-testing O, and releases
    L afterwards.  As first written, that branch leaked the new
    reference on every call reaching it (including every plain `int'
    under Python 3), and never converted objects which only implement
    `__index__', contrary to the description above.

  * `implicitgf': keep the `!o' guard which the old code (and
    `implicitmp') applies before any `*_PYCHECK' test.

  * No test below exercises a non-`int' `__index__' object yet.

diff --git a/mp.c b/mp.c
index 163e690..ce04819 100644
--- a/mp.c
+++ b/mp.c
@@ -289,24 +289,34 @@ int convgf(PyObject *o, void *p)
 
 mp *implicitmp(PyObject *o)
 {
-  if (!o ||
-      GF_PYCHECK(o) ||
-      ECPT_PYCHECK(o) ||
-      FE_PYCHECK(o) ||
-      GE_PYCHECK(o))
-    return (0);
-  return (tomp(o));
+  PyObject *l;
+  mp *x;
+
+  if (!o || GF_PYCHECK(o) || FE_PYCHECK(o)) return (0);
+  else if (MP_PYCHECK(o)) return (MP_COPY(MP_X(o)));
+  else if (PFILT_PYCHECK(o)) return (MP_COPY(PFILT_F(o)->m));
+#ifdef PY2
+  else if (PyInt_Check(o)) return (mp_fromlong(MP_NEW, PyInt_AS_LONG(o)));
+#endif
+  else if ((l = PyNumber_Index(o)) != 0) {
+    /* Convert the integer `PyNumber_Index' gave us -- not O, which may
+     * merely implement `__index__' -- and drop our new reference to it.
+     */
+    x = 0;
+#ifdef PY2
+    if (PyInt_Check(l)) x = mp_fromlong(MP_NEW, PyInt_AS_LONG(l));
+#endif
+    if (PyLong_Check(l)) x = mp_frompylong(l);
+    Py_DECREF(l);
+    return (x);
+  }
+  PyErr_Clear(); return (0);
 }
 
 mp *implicitgf(PyObject *o)
 {
-  if (!o ||
-      MP_PYCHECK(o) ||
-      ECPT_PYCHECK(o) ||
-      FE_PYCHECK(o) ||
-      GE_PYCHECK(o))
-    return (0);
-  return (tomp(o));
+  if (o && GF_PYCHECK(o)) return (MP_COPY(MP_X(o)));
+  return (0);
 }
 
 static int mpbinop(PyObject *x, PyObject *y, mp **xx, mp **yy)
diff --git a/t/t-convert.py b/t/t-convert.py
index a5f7ff0..19d6918 100644
--- a/t/t-convert.py
+++ b/t/t-convert.py
@@ -51,7 +51,7 @@ class TestConvert (U.TestCase):
     me.assertEqual(pow(C.MP(5), 2, 7), 4)
     me.assertEqual(pow(5, C.MP(2), 7), 4)
     me.assertEqual(pow(5, 2, C.MP(7)), 4)
-    for bad in [C.GF, k, kk, float, lambda x: [x]]:
+    for bad in [C.GF, k, kk, str, float, lambda x: [x]]:
       me.assertRaises(TypeError, pow, C.MP(5), bad(2))
       me.assertRaises(TypeError, pow, C.MP(5), bad(2), 7)
     if not (T.PY2 and T.DEBUGP):
@@ -76,10 +76,10 @@ class TestConvert (U.TestCase):
     me.assertEqual(pow(C.GF(0x5), 2), C.GF(0x11))
     me.assertEqual(pow(C.GF(0x5), C.MP(2)), C.GF(0x11))
     me.assertEqual(pow(C.GF(5), 2, C.GF(0x13)), C.GF(0x2))
-    for bad in [k, kk, float, lambda x: [x]]:
+    for bad in [k, kk, str, float, lambda x: [x]]:
       me.assertRaises(TypeError, pow, C.GF(5), bad(2))
       me.assertRaises(TypeError, T.lsl, C.GF(5), bad(2))
-    for bad in [C.MP, k, kk, float, lambda x: [x]]:
+    for bad in [int, C.MP, k, kk, str, float, lambda x: [x]]:
       me.assertRaises(TypeError, pow, bad(5), C.GF(2))
       me.assertRaises(TypeError, pow, bad(5), C.GF(2), bad(7))
       me.assertRaises(TypeError, pow, bad(5), bad(2), C.GF(7))
@@ -91,13 +91,18 @@ class TestConvert (U.TestCase):
     ## `MP' and `GF'.
     me.assertEqual(C.MP(5), 5)
     me.assertEqual(5, C.MP(5))
+    me.assertNotEqual(C.GF(5), 5)
+    me.assertNotEqual(5, C.GF(5))
     me.assertNotEqual(C.MP(5), C.GF(5))
     me.assertNotEqual(C.GF(5), C.MP(5))
     me.assertEqual(C.MP(5) + 3, 8)
     me.assertEqual(3 + C.MP(5), 8)
 
+    me.assertRaises(TypeError, T.add, C.GF(5), 3)
+    me.assertRaises(TypeError, T.add, 3, C.GF(5))
     me.assertRaises(TypeError, T.add, C.MP(5), C.GF(3))
     me.assertRaises(TypeError, T.add, C.GF(3), C.MP(5))
+    if T.PY3: me.assertRaises(TypeError, T.le, C.MP(4), C.GF(5))
 
     ## Field elements.
     me.assertEqual(k(3) + 4, 7)
@@ -109,11 +114,15 @@ class TestConvert (U.TestCase):
     me.assertEqual(kk(3) + C.GF(7), C.GF(4))
 
     me.assertRaises(TypeError, T.add, k(3), 3.0)
+    me.assertRaises(TypeError, T.add, k(3), "3")
     me.assertRaises(TypeError, T.add, k(3), kk(3))
     me.assertRaises(TypeError, T.add, kk(3), k(3))
     me.assertRaises(TypeError, T.add, k(3), C.GF(7))
     me.assertRaises(TypeError, T.add, C.GF(7), k(3))
+    me.assertRaises(TypeError, T.add, kk(3), 7)
     me.assertRaises(TypeError, T.add, kk(3), 7.0)
+    me.assertRaises(TypeError, T.add, kk(3), "7")
+    me.assertRaises(TypeError, T.add, 7, kk(3))
     me.assertRaises(TypeError, T.add, kk(3), C.MP(7))
     me.assertRaises(TypeError, T.add, C.MP(7), kk(3))
 
diff --git a/t/t-mp.py b/t/t-mp.py
index a1dd657..24e80cd 100644
--- a/t/t-mp.py
+++ b/t/t-mp.py
@@ -154,6 +154,14 @@ class TestMP (U.TestCase):
     while z == z + 1: z *= 2.0
     me.assertNotEqual(C.MP(int(z)) + 1, z)
 
+  def test_strconv(me):
+    x, y = C.MP(169), "24"
+    for fn in [T.add, T.sub]:
+      me.assertRaises(TypeError, fn, x, y)
+      me.assertRaises(TypeError, fn, y, x)
+    me.assertEqual(x*y, 169*"24")
+    me.assertEqual(y*x, 169*"24")
+
   def test_bits(me):
     x, y, zero = C.MP(169), C.MP(-24), C.MP(0)
     me.assertTrue(x.testbit(0))
@@ -396,6 +404,7 @@ class TestGF (U.TestCase):
     me.assertEqual(C.GF(E(1, 4)), C.GF(1))
     me.assertRaises(TypeError, C.GF, E())
 
+    me.assertNotEqual(x, 5) # no implicit conversion to int
     me.assertEqual(int(x), 5)
     y = C.GF(0x4eeb684a0954ec4ceb255e3e9778d41)
     me.assertEqual(type(int(y)), T.long)