From cc46a697de71e66e90653e3ac7fffe413acfd8c8 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Tue, 11 Apr 2017 00:14:59 -0400 Subject: API cleanup: pkey_open() and pkey_match(). pkey_open() now looks in both keystores rather than requiring the user to know. The chance of collision with randomly-generated UUID is low enough that we really ought to be able to present a single namespace. So now we do. pkey_match() now takes a couple of extra arguments which allow a single search to cover both keystores, as well as matching for specific key flags. The former interface was pretty much useless for anything involving flags, and required the user to issue a separate call for each keystore. User wheel is now exempt from the per-session key lookup constraints, Whether this is a good idea or not is an interesting question, but the whole PKCS #11 derived per-session key thing is weird to begin with, and having keystore listings on the console deliberately ignore session keys was just too confusing. --- cryptech_backup | 64 +++++-------- ecdsa.c | 2 +- hal.h | 7 +- hal_internal.h | 9 +- ks_flash.c | 4 +- ks_volatile.c | 9 +- libhal.py | 17 ++-- rpc_api.c | 13 +-- rpc_client.c | 18 ++-- rpc_pkey.c | 247 ++++++++++++++++++++++++++++++++------------------ rpc_server.c | 35 +++---- tests/test-rpc_pkey.c | 12 ++- unit-tests.py | 46 ++++++---- 13 files changed, 289 insertions(+), 194 deletions(-) diff --git a/cryptech_backup b/cryptech_backup index 7360a0d..7e465b8 100755 --- a/cryptech_backup +++ b/cryptech_backup @@ -8,22 +8,10 @@ # # Load KEKEK public <---------------- Export KEKEK public # -# { -# "kekek-uuid": "[UUID]", -# "kekek": "[Base64]" -# } -# # hal_rpc_pkey_load() # hal_rpc_pkey_export() # -# Export PKCS #8 and KEK ----------> Load PKCS #8 and KEK, import key: -# -# { -# "kekek-uuid": "[UUID]", -# "pkey": "[Base64]", -# "kek": "[Base64]" -# } -# +# Export PKCS #8 and KEK ----------> Load PKCS #8 and KEK, import key # # hal_rpc_pkey_import() @@ -125,10 +113,11 @@ def cmd_setup(args, hsm): elif not args.new: uuids.extend(hsm.pkey_match( type = HAL_KEY_TYPE_RSA_PRIVATE, + mask = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT | HAL_KEY_FLAG_TOKEN, flags = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT | HAL_KEY_FLAG_TOKEN)) for uuid in uuids: - with hsm.pkey_open(uuid, HAL_KEY_FLAG_TOKEN) as kekek: + with hsm.pkey_open(uuid) as kekek: if kekek.key_type != HAL_KEY_TYPE_RSA_PRIVATE: sys.stderr.write("Key {} is not an RSA private key\n".format(uuid)) elif (kekek.key_flags & HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT) == 0: @@ -179,31 +168,26 @@ def cmd_export(args, hsm): kekek = hsm.pkey_load(der = b64join(db["kekek_pubkey"]), flags = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT) - # What we *should* do here is a single .pkey_match() loop - # matching exactly the keys we want, but the current semantics - # of .pkey_match() are a bit confused. While that yak is - # waiting for its shave, we do this the dumb way by iterating - # over all keys then skipping the ones we don't want. - - for flags in (0, HAL_KEY_FLAG_TOKEN): - for uuid in hsm.pkey_match(flags = flags): - with hsm.pkey_open(uuid, flags) as pkey: - if (pkey.key_flags & HAL_KEY_FLAG_EXPORTABLE) == 0: - continue - if pkey.key_type in (HAL_KEY_TYPE_RSA_PRIVATE, HAL_KEY_TYPE_EC_PRIVATE): - pkcs8, kek = kekek.export_pkey(pkey) - result.append(dict( - comment = "Encrypted private key", - pkcs8 = b64(pkcs8), - kek = b64(kek), - uuid = str(pkey.uuid), - flags = pkey.key_flags)) - elif pkey.key_type in (HAL_KEY_TYPE_RSA_PUBLIC, HAL_KEY_TYPE_EC_PUBLIC): - result.append(dict( - comment = "Public key", - spki = b64(pkey.public_key), - uuid = str(pkey.uuid), - flags = pkey.key_flags)) + for uuid in hsm.pkey_match(mask = HAL_KEY_FLAG_EXPORTABLE, + flags = HAL_KEY_FLAG_EXPORTABLE): + with hsm.pkey_open(uuid) as pkey: + + if pkey.key_type in (HAL_KEY_TYPE_RSA_PRIVATE, HAL_KEY_TYPE_EC_PRIVATE): + pkcs8, kek = kekek.export_pkey(pkey) + result.append(dict( + comment = "Encrypted private key", + pkcs8 = b64(pkcs8), + kek = b64(kek), + uuid = str(pkey.uuid), + flags = pkey.key_flags)) + + elif pkey.key_type in (HAL_KEY_TYPE_RSA_PUBLIC, HAL_KEY_TYPE_EC_PUBLIC): + result.append(dict( + comment = "Public key", + spki = b64(pkey.public_key), + uuid = str(pkey.uuid), + flags = pkey.key_flags)) + finally: if kekek is not None: kekek.delete() @@ -222,7 +206,7 @@ def cmd_import(args, hsm): """ db = json.load(args.input) - with hsm.pkey_open(uuid.UUID(db["kekek_uuid"]).bytes, HAL_KEY_FLAG_TOKEN) as kekek: + with hsm.pkey_open(uuid.UUID(db["kekek_uuid"]).bytes) as kekek: for k in db["keys"]: pkcs8 = b64join(k.get("pkcs8", "")) spki = b64join(k.get("spki", "")) diff --git a/ecdsa.c b/ecdsa.c index 3fc1462..27c4c2e 100644 --- a/ecdsa.c +++ b/ecdsa.c @@ -1315,7 +1315,7 @@ hal_error_t hal_ecdsa_private_key_to_der(const hal_ecdsa_key_t * const key, NULL, hlen + vlen, NULL, der_len, der_max)) != HAL_OK) return err; - + if (der == NULL) return HAL_OK; diff --git a/hal.h b/hal.h index 4d246a3..bfb727a 100644 --- a/hal.h +++ b/hal.h @@ -358,7 +358,7 @@ extern hal_error_t hal_aes_keywrap(hal_core_t *core, extern hal_error_t hal_aes_keyunwrap(hal_core_t *core, const uint8_t *kek, const size_t kek_length, const uint8_t *ciphertext, const size_t ciphertext_length, - unsigned char *plaintext, size_t *plaintext_length); + uint8_t *plaintext, size_t *plaintext_length); extern size_t hal_aes_keywrap_ciphertext_length(const size_t plaintext_length); @@ -756,8 +756,7 @@ extern hal_error_t hal_rpc_pkey_load(const hal_client_handle_t client, extern hal_error_t hal_rpc_pkey_open(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags); + const hal_uuid_t * const name); extern hal_error_t hal_rpc_pkey_generate_rsa(const hal_client_handle_t client, const hal_session_handle_t session, @@ -806,9 +805,11 @@ extern hal_error_t hal_rpc_pkey_match(const hal_client_handle_t client, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, diff --git a/hal_internal.h b/hal_internal.h index 82d0081..1822781 100644 --- a/hal_internal.h +++ b/hal_internal.h @@ -199,8 +199,7 @@ typedef struct { hal_error_t (*open)(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags); + const hal_uuid_t * const name); hal_error_t (*generate_rsa)(const hal_client_handle_t client, const hal_session_handle_t session, @@ -249,9 +248,11 @@ typedef struct { const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, @@ -485,6 +486,7 @@ struct hal_ks_driver { const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, @@ -614,6 +616,7 @@ static inline hal_error_t hal_ks_match(hal_ks_t *ks, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, @@ -628,7 +631,7 @@ static inline hal_error_t hal_ks_match(hal_ks_t *ks, if (ks->driver->match == NULL) return HAL_ERROR_NOT_IMPLEMENTED; - return ks->driver->match(ks, client, session, type, curve, flags, attributes, attributes_len, + return ks->driver->match(ks, client, session, type, curve, mask, flags, attributes, attributes_len, result, result_len, result_max, previous_uuid); } diff --git a/ks_flash.c b/ks_flash.c index 7a87d5c..b14e568 100644 --- a/ks_flash.c +++ b/ks_flash.c @@ -1239,6 +1239,7 @@ static hal_error_t ks_match(hal_ks_t *ks, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, @@ -1284,7 +1285,8 @@ static hal_error_t ks_match(hal_ks_t *ks, if (db.ksi.names[b].chunk == 0) { memset(need_attr, 1, sizeof(need_attr)); possible = ((type == HAL_KEY_TYPE_NONE || type == block->key.type) && - (curve == HAL_CURVE_NONE || curve == block->key.curve)); + (curve == HAL_CURVE_NONE || curve == block->key.curve) && + ((flags ^ block->key.flags) & mask) == 0); } if (!possible) diff --git a/ks_volatile.c b/ks_volatile.c index 9762da3..6d65578 100644 --- a/ks_volatile.c +++ b/ks_volatile.c @@ -124,7 +124,10 @@ static inline int key_visible_to_session(const ks_t * const ksv, const hal_session_handle_t session, const ks_key_t * const k) { - return !ksv->per_session || client.handle == HAL_HANDLE_NONE || k->client.handle == client.handle; + return (!ksv->per_session || + client.handle == HAL_HANDLE_NONE || + k->client.handle == client.handle || + hal_rpc_is_logged_in(client, HAL_USER_WHEEL) == HAL_OK); } static inline void *gnaw(uint8_t **mem, size_t *len, const size_t size) @@ -385,6 +388,7 @@ static hal_error_t ks_match(hal_ks_t *ks, hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, @@ -429,6 +433,9 @@ static hal_error_t ks_match(hal_ks_t *ks, if (curve != HAL_CURVE_NONE && curve != ksv->db->keys[b].curve) continue; + if (((flags ^ ksv->db->keys[b].flags) & mask) != 0) + continue; + if (!key_visible_to_session(ksv, client, session, &ksv->db->keys[b])) continue; diff --git a/libhal.py b/libhal.py index 5a16c40..0c6b3f6 100644 --- a/libhal.py +++ b/libhal.py @@ -420,7 +420,8 @@ class HSM(object): if status != 0: raise HALError.table[status]() - def __init__(self, sockname = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME", "/tmp/.cryptech_muxd.rpc")): + def __init__(self, sockname = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME", + "/tmp/.cryptech_muxd.rpc")): self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.socket.connect(sockname) self.sockfile = self.socket.makefile("rb") @@ -561,8 +562,8 @@ class HSM(object): logger.debug("Loaded pkey %s", pkey.uuid) return pkey - def pkey_open(self, uuid, flags = 0, client = 0, session = 0): - with self.rpc(RPC_FUNC_PKEY_OPEN, session, uuid, flags, client = client) as r: + def pkey_open(self, uuid, client = 0, session = 0): + with self.rpc(RPC_FUNC_PKEY_OPEN, session, uuid, client = client) as r: pkey = PKey(self, r.unpack_uint(), uuid) logger.debug("Opened pkey %s", pkey.uuid) return pkey @@ -631,13 +632,15 @@ class HSM(object): with self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature): return - def pkey_match(self, type = 0, curve = 0, flags = 0, attributes = {}, - length = 64, client = 0, session = 0): + def pkey_match(self, type = 0, curve = 0, mask = 0, flags = 0, + attributes = {}, length = 64, client = 0, session = 0): u = UUID(int = 0) n = length + s = 0 while n == length: - with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags, - attributes, length, u, client = client) as r: + with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, mask, flags, + attributes, s, length, u, client = client) as r: + s = r.unpack_uint() n = r.unpack_uint() for i in xrange(n): u = UUID(bytes = r.unpack_bytes()) diff --git a/rpc_api.c b/rpc_api.c index 37ca968..1a2d268 100644 --- a/rpc_api.c +++ b/rpc_api.c @@ -230,12 +230,11 @@ hal_error_t hal_rpc_pkey_load(const hal_client_handle_t client, hal_error_t hal_rpc_pkey_open(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags) + const hal_uuid_t * const name) { if (pkey == NULL || name == NULL) return HAL_ERROR_BAD_ARGUMENTS; - return hal_rpc_pkey_dispatch->open(client, session, pkey, name, flags); + return hal_rpc_pkey_dispatch->open(client, session, pkey, name); } hal_error_t hal_rpc_pkey_generate_rsa(const hal_client_handle_t client, @@ -338,16 +337,18 @@ hal_error_t hal_rpc_pkey_match(const hal_client_handle_t client, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, const hal_uuid_t * const previous_uuid) { if ((attributes == NULL && attributes_len > 0) || previous_uuid == NULL || - result == NULL || result_len == NULL || result_max == 0) + state == NULL || result == NULL || result_len == NULL || result_max == 0) return HAL_ERROR_BAD_ARGUMENTS; if (attributes != NULL) @@ -355,9 +356,9 @@ hal_error_t hal_rpc_pkey_match(const hal_client_handle_t client, if (attributes[i].value == NULL) return HAL_ERROR_BAD_ARGUMENTS; - return hal_rpc_pkey_dispatch->match(client, session, type, curve, flags, + return hal_rpc_pkey_dispatch->match(client, session, type, curve, mask, flags, attributes, attributes_len, - result, result_len, result_max, previous_uuid); + state, result, result_len, result_max, previous_uuid); } hal_error_t hal_rpc_pkey_set_attributes(const hal_pkey_handle_t pkey, diff --git a/rpc_client.c b/rpc_client.c index e856cce..aad9edf 100644 --- a/rpc_client.c +++ b/rpc_client.c @@ -454,10 +454,9 @@ static hal_error_t pkey_remote_load(const hal_client_handle_t client, static hal_error_t pkey_remote_open(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags) + const hal_uuid_t * const name) { - uint8_t outbuf[nargs(5) + pad(sizeof(name->uuid))], *optr = outbuf, *olimit = outbuf + sizeof(outbuf); + uint8_t outbuf[nargs(4) + pad(sizeof(name->uuid))], *optr = outbuf, *olimit = outbuf + sizeof(outbuf); uint8_t inbuf[nargs(4)]; const uint8_t *iptr = inbuf, *ilimit = inbuf + sizeof(inbuf); hal_error_t rpc_ret; @@ -466,7 +465,6 @@ static hal_error_t pkey_remote_open(const hal_client_handle_t client, check(hal_xdr_encode_int(&optr, olimit, client.handle)); check(hal_xdr_encode_int(&optr, olimit, session.handle)); check(hal_xdr_encode_buffer(&optr, olimit, name->uuid, sizeof(name->uuid))); - check(hal_xdr_encode_int(&optr, olimit, flags)); check(hal_rpc_send(outbuf, optr - outbuf)); check(read_matching_packet(RPC_FUNC_PKEY_OPEN, inbuf, sizeof(inbuf), &iptr, &ilimit)); @@ -772,9 +770,11 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, @@ -785,9 +785,9 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, for (int i = 0; i < attributes_len; i++) attributes_buffer_len += pad(attributes[i].length); - uint8_t outbuf[nargs(9 + attributes_len * 2) + attributes_buffer_len + pad(sizeof(hal_uuid_t))]; + uint8_t outbuf[nargs(11 + attributes_len * 2) + attributes_buffer_len + pad(sizeof(hal_uuid_t))]; uint8_t *optr = outbuf, *olimit = outbuf + sizeof(outbuf); - uint8_t inbuf[nargs(4) + pad(result_max * sizeof(hal_uuid_t))]; + uint8_t inbuf[nargs(5) + pad(result_max * sizeof(hal_uuid_t))]; const uint8_t *iptr = inbuf, *ilimit = inbuf + sizeof(inbuf); hal_error_t rpc_ret; @@ -796,6 +796,7 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, check(hal_xdr_encode_int(&optr, olimit, session.handle)); check(hal_xdr_encode_int(&optr, olimit, type)); check(hal_xdr_encode_int(&optr, olimit, curve)); + check(hal_xdr_encode_int(&optr, olimit, mask)); check(hal_xdr_encode_int(&optr, olimit, flags)); check(hal_xdr_encode_int(&optr, olimit, attributes_len)); if (attributes != NULL) { @@ -804,6 +805,7 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, check(hal_xdr_encode_buffer(&optr, olimit, attributes[i].value, attributes[i].length)); } } + check(hal_xdr_encode_int(&optr, olimit, *state)); check(hal_xdr_encode_int(&optr, olimit, result_max)); check(hal_xdr_encode_buffer(&optr, olimit, previous_uuid->uuid, sizeof(previous_uuid->uuid))); check(hal_rpc_send(outbuf, optr - outbuf)); @@ -812,8 +814,10 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, check(hal_xdr_decode_int(&iptr, ilimit, &rpc_ret)); if (rpc_ret == HAL_OK) { - uint32_t array_len; + uint32_t array_len, ustate; *result_len = 0; + check(hal_xdr_decode_int(&iptr, ilimit, &ustate)); + *state = ustate; check(hal_xdr_decode_int(&iptr, ilimit, &array_len)); for (int i = 0; i < array_len; ++i) { uint32_t uuid_len = sizeof(result[i].uuid); diff --git a/rpc_pkey.c b/rpc_pkey.c index cb83b98..e0d9bdc 100644 --- a/rpc_pkey.c +++ b/rpc_pkey.c @@ -57,7 +57,10 @@ static hal_pkey_slot_t pkey_slot[HAL_STATIC_PKEY_STATE_BLOCKS]; * * The high order bit of the pkey handle is left free for * HAL_PKEY_HANDLE_TOKEN_FLAG, which is used by the mixed-mode - * handlers to route calls to the appropriate destination. + * handlers to route calls to the appropriate destination. In most + * cases this flag is set here, but pkey_local_open() also sets it + * directly, so that we can present a unified UUID namespace + * regardless of which keystore holds a particular key. */ static inline hal_pkey_slot_t *alloc_slot(const hal_key_flags_t flags) @@ -263,6 +266,44 @@ static inline hal_error_t ks_open_from_flags(hal_ks_t **ks, const hal_key_flags_ ks); } +/* + * Fetch a key from a driver. + */ + +static inline hal_error_t ks_fetch_from_driver(const hal_ks_driver_t * const driver, + hal_pkey_slot_t *slot, + uint8_t *der, size_t *der_len, const size_t der_max) +{ + hal_ks_t *ks = NULL; + hal_error_t err; + + if ((err = hal_ks_open(driver, &ks)) != HAL_OK) + return err; + + if ((err = hal_ks_fetch(ks, slot, der, der_len, der_max)) == HAL_OK) + err = hal_ks_close(ks); + else + (void) hal_ks_close(ks); + + return err; +} + +/* + * Same thing but from key flag in slot object rather than explict driver. + */ + +static inline hal_error_t ks_fetch_from_flags(hal_pkey_slot_t *slot, + uint8_t *der, size_t *der_len, const size_t der_max) +{ + assert(slot != NULL); + + return ks_fetch_from_driver((slot->flags & HAL_KEY_FLAG_TOKEN) == 0 + ? hal_ks_volatile_driver + : hal_ks_token_driver, + slot, der, der_len, der_max); +} + + /* * Receive key from application, generate a name (UUID), store it, and * return a key handle and the name. @@ -324,38 +365,38 @@ static hal_error_t pkey_local_load(const hal_client_handle_t client, static hal_error_t pkey_local_open(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags) + const hal_uuid_t * const name) { assert(pkey != NULL && name != NULL); hal_pkey_slot_t *slot; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = check_readable(client, flags)) != HAL_OK) + if ((err = check_readable(client, 0)) != HAL_OK) return err; - if ((slot = alloc_slot(flags)) == NULL) + if ((slot = alloc_slot(0)) == NULL) return HAL_ERROR_NO_KEY_SLOTS_AVAILABLE; slot->name = *name; slot->client_handle = client; slot->session_handle = session; - if ((err = ks_open_from_flags(&ks, flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, NULL, NULL, 0)) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); + if ((err = ks_fetch_from_driver(hal_ks_token_driver, slot, NULL, NULL, 0)) == HAL_OK) + slot->pkey_handle.handle |= HAL_PKEY_HANDLE_TOKEN_FLAG; - if (err != HAL_OK) { - slot->type = HAL_KEY_TYPE_NONE; - return err; - } + else if (err == HAL_ERROR_KEY_NOT_FOUND) + err = ks_fetch_from_driver(hal_ks_volatile_driver, slot, NULL, NULL, 0); + + if (err != HAL_OK) + goto fail; *pkey = slot->pkey_handle; return HAL_OK; + + fail: + memset(slot, 0, sizeof(*slot)); + return err; } /* @@ -608,16 +649,9 @@ static size_t pkey_local_get_public_key_len(const hal_pkey_handle_t pkey) hal_ecdsa_key_t *ecdsa_key = NULL; uint8_t der[HAL_KS_WRAPPED_KEYSIZE]; size_t der_len; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = ks_open_from_flags(&ks, slot->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, der, &der_len, sizeof(der))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err == HAL_OK) { + if ((err = ks_fetch_from_flags(slot, der, &der_len, sizeof(der))) == HAL_OK) { switch (slot->type) { case HAL_KEY_TYPE_RSA_PUBLIC: @@ -658,21 +692,15 @@ static hal_error_t pkey_local_get_public_key(const hal_pkey_handle_t pkey, if (slot == NULL) return HAL_ERROR_KEY_NOT_FOUND; - uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; + uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size + ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; hal_rsa_key_t *rsa_key = NULL; hal_ecdsa_key_t *ecdsa_key = NULL; uint8_t buf[HAL_KS_WRAPPED_KEYSIZE]; size_t buf_len; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = ks_open_from_flags(&ks, slot->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, buf, &buf_len, sizeof(buf))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err == HAL_OK) { + if ((err = ks_fetch_from_flags(slot, buf, &buf_len, sizeof(buf))) == HAL_OK) { switch (slot->type) { case HAL_KEY_TYPE_RSA_PUBLIC: @@ -813,20 +841,15 @@ static hal_error_t pkey_local_sign(const hal_pkey_handle_t pkey, if ((slot->flags & HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE) == 0) return HAL_ERROR_FORBIDDEN; - uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; + uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size + ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; uint8_t der[HAL_KS_WRAPPED_KEYSIZE]; size_t der_len; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = ks_open_from_flags(&ks, slot->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, der, &der_len, sizeof(der))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err == HAL_OK) - err = signer(keybuf, sizeof(keybuf), der, der_len, hash, input, input_len, signature, signature_len, signature_max); + if ((err = ks_fetch_from_flags(slot, der, &der_len, sizeof(der))) == HAL_OK) + err = signer(keybuf, sizeof(keybuf), der, der_len, hash, input, input_len, + signature, signature_len, signature_max); memset(keybuf, 0, sizeof(keybuf)); memset(der, 0, sizeof(der)); @@ -964,20 +987,15 @@ static hal_error_t pkey_local_verify(const hal_pkey_handle_t pkey, if ((slot->flags & HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE) == 0) return HAL_ERROR_FORBIDDEN; - uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; + uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size + ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; uint8_t der[HAL_KS_WRAPPED_KEYSIZE]; size_t der_len; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = ks_open_from_flags(&ks, slot->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, der, &der_len, sizeof(der))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err == HAL_OK) - err = verifier(keybuf, sizeof(keybuf), slot->type, der, der_len, hash, input, input_len, signature, signature_len); + if ((err = ks_fetch_from_flags(slot, der, &der_len, sizeof(der))) == HAL_OK) + err = verifier(keybuf, sizeof(keybuf), slot->type, der, der_len, hash, + input, input_len, signature, signature_len); memset(keybuf, 0, sizeof(keybuf)); memset(der, 0, sizeof(der)); @@ -985,40 +1003,111 @@ static hal_error_t pkey_local_verify(const hal_pkey_handle_t pkey, return err; } +static inline hal_error_t match_one_keystore(const hal_ks_driver_t * const driver, + const hal_client_handle_t client, + const hal_session_handle_t session, + const hal_key_type_t type, + const hal_curve_name_t curve, + const hal_key_flags_t mask, + const hal_key_flags_t flags, + const hal_pkey_attribute_t *attributes, + const unsigned attributes_len, + hal_uuid_t **result, + unsigned *result_len, + const unsigned result_max, + const hal_uuid_t * const previous_uuid) +{ + hal_ks_t *ks = NULL; + hal_error_t err; + unsigned len; + + if ((err = hal_ks_open(driver, &ks)) != HAL_OK) + return err; + + if ((err = hal_ks_match(ks, client, session, type, curve, + mask, flags, attributes, attributes_len, + *result, &len, result_max - *result_len, + previous_uuid)) != HAL_OK) { + (void) hal_ks_close(ks); + return err; + } + + if ((err = hal_ks_close(ks)) != HAL_OK) + return err; + + *result += len; + *result_len += len; + + return HAL_OK; +} + +typedef enum { + MATCH_STATE_START, + MATCH_STATE_TOKEN, + MATCH_STATE_VOLATILE, + MATCH_STATE_DONE +} match_state_t; + static hal_error_t pkey_local_match(const hal_client_handle_t client, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, const hal_uuid_t * const previous_uuid) { - hal_ks_t *ks = NULL; + assert(state != NULL && result_len != NULL); + + static const hal_uuid_t uuid_zero[1] = {{{0}}}; + const hal_uuid_t *prev = previous_uuid; hal_error_t err; - err = check_readable(client, flags); + *result_len = 0; - if (err == HAL_ERROR_FORBIDDEN) { - assert(result_len != NULL); - *result_len = 0; + if ((err = check_readable(client, flags)) == HAL_ERROR_FORBIDDEN) return HAL_OK; - } - - if (err != HAL_OK) + else if (err != HAL_OK) return err; - if ((err = ks_open_from_flags(&ks, flags)) == HAL_OK && - (err = hal_ks_match(ks, client, session, type, curve, flags, attributes, attributes_len, - result, result_len, result_max, previous_uuid)) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); + switch ((match_state_t) *state) { - return err; + case MATCH_STATE_START: + prev = uuid_zero; + ++*state; + + case MATCH_STATE_TOKEN: + if (((mask & HAL_KEY_FLAG_TOKEN) == 0 || (mask & flags & HAL_KEY_FLAG_TOKEN) != 0) && + (err = match_one_keystore(hal_ks_token_driver, client, session, type, curve, + mask, flags, attributes, attributes_len, + &result, result_len, result_max - *result_len, prev)) != HAL_OK) + return err; + if (*result_len == result_max) + return HAL_OK; + prev = uuid_zero; + ++*state; + + case MATCH_STATE_VOLATILE: + if (((mask & HAL_KEY_FLAG_TOKEN) == 0 || (mask & flags & HAL_KEY_FLAG_TOKEN) == 0) && + (err = match_one_keystore(hal_ks_volatile_driver, client, session, type, curve, + mask, flags, attributes, attributes_len, + &result, result_len, result_max - *result_len, prev)) != HAL_OK) + return err; + if (*result_len == result_max) + return HAL_OK; + ++*state; + + case MATCH_STATE_DONE: + return HAL_OK; + + default: + return HAL_ERROR_BAD_ARGUMENTS; + } } static hal_error_t pkey_local_set_attributes(const hal_pkey_handle_t pkey, @@ -1078,7 +1167,6 @@ static hal_error_t pkey_local_export(const hal_pkey_handle_t pkey_handle, uint8_t rsabuf[hal_rsa_key_t_size]; hal_rsa_key_t *rsa = NULL; - hal_ks_t *ks = NULL; hal_error_t err; size_t len; @@ -1100,12 +1188,7 @@ static hal_error_t pkey_local_export(const hal_pkey_handle_t pkey_handle, if (pkcs8_max < HAL_KS_WRAPPED_KEYSIZE) return HAL_ERROR_RESULT_TOO_LONG; - if ((err = ks_open_from_flags(&ks, kekek->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, kekek, pkcs8, &len, pkcs8_max)) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - if (err != HAL_OK) + if ((err = ks_fetch_from_flags(kekek, pkcs8, &len, pkcs8_max)) != HAL_OK) goto fail; switch (kekek->type) { @@ -1129,13 +1212,7 @@ static hal_error_t pkey_local_export(const hal_pkey_handle_t pkey_handle, goto fail; } - if ((err = ks_open_from_flags(&ks, pkey->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, pkey, pkcs8, &len, pkcs8_max)) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err != HAL_OK) + if ((err = ks_fetch_from_flags(pkey, pkcs8, &len, pkcs8_max)) != HAL_OK) goto fail; if ((err = hal_get_random(NULL, kek, KEK_LENGTH)) != HAL_OK) @@ -1189,7 +1266,6 @@ static hal_error_t pkey_local_import(const hal_client_handle_t client, size_t der_len, oid_len, data_len; const uint8_t *oid, *data; hal_rsa_key_t *rsa = NULL; - hal_ks_t *ks = NULL; hal_error_t err; hal_pkey_slot_t * const kekek = find_handle(kekek_handle); @@ -1203,12 +1279,7 @@ static hal_error_t pkey_local_import(const hal_client_handle_t client, if (kekek->type != HAL_KEY_TYPE_RSA_PRIVATE) return HAL_ERROR_UNSUPPORTED_KEY; - if ((err = ks_open_from_flags(&ks, kekek->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, kekek, der, &der_len, sizeof(der))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - if (err != HAL_OK) + if ((err = ks_fetch_from_flags(kekek, der, &der_len, sizeof(der))) != HAL_OK) goto fail; if ((err = hal_rsa_private_key_from_der(&rsa, rsabuf, sizeof(rsabuf), der, der_len)) != HAL_OK) diff --git a/rpc_server.c b/rpc_server.c index 55f15fe..d946b06 100644 --- a/rpc_server.c +++ b/rpc_server.c @@ -356,19 +356,17 @@ static hal_error_t pkey_open(const uint8_t **iptr, const uint8_t * const ilimit, hal_pkey_handle_t pkey; const uint8_t *name_ptr; uint32_t name_len; - hal_key_flags_t flags; hal_error_t ret; check(hal_xdr_decode_int(iptr, ilimit, &client.handle)); check(hal_xdr_decode_int(iptr, ilimit, &session.handle)); check(hal_xdr_decode_buffer_in_place(iptr, ilimit, &name_ptr, &name_len)); - check(hal_xdr_decode_int(iptr, ilimit, &flags)); if (name_len != sizeof(hal_uuid_t)) return HAL_ERROR_KEY_NAME_TOO_LONG; /* call the local function */ - ret = hal_rpc_pkey_open(client, session, &pkey, (const hal_uuid_t *) name_ptr, flags); + ret = hal_rpc_pkey_open(client, session, &pkey, (const hal_uuid_t *) name_ptr); if (ret == HAL_OK) check(hal_xdr_encode_int(optr, olimit, pkey.handle)); @@ -643,15 +641,17 @@ static hal_error_t pkey_match(const uint8_t **iptr, const uint8_t * const ilimit { hal_client_handle_t client; hal_session_handle_t session; - uint32_t type, curve, attributes_len, result_max, previous_uuid_len; + uint32_t type, curve, attributes_len, state, result_max, previous_uuid_len; const uint8_t *previous_uuid_ptr; - hal_key_flags_t flags; + hal_key_flags_t mask, flags; + uint8_t *optr_orig = *optr; hal_error_t ret; check(hal_xdr_decode_int(iptr, ilimit, &client.handle)); check(hal_xdr_decode_int(iptr, ilimit, &session.handle)); check(hal_xdr_decode_int(iptr, ilimit, &type)); check(hal_xdr_decode_int(iptr, ilimit, &curve)); + check(hal_xdr_decode_int(iptr, ilimit, &mask)); check(hal_xdr_decode_int(iptr, ilimit, &flags)); check(hal_xdr_decode_int(iptr, ilimit, &attributes_len)); @@ -667,6 +667,7 @@ static hal_error_t pkey_match(const uint8_t **iptr, const uint8_t * const ilimit a->length = value_len; } + check(hal_xdr_decode_int(iptr, ilimit, &state)); check(hal_xdr_decode_int(iptr, ilimit, &result_max)); check(hal_xdr_decode_buffer_in_place(iptr, ilimit, &previous_uuid_ptr, &previous_uuid_len)); @@ -676,22 +677,24 @@ static hal_error_t pkey_match(const uint8_t **iptr, const uint8_t * const ilimit const hal_uuid_t * const previous_uuid = (const void *) previous_uuid_ptr; hal_uuid_t result[result_max]; - unsigned result_len; + unsigned result_len, ustate = state; - ret = hal_rpc_pkey_match(client, session, type, curve, flags, + ret = hal_rpc_pkey_match(client, session, type, curve, mask, flags, attributes, attributes_len, - result, &result_len, result_max, + &ustate, result, &result_len, result_max, previous_uuid); - if (ret == HAL_OK) { - uint8_t *optr_orig = *optr; + if (ret == HAL_OK) + ret = hal_xdr_encode_int(optr, olimit, ustate); + + if (ret == HAL_OK) ret = hal_xdr_encode_int(optr, olimit, result_len); - for (int i = 0; ret == HAL_OK && i < result_len; ++i) - ret = hal_xdr_encode_buffer(optr, olimit, result[i].uuid, - sizeof(result[i].uuid)); - if (ret != HAL_OK) - *optr = optr_orig; - } + + for (int i = 0; ret == HAL_OK && i < result_len; ++i) + ret = hal_xdr_encode_buffer(optr, olimit, result[i].uuid, + sizeof(result[i].uuid)); + if (ret != HAL_OK) + *optr = optr_orig; return ret; } diff --git a/tests/test-rpc_pkey.c b/tests/test-rpc_pkey.c index 1b5f86a..1f00fb8 100644 --- a/tests/test-rpc_pkey.c +++ b/tests/test-rpc_pkey.c @@ -98,25 +98,27 @@ static int test_attributes(const hal_pkey_handle_t pkey, const hal_client_handle_t client = {HAL_HANDLE_NONE}; const hal_session_handle_t session = {HAL_HANDLE_NONE}; hal_uuid_t result[10], previous_uuid = {{0}}; - unsigned result_len; + unsigned result_len, state; - if ((err = hal_rpc_pkey_match(client, session, HAL_KEY_TYPE_NONE, HAL_CURVE_NONE, flags, NULL, 0, - result, &result_len, sizeof(result)/sizeof(*result), + state = 0; + if ((err = hal_rpc_pkey_match(client, session, HAL_KEY_TYPE_NONE, HAL_CURVE_NONE, 0, 0, NULL, 0, + &state, result, &result_len, sizeof(result)/sizeof(*result), &previous_uuid)) != HAL_OK) lose("Unrestricted match() failed: %s\n", hal_error_string(err)); if (result_len == 0) lose("Unrestricted match found no results\n"); + state = 0; for (const size_t *size = sizes; *size; size++) { uint8_t buf[*size]; memset(buf, 0x55, sizeof(buf)); snprintf((char *) buf, sizeof(buf), format, (unsigned long) *size); hal_pkey_attribute_t attribute[1] = {{ *size, sizeof(buf), buf }}; - if ((err = hal_rpc_pkey_match(client, session, HAL_KEY_TYPE_NONE, HAL_CURVE_NONE, flags, + if ((err = hal_rpc_pkey_match(client, session, HAL_KEY_TYPE_NONE, HAL_CURVE_NONE, 0, 0, attribute, sizeof(attribute)/sizeof(*attribute), - result, &result_len, sizeof(result)/sizeof(*result), + &state, result, &result_len, sizeof(result)/sizeof(*result), &previous_uuid)) != HAL_OK) lose("Restricted match() for attribute %lu failed: %s\n", (unsigned long) *size, hal_error_string(err)); diff --git a/unit-tests.py b/unit-tests.py index d5b6c77..295bf40 100644 --- a/unit-tests.py +++ b/unit-tests.py @@ -550,43 +550,54 @@ class TestPKeyMatch(TestCaseLoggedIn): Tests involving PKey list and match functions. """ + @staticmethod + def key_flag_names(flags): + names = dict(digitalsignature = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE, + keyencipherment = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT, + dataencipherment = HAL_KEY_FLAG_USAGE_DATAENCIPHERMENT, + token = HAL_KEY_FLAG_TOKEN, + public = HAL_KEY_FLAG_PUBLIC, + exportable = HAL_KEY_FLAG_EXPORTABLE) + return ", ".join(sorted(k for k, v in names.iteritems() if (flags & v) != 0)) + def load_keys(self, flags): uuids = set() for obj in PreloadedKey.db.itervalues(): with hsm.pkey_load(obj.der, flags) as k: - self.addCleanup(lambda uuid: hsm.pkey_open(uuid, flags = flags).delete(), k.uuid) + self.addCleanup(lambda uuid: hsm.pkey_open(uuid).delete(), k.uuid) uuids.add(k.uuid) + #print k.uuid, k.key_type, k.key_curve, self.key_flag_names(k.key_flags) k.set_attributes(dict((i, a) for i, a in enumerate((str(obj.keytype), str(obj.fn2))))) return uuids - def match(self, flags, **kwargs): - uuids = kwargs.pop("uuids", None) - kwargs.update(flags = flags) + def match(self, uuids, **kwargs): n = 0 for uuid in hsm.pkey_match(**kwargs): - if uuids is None or uuid in uuids: - with hsm.pkey_open(uuid, flags) as k: + if uuid in uuids: + with hsm.pkey_open(uuid) as k: n += 1 yield n, k - def ks_match(self, flags): + def ks_match(self, mask, flags): tags = [] uuids = set() for i in xrange(2): - uuids |= self.load_keys(flags) + uuids |= self.load_keys(flags if mask else HAL_KEY_FLAG_TOKEN * i) tags.extend(PreloadedKey.db) self.assertEqual(len(tags), len(uuids)) - self.assertEqual(uuids, set(k.uuid for n, k in self.match(flags = flags, uuids = uuids))) + self.assertEqual(uuids, set(k.uuid for n, k in self.match(mask = mask, + flags = flags, + uuids = uuids))) for keytype in set(HALKeyType.index.itervalues()) - {HAL_KEY_TYPE_NONE}: - for n, k in self.match(flags = flags, uuids = uuids, type = keytype): + for n, k in self.match(mask = mask, flags = flags, uuids = uuids, type = keytype): self.assertEqual(k.key_type, keytype) self.assertEqual(k.get_attributes({0}).pop(0), str(keytype)) self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == keytype)) for curve in set(HALCurve.index.itervalues()) - {HAL_CURVE_NONE}: - for n, k in self.match(flags = flags, uuids = uuids, curve = curve): + for n, k in self.match(mask = mask, flags = flags, uuids = uuids, curve = curve): self.assertEqual(k.key_curve, curve) self.assertEqual(k.get_attributes({1}).pop(1), str(curve)) self.assertIn(k.key_type, (HAL_KEY_TYPE_EC_PUBLIC, @@ -594,7 +605,7 @@ class TestPKeyMatch(TestCaseLoggedIn): self.assertEqual(n, sum(1 for t1, t2 in tags if t2 == curve)) for keylen in set(kl for kt, kl in tags if not isinstance(kl, Enum)): - for n, k in self.match(flags = flags, uuids = uuids, + for n, k in self.match(mask = mask, flags = flags, uuids = uuids, attributes = {1 : str(keylen)}): self.assertEqual(keylen, int(k.get_attributes({1}).pop(1))) self.assertIn(k.key_type, (HAL_KEY_TYPE_RSA_PUBLIC, @@ -602,17 +613,20 @@ class TestPKeyMatch(TestCaseLoggedIn): self.assertEqual(n, sum(1 for t1, t2 in tags if not isinstance(t2, Enum) and t2 == keylen)) - for n, k in self.match(flags = flags, uuids = uuids, + for n, k in self.match(mask = mask, flags = flags, uuids = uuids, type = HAL_KEY_TYPE_RSA_PUBLIC, attributes = {1 : "2048"}): self.assertEqual(k.key_type, HAL_KEY_TYPE_RSA_PUBLIC) self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == HAL_KEY_TYPE_RSA_PUBLIC and t2 == 2048)) def test_ks_match_token(self): - self.ks_match(HAL_KEY_FLAG_TOKEN) + self.ks_match(mask = HAL_KEY_FLAG_TOKEN, flags = HAL_KEY_FLAG_TOKEN) def test_ks_match_volatile(self): - self.ks_match(0) + self.ks_match(mask = HAL_KEY_FLAG_TOKEN, flags = 0) + + def test_ks_match_all(self): + self.ks_match(mask = 0, flags = 0) class TestPKeyAttribute(TestCaseLoggedIn): @@ -626,7 +640,7 @@ class TestPKeyAttribute(TestCaseLoggedIn): for obj in PreloadedKey.db.itervalues(): with hsm.pkey_load(obj.der, flags) as k: pinwheel() - self.addCleanup(lambda uuid: hsm.pkey_open(uuid, flags = flags).delete(), k.uuid) + self.addCleanup(lambda uuid: hsm.pkey_open(uuid).delete(), k.uuid) k.set_attributes(dict((j, "Attribute {}{}".format(j, "*" * n_fill)) for j in xrange(n_attrs))) pinwheel() -- cgit v1.2.3