diff options
-rwxr-xr-x | cryptech_backup | 64 | ||||
-rw-r--r-- | ecdsa.c | 2 | ||||
-rw-r--r-- | hal.h | 7 | ||||
-rw-r--r-- | hal_internal.h | 9 | ||||
-rw-r--r-- | ks_flash.c | 4 | ||||
-rw-r--r-- | ks_volatile.c | 9 | ||||
-rw-r--r-- | libhal.py | 17 | ||||
-rw-r--r-- | rpc_api.c | 13 | ||||
-rw-r--r-- | rpc_client.c | 18 | ||||
-rw-r--r-- | rpc_pkey.c | 247 | ||||
-rw-r--r-- | rpc_server.c | 35 | ||||
-rw-r--r-- | tests/test-rpc_pkey.c | 12 | ||||
-rw-r--r-- | unit-tests.py | 46 |
13 files changed, 289 insertions, 194 deletions
diff --git a/cryptech_backup b/cryptech_backup index 7360a0d..7e465b8 100755 --- a/cryptech_backup +++ b/cryptech_backup @@ -8,22 +8,10 @@ # # Load KEKEK public <---------------- Export KEKEK public # -# { -# "kekek-uuid": "[UUID]", -# "kekek": "[Base64]" -# } -# # hal_rpc_pkey_load() # hal_rpc_pkey_export() # -# Export PKCS #8 and KEK ----------> Load PKCS #8 and KEK, import key: -# -# { -# "kekek-uuid": "[UUID]", -# "pkey": "[Base64]", -# "kek": "[Base64]" -# } -# +# Export PKCS #8 and KEK ----------> Load PKCS #8 and KEK, import key # # hal_rpc_pkey_import() @@ -125,10 +113,11 @@ def cmd_setup(args, hsm): elif not args.new: uuids.extend(hsm.pkey_match( type = HAL_KEY_TYPE_RSA_PRIVATE, + mask = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT | HAL_KEY_FLAG_TOKEN, flags = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT | HAL_KEY_FLAG_TOKEN)) for uuid in uuids: - with hsm.pkey_open(uuid, HAL_KEY_FLAG_TOKEN) as kekek: + with hsm.pkey_open(uuid) as kekek: if kekek.key_type != HAL_KEY_TYPE_RSA_PRIVATE: sys.stderr.write("Key {} is not an RSA private key\n".format(uuid)) elif (kekek.key_flags & HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT) == 0: @@ -179,31 +168,26 @@ def cmd_export(args, hsm): kekek = hsm.pkey_load(der = b64join(db["kekek_pubkey"]), flags = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT) - # What we *should* do here is a single .pkey_match() loop - # matching exactly the keys we want, but the current semantics - # of .pkey_match() are a bit confused. While that yak is - # waiting for its shave, we do this the dumb way by iterating - # over all keys then skipping the ones we don't want. - - for flags in (0, HAL_KEY_FLAG_TOKEN): - for uuid in hsm.pkey_match(flags = flags): - with hsm.pkey_open(uuid, flags) as pkey: - if (pkey.key_flags & HAL_KEY_FLAG_EXPORTABLE) == 0: - continue - if pkey.key_type in (HAL_KEY_TYPE_RSA_PRIVATE, HAL_KEY_TYPE_EC_PRIVATE): - pkcs8, kek = kekek.export_pkey(pkey) - result.append(dict( - comment = "Encrypted private key", - pkcs8 = b64(pkcs8), - kek = b64(kek), - uuid = str(pkey.uuid), - flags = pkey.key_flags)) - elif pkey.key_type in (HAL_KEY_TYPE_RSA_PUBLIC, HAL_KEY_TYPE_EC_PUBLIC): - result.append(dict( - comment = "Public key", - spki = b64(pkey.public_key), - uuid = str(pkey.uuid), - flags = pkey.key_flags)) + for uuid in hsm.pkey_match(mask = HAL_KEY_FLAG_EXPORTABLE, + flags = HAL_KEY_FLAG_EXPORTABLE): + with hsm.pkey_open(uuid) as pkey: + + if pkey.key_type in (HAL_KEY_TYPE_RSA_PRIVATE, HAL_KEY_TYPE_EC_PRIVATE): + pkcs8, kek = kekek.export_pkey(pkey) + result.append(dict( + comment = "Encrypted private key", + pkcs8 = b64(pkcs8), + kek = b64(kek), + uuid = str(pkey.uuid), + flags = pkey.key_flags)) + + elif pkey.key_type in (HAL_KEY_TYPE_RSA_PUBLIC, HAL_KEY_TYPE_EC_PUBLIC): + result.append(dict( + comment = "Public key", + spki = b64(pkey.public_key), + uuid = str(pkey.uuid), + flags = pkey.key_flags)) + finally: if kekek is not None: kekek.delete() @@ -222,7 +206,7 @@ def cmd_import(args, hsm): """ db = json.load(args.input) - with hsm.pkey_open(uuid.UUID(db["kekek_uuid"]).bytes, HAL_KEY_FLAG_TOKEN) as kekek: + with hsm.pkey_open(uuid.UUID(db["kekek_uuid"]).bytes) as kekek: for k in db["keys"]: pkcs8 = b64join(k.get("pkcs8", "")) spki = b64join(k.get("spki", "")) @@ -1315,7 +1315,7 @@ hal_error_t hal_ecdsa_private_key_to_der(const hal_ecdsa_key_t * const key, NULL, hlen + vlen, NULL, der_len, der_max)) != HAL_OK) return err; - + if (der == NULL) return HAL_OK; @@ -358,7 +358,7 @@ extern hal_error_t hal_aes_keywrap(hal_core_t *core, extern hal_error_t hal_aes_keyunwrap(hal_core_t *core, const uint8_t *kek, const size_t kek_length, const uint8_t *ciphertext, const size_t ciphertext_length, - unsigned char *plaintext, size_t *plaintext_length); + uint8_t *plaintext, size_t *plaintext_length); extern size_t hal_aes_keywrap_ciphertext_length(const size_t plaintext_length); @@ -756,8 +756,7 @@ extern hal_error_t hal_rpc_pkey_load(const hal_client_handle_t client, extern hal_error_t hal_rpc_pkey_open(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags); + const hal_uuid_t * const name); extern hal_error_t hal_rpc_pkey_generate_rsa(const hal_client_handle_t client, const hal_session_handle_t session, @@ -806,9 +805,11 @@ extern hal_error_t hal_rpc_pkey_match(const hal_client_handle_t client, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, diff --git a/hal_internal.h b/hal_internal.h index 82d0081..1822781 100644 --- a/hal_internal.h +++ b/hal_internal.h @@ -199,8 +199,7 @@ typedef struct { hal_error_t (*open)(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags); + const hal_uuid_t * const name); hal_error_t (*generate_rsa)(const hal_client_handle_t client, const hal_session_handle_t session, @@ -249,9 +248,11 @@ typedef struct { const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, @@ -485,6 +486,7 @@ struct hal_ks_driver { const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, @@ -614,6 +616,7 @@ static inline hal_error_t hal_ks_match(hal_ks_t *ks, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, @@ -628,7 +631,7 @@ static inline hal_error_t hal_ks_match(hal_ks_t *ks, if (ks->driver->match == NULL) return HAL_ERROR_NOT_IMPLEMENTED; - return ks->driver->match(ks, client, session, type, curve, flags, attributes, attributes_len, + return ks->driver->match(ks, client, session, type, curve, mask, flags, attributes, attributes_len, result, result_len, result_max, previous_uuid); } @@ -1239,6 +1239,7 @@ static hal_error_t ks_match(hal_ks_t *ks, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, @@ -1284,7 +1285,8 @@ static hal_error_t ks_match(hal_ks_t *ks, if (db.ksi.names[b].chunk == 0) { memset(need_attr, 1, sizeof(need_attr)); possible = ((type == HAL_KEY_TYPE_NONE || type == block->key.type) && - (curve == HAL_CURVE_NONE || curve == block->key.curve)); + (curve == HAL_CURVE_NONE || curve == block->key.curve) && + ((flags ^ block->key.flags) & mask) == 0); } if (!possible) diff --git a/ks_volatile.c b/ks_volatile.c index 9762da3..6d65578 100644 --- a/ks_volatile.c +++ b/ks_volatile.c @@ -124,7 +124,10 @@ static inline int key_visible_to_session(const ks_t * const ksv, const hal_session_handle_t session, const ks_key_t * const k) { - return !ksv->per_session || client.handle == HAL_HANDLE_NONE || k->client.handle == client.handle; + return (!ksv->per_session || + client.handle == HAL_HANDLE_NONE || + k->client.handle == client.handle || + hal_rpc_is_logged_in(client, HAL_USER_WHEEL) == HAL_OK); } static inline void *gnaw(uint8_t **mem, size_t *len, const size_t size) @@ -385,6 +388,7 @@ static hal_error_t ks_match(hal_ks_t *ks, hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, @@ -429,6 +433,9 @@ static hal_error_t ks_match(hal_ks_t *ks, if (curve != HAL_CURVE_NONE && curve != ksv->db->keys[b].curve) continue; + if (((flags ^ ksv->db->keys[b].flags) & mask) != 0) + continue; + if (!key_visible_to_session(ksv, client, session, &ksv->db->keys[b])) continue; @@ -420,7 +420,8 @@ class HSM(object): if status != 0: raise HALError.table[status]() - def __init__(self, sockname = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME", "/tmp/.cryptech_muxd.rpc")): + def __init__(self, sockname = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME", + "/tmp/.cryptech_muxd.rpc")): self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.socket.connect(sockname) self.sockfile = self.socket.makefile("rb") @@ -561,8 +562,8 @@ class HSM(object): logger.debug("Loaded pkey %s", pkey.uuid) return pkey - def pkey_open(self, uuid, flags = 0, client = 0, session = 0): - with self.rpc(RPC_FUNC_PKEY_OPEN, session, uuid, flags, client = client) as r: + def pkey_open(self, uuid, client = 0, session = 0): + with self.rpc(RPC_FUNC_PKEY_OPEN, session, uuid, client = client) as r: pkey = PKey(self, r.unpack_uint(), uuid) logger.debug("Opened pkey %s", pkey.uuid) return pkey @@ -631,13 +632,15 @@ class HSM(object): with self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature): return - def pkey_match(self, type = 0, curve = 0, flags = 0, attributes = {}, - length = 64, client = 0, session = 0): + def pkey_match(self, type = 0, curve = 0, mask = 0, flags = 0, + attributes = {}, length = 64, client = 0, session = 0): u = UUID(int = 0) n = length + s = 0 while n == length: - with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags, - attributes, length, u, client = client) as r: + with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, mask, flags, + attributes, s, length, u, client = client) as r: + s = r.unpack_uint() n = r.unpack_uint() for i in xrange(n): u = UUID(bytes = r.unpack_bytes()) @@ -230,12 +230,11 @@ hal_error_t hal_rpc_pkey_load(const hal_client_handle_t client, hal_error_t hal_rpc_pkey_open(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags) + const hal_uuid_t * const name) { if (pkey == NULL || name == NULL) return HAL_ERROR_BAD_ARGUMENTS; - return hal_rpc_pkey_dispatch->open(client, session, pkey, name, flags); + return hal_rpc_pkey_dispatch->open(client, session, pkey, name); } hal_error_t hal_rpc_pkey_generate_rsa(const hal_client_handle_t client, @@ -338,16 +337,18 @@ hal_error_t hal_rpc_pkey_match(const hal_client_handle_t client, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, const hal_uuid_t * const previous_uuid) { if ((attributes == NULL && attributes_len > 0) || previous_uuid == NULL || - result == NULL || result_len == NULL || result_max == 0) + state == NULL || result == NULL || result_len == NULL || result_max == 0) return HAL_ERROR_BAD_ARGUMENTS; if (attributes != NULL) @@ -355,9 +356,9 @@ hal_error_t hal_rpc_pkey_match(const hal_client_handle_t client, if (attributes[i].value == NULL) return HAL_ERROR_BAD_ARGUMENTS; - return hal_rpc_pkey_dispatch->match(client, session, type, curve, flags, + return hal_rpc_pkey_dispatch->match(client, session, type, curve, mask, flags, attributes, attributes_len, - result, result_len, result_max, previous_uuid); + state, result, result_len, result_max, previous_uuid); } hal_error_t hal_rpc_pkey_set_attributes(const hal_pkey_handle_t pkey, diff --git a/rpc_client.c b/rpc_client.c index e856cce..aad9edf 100644 --- a/rpc_client.c +++ b/rpc_client.c @@ -454,10 +454,9 @@ static hal_error_t pkey_remote_load(const hal_client_handle_t client, static hal_error_t pkey_remote_open(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags) + const hal_uuid_t * const name) { - uint8_t outbuf[nargs(5) + pad(sizeof(name->uuid))], *optr = outbuf, *olimit = outbuf + sizeof(outbuf); + uint8_t outbuf[nargs(4) + pad(sizeof(name->uuid))], *optr = outbuf, *olimit = outbuf + sizeof(outbuf); uint8_t inbuf[nargs(4)]; const uint8_t *iptr = inbuf, *ilimit = inbuf + sizeof(inbuf); hal_error_t rpc_ret; @@ -466,7 +465,6 @@ static hal_error_t pkey_remote_open(const hal_client_handle_t client, check(hal_xdr_encode_int(&optr, olimit, client.handle)); check(hal_xdr_encode_int(&optr, olimit, session.handle)); check(hal_xdr_encode_buffer(&optr, olimit, name->uuid, sizeof(name->uuid))); - check(hal_xdr_encode_int(&optr, olimit, flags)); check(hal_rpc_send(outbuf, optr - outbuf)); check(read_matching_packet(RPC_FUNC_PKEY_OPEN, inbuf, sizeof(inbuf), &iptr, &ilimit)); @@ -772,9 +770,11 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, @@ -785,9 +785,9 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, for (int i = 0; i < attributes_len; i++) attributes_buffer_len += pad(attributes[i].length); - uint8_t outbuf[nargs(9 + attributes_len * 2) + attributes_buffer_len + pad(sizeof(hal_uuid_t))]; + uint8_t outbuf[nargs(11 + attributes_len * 2) + attributes_buffer_len + pad(sizeof(hal_uuid_t))]; uint8_t *optr = outbuf, *olimit = outbuf + sizeof(outbuf); - uint8_t inbuf[nargs(4) + pad(result_max * sizeof(hal_uuid_t))]; + uint8_t inbuf[nargs(5) + pad(result_max * sizeof(hal_uuid_t))]; const uint8_t *iptr = inbuf, *ilimit = inbuf + sizeof(inbuf); hal_error_t rpc_ret; @@ -796,6 +796,7 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, check(hal_xdr_encode_int(&optr, olimit, session.handle)); check(hal_xdr_encode_int(&optr, olimit, type)); check(hal_xdr_encode_int(&optr, olimit, curve)); + check(hal_xdr_encode_int(&optr, olimit, mask)); check(hal_xdr_encode_int(&optr, olimit, flags)); check(hal_xdr_encode_int(&optr, olimit, attributes_len)); if (attributes != NULL) { @@ -804,6 +805,7 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, check(hal_xdr_encode_buffer(&optr, olimit, attributes[i].value, attributes[i].length)); } } + check(hal_xdr_encode_int(&optr, olimit, *state)); check(hal_xdr_encode_int(&optr, olimit, result_max)); check(hal_xdr_encode_buffer(&optr, olimit, previous_uuid->uuid, sizeof(previous_uuid->uuid))); check(hal_rpc_send(outbuf, optr - outbuf)); @@ -812,8 +814,10 @@ static hal_error_t pkey_remote_match(const hal_client_handle_t client, check(hal_xdr_decode_int(&iptr, ilimit, &rpc_ret)); if (rpc_ret == HAL_OK) { - uint32_t array_len; + uint32_t array_len, ustate; *result_len = 0; + check(hal_xdr_decode_int(&iptr, ilimit, &ustate)); + *state = ustate; check(hal_xdr_decode_int(&iptr, ilimit, &array_len)); for (int i = 0; i < array_len; ++i) { uint32_t uuid_len = sizeof(result[i].uuid); @@ -57,7 +57,10 @@ static hal_pkey_slot_t pkey_slot[HAL_STATIC_PKEY_STATE_BLOCKS]; * * The high order bit of the pkey handle is left free for * HAL_PKEY_HANDLE_TOKEN_FLAG, which is used by the mixed-mode - * handlers to route calls to the appropriate destination. + * handlers to route calls to the appropriate destination. In most + * cases this flag is set here, but pkey_local_open() also sets it + * directly, so that we can present a unified UUID namespace + * regardless of which keystore holds a particular key. */ static inline hal_pkey_slot_t *alloc_slot(const hal_key_flags_t flags) @@ -264,6 +267,44 @@ static inline hal_error_t ks_open_from_flags(hal_ks_t **ks, const hal_key_flags_ } /* + * Fetch a key from a driver. + */ + +static inline hal_error_t ks_fetch_from_driver(const hal_ks_driver_t * const driver, + hal_pkey_slot_t *slot, + uint8_t *der, size_t *der_len, const size_t der_max) +{ + hal_ks_t *ks = NULL; + hal_error_t err; + + if ((err = hal_ks_open(driver, &ks)) != HAL_OK) + return err; + + if ((err = hal_ks_fetch(ks, slot, der, der_len, der_max)) == HAL_OK) + err = hal_ks_close(ks); + else + (void) hal_ks_close(ks); + + return err; +} + +/* + * Same thing but from key flag in slot object rather than explict driver. + */ + +static inline hal_error_t ks_fetch_from_flags(hal_pkey_slot_t *slot, + uint8_t *der, size_t *der_len, const size_t der_max) +{ + assert(slot != NULL); + + return ks_fetch_from_driver((slot->flags & HAL_KEY_FLAG_TOKEN) == 0 + ? hal_ks_volatile_driver + : hal_ks_token_driver, + slot, der, der_len, der_max); +} + + +/* * Receive key from application, generate a name (UUID), store it, and * return a key handle and the name. */ @@ -324,38 +365,38 @@ static hal_error_t pkey_local_load(const hal_client_handle_t client, static hal_error_t pkey_local_open(const hal_client_handle_t client, const hal_session_handle_t session, hal_pkey_handle_t *pkey, - const hal_uuid_t * const name, - const hal_key_flags_t flags) + const hal_uuid_t * const name) { assert(pkey != NULL && name != NULL); hal_pkey_slot_t *slot; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = check_readable(client, flags)) != HAL_OK) + if ((err = check_readable(client, 0)) != HAL_OK) return err; - if ((slot = alloc_slot(flags)) == NULL) + if ((slot = alloc_slot(0)) == NULL) return HAL_ERROR_NO_KEY_SLOTS_AVAILABLE; slot->name = *name; slot->client_handle = client; slot->session_handle = session; - if ((err = ks_open_from_flags(&ks, flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, NULL, NULL, 0)) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); + if ((err = ks_fetch_from_driver(hal_ks_token_driver, slot, NULL, NULL, 0)) == HAL_OK) + slot->pkey_handle.handle |= HAL_PKEY_HANDLE_TOKEN_FLAG; - if (err != HAL_OK) { - slot->type = HAL_KEY_TYPE_NONE; - return err; - } + else if (err == HAL_ERROR_KEY_NOT_FOUND) + err = ks_fetch_from_driver(hal_ks_volatile_driver, slot, NULL, NULL, 0); + + if (err != HAL_OK) + goto fail; *pkey = slot->pkey_handle; return HAL_OK; + + fail: + memset(slot, 0, sizeof(*slot)); + return err; } /* @@ -608,16 +649,9 @@ static size_t pkey_local_get_public_key_len(const hal_pkey_handle_t pkey) hal_ecdsa_key_t *ecdsa_key = NULL; uint8_t der[HAL_KS_WRAPPED_KEYSIZE]; size_t der_len; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = ks_open_from_flags(&ks, slot->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, der, &der_len, sizeof(der))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err == HAL_OK) { + if ((err = ks_fetch_from_flags(slot, der, &der_len, sizeof(der))) == HAL_OK) { switch (slot->type) { case HAL_KEY_TYPE_RSA_PUBLIC: @@ -658,21 +692,15 @@ static hal_error_t pkey_local_get_public_key(const hal_pkey_handle_t pkey, if (slot == NULL) return HAL_ERROR_KEY_NOT_FOUND; - uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; + uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size + ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; hal_rsa_key_t *rsa_key = NULL; hal_ecdsa_key_t *ecdsa_key = NULL; uint8_t buf[HAL_KS_WRAPPED_KEYSIZE]; size_t buf_len; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = ks_open_from_flags(&ks, slot->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, buf, &buf_len, sizeof(buf))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err == HAL_OK) { + if ((err = ks_fetch_from_flags(slot, buf, &buf_len, sizeof(buf))) == HAL_OK) { switch (slot->type) { case HAL_KEY_TYPE_RSA_PUBLIC: @@ -813,20 +841,15 @@ static hal_error_t pkey_local_sign(const hal_pkey_handle_t pkey, if ((slot->flags & HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE) == 0) return HAL_ERROR_FORBIDDEN; - uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; + uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size + ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; uint8_t der[HAL_KS_WRAPPED_KEYSIZE]; size_t der_len; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = ks_open_from_flags(&ks, slot->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, der, &der_len, sizeof(der))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err == HAL_OK) - err = signer(keybuf, sizeof(keybuf), der, der_len, hash, input, input_len, signature, signature_len, signature_max); + if ((err = ks_fetch_from_flags(slot, der, &der_len, sizeof(der))) == HAL_OK) + err = signer(keybuf, sizeof(keybuf), der, der_len, hash, input, input_len, + signature, signature_len, signature_max); memset(keybuf, 0, sizeof(keybuf)); memset(der, 0, sizeof(der)); @@ -964,20 +987,15 @@ static hal_error_t pkey_local_verify(const hal_pkey_handle_t pkey, if ((slot->flags & HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE) == 0) return HAL_ERROR_FORBIDDEN; - uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; + uint8_t keybuf[hal_rsa_key_t_size > hal_ecdsa_key_t_size + ? hal_rsa_key_t_size : hal_ecdsa_key_t_size]; uint8_t der[HAL_KS_WRAPPED_KEYSIZE]; size_t der_len; - hal_ks_t *ks = NULL; hal_error_t err; - if ((err = ks_open_from_flags(&ks, slot->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, slot, der, &der_len, sizeof(der))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err == HAL_OK) - err = verifier(keybuf, sizeof(keybuf), slot->type, der, der_len, hash, input, input_len, signature, signature_len); + if ((err = ks_fetch_from_flags(slot, der, &der_len, sizeof(der))) == HAL_OK) + err = verifier(keybuf, sizeof(keybuf), slot->type, der, der_len, hash, + input, input_len, signature, signature_len); memset(keybuf, 0, sizeof(keybuf)); memset(der, 0, sizeof(der)); @@ -985,40 +1003,111 @@ static hal_error_t pkey_local_verify(const hal_pkey_handle_t pkey, return err; } +static inline hal_error_t match_one_keystore(const hal_ks_driver_t * const driver, + const hal_client_handle_t client, + const hal_session_handle_t session, + const hal_key_type_t type, + const hal_curve_name_t curve, + const hal_key_flags_t mask, + const hal_key_flags_t flags, + const hal_pkey_attribute_t *attributes, + const unsigned attributes_len, + hal_uuid_t **result, + unsigned *result_len, + const unsigned result_max, + const hal_uuid_t * const previous_uuid) +{ + hal_ks_t *ks = NULL; + hal_error_t err; + unsigned len; + + if ((err = hal_ks_open(driver, &ks)) != HAL_OK) + return err; + + if ((err = hal_ks_match(ks, client, session, type, curve, + mask, flags, attributes, attributes_len, + *result, &len, result_max - *result_len, + previous_uuid)) != HAL_OK) { + (void) hal_ks_close(ks); + return err; + } + + if ((err = hal_ks_close(ks)) != HAL_OK) + return err; + + *result += len; + *result_len += len; + + return HAL_OK; +} + +typedef enum { + MATCH_STATE_START, + MATCH_STATE_TOKEN, + MATCH_STATE_VOLATILE, + MATCH_STATE_DONE +} match_state_t; + static hal_error_t pkey_local_match(const hal_client_handle_t client, const hal_session_handle_t session, const hal_key_type_t type, const hal_curve_name_t curve, + const hal_key_flags_t mask, const hal_key_flags_t flags, const hal_pkey_attribute_t *attributes, const unsigned attributes_len, + unsigned *state, hal_uuid_t *result, unsigned *result_len, const unsigned result_max, const hal_uuid_t * const previous_uuid) { - hal_ks_t *ks = NULL; + assert(state != NULL && result_len != NULL); + + static const hal_uuid_t uuid_zero[1] = {{{0}}}; + const hal_uuid_t *prev = previous_uuid; hal_error_t err; - err = check_readable(client, flags); + *result_len = 0; - if (err == HAL_ERROR_FORBIDDEN) { - assert(result_len != NULL); - *result_len = 0; + if ((err = check_readable(client, flags)) == HAL_ERROR_FORBIDDEN) return HAL_OK; - } - - if (err != HAL_OK) + else if (err != HAL_OK) return err; - if ((err = ks_open_from_flags(&ks, flags)) == HAL_OK && - (err = hal_ks_match(ks, client, session, type, curve, flags, attributes, attributes_len, - result, result_len, result_max, previous_uuid)) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); + switch ((match_state_t) *state) { - return err; + case MATCH_STATE_START: + prev = uuid_zero; + ++*state; + + case MATCH_STATE_TOKEN: + if (((mask & HAL_KEY_FLAG_TOKEN) == 0 || (mask & flags & HAL_KEY_FLAG_TOKEN) != 0) && + (err = match_one_keystore(hal_ks_token_driver, client, session, type, curve, + mask, flags, attributes, attributes_len, + &result, result_len, result_max - *result_len, prev)) != HAL_OK) + return err; + if (*result_len == result_max) + return HAL_OK; + prev = uuid_zero; + ++*state; + + case MATCH_STATE_VOLATILE: + if (((mask & HAL_KEY_FLAG_TOKEN) == 0 || (mask & flags & HAL_KEY_FLAG_TOKEN) == 0) && + (err = match_one_keystore(hal_ks_volatile_driver, client, session, type, curve, + mask, flags, attributes, attributes_len, + &result, result_len, result_max - *result_len, prev)) != HAL_OK) + return err; + if (*result_len == result_max) + return HAL_OK; + ++*state; + + case MATCH_STATE_DONE: + return HAL_OK; + + default: + return HAL_ERROR_BAD_ARGUMENTS; + } } static hal_error_t pkey_local_set_attributes(const hal_pkey_handle_t pkey, @@ -1078,7 +1167,6 @@ static hal_error_t pkey_local_export(const hal_pkey_handle_t pkey_handle, uint8_t rsabuf[hal_rsa_key_t_size]; hal_rsa_key_t *rsa = NULL; - hal_ks_t *ks = NULL; hal_error_t err; size_t len; @@ -1100,12 +1188,7 @@ static hal_error_t pkey_local_export(const hal_pkey_handle_t pkey_handle, if (pkcs8_max < HAL_KS_WRAPPED_KEYSIZE) return HAL_ERROR_RESULT_TOO_LONG; - if ((err = ks_open_from_flags(&ks, kekek->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, kekek, pkcs8, &len, pkcs8_max)) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - if (err != HAL_OK) + if ((err = ks_fetch_from_flags(kekek, pkcs8, &len, pkcs8_max)) != HAL_OK) goto fail; switch (kekek->type) { @@ -1129,13 +1212,7 @@ static hal_error_t pkey_local_export(const hal_pkey_handle_t pkey_handle, goto fail; } - if ((err = ks_open_from_flags(&ks, pkey->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, pkey, pkcs8, &len, pkcs8_max)) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - - if (err != HAL_OK) + if ((err = ks_fetch_from_flags(pkey, pkcs8, &len, pkcs8_max)) != HAL_OK) goto fail; if ((err = hal_get_random(NULL, kek, KEK_LENGTH)) != HAL_OK) @@ -1189,7 +1266,6 @@ static hal_error_t pkey_local_import(const hal_client_handle_t client, size_t der_len, oid_len, data_len; const uint8_t *oid, *data; hal_rsa_key_t *rsa = NULL; - hal_ks_t *ks = NULL; hal_error_t err; hal_pkey_slot_t * const kekek = find_handle(kekek_handle); @@ -1203,12 +1279,7 @@ static hal_error_t pkey_local_import(const hal_client_handle_t client, if (kekek->type != HAL_KEY_TYPE_RSA_PRIVATE) return HAL_ERROR_UNSUPPORTED_KEY; - if ((err = ks_open_from_flags(&ks, kekek->flags)) == HAL_OK && - (err = hal_ks_fetch(ks, kekek, der, &der_len, sizeof(der))) == HAL_OK) - err = hal_ks_close(ks); - else if (ks != NULL) - (void) hal_ks_close(ks); - if (err != HAL_OK) + if ((err = ks_fetch_from_flags(kekek, der, &der_len, sizeof(der))) != HAL_OK) goto fail; if ((err = hal_rsa_private_key_from_der(&rsa, rsabuf, sizeof(rsabuf), der, der_len)) != HAL_OK) diff --git a/rpc_server.c b/rpc_server.c index 55f15fe..d946b06 100644 --- a/rpc_server.c +++ b/rpc_server.c @@ -356,19 +356,17 @@ static hal_error_t pkey_open(const uint8_t **iptr, const uint8_t * const ilimit, hal_pkey_handle_t pkey; const uint8_t *name_ptr; uint32_t name_len; - hal_key_flags_t flags; hal_error_t ret; check(hal_xdr_decode_int(iptr, ilimit, &client.handle)); check(hal_xdr_decode_int(iptr, ilimit, &session.handle)); check(hal_xdr_decode_buffer_in_place(iptr, ilimit, &name_ptr, &name_len)); - check(hal_xdr_decode_int(iptr, ilimit, &flags)); if (name_len != sizeof(hal_uuid_t)) return HAL_ERROR_KEY_NAME_TOO_LONG; /* call the local function */ - ret = hal_rpc_pkey_open(client, session, &pkey, (const hal_uuid_t *) name_ptr, flags); + ret = hal_rpc_pkey_open(client, session, &pkey, (const hal_uuid_t *) name_ptr); if (ret == HAL_OK) check(hal_xdr_encode_int(optr, olimit, pkey.handle)); @@ -643,15 +641,17 @@ static hal_error_t pkey_match(const uint8_t **iptr, const uint8_t * const ilimit { hal_client_handle_t client; hal_session_handle_t session; - uint32_t type, curve, attributes_len, result_max, previous_uuid_len; + uint32_t type, curve, attributes_len, state, result_max, previous_uuid_len; const uint8_t *previous_uuid_ptr; - hal_key_flags_t flags; + hal_key_flags_t mask, flags; + uint8_t *optr_orig = *optr; hal_error_t ret; check(hal_xdr_decode_int(iptr, ilimit, &client.handle)); check(hal_xdr_decode_int(iptr, ilimit, &session.handle)); check(hal_xdr_decode_int(iptr, ilimit, &type)); check(hal_xdr_decode_int(iptr, ilimit, &curve)); + check(hal_xdr_decode_int(iptr, ilimit, &mask)); check(hal_xdr_decode_int(iptr, ilimit, &flags)); check(hal_xdr_decode_int(iptr, ilimit, &attributes_len)); @@ -667,6 +667,7 @@ static hal_error_t pkey_match(const uint8_t **iptr, const uint8_t * const ilimit a->length = value_len; } + check(hal_xdr_decode_int(iptr, ilimit, &state)); check(hal_xdr_decode_int(iptr, ilimit, &result_max)); check(hal_xdr_decode_buffer_in_place(iptr, ilimit, &previous_uuid_ptr, &previous_uuid_len)); @@ -676,22 +677,24 @@ static hal_error_t pkey_match(const uint8_t **iptr, const uint8_t * const ilimit const hal_uuid_t * const previous_uuid = (const void *) previous_uuid_ptr; hal_uuid_t result[result_max]; - unsigned result_len; + unsigned result_len, ustate = state; - ret = hal_rpc_pkey_match(client, session, type, curve, flags, + ret = hal_rpc_pkey_match(client, session, type, curve, mask, flags, attributes, attributes_len, - result, &result_len, result_max, + &ustate, result, &result_len, result_max, previous_uuid); - if (ret == HAL_OK) { - uint8_t *optr_orig = *optr; + if (ret == HAL_OK) + ret = hal_xdr_encode_int(optr, olimit, ustate); + + if (ret == HAL_OK) ret = hal_xdr_encode_int(optr, olimit, result_len); - for (int i = 0; ret == HAL_OK && i < result_len; ++i) - ret = hal_xdr_encode_buffer(optr, olimit, result[i].uuid, - sizeof(result[i].uuid)); - if (ret != HAL_OK) - *optr = optr_orig; - } + + for (int i = 0; ret == HAL_OK && i < result_len; ++i) + ret = hal_xdr_encode_buffer(optr, olimit, result[i].uuid, + sizeof(result[i].uuid)); + if (ret != HAL_OK) + *optr = optr_orig; return ret; } diff --git a/tests/test-rpc_pkey.c b/tests/test-rpc_pkey.c index 1b5f86a..1f00fb8 100644 --- a/tests/test-rpc_pkey.c +++ b/tests/test-rpc_pkey.c @@ -98,25 +98,27 @@ static int test_attributes(const hal_pkey_handle_t pkey, const hal_client_handle_t client = {HAL_HANDLE_NONE}; const hal_session_handle_t session = {HAL_HANDLE_NONE}; hal_uuid_t result[10], previous_uuid = {{0}}; - unsigned result_len; + unsigned result_len, state; - if ((err = hal_rpc_pkey_match(client, session, HAL_KEY_TYPE_NONE, HAL_CURVE_NONE, flags, NULL, 0, - result, &result_len, sizeof(result)/sizeof(*result), + state = 0; + if ((err = hal_rpc_pkey_match(client, session, HAL_KEY_TYPE_NONE, HAL_CURVE_NONE, 0, 0, NULL, 0, + &state, result, &result_len, sizeof(result)/sizeof(*result), &previous_uuid)) != HAL_OK) lose("Unrestricted match() failed: %s\n", hal_error_string(err)); if (result_len == 0) lose("Unrestricted match found no results\n"); + state = 0; for (const size_t *size = sizes; *size; size++) { uint8_t buf[*size]; memset(buf, 0x55, sizeof(buf)); snprintf((char *) buf, sizeof(buf), format, (unsigned long) *size); hal_pkey_attribute_t attribute[1] = {{ *size, sizeof(buf), buf }}; - if ((err = hal_rpc_pkey_match(client, session, HAL_KEY_TYPE_NONE, HAL_CURVE_NONE, flags, + if ((err = hal_rpc_pkey_match(client, session, HAL_KEY_TYPE_NONE, HAL_CURVE_NONE, 0, 0, attribute, sizeof(attribute)/sizeof(*attribute), - result, &result_len, sizeof(result)/sizeof(*result), + &state, result, &result_len, sizeof(result)/sizeof(*result), &previous_uuid)) != HAL_OK) lose("Restricted match() for attribute %lu failed: %s\n", (unsigned long) *size, hal_error_string(err)); diff --git a/unit-tests.py b/unit-tests.py index d5b6c77..295bf40 100644 --- a/unit-tests.py +++ b/unit-tests.py @@ -550,43 +550,54 @@ class TestPKeyMatch(TestCaseLoggedIn): Tests involving PKey list and match functions. """ + @staticmethod + def key_flag_names(flags): + names = dict(digitalsignature = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE, + keyencipherment = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT, + dataencipherment = HAL_KEY_FLAG_USAGE_DATAENCIPHERMENT, + token = HAL_KEY_FLAG_TOKEN, + public = HAL_KEY_FLAG_PUBLIC, + exportable = HAL_KEY_FLAG_EXPORTABLE) + return ", ".join(sorted(k for k, v in names.iteritems() if (flags & v) != 0)) + def load_keys(self, flags): uuids = set() for obj in PreloadedKey.db.itervalues(): with hsm.pkey_load(obj.der, flags) as k: - self.addCleanup(lambda uuid: hsm.pkey_open(uuid, flags = flags).delete(), k.uuid) + self.addCleanup(lambda uuid: hsm.pkey_open(uuid).delete(), k.uuid) uuids.add(k.uuid) + #print k.uuid, k.key_type, k.key_curve, self.key_flag_names(k.key_flags) k.set_attributes(dict((i, a) for i, a in enumerate((str(obj.keytype), str(obj.fn2))))) return uuids - def match(self, flags, **kwargs): - uuids = kwargs.pop("uuids", None) - kwargs.update(flags = flags) + def match(self, uuids, **kwargs): n = 0 for uuid in hsm.pkey_match(**kwargs): - if uuids is None or uuid in uuids: - with hsm.pkey_open(uuid, flags) as k: + if uuid in uuids: + with hsm.pkey_open(uuid) as k: n += 1 yield n, k - def ks_match(self, flags): + def ks_match(self, mask, flags): tags = [] uuids = set() for i in xrange(2): - uuids |= self.load_keys(flags) + uuids |= self.load_keys(flags if mask else HAL_KEY_FLAG_TOKEN * i) tags.extend(PreloadedKey.db) self.assertEqual(len(tags), len(uuids)) - self.assertEqual(uuids, set(k.uuid for n, k in self.match(flags = flags, uuids = uuids))) + self.assertEqual(uuids, set(k.uuid for n, k in self.match(mask = mask, + flags = flags, + uuids = uuids))) for keytype in set(HALKeyType.index.itervalues()) - {HAL_KEY_TYPE_NONE}: - for n, k in self.match(flags = flags, uuids = uuids, type = keytype): + for n, k in self.match(mask = mask, flags = flags, uuids = uuids, type = keytype): self.assertEqual(k.key_type, keytype) self.assertEqual(k.get_attributes({0}).pop(0), str(keytype)) self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == keytype)) for curve in set(HALCurve.index.itervalues()) - {HAL_CURVE_NONE}: - for n, k in self.match(flags = flags, uuids = uuids, curve = curve): + for n, k in self.match(mask = mask, flags = flags, uuids = uuids, curve = curve): self.assertEqual(k.key_curve, curve) self.assertEqual(k.get_attributes({1}).pop(1), str(curve)) self.assertIn(k.key_type, (HAL_KEY_TYPE_EC_PUBLIC, @@ -594,7 +605,7 @@ class TestPKeyMatch(TestCaseLoggedIn): self.assertEqual(n, sum(1 for t1, t2 in tags if t2 == curve)) for keylen in set(kl for kt, kl in tags if not isinstance(kl, Enum)): - for n, k in self.match(flags = flags, uuids = uuids, + for n, k in self.match(mask = mask, flags = flags, uuids = uuids, attributes = {1 : str(keylen)}): self.assertEqual(keylen, int(k.get_attributes({1}).pop(1))) self.assertIn(k.key_type, (HAL_KEY_TYPE_RSA_PUBLIC, @@ -602,17 +613,20 @@ class TestPKeyMatch(TestCaseLoggedIn): self.assertEqual(n, sum(1 for t1, t2 in tags if not isinstance(t2, Enum) and t2 == keylen)) - for n, k in self.match(flags = flags, uuids = uuids, + for n, k in self.match(mask = mask, flags = flags, uuids = uuids, type = HAL_KEY_TYPE_RSA_PUBLIC, attributes = {1 : "2048"}): self.assertEqual(k.key_type, HAL_KEY_TYPE_RSA_PUBLIC) self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == HAL_KEY_TYPE_RSA_PUBLIC and t2 == 2048)) def test_ks_match_token(self): - self.ks_match(HAL_KEY_FLAG_TOKEN) + self.ks_match(mask = HAL_KEY_FLAG_TOKEN, flags = HAL_KEY_FLAG_TOKEN) def test_ks_match_volatile(self): - self.ks_match(0) + self.ks_match(mask = HAL_KEY_FLAG_TOKEN, flags = 0) + + def test_ks_match_all(self): + self.ks_match(mask = 0, flags = 0) class TestPKeyAttribute(TestCaseLoggedIn): @@ -626,7 +640,7 @@ class TestPKeyAttribute(TestCaseLoggedIn): for obj in PreloadedKey.db.itervalues(): with hsm.pkey_load(obj.der, flags) as k: pinwheel() - self.addCleanup(lambda uuid: hsm.pkey_open(uuid, flags = flags).delete(), k.uuid) + self.addCleanup(lambda uuid: hsm.pkey_open(uuid).delete(), k.uuid) k.set_attributes(dict((j, "Attribute {}{}".format(j, "*" * n_fill)) for j in xrange(n_attrs))) pinwheel() |