aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--unit-tests.py42
1 files changed, 25 insertions, 17 deletions
diff --git a/unit-tests.py b/unit-tests.py
index dc13265..5f472fd 100644
--- a/unit-tests.py
+++ b/unit-tests.py
@@ -510,16 +510,19 @@ class TestPKeyList(TestCaseLoggedIn):
"""
def load_keys(self, flags):
+ uuids = set()
for obj in PreloadedKey.db.itervalues():
with hsm.pkey_load(obj.keytype, obj.curve, obj.der, flags) as k:
self.addCleanup(lambda uuid: hsm.pkey_find(uuid, flags = flags).delete(), k.uuid)
+ uuids.add(k.uuid)
for i, a in enumerate((str(obj.keytype), str(obj.fn2))):
k.set_attribute(i, a)
+ return uuids
def ks_list(self, flags):
- self.load_keys(flags)
- hsm.pkey_list(flags = flags)
- hsm.pkey_match(flags = flags)
+ uuids = self.load_keys(flags)
+ self.assertLessEqual(len(uuids), len(set(hsm.pkey_list(flags = flags))))
+ self.assertLessEqual(uuids, set(hsm.pkey_match(flags = flags)))
def test_ks_list_volatile(self):
self.ks_list(0)
@@ -528,31 +531,34 @@ class TestPKeyList(TestCaseLoggedIn):
self.ks_list(HAL_KEY_FLAG_TOKEN)
def match(self, flags, **kwargs):
+ uuids = kwargs.pop("uuids", None)
kwargs.update(flags = flags)
- for n, uuid in enumerate(hsm.pkey_match(**kwargs), 1):
- with hsm.pkey_find(uuid, flags) as k:
- yield n, k
+ n = 0
+ for uuid in hsm.pkey_match(**kwargs):
+ if uuids is None or uuid in uuids:
+ with hsm.pkey_find(uuid, flags) as k:
+ n += 1
+ yield n, k
def ks_match(self, flags):
- tags = []
+ tags = []
+ uuids = set()
for i in xrange(2):
- self.load_keys(flags)
+ uuids |= self.load_keys(flags)
tags.extend(PreloadedKey.db)
+ self.assertEqual(len(tags), len(uuids))
- uuids = set()
- for n, k in self.match(flags = flags):
- uuids.add(k.uuid)
- self.assertEqual(n, len(uuids))
- self.assertEqual(n, len(tags))
+ matched_uuids = set(k.uuid for n, k in self.match(flags = flags))
+ self.assertGreaterEqual(matched_uuids, uuids)
for keytype in set(HALKeyType.index.itervalues()) - {HAL_KEY_TYPE_NONE}:
- for n, k in self.match(flags = flags, type = keytype):
+ for n, k in self.match(flags = flags, uuids = uuids, type = keytype):
self.assertEqual(k.key_type, keytype)
self.assertEqual(k.get_attribute(0), str(keytype))
self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == keytype))
for curve in set(HALCurve.index.itervalues()) - {HAL_CURVE_NONE}:
- for n, k in self.match(flags = flags, curve = curve):
+ for n, k in self.match(flags = flags, uuids = uuids, curve = curve):
self.assertEqual(k.key_curve, curve)
self.assertEqual(k.get_attribute(1), str(curve))
self.assertIn(k.key_type, (HAL_KEY_TYPE_EC_PUBLIC,
@@ -560,13 +566,15 @@ class TestPKeyList(TestCaseLoggedIn):
self.assertEqual(n, sum(1 for t1, t2 in tags if t2 == curve))
for keylen in set(kl for kt, kl in tags if not isinstance(kl, Enum)):
- for n, k in self.match(flags = flags, attributes = {1 : str(keylen)}):
+ for n, k in self.match(flags = flags, uuids = uuids,
+ attributes = {1 : str(keylen)}):
self.assertEqual(keylen, int(k.get_attribute(1)))
self.assertIn(k.key_type, (HAL_KEY_TYPE_RSA_PUBLIC,
HAL_KEY_TYPE_RSA_PRIVATE))
self.assertEqual(n, sum(1 for t1, t2 in tags if not isinstance(t2, Enum) and t2 == keylen))
- for n, k in self.match(flags = flags, type = HAL_KEY_TYPE_RSA_PUBLIC, attributes = {1 : "2048"}):
+ for n, k in self.match(flags = flags, uuids = uuids,
+ type = HAL_KEY_TYPE_RSA_PUBLIC, attributes = {1 : "2048"}):
self.assertEqual(k.key_type, HAL_KEY_TYPE_RSA_PUBLIC)
self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == HAL_KEY_TYPE_RSA_PUBLIC and t2 == 2048))