From bfc8465895e2f792f878119edf216d9b6c51cd1f Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Wed, 2 Nov 2016 19:01:21 -0400 Subject: Convert pkey_match() test into a proper assertion-based unit test. --- unit-tests.py | 66 ++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/unit-tests.py b/unit-tests.py index 3b67f1f..cee3202 100644 --- a/unit-tests.py +++ b/unit-tests.py @@ -539,42 +539,48 @@ class TestPKeyList(TestCaseLoggedIn): def test_ks_list_token(self): self.ks_list(HAL_KEY_FLAG_TOKEN) - verbose = False - - def blather(self, line = ""): - if self.verbose: - print line - - def ks_print(self, flags, **kwargs): + def match(self, flags, **kwargs): kwargs.update(flags = flags) - for uuid in hsm.pkey_match(**kwargs): + for n, uuid in enumerate(hsm.pkey_match(**kwargs), 1): with hsm.pkey_find(uuid, flags) as k: - line = "{k.uuid} 0x{k.key_flags:02x} {k.key_curve} {k.key_type} {a[0]} {a[1]}".format( - k = k, a = [k.get_attribute(i) for i in xrange(2)]) - self.blather(line) + yield n, k def ks_match(self, flags): - self.load_keys(flags) - self.load_keys(flags) - self.blather() - self.blather("All:") - self.ks_print(flags = flags) - self.blather() + tags = [] + for i in xrange(2): + self.load_keys(flags) + tags.extend(static_keys) + + uuids = set() + for n, k in self.match(flags = flags): + uuids.add(k.uuid) + self.assertEqual(n, len(uuids)) + self.assertEqual(n, len(tags)) + for keytype in set(HALKeyType.index.itervalues()) - {HAL_KEY_TYPE_NONE}: - self.blather("Type: {}".format(keytype)) - self.ks_print(flags = flags, type = keytype) - self.blather() + for n, k in self.match(flags = flags, type = keytype): + self.assertEqual(k.key_type, keytype) + self.assertEqual(k.get_attribute(0), str(keytype)) + self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == keytype)) + for curve in set(HALCurve.index.itervalues()) - {HAL_CURVE_NONE}: - self.blather("Curve: {}".format(curve)) - self.ks_print(flags = flags, curve = curve) - self.blather() - for keylen in sorted(set(kl for kt, kl in static_keys if not isinstance(kl, Enum))): - self.blather("Keylen: {}".format(keylen)) - self.ks_print(flags = flags, attributes = {1 : str(keylen)}) - self.blather() - self.blather("Keylen 2048 RSA public keys:") - self.ks_print(flags = flags, type = HAL_KEY_TYPE_RSA_PUBLIC, attributes = {1 : "2048"}) - self.blather() + for n, k in self.match(flags = flags, curve = curve): + self.assertEqual(k.key_curve, curve) + self.assertEqual(k.get_attribute(1), str(curve)) + self.assertIn(k.key_type, (HAL_KEY_TYPE_EC_PUBLIC, + HAL_KEY_TYPE_EC_PRIVATE)) + self.assertEqual(n, sum(1 for t1, t2 in tags if t2 == curve)) + + for keylen in set(kl for kt, kl in tags if not isinstance(kl, Enum)): + for n, k in self.match(flags = flags, attributes = {1 : str(keylen)}): + self.assertEqual(keylen, int(k.get_attribute(1))) + self.assertIn(k.key_type, (HAL_KEY_TYPE_RSA_PUBLIC, + HAL_KEY_TYPE_RSA_PRIVATE)) + self.assertEqual(n, sum(1 for t1, t2 in tags if not isinstance(t2, Enum) and t2 == keylen)) + + for n, k in self.match(flags = flags, type = HAL_KEY_TYPE_RSA_PUBLIC, attributes = {1 : "2048"}): + self.assertEqual(k.key_type, HAL_KEY_TYPE_RSA_PUBLIC) + self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == HAL_KEY_TYPE_RSA_PUBLIC and t2 == 2048)) def test_ks_match_token(self): self.ks_match(HAL_KEY_FLAG_TOKEN) -- cgit v1.2.3