From d7157081e8edd10d15b0686429d9323a584c7133 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Wed, 19 Oct 2016 18:18:10 -0400 Subject: Add handle objects to make API a bit more Pythonic. --- libhal.py | 208 +++++++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 151 insertions(+), 57 deletions(-) diff --git a/libhal.py b/libhal.py index 99004ad..fa01ffc 100644 --- a/libhal.py +++ b/libhal.py @@ -107,6 +107,7 @@ HALError.define(HAL_ERROR_BAD_ATTRIBUTE_LENGTH = "Bad attribute length") HALError.define(HAL_ERROR_ATTRIBUTE_NOT_FOUND = "Attribute not found") HALError.define(HAL_ERROR_NO_KEY_INDEX_SLOTS = "No key index slots available") + def def_enum(text): for i, name in enumerate(text.translate(None, ",").split()): globals()[name] = i @@ -183,7 +184,101 @@ HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT = (1 << 1) HAL_KEY_FLAG_USAGE_DATAENCIPHERMENT = (1 << 2) HAL_KEY_FLAG_TOKEN = (1 << 3) -class RPC(object): + +def cached_property(func): + + attr_name = "_" + func.__name__ + + def wrapped(self): + try: + value = getattr(self, attr_name) + except AttributeError: + value = func(self) + setattr(self, attr_name, value) + return value + + wrapped.__name__ = func.__name__ + + return property(wrapped) + + +class Handle(object): + + def __int__(self): + return self.handle + + def __cmp__(self, other): + return cmp(self.handle, int(other)) + + +class Digest(Handle): + + def __init__(self, hsm, handle, algorithm): + self.hsm = hsm + self.handle = handle + self.algorithm = algorithm + + def update(self, data): + self.hsm.hash_update(self, data) + + def finalize(self, length = None): + return self.hsm.hash_finalize(self, length or self.digest_length) + + @cached_property + def algorithm_id(self): + return self.hsm.hash_get_digest_algorithm_id(self.algorithm) + + @cached_property + def digest_length(self): + return self.hsm.hash_get_digest_length(self.algorithm) + + +class PKey(Handle): + + def __init__(self, hsm, handle, uuid): + self.hsm = hsm + self.handle = handle + self.uuid = uuid + + def close(self): + self.hsm.pkey_close(self) + + def delete(self): + self.hsm.pkey_delete(self) + + @cached_property + def key_type(self): + return self.hsm.pkey_get_key_type(self) + + @cached_property + def key_flags(self): + return self.hsm.pkey_get_key_flags(self) + + @cached_property + def public_key_len(self): + return self.hsm.pkey_get_public_key_len(self) + + @cached_property + def public_key(self): + return self.hsm.pkey_get_public_key(self, self.public_key_len) + + def sign(self, hash = 0, data = "", length = 1024): + return self.hsm.pkey_sign(self, hash, data, length) + + def verify(self, hash = 0, data = "", signature = None): + self.hsm.pkey_verify(self, hash, data, signature) + + def set_attribute(self, type, value): + self.hsm.pkey_set_attribute(self, type, value) + + def get_attribute(self, type): + return self.hsm.pkey_get_attribute(self, type) + + def delete_attribute(self, type): + self.hsm.pkey_delete_attribute(self, type) + + +class HSM(object): debug = True @@ -257,7 +352,7 @@ class RPC(object): def _pack(self, packer, args): for arg in args: - if isinstance(arg, (int, long)): + if isinstance(arg, (int, long, Handle)): packer.pack_uint(arg) elif isinstance(arg, str): packer.pack_bytes(arg) @@ -269,7 +364,7 @@ class RPC(object): else: raise RuntimeError("Don't know how to pack {!r} ({!r})".format(arg, type(arg))) - def _call(self, code, *args, **kwargs): + def rpc(self, code, *args, **kwargs): client = kwargs.get("client", 0) packer = xdrlib.Packer() packer.pack_uint(code) @@ -282,121 +377,121 @@ class RPC(object): return unpacker def get_version(self): - u = self._call(RPC_FUNC_GET_VERSION) + u = self.rpc(RPC_FUNC_GET_VERSION) r = u.unpack_uint() u.done() return r def get_random(self, n): - u = self._call(RPC_FUNC_GET_RANDOM, n) + u = self.rpc(RPC_FUNC_GET_RANDOM, n) r = u.unpack_bytes() u.done() return r def set_pin(self, user, pin): - u = self._call(RPC_FUNC_SET_PIN, user, pin) + u = self.rpc(RPC_FUNC_SET_PIN, user, pin) u.done() def login(self, user, pin): - u = self._call(RPC_FUNC_LOGIN, user, pin) + u = self.rpc(RPC_FUNC_LOGIN, user, pin) u.done() def logout(self): - u = self._call(RPC_FUNC_LOGOUT) + u = self.rpc(RPC_FUNC_LOGOUT) u.done() def logout_all(self): - u = self._call(RPC_FUNC_LOGOUT_ALL) + u = self.rpc(RPC_FUNC_LOGOUT_ALL) u.done() def is_logged_in(self, user): - u = self._call(RPC_FUNC_IS_LOGGED_IN, user) + u = self.rpc(RPC_FUNC_IS_LOGGED_IN, user) r = u.unpack_bool() u.done() return r def hash_get_digest_length(self, alg): - u = self._call(RPC_FUNC_HASH_GET_DIGEST_LEN, alg) + u = self.rpc(RPC_FUNC_HASH_GET_DIGEST_LEN, alg) r = u.unpack_uint() u.done() return r def hash_get_digest_algorithm_id(self, alg, max_len = 256): - u = self._call(RPC_FUNC_HASH_GET_DIGEST_ALGORITHM_ID, alg, max_len) + u = self.rpc(RPC_FUNC_HASH_GET_DIGEST_ALGORITHM_ID, alg, max_len) r = u.unpack_bytes() u.done() return r def hash_get_algorithm(self, handle): - u = self._call(RPC_FUNC_HASH_GET_ALGORITHM, handle) + u = self.rpc(RPC_FUNC_HASH_GET_ALGORITHM, handle) r = u.unpack_uint() u.done() return r def hash_initialize(self, alg, key = "", client = 0, session = 0): - u = self._call(RPC_FUNC_HASH_INITIALIZE, session, alg, key, client = client) - r = u.unpack_uint() + u = self.rpc(RPC_FUNC_HASH_INITIALIZE, session, alg, key, client = client) + r = Digest(self, u.unpack_uint(), alg) u.done() return r def hash_update(self, handle, data): - u = self._call(RPC_FUNC_HASH_UPDATE, handle, data) + u = self.rpc(RPC_FUNC_HASH_UPDATE, handle, data) u.done() def hash_finalize(self, handle, length = None): if length is None: length = self.hash_get_digest_length(self.hash_get_algorithm(handle)) - u = self._call(RPC_FUNC_HASH_FINALIZE, handle, length) + u = self.rpc(RPC_FUNC_HASH_FINALIZE, handle, length) r = u.unpack_bytes() u.done() return r def pkey_load(self, type, curve, der, flags = 0, client = 0, session = 0): - u = self._call(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client) - r = u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()) + u = self.rpc(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client) + r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes())) u.done() return r def pkey_find(self, uuid, flags = 0, client = 0, session = 0): - u = self._call(RPC_FUNC_PKEY_FIND, session, uuid, flags, client = client) - r = u.unpack_uint() + u = self.rpc(RPC_FUNC_PKEY_FIND, session, uuid, flags, client = client) + r = PKey(self, u.unpack_uint(), uuid) u.done() return r def pkey_generate_rsa(self, keylen, exponent, flags = 0, client = 0, session = 0): - u = self._call(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client) - r = u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()) + u = self.rpc(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client) + r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes())) u.done() return r def pkey_generate_ec(self, curve, flags = 0, client = 0, session = 0): - u = self._call(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client) - r = u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()) + u = self.rpc(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client) + r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes())) u.done() return r def pkey_close(self, pkey): - u = self._call(RPC_FUNC_PKEY_CLOSE, pkey) + u = self.rpc(RPC_FUNC_PKEY_CLOSE, pkey) u.done() def pkey_delete(self, pkey): - u = self._call(RPC_FUNC_PKEY_DELETE, pkey) + u = self.rpc(RPC_FUNC_PKEY_DELETE, pkey) u.done() def pkey_get_key_type(self, pkey): - u = self._call(RPC_FUNC_PKEY_GET_KEY_TYPE, pkey) + u = self.rpc(RPC_FUNC_PKEY_GET_KEY_TYPE, pkey) r = u.unpack_uint() u.done() return r def pkey_get_key_flags(self, pkey): - u = self._call(RPC_FUNC_PKEY_GET_KEY_FLAGS, pkey) + u = self.rpc(RPC_FUNC_PKEY_GET_KEY_FLAGS, pkey) r = u.unpack_uint() u.done() return r def pkey_get_public_key_len(self, pkey): - u = self._call(RPC_FUNC_PKEY_GET_PUBLIC_KEY_LEN, pkey) + u = self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY_LEN, pkey) r = u.unpack_uint() u.done() return r @@ -404,23 +499,23 @@ class RPC(object): def pkey_get_public_key(self, pkey, length = None): if length is None: length = self.pkey_get_public_key_len(pkey) - u = self._call(RPC_FUNC_PKEY_GET_PUBLIC_KEY, pkey, length) + u = self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY, pkey, length) r = u.unpack_bytes() u.done() return r def pkey_sign(self, pkey, hash = 0, data = "", length = 1024): - u = self._call(RPC_FUNC_PKEY_SIGN, pkey, hash, data, length) + u = self.rpc(RPC_FUNC_PKEY_SIGN, pkey, hash, data, length) r = u.unpack_bytes() u.done() return r def pkey_verify(self, pkey, hash = 0, data = "", signature = None): - u = self._call(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature) + u = self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature) u.done() def pkey_list(self, flags = 0, client = 0, session = 0, length = 512): - u = self._call(RPC_FUNC_PKEY_LIST, session, length, flags, client = client) + u = self.rpc(RPC_FUNC_PKEY_LIST, session, length, flags, client = client) r = tuple((u.unpack_uint(), u.unpack_uint(), u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes())) for i in xrange(u.unpack_uint())) @@ -429,7 +524,8 @@ class RPC(object): def pkey_match(self, type = 0, curve = 0, flags = 0, attributes = (), previous_uuid = uuid.UUID(int = 0), length = 512, client = 0, session = 0): - u = self._call(RPC_FUNC_PKEY_MATCH, session, type, curve, flags, attributes, length, previous_uuid, client = client) + u = self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags, + attributes, length, previous_uuid, client = client) r = tuple(uuid.UUID(bytes = u.unpack_bytes()) for i in xrange(u.unpack_uint())) x = uuid.UUID(bytes = u.unpack_bytes()) @@ -438,51 +534,49 @@ class RPC(object): return r def pkey_set_attribute(self, pkey, type, value): - u = self._call(RPC_FUNC_PKEY_SET_ATTRIBUTE, pkey, type, value) + u = self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTE, pkey, type, value) u.done() def pkey_get_attribute(self, pkey, type): - u = self._call(RPC_FUNC_PKEY_GET_ATTRIBUTE, pkey, type) + u = self.rpc(RPC_FUNC_PKEY_GET_ATTRIBUTE, pkey, type) r = u.unpack_bytes() u.done() return r def pkey_delete_attribute(self, pkey, type): - u = self._call(RPC_FUNC_PKEY_DELETE_ATTRIBUTE, pkey, type) + u = self.rpc(RPC_FUNC_PKEY_DELETE_ATTRIBUTE, pkey, type) u.done() + if __name__ == "__main__": def hexstr(s): return "".join("{:02x}".format(ord(c)) for c in s) - rpc = RPC() + hsm = HSM() - print hex(rpc.get_version()) + print hex(hsm.get_version()) - print hexstr(rpc.get_random(16)) + print hexstr(hsm.get_random(16)) - h = rpc.hash_initialize(hal_digest_algorithm_sha256) - rpc.hash_update(h, "Hi, Mom") - print hexstr(rpc.hash_finalize(h)) + h = hsm.hash_initialize(hal_digest_algorithm_sha256) + h.update("Hi, Mom") + print hexstr(h.finalize()) - h = rpc.hash_initialize(hal_digest_algorithm_sha256, key = "secret") - rpc.hash_update(h, "Hi, Dad") - print hexstr(rpc.hash_finalize(h)) + h = hsm.hash_initialize(hal_digest_algorithm_sha256, key = "secret") + h.update("Hi, Dad") + print hexstr(h.finalize()) - k, u = rpc.pkey_generate_ec(HAL_CURVE_P256) - t = rpc.pkey_get_key_type(k) - f = rpc.pkey_get_key_flags(k) - d = rpc.pkey_get_public_key(k) - print u, t, f, hexstr(d) - rpc.pkey_close(k) - k = rpc.pkey_find(u) - rpc.pkey_delete(k) + k = hsm.pkey_generate_ec(HAL_CURVE_P256) + print "{0.uuid} {0.key_type} {0.key_flags} {1}".format(k, hexstr(k.public_key)) + hsm.pkey_close(k) + k = hsm.pkey_find(k.uuid) + hsm.pkey_delete(k) for flags in (0, HAL_KEY_FLAG_TOKEN): - for t, c, f, u in rpc.pkey_list(flags = flags): + for t, c, f, u in hsm.pkey_list(flags = flags): print u, t, c, f for f in (HAL_KEY_FLAG_TOKEN, 0): - for u in rpc.pkey_match(flags = f): + for u in hsm.pkey_match(flags = f): print u -- cgit v1.2.3