From 2f21fc43638cf39da07ab85dd21546bc34515730 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Tue, 9 Jun 2020 18:34:06 -0400 Subject: Whack with club until Python 2 works again and Python 3 almost works There's still something wrong with XDR for attribute lists in Python 3, XDR complains that there's unconsumed data and attributes coming back are (sometimes truncated). Python 2 works. Probably data type issue somewhere but haven't spotted it yet. --- cryptech/libhal.py | 10 +++++++--- unit-tests.py | 30 +++++++++++++++--------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/cryptech/libhal.py b/cryptech/libhal.py index 56712cb..1e2dbc6 100644 --- a/cryptech/libhal.py +++ b/cryptech/libhal.py @@ -490,18 +490,19 @@ class HSM(object): msg = slip_decode(b"".join(msg)) if not msg: continue - msg = ContextManagedUnpacker(b"".join(msg)) + msg = ContextManagedUnpacker(msg) if msg.unpack_uint() != code: continue return msg _pack_builtin = ((int, "_pack_uint"), - (str, "_pack_bytes"), + (bytes, "_pack_bytes"), + (str, "_pack_str"), ((list, tuple, set), "_pack_array"), (dict, "_pack_items")) try: - _pack_builtin += (long, "_pack_uint") + _pack_builtin += ((long, "_pack_uint"),) except NameError: # "long" merged with "int" in Python 3 pass @@ -523,6 +524,9 @@ class HSM(object): def _pack_bytes(self, packer, arg): packer.pack_bytes(arg) + def _pack_str(self, packer, arg): + packer.pack_bytes(arg.encode()) + def _pack_array(self, packer, arg): packer.pack_uint(len(arg)) self._pack_args(packer, arg) diff --git a/unit-tests.py b/unit-tests.py index 77ca4cb..fab09e8 100644 --- a/unit-tests.py +++ b/unit-tests.py @@ -171,7 +171,7 @@ class TestBasic(TestCase): def test_get_random(self): length = 32 random = hsm.get_random(length) - self.assertIsInstance(random, str) + self.assertIsInstance(random, bytes) self.assertEqual(length, len(random)) @@ -863,7 +863,7 @@ class TestPKeyGen(TestCaseLoggedIn): def sign_verify(self, hashalg, k1, k2, length = 1024): h = hsm.hash_initialize(hashalg) - h.update("Your mother was a hamster") + h.update("Your mother was a hamster".encode()) data = h.finalize() sig = k1.sign(data = data, length = length) k1.verify(signature = sig, data = data) @@ -954,7 +954,7 @@ class TestPKeyHashing(TestCaseLoggedIn): @staticmethod def h(alg, mixed_mode = False): h = hsm.hash_initialize(alg, mixed_mode = mixed_mode) - h.update("Your mother was a hamster") + h.update("Your mother was a hamster".encode()) return h def sign_verify_data(self, alg, k1, k2, length = 1024): @@ -1126,7 +1126,7 @@ class TestPKeyRSAInterop(TestCaseLoggedIn): return h def load_sign_verify_rsa(self, alg, pyhash, keylen): - hamster = "Your mother was a hamster" + hamster = "Your mother was a hamster".encode() sk = PreloadedKey.db[HAL_KEY_TYPE_RSA_PRIVATE, keylen] vk = PreloadedKey.db[HAL_KEY_TYPE_RSA_PUBLIC, keylen] k1 = hsm.pkey_load(sk.der, HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE) @@ -1161,7 +1161,7 @@ class TestPKeyECDSAInterop(TestCaseLoggedIn): return h def load_sign_verify_ecdsa(self, alg, pyhash, curve): - hamster = "Your mother was a hamster" + hamster = "Your mother was a hamster".encode() sk = PreloadedKey.db[HAL_KEY_TYPE_EC_PRIVATE, curve] vk = PreloadedKey.db[HAL_KEY_TYPE_EC_PUBLIC, curve] k1 = hsm.pkey_load(sk.der, HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE) @@ -1256,7 +1256,7 @@ class TestPKeyMatch(TestCaseLoggedIn): n = 0 for n, k in self.match(mask = mask, flags = flags, uuids = uuids, type = keytype): self.assertEqual(k.key_type, keytype) - self.assertEqual(k.get_attributes({0}).pop(0), str(keytype)) + self.assertEqual(k.get_attributes({0}).pop(0), bytes(keytype)) self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == keytype)) for curve in set(HALCurve.index.values()) - {HAL_CURVE_NONE}: @@ -1268,7 +1268,7 @@ class TestPKeyMatch(TestCaseLoggedIn): HAL_KEY_TYPE_EC_PRIVATE)) self.assertEqual(n, sum(1 for t1, t2 in tags if t2 == curve)) - for keylen in set(kl for kt, kl in tags if not isinstance(kl, Enum)): + for keylen in set(kl for kt, kl in tags if not isinstance(kl, CEnum)): n = 0 for n, k in self.match(mask = mask, flags = flags, uuids = uuids, attributes = {1 : str(keylen)}): @@ -1276,7 +1276,7 @@ class TestPKeyMatch(TestCaseLoggedIn): self.assertIn(k.key_type, (HAL_KEY_TYPE_RSA_PUBLIC, HAL_KEY_TYPE_RSA_PRIVATE)) self.assertEqual(n, sum(1 for t1, t2 in tags - if not isinstance(t2, Enum) and t2 == keylen)) + if not isinstance(t2, CEnum) and t2 == keylen)) n = 0 for n, k in self.match(mask = mask, flags = flags, uuids = uuids, @@ -1477,7 +1477,7 @@ class TestPKeyAttributeReadSpeedToken(TestCaseLoggedIn): super(TestPKeyAttributeReadSpeedToken, self).setUp() def verify_attributes(self, n_attrs, attributes): - expected = dict((i, "Attribute {}".format(i)) + expected = dict((i, "Attribute {}".format(i).encode()) for i in range(n_attrs)) self.assertEqual(attributes, expected) @@ -1509,7 +1509,7 @@ class TestPKeyAttributeReadSpeedVolatile(TestCaseLoggedIn): super(TestPKeyAttributeReadSpeedVolatile, self).setUp() def verify_attributes(self, n_attrs, attributes): - expected = dict((i, "Attribute {}".format(i)) + expected = dict((i, "Attribute {}".format(i).encode()) for i in range(n_attrs)) self.assertEqual(attributes, expected) @@ -1601,12 +1601,12 @@ class TestPKeyBackup(TestCaseLoggedIn): @staticmethod def make_kek(): import Crypto.Random - return Crypto.Random.new().read(256/8) + return Crypto.Random.new().read(256//8) def sig_check(self, pkey, der): from Crypto.Util.asn1 import DerSequence, DerNull, DerOctetString p115 = PKCS115_SigScheme(RSA.importKey(der)) - hash = SHA256("Your mother was a hamster") + hash = SHA256("Your mother was a hamster".encode()) data = DerSequence([ DerSequence([hash.oid, DerNull().encode()]).encode(), DerOctetString(hash.digest()).encode() @@ -1708,7 +1708,7 @@ class AESKeyWrapWithPadding(object): if len(C) % 8 != 0: raise self.UnwrapError("Ciphertext length {} is not an integral number of blocks" .format(len(C))) - n = (len(C) / 8) - 1 + n = (len(C) // 8) - 1 R = [C[i : i + 8] for i in range(0, len(C), 8)] if n == 1: R[0], R[1] = self._decrypt(R[0], R[1]) @@ -1726,8 +1726,8 @@ class AESKeyWrapWithPadding(object): raise self.UnwrapError("Length encoded in AIV out of range: m {}, n {}".format(m, n)) R = b"".join(R[1:]) assert len(R) == 8 * n - if any(r != b"\x00" for r in R[m:]): - raise self.UnwrapError("Nonzero trailing bytes 0x{}".format(binascii.hexlify(R[m:]).decode("ascii"))) + if R[m:].strip(b"\x00"): + raise self.UnwrapError("Nonzero trailing bytes 0x{}".format(binascii.hexlify(R[m:]).decode())) return R[:m] -- cgit v1.2.3