diff options
-rw-r--r-- | unit-tests.py | 42 |
1 files changed, 25 insertions, 17 deletions
diff --git a/unit-tests.py b/unit-tests.py index dc13265..5f472fd 100644 --- a/unit-tests.py +++ b/unit-tests.py @@ -510,16 +510,19 @@ class TestPKeyList(TestCaseLoggedIn): """ def load_keys(self, flags): + uuids = set() for obj in PreloadedKey.db.itervalues(): with hsm.pkey_load(obj.keytype, obj.curve, obj.der, flags) as k: self.addCleanup(lambda uuid: hsm.pkey_find(uuid, flags = flags).delete(), k.uuid) + uuids.add(k.uuid) for i, a in enumerate((str(obj.keytype), str(obj.fn2))): k.set_attribute(i, a) + return uuids def ks_list(self, flags): - self.load_keys(flags) - hsm.pkey_list(flags = flags) - hsm.pkey_match(flags = flags) + uuids = self.load_keys(flags) + self.assertLessEqual(len(uuids), len(set(hsm.pkey_list(flags = flags)))) + self.assertLessEqual(uuids, set(hsm.pkey_match(flags = flags))) def test_ks_list_volatile(self): self.ks_list(0) @@ -528,31 +531,34 @@ class TestPKeyList(TestCaseLoggedIn): self.ks_list(HAL_KEY_FLAG_TOKEN) def match(self, flags, **kwargs): + uuids = kwargs.pop("uuids", None) kwargs.update(flags = flags) - for n, uuid in enumerate(hsm.pkey_match(**kwargs), 1): - with hsm.pkey_find(uuid, flags) as k: - yield n, k + n = 0 + for uuid in hsm.pkey_match(**kwargs): + if uuids is None or uuid in uuids: + with hsm.pkey_find(uuid, flags) as k: + n += 1 + yield n, k def ks_match(self, flags): - tags = [] + tags = [] + uuids = set() for i in xrange(2): - self.load_keys(flags) + uuids |= self.load_keys(flags) tags.extend(PreloadedKey.db) + self.assertEqual(len(tags), len(uuids)) - uuids = set() - for n, k in self.match(flags = flags): - uuids.add(k.uuid) - self.assertEqual(n, len(uuids)) - self.assertEqual(n, len(tags)) + matched_uuids = set(k.uuid for n, k in self.match(flags = flags)) + self.assertGreaterEqual(matched_uuids, uuids) for keytype in set(HALKeyType.index.itervalues()) - {HAL_KEY_TYPE_NONE}: - for n, k in self.match(flags = flags, type = keytype): + for n, k in self.match(flags = flags, uuids = uuids, type = keytype): self.assertEqual(k.key_type, keytype) self.assertEqual(k.get_attribute(0), str(keytype)) self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == keytype)) for curve in set(HALCurve.index.itervalues()) - {HAL_CURVE_NONE}: - for n, k in self.match(flags = flags, curve = curve): + for n, k in self.match(flags = flags, uuids = uuids, curve = curve): self.assertEqual(k.key_curve, curve) self.assertEqual(k.get_attribute(1), str(curve)) self.assertIn(k.key_type, (HAL_KEY_TYPE_EC_PUBLIC, @@ -560,13 +566,15 @@ class TestPKeyList(TestCaseLoggedIn): self.assertEqual(n, sum(1 for t1, t2 in tags if t2 == curve)) for keylen in set(kl for kt, kl in tags if not isinstance(kl, Enum)): - for n, k in self.match(flags = flags, attributes = {1 : str(keylen)}): + for n, k in self.match(flags = flags, uuids = uuids, + attributes = {1 : str(keylen)}): self.assertEqual(keylen, int(k.get_attribute(1))) self.assertIn(k.key_type, (HAL_KEY_TYPE_RSA_PUBLIC, HAL_KEY_TYPE_RSA_PRIVATE)) self.assertEqual(n, sum(1 for t1, t2 in tags if not isinstance(t2, Enum) and t2 == keylen)) - for n, k in self.match(flags = flags, type = HAL_KEY_TYPE_RSA_PUBLIC, attributes = {1 : "2048"}): + for n, k in self.match(flags = flags, uuids = uuids, + type = HAL_KEY_TYPE_RSA_PUBLIC, attributes = {1 : "2048"}): self.assertEqual(k.key_type, HAL_KEY_TYPE_RSA_PUBLIC) self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == HAL_KEY_TYPE_RSA_PUBLIC and t2 == 2048)) |