From bbb84e218f971d9dd134e85557951b36146c017a Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Fri, 21 Oct 2016 00:44:46 -0400 Subject: Better enum handling, more readable RPC methods. Using a context manager allows us to write the individual RPC methods fairly legibly, while still enforcing xdrlib.Unpacker.done() logic. Python doesn't really have enums in the sense that C does, and many people have put entirely too much skull sweat into trying to invent the Most Pythonic reimplementation of the enum concept, but an int subclass with a few extra methods is close enough for our purposes. --- libhal.py | 306 +++++++++++++++++++++++++++++++++----------------------------- 1 file changed, 163 insertions(+), 143 deletions(-) diff --git a/libhal.py b/libhal.py index 0924863..5e5832b 100644 --- a/libhal.py +++ b/libhal.py @@ -44,6 +44,7 @@ import time import uuid import xdrlib import serial +import contextlib SLIP_END = chr(0300) # indicates end of packet SLIP_ESC = chr(0333) # indicates byte stuffing @@ -108,11 +109,32 @@ HALError.define(HAL_ERROR_ATTRIBUTE_NOT_FOUND = "Attribute not found") HALError.define(HAL_ERROR_NO_KEY_INDEX_SLOTS = "No key index slots available") -def def_enum(text): - for i, name in enumerate(text.translate(None, ",").split()): - globals()[name] = i +class Enum(int): -def_enum(''' + def __new__(cls, name, value): + self = int.__new__(cls, value) + self._name = name + setattr(self.__class__, name, self) + return self + + def __str__(self): + return self._name + + def __repr__(self): + return "".format(self) + + @classmethod + def define(cls, names): + cls.index = tuple(cls(name, i) for i, name in enumerate(names.translate(None, ",").split())) + globals().update((symbol._name, symbol) for symbol in cls.index) + + def xdr_packer(self, packer): + packer.pack_uint(self) + + +class RPCFunc(Enum): pass + +RPCFunc.define(''' RPC_FUNC_GET_VERSION, RPC_FUNC_GET_RANDOM, RPC_FUNC_SET_PIN, @@ -146,7 +168,9 @@ def_enum(''' RPC_FUNC_PKEY_DELETE_ATTRIBUTE, ''') -def_enum(''' +class HALDigestAlgorithm(Enum): pass + +HALDigestAlgorithm.define(''' hal_digest_algorithm_none, hal_digest_algorithm_sha1, hal_digest_algorithm_sha224, @@ -157,7 +181,9 @@ def_enum(''' hal_digest_algorithm_sha512 ''') -def_enum(''' +class HALKeyType(Enum): pass + +HALKeyType.define(''' HAL_KEY_TYPE_NONE, HAL_KEY_TYPE_RSA_PRIVATE, HAL_KEY_TYPE_RSA_PUBLIC, @@ -165,14 +191,18 @@ def_enum(''' HAL_KEY_TYPE_EC_PUBLIC ''') -def_enum(''' +class HALCurve(Enum): pass + +HALCurve.define(''' HAL_CURVE_NONE, HAL_CURVE_P256, HAL_CURVE_P384, HAL_CURVE_P521 ''') -def_enum(''' +class HALUser(Enum): pass + +HALUser.define(''' HAL_USER_NONE, HAL_USER_NORMAL, HAL_USER_SO, @@ -196,6 +226,12 @@ class Attribute(object): packer.pack_bytes(self.value) +class UUID(uuid.UUID): + + def xdr_packer(self, packer): + packer.pack_bytes(self.bytes) + + def cached_property(func): attr_name = "_" + func.__name__ @@ -221,6 +257,9 @@ class Handle(object): def __cmp__(self, other): return cmp(self.handle, int(other)) + def xdr_packer(self, packer): + packer.pack_uint(self.handle) + class Digest(Handle): @@ -279,20 +318,22 @@ class PKey(Handle): def verify(self, hash = 0, data = "", signature = None): self.hsm.pkey_verify(self, hash, data, signature) - def set_attribute(self, type, value): - self.hsm.pkey_set_attribute(self, type, value) + def set_attribute(self, attr_type, attr_value = None): + self.hsm.pkey_set_attribute(self, attr_type, attr_value) - def get_attribute(self, type): - return self.hsm.pkey_get_attribute(self, type) + def get_attribute(self, attr_type): + return self.hsm.pkey_get_attribute(self, attr_type) - def delete_attribute(self, type): - self.hsm.pkey_delete_attribute(self, type) + def delete_attribute(self, attr_type): + self.hsm.pkey_delete_attribute(self, attr_type) class HSM(object): debug = False + _send_delay = 0 # 0.1 + def _raise_if_error(self, status): if status != 0: raise HALError.table[status]() @@ -300,7 +341,7 @@ class HSM(object): def __init__(self, device = os.getenv("CRYPTECH_RPC_CLIENT_SERIAL_DEVICE", "/dev/ttyUSB0")): while True: try: - self.tty = serial.Serial(device, 921600, timeout=0.1) + self.tty = serial.Serial(device, 921600, timeout = 0.1) break except serial.SerialException: time.sleep(0.2) @@ -309,7 +350,8 @@ class HSM(object): if self.debug: sys.stdout.write("{:02x}".format(ord(c))) self.tty.write(c) - time.sleep(0.1) + if self._send_delay > 0: + time.sleep(self._send_delay) def _send(self, msg): # Expects an xdrlib.Packer if self.debug: @@ -365,18 +407,30 @@ class HSM(object): for arg in args: if hasattr(arg, "xdr_packer"): arg.xdr_packer(packer) - elif isinstance(arg, (int, long, Handle)): - packer.pack_uint(arg) - elif isinstance(arg, str): - packer.pack_bytes(arg) - elif isinstance(arg, uuid.UUID): - packer.pack_bytes(arg.bytes) - elif isinstance(arg, (list, tuple)): - packer.pack_uint(len(arg)) - self._pack(packer, arg) else: - raise RuntimeError("Don't know how to pack {!r} ({!r})".format(arg, type(arg))) + try: + func = getattr(self, "_pack_" + type(arg).__name__) + except AttributeError: + raise RuntimeError("Don't know how to pack {!r} ({!r})".format(arg, type(arg))) + else: + func(packer, arg) + + @staticmethod + def _pack_int(packer, arg): + packer.pack_uint(arg) + + @staticmethod + def _pack_str(packer, arg): + packer.pack_bytes(arg) + + def _pack_tuple(self, packer, arg): + packer.pack_uint(len(arg)) + self._pack(packer, arg) + _pack_long = _pack_int + _pack_list = _pack_tuple + + @contextlib.contextmanager def rpc(self, code, *args, **kwargs): client = kwargs.get("client", 0) packer = xdrlib.Packer() @@ -387,179 +441,144 @@ class HSM(object): unpacker = self._recv(code) client = unpacker.unpack_uint() self._raise_if_error(unpacker.unpack_uint()) - return unpacker + yield unpacker + unpacker.done() def get_version(self): - u = self.rpc(RPC_FUNC_GET_VERSION) - r = u.unpack_uint() - u.done() - return r + with self.rpc(RPC_FUNC_GET_VERSION) as r: + return r.unpack_uint() def get_random(self, n): - u = self.rpc(RPC_FUNC_GET_RANDOM, n) - r = u.unpack_bytes() - u.done() - return r + with self.rpc(RPC_FUNC_GET_RANDOM, n) as r: + return r.unpack_bytes() def set_pin(self, user, pin): - u = self.rpc(RPC_FUNC_SET_PIN, user, pin) - u.done() + with self.rpc(RPC_FUNC_SET_PIN, user, pin): + return def login(self, user, pin): - u = self.rpc(RPC_FUNC_LOGIN, user, pin) - u.done() + with self.rpc(RPC_FUNC_LOGIN, user, pin): + return def logout(self): - u = self.rpc(RPC_FUNC_LOGOUT) - u.done() + with self.rpc(RPC_FUNC_LOGOUT): + return def logout_all(self): - u = self.rpc(RPC_FUNC_LOGOUT_ALL) - u.done() + with self.rpc(RPC_FUNC_LOGOUT_ALL): + return def is_logged_in(self, user): - u = self.rpc(RPC_FUNC_IS_LOGGED_IN, user) - r = u.unpack_bool() - u.done() - return r + with self.rpc(RPC_FUNC_IS_LOGGED_IN, user) as r: + return r.unpack_bool() def hash_get_digest_length(self, alg): - u = self.rpc(RPC_FUNC_HASH_GET_DIGEST_LEN, alg) - r = u.unpack_uint() - u.done() - return r + with self.rpc(RPC_FUNC_HASH_GET_DIGEST_LEN, alg) as r: + return r.unpack_uint() def hash_get_digest_algorithm_id(self, alg, max_len = 256): - u = self.rpc(RPC_FUNC_HASH_GET_DIGEST_ALGORITHM_ID, alg, max_len) - r = u.unpack_bytes() - u.done() - return r + with self.rpc(RPC_FUNC_HASH_GET_DIGEST_ALGORITHM_ID, alg, max_len) as r: + return r.unpack_bytes() def hash_get_algorithm(self, handle): - u = self.rpc(RPC_FUNC_HASH_GET_ALGORITHM, handle) - r = u.unpack_uint() - u.done() - return r + with self.rpc(RPC_FUNC_HASH_GET_ALGORITHM, handle) as r: + return HALDigestAlgorithm.index[r.unpack_uint()] def hash_initialize(self, alg, key = "", client = 0, session = 0): - u = self.rpc(RPC_FUNC_HASH_INITIALIZE, session, alg, key, client = client) - r = Digest(self, u.unpack_uint(), alg) - u.done() - return r + with self.rpc(RPC_FUNC_HASH_INITIALIZE, session, alg, key, client = client) as r: + return Digest(self, r.unpack_uint(), alg) def hash_update(self, handle, data): - u = self.rpc(RPC_FUNC_HASH_UPDATE, handle, data) - u.done() + with self.rpc(RPC_FUNC_HASH_UPDATE, handle, data): + return def hash_finalize(self, handle, length = None): if length is None: length = self.hash_get_digest_length(self.hash_get_algorithm(handle)) - u = self.rpc(RPC_FUNC_HASH_FINALIZE, handle, length) - r = u.unpack_bytes() - u.done() - return r + with self.rpc(RPC_FUNC_HASH_FINALIZE, handle, length) as r: + return r.unpack_bytes() def pkey_load(self, type, curve, der, flags = 0, client = 0, session = 0): - u = self.rpc(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client) - r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes())) - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client) as r: + return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) def pkey_find(self, uuid, flags = 0, client = 0, session = 0): - u = self.rpc(RPC_FUNC_PKEY_FIND, session, uuid, flags, client = client) - r = PKey(self, u.unpack_uint(), uuid) - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_FIND, session, uuid, flags, client = client) as r: + return PKey(self, r.unpack_uint(), uuid) def pkey_generate_rsa(self, keylen, exponent, flags = 0, client = 0, session = 0): - u = self.rpc(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client) - r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes())) - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client) as r: + return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) def pkey_generate_ec(self, curve, flags = 0, client = 0, session = 0): - u = self.rpc(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client) - r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes())) - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client) as r: + return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) def pkey_close(self, pkey): - u = self.rpc(RPC_FUNC_PKEY_CLOSE, pkey) - u.done() + with self.rpc(RPC_FUNC_PKEY_CLOSE, pkey): + return def pkey_delete(self, pkey): - u = self.rpc(RPC_FUNC_PKEY_DELETE, pkey) - u.done() + with self.rpc(RPC_FUNC_PKEY_DELETE, pkey): + return def pkey_get_key_type(self, pkey): - u = self.rpc(RPC_FUNC_PKEY_GET_KEY_TYPE, pkey) - r = u.unpack_uint() - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_GET_KEY_TYPE, pkey) as r: + return HALKeyType.index[r.unpack_uint()] def pkey_get_key_flags(self, pkey): - u = self.rpc(RPC_FUNC_PKEY_GET_KEY_FLAGS, pkey) - r = u.unpack_uint() - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_GET_KEY_FLAGS, pkey) as r: + return r.unpack_uint() def pkey_get_public_key_len(self, pkey): - u = self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY_LEN, pkey) - r = u.unpack_uint() - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY_LEN, pkey) as r: + return r.unpack_uint() def pkey_get_public_key(self, pkey, length = None): if length is None: length = self.pkey_get_public_key_len(pkey) - u = self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY, pkey, length) - r = u.unpack_bytes() - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY, pkey, length) as r: + return r.unpack_bytes() def pkey_sign(self, pkey, hash = 0, data = "", length = 1024): - u = self.rpc(RPC_FUNC_PKEY_SIGN, pkey, hash, data, length) - r = u.unpack_bytes() - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_SIGN, pkey, hash, data, length) as r: + return r.unpack_bytes() def pkey_verify(self, pkey, hash = 0, data = "", signature = None): - u = self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature) - u.done() + with self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature): + return def pkey_list(self, flags = 0, client = 0, session = 0, length = 512): - u = self.rpc(RPC_FUNC_PKEY_LIST, session, length, flags, client = client) - r = tuple((u.unpack_uint(), u.unpack_uint(), u.unpack_uint(), - uuid.UUID(bytes = u.unpack_bytes())) - for i in xrange(u.unpack_uint())) - u.done() - return r + with self.rpc(RPC_FUNC_PKEY_LIST, session, length, flags, client = client) as r: + return tuple((HALKeyType.index[r.unpack_uint()], + HALCurve.index[r.unpack_uint()], + r.unpack_uint(), + UUID(bytes = r.unpack_bytes())) + for i in xrange(r.unpack_uint())) def pkey_match(self, type = 0, curve = 0, flags = 0, attributes = (), - previous_uuid = uuid.UUID(int = 0), length = 512, client = 0, session = 0): - u = self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags, - attributes, length, previous_uuid, client = client) - r = tuple(uuid.UUID(bytes = u.unpack_bytes()) - for i in xrange(u.unpack_uint())) - x = uuid.UUID(bytes = u.unpack_bytes()) - u.done() - assert len(r) == 0 or x == r[-1] - return r - - def pkey_set_attribute(self, pkey, type, value): - u = self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTE, pkey, type, value) - u.done() - - def pkey_get_attribute(self, pkey, type): - u = self.rpc(RPC_FUNC_PKEY_GET_ATTRIBUTE, pkey, type) - r = u.unpack_bytes() - u.done() - return r - - def pkey_delete_attribute(self, pkey, type): - u = self.rpc(RPC_FUNC_PKEY_DELETE_ATTRIBUTE, pkey, type) - u.done() - + previous_uuid = UUID(int = 0), length = 512, client = 0, session = 0): + with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags, + attributes, length, previous_uuid, client = client) as r: + x = tuple(UUID(bytes = r.unpack_bytes()) + for i in xrange(r.unpack_uint())) + y = UUID(bytes = r.unpack_bytes()) + assert len(x) == 0 or y == x[-1] + return x + + def pkey_set_attribute(self, pkey, attr_type, attr_value = None): + if attr_value is None and isinstance(attr_type, Attribute): + attr_type, attr_value = attr_type.type, attr_type.attr_value + with self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTE, pkey, attr_type, attr_value): + return + + def pkey_get_attribute(self, pkey, attr_type): + with self.rpc(RPC_FUNC_PKEY_GET_ATTRIBUTE, pkey, attr_type) as r: + return Attribute(attr_type, r.unpack_bytes()) + + def pkey_delete_attribute(self, pkey, attr_type): + with self.rpc(RPC_FUNC_PKEY_DELETE_ATTRIBUTE, pkey, attr_type): + return if __name__ == "__main__": @@ -583,8 +602,6 @@ if __name__ == "__main__": k = hsm.pkey_generate_ec(HAL_CURVE_P256) print "{0.uuid} {0.key_type} {0.key_flags} {1}".format(k, hexstr(k.public_key)) hsm.pkey_close(k) - k = hsm.pkey_find(k.uuid) - hsm.pkey_delete(k) for flags in (0, HAL_KEY_FLAG_TOKEN): for t, c, f, u in hsm.pkey_list(flags = flags): @@ -593,3 +610,6 @@ if __name__ == "__main__": for f in (HAL_KEY_FLAG_TOKEN, 0): for u in hsm.pkey_match(flags = f): print u + + k = hsm.pkey_find(k.uuid) + hsm.pkey_delete(k) -- cgit v1.2.3