From 1fae937d27b04ee48b15ae442ee7c46cc10511fd Mon Sep 17 00:00:00 2001 From: Mark Wooding Date: Mon, 25 Nov 2019 12:07:16 +0000 Subject: [PATCH] catacomb/__init__.py: Implement equality and hashing for `KeyData' objects. Equality is determined by value, so don't use `KeyData' objects as hashtable keys and then mutate them. --- catacomb/__init__.py | 28 ++++++++++++++++++++++++++++ t/t-key.py | 33 ++++++++++----------------------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/catacomb/__init__.py b/catacomb/__init__.py index bd8aa57..94efb6f 100644 --- a/catacomb/__init__.py +++ b/catacomb/__init__.py @@ -701,27 +701,41 @@ class _tmp: pp.text(','); pp.breakable() pp.pretty(me.writeflags(me.flags)) pp.end_group(ind, ')') + def __hash__(me): return me._HASHBASE ^ hash(me._guts()) + def __eq__(me, kd): + return type(me) == type(kd) and \ + me._guts() == kd._guts() and \ + me.flags == kd.flags + def __ne__(me, kd): + return not me == kd _augment(KeyData, _tmp) class _tmp: def _guts(me): return me.bin + def __eq__(me, kd): + return isinstance(kd, KeyDataBinary) and me.bin == kd.bin _augment(KeyDataBinary, _tmp) +KeyDataBinary._HASHBASE = 0x961755c3 class _tmp: def _guts(me): return me.ct _augment(KeyDataEncrypted, _tmp) +KeyDataEncrypted._HASHBASE = 0xffe000d4 class _tmp: def _guts(me): return me.mp _augment(KeyDataMP, _tmp) +KeyDataMP._HASHBASE = 0x1cb64d69 class _tmp: def _guts(me): return me.str _augment(KeyDataString, _tmp) +KeyDataString._HASHBASE = 0x349c33ea class _tmp: def _guts(me): return me.ecpt _augment(KeyDataECPt, _tmp) +KeyDataECPt._HASHBASE = 0x2509718b class _tmp: def __repr__(me): @@ -732,7 +746,21 @@ class _tmp: if cyclep: pp.text('...') else: _pp_dict(pp, _iteritems(me)) pp.end_group(ind, ' })') + def __hash__(me): + h = me._HASHBASE + for k, v in _iteritems(me): + h = ((h << 1) ^ 3*hash(k) ^ 5*hash(v))&0xffffffff + return h + def __eq__(me, kd): + if type(me) != type(kd) or me.flags != kd.flags or len(me) != len(kd): + return False + for k, v in _iteritems(me): + try: vv = kd[k] + except KeyError: return False + if v != vv: return False + return True _augment(KeyDataStructured, _tmp) +KeyDataStructured._HASHBASE = 0x85851b21 ###-------------------------------------------------------------------------- ### Abstract groups. diff --git a/t/t-key.py b/t/t-key.py index 20cadb1..bee107b 100644 --- a/t/t-key.py +++ b/t/t-key.py @@ -188,7 +188,7 @@ class TestKeyFile (U.TestCase): k = kf.newkey(0x11111111, "first", exp) me.assertEqual(kf.modifiedp, True) - me.assertEqual(kf[0x11111111].id, 0x11111111) + me.assertEqual(k, kf[0x11111111]) me.assertEqual(k.exptime, exp) me.assertEqual(k.deltime, exp) me.assertRaises(ValueError, setattr, k, "deltime", C.KEXP_FOREVER) @@ -212,24 +212,6 @@ class TestKeyFile (U.TestCase): "22222222:test integer,public:32519164 forever forever -") ###-------------------------------------------------------------------------- - -def keydata_equalp(kd0, kd1): - if type(kd0) is not type(kd1): return False - elif type(kd0) is C.KeyDataBinary: return kd0.bin == kd1.bin - elif type(kd0) is C.KeyDataMP: return kd0.mp == kd1.mp - elif type(kd0) is C.KeyDataEncrypted: return kd0.ct == kd1.ct - elif type(kd0) is C.KeyDataECPt: return kd0.ecpt == kd1.ecpt - elif type(kd0) is C.KeyDataString: return kd0.str == kd1.str - elif type(kd0) is C.KeyDataStructured: - if len(kd0) != len(kd1): return False - for t, v0 in T.iteritems(kd0): - try: v1 = kd1[t] - except KeyError: return False - if not keydata_equalp(v0, v1): return False - return True - else: - raise SystemError("unexpected keydata type") - class TestKeyData (U.TestCase): def test_flags(me): @@ -266,10 +248,8 @@ class TestKeyData (U.TestCase): me.assertEqual(set(T.iterkeys(kd2)), set(["b"])) def check_encode(me, kd): - me.assertTrue(keydata_equalp(C.KeyData.decode(kd.encode()), kd)) - kd1, tail = C.KeyData.read(kd.write()) - me.assertEqual(tail, "") - me.assertTrue(keydata_equalp(kd, kd1)) + me.assertEqual(C.KeyData.decode(kd.encode()), kd) + me.assertEqual(C.KeyData.read(kd.write()), (kd, "")) def test_bin(me): rng = T.detrand("kd-bin") @@ -335,6 +315,13 @@ class TestKeyFileMapping (T.ImmutableMappingTextMixin): me.check_immutable_mapping(kf, model) +class TestKeyStructMapping (T.MutableMappingTestMixin): + def _mkvalue(me, i): return C.KeyDataMP(i) + def _getvalue(me, v): return v.mp + + def test_keystructmap(me): + me.check_mapping(C.KeyDataStructured) + class TestKeyAttrMapping (T.MutableMappingTestMixin): def test_attrmap(me): -- 2.11.0