aboutsummaryrefslogtreecommitdiff
path: root/libhal.py
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2016-10-21 00:44:46 -0400
committerRob Austein <sra@hactrn.net>2016-10-21 00:44:46 -0400
commitbbb84e218f971d9dd134e85557951b36146c017a (patch)
tree58f155d1bfc4bb85b42b31a7e137fb688f764925 /libhal.py
parent7c47b5772bb2846fc6193b3c8128b376637c32e7 (diff)
Better enum handling, more readable RPC methods.
Using a context manager allows us to write the individual RPC methods fairly legibly, while still enforcing xdrlib.Unpacker.done() logic. Python doesn't really have enums in the sense that C does, and many people have put entirely too much skull sweat into trying to invent the Most Pythonic reimplementation of the enum concept, but an int subclass with a few extra methods is close enough for our purposes.
Diffstat (limited to 'libhal.py')
-rw-r--r--libhal.py306
1 files changed, 163 insertions, 143 deletions
diff --git a/libhal.py b/libhal.py
index 0924863..5e5832b 100644
--- a/libhal.py
+++ b/libhal.py
@@ -44,6 +44,7 @@ import time
import uuid
import xdrlib
import serial
+import contextlib
SLIP_END = chr(0300) # indicates end of packet
SLIP_ESC = chr(0333) # indicates byte stuffing
@@ -108,11 +109,32 @@ HALError.define(HAL_ERROR_ATTRIBUTE_NOT_FOUND = "Attribute not found")
HALError.define(HAL_ERROR_NO_KEY_INDEX_SLOTS = "No key index slots available")
-def def_enum(text):
- for i, name in enumerate(text.translate(None, ",").split()):
- globals()[name] = i
+class Enum(int):
-def_enum('''
+ def __new__(cls, name, value):
+ self = int.__new__(cls, value)
+ self._name = name
+ setattr(self.__class__, name, self)
+ return self
+
+ def __str__(self):
+ return self._name
+
+ def __repr__(self):
+ return "<Enum:{0.__class__.__name__} {0._name}:{0:d}>".format(self)
+
+ @classmethod
+ def define(cls, names):
+ cls.index = tuple(cls(name, i) for i, name in enumerate(names.translate(None, ",").split()))
+ globals().update((symbol._name, symbol) for symbol in cls.index)
+
+ def xdr_packer(self, packer):
+ packer.pack_uint(self)
+
+
+class RPCFunc(Enum): pass
+
+RPCFunc.define('''
RPC_FUNC_GET_VERSION,
RPC_FUNC_GET_RANDOM,
RPC_FUNC_SET_PIN,
@@ -146,7 +168,9 @@ def_enum('''
RPC_FUNC_PKEY_DELETE_ATTRIBUTE,
''')
-def_enum('''
+class HALDigestAlgorithm(Enum): pass
+
+HALDigestAlgorithm.define('''
hal_digest_algorithm_none,
hal_digest_algorithm_sha1,
hal_digest_algorithm_sha224,
@@ -157,7 +181,9 @@ def_enum('''
hal_digest_algorithm_sha512
''')
-def_enum('''
+class HALKeyType(Enum): pass
+
+HALKeyType.define('''
HAL_KEY_TYPE_NONE,
HAL_KEY_TYPE_RSA_PRIVATE,
HAL_KEY_TYPE_RSA_PUBLIC,
@@ -165,14 +191,18 @@ def_enum('''
HAL_KEY_TYPE_EC_PUBLIC
''')
-def_enum('''
+class HALCurve(Enum): pass
+
+HALCurve.define('''
HAL_CURVE_NONE,
HAL_CURVE_P256,
HAL_CURVE_P384,
HAL_CURVE_P521
''')
-def_enum('''
+class HALUser(Enum): pass
+
+HALUser.define('''
HAL_USER_NONE,
HAL_USER_NORMAL,
HAL_USER_SO,
@@ -196,6 +226,12 @@ class Attribute(object):
packer.pack_bytes(self.value)
+class UUID(uuid.UUID):
+
+ def xdr_packer(self, packer):
+ packer.pack_bytes(self.bytes)
+
+
def cached_property(func):
attr_name = "_" + func.__name__
@@ -221,6 +257,9 @@ class Handle(object):
def __cmp__(self, other):
return cmp(self.handle, int(other))
+ def xdr_packer(self, packer):
+ packer.pack_uint(self.handle)
+
class Digest(Handle):
@@ -279,20 +318,22 @@ class PKey(Handle):
def verify(self, hash = 0, data = "", signature = None):
self.hsm.pkey_verify(self, hash, data, signature)
- def set_attribute(self, type, value):
- self.hsm.pkey_set_attribute(self, type, value)
+ def set_attribute(self, attr_type, attr_value = None):
+ self.hsm.pkey_set_attribute(self, attr_type, attr_value)
- def get_attribute(self, type):
- return self.hsm.pkey_get_attribute(self, type)
+ def get_attribute(self, attr_type):
+ return self.hsm.pkey_get_attribute(self, attr_type)
- def delete_attribute(self, type):
- self.hsm.pkey_delete_attribute(self, type)
+ def delete_attribute(self, attr_type):
+ self.hsm.pkey_delete_attribute(self, attr_type)
class HSM(object):
debug = False
+ _send_delay = 0 # 0.1
+
def _raise_if_error(self, status):
if status != 0:
raise HALError.table[status]()
@@ -300,7 +341,7 @@ class HSM(object):
def __init__(self, device = os.getenv("CRYPTECH_RPC_CLIENT_SERIAL_DEVICE", "/dev/ttyUSB0")):
while True:
try:
- self.tty = serial.Serial(device, 921600, timeout=0.1)
+ self.tty = serial.Serial(device, 921600, timeout = 0.1)
break
except serial.SerialException:
time.sleep(0.2)
@@ -309,7 +350,8 @@ class HSM(object):
if self.debug:
sys.stdout.write("{:02x}".format(ord(c)))
self.tty.write(c)
- time.sleep(0.1)
+ if self._send_delay > 0:
+ time.sleep(self._send_delay)
def _send(self, msg): # Expects an xdrlib.Packer
if self.debug:
@@ -365,18 +407,30 @@ class HSM(object):
for arg in args:
if hasattr(arg, "xdr_packer"):
arg.xdr_packer(packer)
- elif isinstance(arg, (int, long, Handle)):
- packer.pack_uint(arg)
- elif isinstance(arg, str):
- packer.pack_bytes(arg)
- elif isinstance(arg, uuid.UUID):
- packer.pack_bytes(arg.bytes)
- elif isinstance(arg, (list, tuple)):
- packer.pack_uint(len(arg))
- self._pack(packer, arg)
else:
- raise RuntimeError("Don't know how to pack {!r} ({!r})".format(arg, type(arg)))
+ try:
+ func = getattr(self, "_pack_" + type(arg).__name__)
+ except AttributeError:
+ raise RuntimeError("Don't know how to pack {!r} ({!r})".format(arg, type(arg)))
+ else:
+ func(packer, arg)
+
+ @staticmethod
+ def _pack_int(packer, arg):
+ packer.pack_uint(arg)
+
+ @staticmethod
+ def _pack_str(packer, arg):
+ packer.pack_bytes(arg)
+
+ def _pack_tuple(self, packer, arg):
+ packer.pack_uint(len(arg))
+ self._pack(packer, arg)
+ _pack_long = _pack_int
+ _pack_list = _pack_tuple
+
+ @contextlib.contextmanager
def rpc(self, code, *args, **kwargs):
client = kwargs.get("client", 0)
packer = xdrlib.Packer()
@@ -387,179 +441,144 @@ class HSM(object):
unpacker = self._recv(code)
client = unpacker.unpack_uint()
self._raise_if_error(unpacker.unpack_uint())
- return unpacker
+ yield unpacker
+ unpacker.done()
def get_version(self):
- u = self.rpc(RPC_FUNC_GET_VERSION)
- r = u.unpack_uint()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_GET_VERSION) as r:
+ return r.unpack_uint()
def get_random(self, n):
- u = self.rpc(RPC_FUNC_GET_RANDOM, n)
- r = u.unpack_bytes()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_GET_RANDOM, n) as r:
+ return r.unpack_bytes()
def set_pin(self, user, pin):
- u = self.rpc(RPC_FUNC_SET_PIN, user, pin)
- u.done()
+ with self.rpc(RPC_FUNC_SET_PIN, user, pin):
+ return
def login(self, user, pin):
- u = self.rpc(RPC_FUNC_LOGIN, user, pin)
- u.done()
+ with self.rpc(RPC_FUNC_LOGIN, user, pin):
+ return
def logout(self):
- u = self.rpc(RPC_FUNC_LOGOUT)
- u.done()
+ with self.rpc(RPC_FUNC_LOGOUT):
+ return
def logout_all(self):
- u = self.rpc(RPC_FUNC_LOGOUT_ALL)
- u.done()
+ with self.rpc(RPC_FUNC_LOGOUT_ALL):
+ return
def is_logged_in(self, user):
- u = self.rpc(RPC_FUNC_IS_LOGGED_IN, user)
- r = u.unpack_bool()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_IS_LOGGED_IN, user) as r:
+ return r.unpack_bool()
def hash_get_digest_length(self, alg):
- u = self.rpc(RPC_FUNC_HASH_GET_DIGEST_LEN, alg)
- r = u.unpack_uint()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_HASH_GET_DIGEST_LEN, alg) as r:
+ return r.unpack_uint()
def hash_get_digest_algorithm_id(self, alg, max_len = 256):
- u = self.rpc(RPC_FUNC_HASH_GET_DIGEST_ALGORITHM_ID, alg, max_len)
- r = u.unpack_bytes()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_HASH_GET_DIGEST_ALGORITHM_ID, alg, max_len) as r:
+ return r.unpack_bytes()
def hash_get_algorithm(self, handle):
- u = self.rpc(RPC_FUNC_HASH_GET_ALGORITHM, handle)
- r = u.unpack_uint()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_HASH_GET_ALGORITHM, handle) as r:
+ return HALDigestAlgorithm.index[r.unpack_uint()]
def hash_initialize(self, alg, key = "", client = 0, session = 0):
- u = self.rpc(RPC_FUNC_HASH_INITIALIZE, session, alg, key, client = client)
- r = Digest(self, u.unpack_uint(), alg)
- u.done()
- return r
+ with self.rpc(RPC_FUNC_HASH_INITIALIZE, session, alg, key, client = client) as r:
+ return Digest(self, r.unpack_uint(), alg)
def hash_update(self, handle, data):
- u = self.rpc(RPC_FUNC_HASH_UPDATE, handle, data)
- u.done()
+ with self.rpc(RPC_FUNC_HASH_UPDATE, handle, data):
+ return
def hash_finalize(self, handle, length = None):
if length is None:
length = self.hash_get_digest_length(self.hash_get_algorithm(handle))
- u = self.rpc(RPC_FUNC_HASH_FINALIZE, handle, length)
- r = u.unpack_bytes()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_HASH_FINALIZE, handle, length) as r:
+ return r.unpack_bytes()
def pkey_load(self, type, curve, der, flags = 0, client = 0, session = 0):
- u = self.rpc(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client)
- r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()))
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client) as r:
+ return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes()))
def pkey_find(self, uuid, flags = 0, client = 0, session = 0):
- u = self.rpc(RPC_FUNC_PKEY_FIND, session, uuid, flags, client = client)
- r = PKey(self, u.unpack_uint(), uuid)
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_FIND, session, uuid, flags, client = client) as r:
+ return PKey(self, r.unpack_uint(), uuid)
def pkey_generate_rsa(self, keylen, exponent, flags = 0, client = 0, session = 0):
- u = self.rpc(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client)
- r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()))
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client) as r:
+ return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes()))
def pkey_generate_ec(self, curve, flags = 0, client = 0, session = 0):
- u = self.rpc(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client)
- r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()))
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client) as r:
+ return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes()))
def pkey_close(self, pkey):
- u = self.rpc(RPC_FUNC_PKEY_CLOSE, pkey)
- u.done()
+ with self.rpc(RPC_FUNC_PKEY_CLOSE, pkey):
+ return
def pkey_delete(self, pkey):
- u = self.rpc(RPC_FUNC_PKEY_DELETE, pkey)
- u.done()
+ with self.rpc(RPC_FUNC_PKEY_DELETE, pkey):
+ return
def pkey_get_key_type(self, pkey):
- u = self.rpc(RPC_FUNC_PKEY_GET_KEY_TYPE, pkey)
- r = u.unpack_uint()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_GET_KEY_TYPE, pkey) as r:
+ return HALKeyType.index[r.unpack_uint()]
def pkey_get_key_flags(self, pkey):
- u = self.rpc(RPC_FUNC_PKEY_GET_KEY_FLAGS, pkey)
- r = u.unpack_uint()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_GET_KEY_FLAGS, pkey) as r:
+ return r.unpack_uint()
def pkey_get_public_key_len(self, pkey):
- u = self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY_LEN, pkey)
- r = u.unpack_uint()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY_LEN, pkey) as r:
+ return r.unpack_uint()
def pkey_get_public_key(self, pkey, length = None):
if length is None:
length = self.pkey_get_public_key_len(pkey)
- u = self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY, pkey, length)
- r = u.unpack_bytes()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY, pkey, length) as r:
+ return r.unpack_bytes()
def pkey_sign(self, pkey, hash = 0, data = "", length = 1024):
- u = self.rpc(RPC_FUNC_PKEY_SIGN, pkey, hash, data, length)
- r = u.unpack_bytes()
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_SIGN, pkey, hash, data, length) as r:
+ return r.unpack_bytes()
def pkey_verify(self, pkey, hash = 0, data = "", signature = None):
- u = self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature)
- u.done()
+ with self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature):
+ return
def pkey_list(self, flags = 0, client = 0, session = 0, length = 512):
- u = self.rpc(RPC_FUNC_PKEY_LIST, session, length, flags, client = client)
- r = tuple((u.unpack_uint(), u.unpack_uint(), u.unpack_uint(),
- uuid.UUID(bytes = u.unpack_bytes()))
- for i in xrange(u.unpack_uint()))
- u.done()
- return r
+ with self.rpc(RPC_FUNC_PKEY_LIST, session, length, flags, client = client) as r:
+ return tuple((HALKeyType.index[r.unpack_uint()],
+ HALCurve.index[r.unpack_uint()],
+ r.unpack_uint(),
+ UUID(bytes = r.unpack_bytes()))
+ for i in xrange(r.unpack_uint()))
def pkey_match(self, type = 0, curve = 0, flags = 0, attributes = (),
- previous_uuid = uuid.UUID(int = 0), length = 512, client = 0, session = 0):
- u = self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags,
- attributes, length, previous_uuid, client = client)
- r = tuple(uuid.UUID(bytes = u.unpack_bytes())
- for i in xrange(u.unpack_uint()))
- x = uuid.UUID(bytes = u.unpack_bytes())
- u.done()
- assert len(r) == 0 or x == r[-1]
- return r
-
- def pkey_set_attribute(self, pkey, type, value):
- u = self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTE, pkey, type, value)
- u.done()
-
- def pkey_get_attribute(self, pkey, type):
- u = self.rpc(RPC_FUNC_PKEY_GET_ATTRIBUTE, pkey, type)
- r = u.unpack_bytes()
- u.done()
- return r
-
- def pkey_delete_attribute(self, pkey, type):
- u = self.rpc(RPC_FUNC_PKEY_DELETE_ATTRIBUTE, pkey, type)
- u.done()
-
+ previous_uuid = UUID(int = 0), length = 512, client = 0, session = 0):
+ with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags,
+ attributes, length, previous_uuid, client = client) as r:
+ x = tuple(UUID(bytes = r.unpack_bytes())
+ for i in xrange(r.unpack_uint()))
+ y = UUID(bytes = r.unpack_bytes())
+ assert len(x) == 0 or y == x[-1]
+ return x
+
+ def pkey_set_attribute(self, pkey, attr_type, attr_value = None):
+ if attr_value is None and isinstance(attr_type, Attribute):
+ attr_type, attr_value = attr_type.type, attr_type.attr_value
+ with self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTE, pkey, attr_type, attr_value):
+ return
+
+ def pkey_get_attribute(self, pkey, attr_type):
+ with self.rpc(RPC_FUNC_PKEY_GET_ATTRIBUTE, pkey, attr_type) as r:
+ return Attribute(attr_type, r.unpack_bytes())
+
+ def pkey_delete_attribute(self, pkey, attr_type):
+ with self.rpc(RPC_FUNC_PKEY_DELETE_ATTRIBUTE, pkey, attr_type):
+ return
if __name__ == "__main__":
@@ -583,8 +602,6 @@ if __name__ == "__main__":
k = hsm.pkey_generate_ec(HAL_CURVE_P256)
print "{0.uuid} {0.key_type} {0.key_flags} {1}".format(k, hexstr(k.public_key))
hsm.pkey_close(k)
- k = hsm.pkey_find(k.uuid)
- hsm.pkey_delete(k)
for flags in (0, HAL_KEY_FLAG_TOKEN):
for t, c, f, u in hsm.pkey_list(flags = flags):
@@ -593,3 +610,6 @@ if __name__ == "__main__":
for f in (HAL_KEY_FLAG_TOKEN, 0):
for u in hsm.pkey_match(flags = f):
print u
+
+ k = hsm.pkey_find(k.uuid)
+ hsm.pkey_delete(k)