diff options
-rw-r--r-- | Makefile | 23 | ||||
-rw-r--r-- | core.c | 10 | ||||
-rwxr-xr-x | cryptech_console | 116 | ||||
-rwxr-xr-x | cryptech_muxd | 422 | ||||
-rw-r--r-- | csprng.c | 23 | ||||
-rw-r--r-- | daemon.c | 330 | ||||
-rw-r--r-- | ecdsa.c | 6 | ||||
-rw-r--r-- | hal_internal.h | 11 | ||||
-rw-r--r-- | ks_attribute.c | 9 | ||||
-rw-r--r-- | ks_flash.c | 979 | ||||
-rw-r--r-- | ks_index.c | 42 | ||||
-rw-r--r-- | ks_volatile.c | 240 | ||||
-rw-r--r-- | libhal.py | 129 | ||||
-rw-r--r-- | locks.c | 108 | ||||
-rw-r--r-- | rpc_client_daemon.c | 41 | ||||
-rw-r--r-- | rpc_misc.c | 20 | ||||
-rw-r--r-- | rpc_pkey.c | 35 | ||||
-rw-r--r-- | unit-tests.py | 60 |
18 files changed, 1639 insertions, 965 deletions
@@ -27,13 +27,13 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Number of static hash and HMAC state blocks to allocate. -# Numbers pulled out of a hat, just testing. +# Number of various kinds of static state blocks to allocate. +# Numbers pulled out of a hat, tune as we go. STATIC_CORE_STATE_BLOCKS = 32 -STATIC_HASH_STATE_BLOCKS = 10 -STATIC_HMAC_STATE_BLOCKS = 4 -STATIC_PKEY_STATE_BLOCKS = 32 +STATIC_HASH_STATE_BLOCKS = 32 +STATIC_HMAC_STATE_BLOCKS = 16 +STATIC_PKEY_STATE_BLOCKS = 256 STATIC_KS_VOLATILE_SLOTS = 128 INC = hal.h hal_internal.h @@ -97,7 +97,7 @@ endif # just "building this is harmless even if we don't use it." OBJ += errorstrings.o hash.o asn1.o ecdsa.o rsa.o xdr.o slip.o -OBJ += rpc_api.o rpc_hash.o uuid.o rpc_pkcs1.o crc32.o +OBJ += rpc_api.o rpc_hash.o uuid.o rpc_pkcs1.o crc32.o locks.o # Object files to build when we're on a platform with direct access # to our hardware (Verilog) cores. @@ -166,16 +166,12 @@ endif # the C preprocessor: we can use symbolic names so long as they're defined as macros # in the C code, but we can't use things like C enum symbols. -ifneq "${RPC_MODE}" "server" - OBJ += rpc_serial.o -endif - RPC_CLIENT_OBJ = rpc_client.o ifeq "${RPC_TRANSPORT}" "loopback" RPC_CLIENT_OBJ += rpc_client_loopback.o else ifeq "${RPC_TRANSPORT}" "serial" - RPC_CLIENT_OBJ += rpc_client_serial.o + RPC_CLIENT_OBJ += rpc_serial.o rpc_client_serial.o else ifeq "${RPC_TRANSPORT}" "daemon" RPC_CLIENT_OBJ += rpc_client_daemon.o endif @@ -269,13 +265,10 @@ server: serial: ${MAKE} RPC_MODE=client-mixed RPC_TRANSPORT=serial -daemon: mixed cryptech_rpcd +daemon: mixed .PHONY: client mixed server serial daemon -cryptech_rpcd: daemon.o ${LIB} - ${CC} ${CFLAGS} -o $@ $^ ${LDFLAGS} - ${OBJ}: ${INC} ${LIB}: ${OBJ} @@ -201,16 +201,6 @@ hal_core_t *hal_core_find(const char *name, hal_core_t *core) return NULL; } -__attribute__((weak)) void hal_critical_section_start(void) -{ - return; -} - -__attribute__((weak)) void hal_critical_section_end(void) -{ - return; -} - hal_error_t hal_core_alloc(const char *name, hal_core_t **pcore) { hal_core_t *core; diff --git a/cryptech_console b/cryptech_console new file mode 100755 index 0000000..5ac12ba --- /dev/null +++ b/cryptech_console @@ -0,0 +1,116 @@ +#!/usr/bin/env python +# +# Copyright (c) 2017, NORDUnet A/S All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# - Neither the name of the NORDUnet nor the names of its contributors may +# be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Console client shim to work with Cryptech Python multiplexer. +""" + +import os +import sys +import socket +import atexit +import termios +import argparse + +import tornado.iostream +import tornado.ioloop +import tornado.gen + +class FemtoTerm(object): + + def __init__(self, s): + self.termios_setup() + self.stdin_stream = tornado.iostream.PipeIOStream(sys.stdin.fileno()) + self.stdout_stream = tornado.iostream.PipeIOStream(sys.stdout.fileno()) + self.socket_stream = tornado.iostream.IOStream(s) + self.closed = False + + def close(self): + self.termios_teardown() + self.stdin_stream.close() + self.stdout_stream.close() + self.socket_stream.close() + self.closed = True + + @tornado.gen.coroutine + def run(self): + yield [self.stdout_loop(), self.stdin_loop()] + + def termios_setup(self): + self.fd = sys.stdin.fileno() + self.old_tcattr = termios.tcgetattr(self.fd) + self.new_tcattr = termios.tcgetattr(self.fd) + atexit.register(self.termios_teardown) + self.new_tcattr[3] &= ~(termios.ICANON | termios.ECHO) # | termios.ISIG + self.new_tcattr[6][termios.VMIN] = 1 + self.new_tcattr[6][termios.VTIME] = 0 + termios.tcsetattr(self.fd, termios.TCSANOW, self.new_tcattr) + + def termios_teardown(self): + if self.fd is not None: + termios.tcsetattr(self.fd, termios.TCSAFLUSH, self.old_tcattr) + self.fd = None + + @tornado.gen.coroutine + def copy_loop(self, stream1, stream2, text1, text2, buffer_size = 1024): + try: + while not self.closed: + buffer = yield stream1.read_bytes(buffer_size, partial = True) + yield stream2.write(buffer.replace(text1, text2)) + except tornado.iostream.StreamClosedError: + self.close() + + def stdin_loop(self): + return self.copy_loop(self.stdin_stream, self.socket_stream, "\n", "\r") + + def stdout_loop(self): + return self.copy_loop(self.socket_stream, self.stdout_stream, "\r\n", "\n") + +def main(): + parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("cty_socket", + nargs = "?", + help = "CTY PF_UNIX socket name", + default = os.getenv("CRYPTECH_CTY_CLIENT_SOCKET_NAME", + "/tmp/.cryptech_muxd.cty")) + args = parser.parse_args() + + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + s.connect(args.cty_socket) + except socket.error: + sys.exit("Couldn't connect to socket {}".format(args.cty_socket)) + tornado.ioloop.IOLoop.current().run_sync(FemtoTerm(s).run) + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + pass diff --git a/cryptech_muxd b/cryptech_muxd new file mode 100755 index 0000000..269ac15 --- /dev/null +++ b/cryptech_muxd @@ -0,0 +1,422 @@ +#!/usr/bin/env python +# +# Copyright (c) 2016-2017, NORDUnet A/S All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# - Neither the name of the NORDUnet nor the names of its contributors may +# be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Implementation of Cryptech RPC protocol multiplexer in Python. + +Unlike the original C implementation, this uses SLIP encapsulation +over a SOCK_STREAM channel, because support for SOCK_SEQPACKET is not +what we might wish. We outsource all the heavy lifting for serial and +network I/O to the PySerial and Tornado libraries, respectively. +""" + +import os +import sys +import time +import struct +import atexit +import weakref +import logging +import argparse +import logging.handlers + +import serial +import serial.tools.list_ports_posix + +import tornado.tcpserver +import tornado.iostream +import tornado.netutil +import tornado.ioloop +import tornado.queues +import tornado.locks +import tornado.gen + + +logger = logging.getLogger("cryptech_muxd") + + +SLIP_END = chr(0300) # Indicates end of SLIP packet +SLIP_ESC = chr(0333) # Indicates byte stuffing +SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte +SLIP_ESC_ESC = chr(0335) # ESC ESC_ESC means ESC data byte + +Control_U = chr(0025) # Console: clear line +Control_M = chr(0015) # Console: end of line + + +def slip_encode(buffer): + "Encode a buffer using SLIP encapsulation." + return SLIP_END + buffer.replace(SLIP_ESC, SLIP_ESC + SLIP_ESC_ESC).replace(SLIP_END, SLIP_ESC + SLIP_ESC_END) + SLIP_END + +def slip_decode(buffer): + "Decode a SLIP-encapsulated buffer." + return buffer.strip(SLIP_END).replace(SLIP_ESC + SLIP_ESC_END, SLIP_END).replace(SLIP_ESC + SLIP_ESC_ESC, SLIP_ESC) + + +def client_handle_get(msg): + "Extract client_handle field from a Cryptech RPC message." + return struct.unpack(">L", msg[4:8])[0] + +def client_handle_set(msg, handle): + "Replace client_handle field in a Cryptech RPC message." + return msg[:4] + struct.pack(">L", handle) + msg[8:] + + +class SerialIOStream(tornado.iostream.BaseIOStream): + """ + Implementation of a Tornado IOStream over a PySerial device. + """ + + def __init__(self, device): + self.serial = serial.Serial(device, 921600, timeout = 0, write_timeout = 0) + self.serial_device = device + super(SerialIOStream, self).__init__() + + def fileno(self): + return self.serial.fileno() + + def close_fd(self): + self.serial.close() + + def write_to_fd(self, data): + return self.serial.write(data) + + def read_from_fd(self): + return self.serial.read(self.read_chunk_size) or None + + +class PFUnixServer(tornado.tcpserver.TCPServer): + """ + Variant on tornado.tcpserver.TCPServer, listening on a PF_UNIX + (aka PF_LOCAL) socket instead of a TCP socket. + """ + + def __init__(self, serial_stream, socket_filename, mode = 0600): + super(PFUnixServer, self).__init__() + self.serial = serial_stream + self.socket_filename = socket_filename + self.add_socket(tornado.netutil.bind_unix_socket(socket_filename, mode)) + atexit.register(self.atexit_unlink) + + def atexit_unlink(self): + try: + os.unlink(self.socket_filename) + except: + pass + + +class RPCIOStream(SerialIOStream): + """ + Tornado IOStream for a serial RPC channel. + """ + + def __init__(self, device): + super(RPCIOStream, self).__init__(device) + self.queues = weakref.WeakValueDictionary() + self.rpc_input_lock = tornado.locks.Lock() + + @tornado.gen.coroutine + def rpc_input(self, query, handle, queue): + "Send a query to the HSM." + logger.debug("RPC send: %s", ":".join("{:02x}".format(ord(c)) for c in query)) + self.queues[handle] = queue + with (yield self.rpc_input_lock.acquire()): + yield self.write(query) + logger.debug("RPC sent") + + @tornado.gen.coroutine + def rpc_output_loop(self): + "Handle reply stream HSM -> network." + while True: + try: + logger.debug("RPC UART read") + reply = yield self.read_until(SLIP_END) + except tornado.iostream.StreamClosedError: + logger.info("RPC UART closed") + for q in self.queues.itervalues(): + q.put_nowait(None) + return + logger.debug("RPC recv: %s", ":".join("{:02x}".format(ord(c)) for c in reply)) + try: + handle = client_handle_get(slip_decode(reply)) + except: + continue + logger.debug("RPC queue put: handle 0x%x, qsize %s, maxsize %s", + handle, self.queues[handle].qsize(), self.queues[handle].maxsize) + self.queues[handle].put_nowait(reply) + + +class QueuedStreamClosedError(tornado.iostream.StreamClosedError): + "Deferred StreamClosedError passed throught a Queue." + + +class RPCServer(PFUnixServer): + """ + Serve multiplexed Cryptech RPC over a PF_UNIX socket. + """ + + @tornado.gen.coroutine + def handle_stream(self, stream, address): + "Handle one network connection." + handle = stream.socket.fileno() + queue = tornado.queues.Queue() + logger.info("RPC connected %r, handle 0x%x", stream, handle) + while True: + try: + logger.debug("RPC socket read, handle 0x%x", handle) + query = yield stream.read_until(SLIP_END) + if len(query) < 9: + continue + query = slip_encode(client_handle_set(slip_decode(query), handle)) + yield self.serial.rpc_input(query, handle, queue) + logger.debug("RPC queue wait, handle 0x%x", handle) + reply = yield queue.get() + if reply is None: + raise QueuedStreamClosedError() + logger.debug("RPC socket write, handle 0x%x", handle) + yield stream.write(SLIP_END + reply) + except tornado.iostream.StreamClosedError: + logger.info("RPC closing %r, handle 0x%x", stream, handle) + stream.close() + return + + +class CTYIOStream(SerialIOStream): + """ + Tornado IOStream for a serial console channel. + """ + + def __init__(self, device): + super(CTYIOStream, self).__init__(device) + self.attached_cty = None + + @tornado.gen.coroutine + def cty_output_loop(self): + while True: + try: + buffer = yield self.read_bytes(self.read_chunk_size, partial = True) + except tornado.iostream.StreamClosedError: + logger.info("CTY UART closed") + if self.attached_cty is not None: + self.attached_cty.close() + return + try: + if self.attached_cty is not None: + yield self.attached_cty.write(buffer) + except tornado.iostream.StreamClosedError: + pass + + +class CTYServer(PFUnixServer): + """ + Serve Cryptech console over a PF_UNIX socket. + """ + + @tornado.gen.coroutine + def handle_stream(self, stream, address): + "Handle one network connection." + + if self.serial.attached_cty is not None: + yield stream.write("[Console already in use, sorry]\n") + stream.close() + return + + logger.info("CTY connected to %r", stream) + + try: + self.serial.attached_cty = stream + while self.serial.attached_cty is stream: + yield self.serial.write((yield stream.read_bytes(1024, partial = True))) + except tornado.iostream.StreamClosedError: + stream.close() + finally: + logger.info("CTY disconnected from %r", stream) + if self.serial.attached_cty is stream: + self.serial.attached_cty = None + + +class ProbeIOStream(SerialIOStream): + """ + Tornado IOStream for probing a serial port. This is nasty. + """ + + def __init__(self, device): + super(ProbeIOStream, self).__init__(device) + + @classmethod + @tornado.gen.coroutine + def run_probes(cls, args): + + if args.rpc_device is not None and args.cty_device is not None: + return + + if args.probe: + devs = set(args.probe) + else: + devs = set(str(port) + for port, desc, hwid in serial.tools.list_ports_posix.comports() + if "VID:PID=0403:6014" in hwid) + + devs.discard(args.rpc_device) + devs.discard(args.cty_device) + + if not devs: + return + + logging.debug("Probing candidate devices %s", " ".join(devs)) + + results = yield dict((dev, ProbeIOStream(dev).run_probe()) for dev in devs) + + for dev, result in results.iteritems(): + + if result == "cty" and args.cty_device is None: + logger.info("Selecting %s as CTY device", dev) + args.cty_device = dev + + if result == "rpc" and args.rpc_device is None: + logger.info("Selecting %s as RPC device", dev) + args.rpc_device = dev + + @tornado.gen.coroutine + def run_probe(self): + + RPC_query = chr(0) * 8 # client_handle = 0, function code = RPC_FUNC_GET_VERSION + RPC_reply = chr(0) * 12 # opcode = RPC_FUNC_GET_VERSION, client_handle = 0, valret = HAL_OK + + probe_string = SLIP_END + Control_U + SLIP_END + RPC_query + SLIP_END + Control_U + Control_M + + yield self.write(probe_string) + yield tornado.gen.sleep(0.5) + response = yield self.read_bytes(self.read_chunk_size, partial = True) + + logger.debug("Probing %s: %r %s", self.serial_device, response, ":".join("{:02x}".format(ord(c)) for c in response)) + + is_cty = any(prompt in response for prompt in ("Username:", "Password:", "cryptech>")) + + try: + is_rpc = response[response.index(SLIP_END + RPC_reply) + len(SLIP_END + RPC_reply) + 4] == SLIP_END + except ValueError: + is_rpc = False + except IndexError: + is_rpc = False + + assert not is_cty or not is_rpc + + result = None + + if is_cty: + result = "cty" + yield self.write(Control_U) + + if is_rpc: + result = "rpc" + yield self.write(SLIP_END) + + self.close() + raise tornado.gen.Return(result) + + + +@tornado.gen.coroutine +def main(): + parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("-v", "--verbose", + action = "count", + help = "blather about what we're doing") + + parser.add_argument("-l", "--log-file", + help = "log to file instead of stderr") + + parser.add_argument("-p", "--probe", + nargs = "*", + metavar = "DEVICE", + help = "probe for device UARTs") + + parser.add_argument("--rpc-device", + help = "RPC serial device name", + default = os.getenv("CRYPTECH_RPC_CLIENT_SERIAL_DEVICE")) + + parser.add_argument("--rpc-socket", + help = "RPC PF_UNIX socket name", + default = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME", + "/tmp/.cryptech_muxd.rpc")) + + parser.add_argument("--cty-device", + help = "CTY serial device name", + default = os.getenv("CRYPTECH_CTY_CLIENT_SERIAL_DEVICE")) + + parser.add_argument("--cty-socket", + help = "CTY PF_UNIX socket name", + default = os.getenv("CRYPTECH_CTY_CLIENT_SOCKET_NAME", + "/tmp/.cryptech_muxd.cty")) + + args = parser.parse_args() + + if args.log_file is not None: + logging.getLogger().handlers[:] = [logging.handlers.WatchedFileHandler(args.log_file)] + + logging.getLogger().handlers[0].setFormatter( + logging.Formatter("%(asctime)-15s %(name)s[%(process)d]:%(levelname)s: %(message)s", + "%Y-%m-%d %H:%M:%S")) + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG if args.verbose > 1 else logging.INFO) + + if args.probe is not None: + yield ProbeIOStream.run_probes(args) + + futures = [] + + if args.rpc_device is None: + logger.warn("No RPC device found") + else: + rpc_stream = RPCIOStream(device = args.rpc_device) + rpc_server = RPCServer(rpc_stream, args.rpc_socket) + futures.append(rpc_stream.rpc_output_loop()) + + if args.cty_device is None: + logger.warn("No CTY device found") + else: + cty_stream = CTYIOStream(device = args.cty_device) + cty_server = CTYServer(cty_stream, args.cty_socket) + futures.append(cty_stream.cty_output_loop()) + + # Might want to use WaitIterator(dict(...)) here so we can + # diagnose and restart output loops if they fail? + + if futures: + yield futures + +if __name__ == "__main__": + try: + tornado.ioloop.IOLoop.current().run_sync(main) + except KeyboardInterrupt: + pass @@ -45,35 +45,34 @@ hal_error_t hal_get_random(hal_core_t *core, void *buffer, const size_t length) { - uint8_t temp[4], *buf = buffer; + uint8_t temp[4], ior = 0, * const buf = buffer; hal_error_t err; - size_t i; if ((err = hal_core_alloc(CSPRNG_NAME, &core)) != HAL_OK) return err; - for (i = 0; i < length; i += 4) { + for (size_t i = 0; i < length; i += 4) { const int last = (length - i) < 4; if (WAIT_FOR_CSPRNG_VALID && (err = hal_io_wait_valid(core)) != HAL_OK) - goto out; + break; if ((err = hal_io_read(core, CSPRNG_ADDR_RANDOM, (last ? temp : &buf[i]), 4)) != HAL_OK) - goto out; + break; if (last) for (; i < length; i++) buf[i] = temp[i&3]; } - for (i = 0, buf = buffer; i < length; i++, buf++) - if (*buf != 0) { - err = HAL_OK; - goto out; - } - err = HAL_ERROR_CSPRNG_BROKEN; + if (err == HAL_OK) { + for (size_t i = 0; i < length; i++) + ior |= buf[i]; + + if (ior == 0 && length > 0) + err = HAL_ERROR_CSPRNG_BROKEN; + } -out: hal_core_free(core); return err; } diff --git a/daemon.c b/daemon.c deleted file mode 100644 index ff95353..0000000 --- a/daemon.c +++ /dev/null @@ -1,330 +0,0 @@ -#define DEBUG -/* - * daemon.c - * -------- - * A daemon to arbitrate shared access to a serial connection to the HSM. - * - * Copyright (c) 2016, NORDUnet A/S All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * - Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * - Neither the name of the NORDUnet nor the names of its contributors may - * be used to endorse or promote products derived from this software - * without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS - * IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED - * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A - * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED - * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <sys/socket.h> -#include <sys/un.h> -#include <unistd.h> -#include <poll.h> -#include <getopt.h> /* required with -std=c99 */ -#include <termios.h> /* for default speed */ - -#include "hal_internal.h" -#include "slip_internal.h" -#include "xdr_internal.h" - -static char usage[] = - "usage: %s [-n socketname] [-d ttydevice] [-s ttyspeed]\n"; - -/* - * Work around glibc "feature test" insanity. This isn't the correct - * definition according to the POSIX, but it does what seems to be the - * normal hack on Linux (where this is broken more often than not). - */ - -#ifndef SUN_LEN -#define SUN_LEN(_sun_ptr_) (sizeof(*(_sun_ptr_))) -#endif - -/* select() is hopelessly broken, and epoll() is Linux-specific, so we'll use - * poll() until such a time as libevent or libev seems more appropriate. - * Unfortunately, poll() doesn't come with any macros or functions to manage - * the pollfd array, so we have to invent them. - */ - -static struct pollfd *pollfds = NULL; -static nfds_t nfds = 0; -static nfds_t npollfds = 0; - -static void poll_add(int fd) -{ - /* add 4 entries at a time to avoid having to realloc too often */ -#define NNEW 4 - - /* expand the array if necessary */ - if (nfds == npollfds) { - npollfds = nfds + NNEW; - pollfds = realloc(pollfds, npollfds * sizeof(struct pollfd)); - if (pollfds == NULL) { - perror("realloc"); - exit(EXIT_FAILURE); - } - /* zero the new entries for hygiene */ - memset(&pollfds[nfds], 0, NNEW * sizeof(struct pollfd)); - } - - /* populate the new entry */ - pollfds[nfds].fd = fd; - pollfds[nfds].events = POLLIN; - ++nfds; -} - -static void poll_remove(int fd) -{ - nfds_t i; - - /* search the pollfd array */ - for (i = 0; i < nfds; ++i) { - if (pollfds[i].fd == fd) { - /* shift remainder of the array left by one */ - memmove(&pollfds[i], &pollfds[i + 1], (nfds - i - 1) * sizeof(struct pollfd)); - /* zero the last entry for hygiene */ - memset(&pollfds[nfds - 1], 0, sizeof(struct pollfd)); - --nfds; - return; - } - } - /* if it's not found, return without an error */ -} - -typedef struct { - size_t len; - uint8_t buf[HAL_RPC_MAX_PKT_SIZE]; -} rpc_buffer_t; -static rpc_buffer_t ibuf, obuf; - -const char *socket_name = HAL_CLIENT_DAEMON_DEFAULT_SOCKET_NAME; - -/* Set up an atexit handler to remove the filesystem entry for the unix domain - * socket. This will trigger on error exits, but not on the "normal" SIGKILL. - */ -void atexit_cleanup(void) -{ - unlink(socket_name); -} - -#ifdef DEBUG -static void hexdump(uint8_t *buf, uint32_t len) -{ - for (uint32_t i = 0; i < len; ++i) - printf("%02x%c", buf[i], ((i & 0x07) == 0x07) ? '\n' : ' '); - if ((len & 0x07) != 0) - printf("\n"); -} -#endif - -int main(int argc, char *argv[]) -{ - struct sockaddr_un name; - int ret; - int lsock; - int dsock; - int opt; - const char *device = getenv(HAL_CLIENT_SERIAL_DEVICE_ENVVAR); - const char *speed_ = getenv(HAL_CLIENT_SERIAL_SPEED_ENVVAR); - uint32_t speed = HAL_CLIENT_SERIAL_DEFAULT_SPEED; - - if (device == NULL) - device = HAL_CLIENT_SERIAL_DEFAULT_DEVICE; - - if (speed_ != NULL) - speed = (uint32_t) strtoul(speed_, NULL, 10); - - while ((opt = getopt(argc, argv, "hn:d:s:")) != -1) { - switch (opt) { - case 'h': - printf(usage, argv[0]); - exit(EXIT_SUCCESS); - case 'n': - socket_name = optarg; - break; - case 'd': - device = optarg; - break; - case 's': - speed = (uint32_t) strtoul(optarg, NULL, 10); - switch (speed) { - case 115200: - case 921600: - break; - default: - printf("invalid speed value %s\n", optarg); - exit(EXIT_FAILURE); - } - break; - default: - printf(usage, argv[0]); - exit(EXIT_FAILURE); - } - } - - if (atexit(atexit_cleanup) != 0) { - perror("atexit"); - exit(EXIT_FAILURE); - } - - if (hal_serial_init(device, speed) != HAL_OK) - exit(EXIT_FAILURE); - - int serial_fd = hal_serial_get_fd(); - poll_add(serial_fd); - - /* Remove the filesystem entry for the unix domain socket. The usual way - * to stop a daemon is SIGKILL, which we can't catch, so the file remains, - * and will prevent us from binding the socket. - * - * XXX We should also scan the process table, to make sure the daemon - * isn't already running. - */ - unlink(socket_name); - - /* Create the listening socket. - */ - lsock = socket(AF_UNIX, SOCK_SEQPACKET, 0); - if (lsock == -1) { - perror("socket"); - exit(EXIT_FAILURE); - } - poll_add(lsock); - - /* For portability, clear the whole address structure, since some - * implementations have additional (nonstandard) fields in the structure. - */ - memset(&name, 0, sizeof(struct sockaddr_un)); - - /* Bind the listening socket. On some platforms, we have to pass the "real" - * (number of bytes in use) length of the sockaddr_un to get the name bound - * correctly, so use the SUN_LEN() macro to calculate that. - */ - name.sun_family = AF_UNIX; - strncpy(name.sun_path, socket_name, sizeof(name.sun_path) - 1); - ret = bind(lsock, (const struct sockaddr *) &name, SUN_LEN(&name)); - if (ret == -1) { - perror("bind"); - exit(EXIT_FAILURE); - } - - /* Prepare to accept connections. - */ - ret = listen(lsock, 20); - if (ret == -1) { - perror("listen"); - exit(EXIT_FAILURE); - } - - /* The main loop. - */ - for (;;) { - - /* Blocking poll on all descriptors of interest. - */ - ret = poll(pollfds, nfds, -1); - if (ret == -1) { - perror("poll"); - exit(EXIT_FAILURE); - } - - for (nfds_t i = 0; i < nfds; ++i) { - if (pollfds[i].revents != 0) { - /* XXX POLLERR|POLLHUP|POLLNVAL */ - - /* serial port */ - if (pollfds[i].fd == serial_fd) { - int complete; - hal_slip_recv_char(ibuf.buf, &ibuf.len, sizeof(ibuf.buf), &complete); - if (complete) { -#ifdef DEBUG - printf("serial port received response:\n"); - hexdump(ibuf.buf, ibuf.len); -#endif - /* We've got a complete rpc response packet. */ - const uint8_t *bufptr = ibuf.buf + 4; - const uint8_t * const limit = ibuf.buf + ibuf.len; - uint32_t sock; - /* Second word of the response is the client ID. */ - hal_xdr_decode_int(&bufptr, limit, &sock); - /* Pass response on to the client that requested it. */ - send(sock, ibuf.buf, ibuf.len, 0); - /* Reinitialize the receive buffer. */ - memset(&ibuf, 0, sizeof(ibuf)); - } - } - - /* listening socket */ - else if (pollfds[i].fd == lsock) { - /* Accept incoming connection. */ - dsock = accept(lsock, NULL, NULL); - if (ret == -1) { - perror("accept"); - exit(EXIT_FAILURE); - } - poll_add(dsock); -#ifdef DEBUG - printf("listening socket accept data socket %d\n", dsock); -#endif - } - - /* client data socket */ - else { - const uint8_t * const limit = obuf.buf + HAL_RPC_MAX_PKT_SIZE; - /* Get the client's rpc request packet. */ - obuf.len = recv(pollfds[i].fd, obuf.buf, HAL_RPC_MAX_PKT_SIZE, 0); -#ifdef DEBUG - printf("data socket %d received request:\n", pollfds[i].fd); - hexdump(obuf.buf, obuf.len); -#endif - - /* Fill in the client handle arg - first field after opcode. */ - uint8_t *bufptr = obuf.buf + 4; - hal_xdr_encode_int(&bufptr, limit, pollfds[i].fd); - - if (obuf.len > 0) { -#ifdef DEBUG - printf("passing to serial port:\n"); - hexdump(obuf.buf, obuf.len); -#endif - /* Pass it on to the serial port. */ - hal_slip_send(obuf.buf, obuf.len); - } - else { -#ifdef DEBUG - printf("closing data socket\n"); -#endif - /* Client has closed the socket. */ - close(pollfds[i].fd); - poll_remove(pollfds[i].fd); - } - /* Reinitialize the transmit buffer. */ - memset(&obuf, 0, sizeof(obuf)); - } - } - } - } - - /*NOTREACHED*/ - exit(EXIT_SUCCESS); -} @@ -1010,7 +1010,8 @@ hal_error_t hal_ecdsa_key_gen(const hal_core_t *core, if ((err = point_pick_random(curve, key->d, key->Q)) != HAL_OK) return err; - assert(point_is_on_curve(key->Q, curve)); + if (!point_is_on_curve(key->Q, curve)) + return HAL_ERROR_KEY_NOT_ON_CURVE; *key_ = key; return HAL_OK; @@ -1668,7 +1669,8 @@ hal_error_t hal_ecdsa_sign(const hal_core_t *core, if ((err = point_pick_random(curve, k, R)) != HAL_OK) goto fail; - assert(point_is_on_curve(R, curve)); + if (!point_is_on_curve(R, curve)) + lose(HAL_ERROR_IMPOSSIBLE); if (fp_mod(R->x, n, r) != FP_OKAY) lose(HAL_ERROR_IMPOSSIBLE); diff --git a/hal_internal.h b/hal_internal.h index a8f88e2..40a600c 100644 --- a/hal_internal.h +++ b/hal_internal.h @@ -90,6 +90,15 @@ extern void *hal_allocate_static_memory(const size_t size); #define HAL_MAX_HASH_DIGEST_LENGTH SHA512_DIGEST_LEN /* + * Locks and critical sections. + */ + +extern void hal_critical_section_start(void); +extern void hal_critical_section_end(void); +extern void hal_ks_lock(void); +extern void hal_ks_unlock(void); + +/* * Dispatch structures for RPC implementation. * * The breakdown of which functions go into which dispatch vectors is @@ -880,7 +889,7 @@ typedef enum { */ #ifndef HAL_CLIENT_DAEMON_DEFAULT_SOCKET_NAME -#define HAL_CLIENT_DAEMON_DEFAULT_SOCKET_NAME "/tmp/cryptech_rpcd.socket" +#define HAL_CLIENT_DAEMON_DEFAULT_SOCKET_NAME "/tmp/.cryptech_muxd.rpc" #endif /* diff --git a/ks_attribute.c b/ks_attribute.c index 92e450d..ec674f5 100644 --- a/ks_attribute.c +++ b/ks_attribute.c @@ -120,11 +120,18 @@ hal_error_t hal_ks_attribute_delete(uint8_t *bytes, const size_t bytes_len, if (bytes == NULL || attributes == NULL || attributes_len == NULL || total_len == NULL) return HAL_ERROR_BAD_ARGUMENTS; + /* + * Search for attribute by type. Note that there can be only one + * attribute of any given type. + */ + int i = 0; while (i < *attributes_len && attributes[i].type != type) i++; + /* If not found, great, it's already deleted from the key. */ + if (i == *attributes_len) return HAL_OK; @@ -152,6 +159,8 @@ hal_error_t hal_ks_attribute_insert(uint8_t *bytes, const size_t bytes_len, total_len == NULL || value == NULL) return HAL_ERROR_BAD_ARGUMENTS; + /* Delete the existing attribute value (if present), then write the new value. */ + hal_error_t err = hal_ks_attribute_delete(bytes, bytes_len, attributes, attributes_len, total_len, type); @@ -33,6 +33,15 @@ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +/* + * This keystore driver operates over bare flash, versus over a flash file + * system or flash translation layer. The block size is large enough to + * hold an AES-keywrapped 4096-bit RSA key. Any remaining space in the key + * block may be used to store attributes (opaque TLV blobs). If the + * attributes overflow the key block, additional blocks may be added, but + * no attribute may exceed the block size. + */ + #include <stddef.h> #include <string.h> #include <assert.h> @@ -312,7 +321,7 @@ static hal_crc32_t calculate_block_crc(const flash_block_t * const block) } /* - * Calculate block offset. + * Calculate offset of the block in the flash address space. */ static inline uint32_t block_offset(const unsigned blockno) @@ -446,7 +455,7 @@ static hal_error_t block_erase(const unsigned blockno) return HAL_ERROR_IMPOSSIBLE; /* Sigh, magic numeric return codes */ - if (keystore_erase_subsectors(blockno, blockno) != 1) + if (keystore_erase_subsector(blockno) != 1) return HAL_ERROR_KEYSTORE_ACCESS; return HAL_OK; @@ -537,6 +546,11 @@ static hal_error_t block_update(const unsigned b1, flash_block_t *block, cache_mark_used(block, b2); + /* + * Erase the first block in the free list. In case of restart, this + * puts the block back at the head of the free list. + */ + return block_erase_maybe(db.ksi.index[db.ksi.used]); } @@ -565,6 +579,10 @@ static inline void *gnaw(uint8_t **mem, size_t *len, const size_t size) static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc) { + hal_error_t err = HAL_OK; + + hal_ks_lock(); + /* * Initialize the in-memory database. */ @@ -575,10 +593,18 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc sizeof(*db.ksi.names) * NUM_FLASH_BLOCKS + sizeof(*db.cache) * KS_FLASH_CACHE_SIZE); + /* + * This is done as a single large allocation, rather than 3 smaller + * allocations, to make it atomic - we need all 3, so either all + * succeed or all fail. + */ + uint8_t *mem = hal_allocate_static_memory(len); - if (mem == NULL) - return HAL_ERROR_ALLOCATION_FAILURE; + if (mem == NULL) { + err = HAL_ERROR_ALLOCATION_FAILURE; + goto done; + } memset(&db, 0, sizeof(db)); memset(mem, 0, len); @@ -597,8 +623,10 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc db.ksi.used = 0; - if (db.ksi.index == NULL || db.ksi.names == NULL || db.cache == NULL) - return HAL_ERROR_IMPOSSIBLE; + if (db.ksi.index == NULL || db.ksi.names == NULL || db.cache == NULL) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } for (int i = 0; i < KS_FLASH_CACHE_SIZE; i++) db.cache[i].blockno = ~0; @@ -613,11 +641,12 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc flash_block_status_t block_status[NUM_FLASH_BLOCKS]; flash_block_t *block = cache_pick_lru(); int first_erased = -1; - hal_error_t err; uint16_t n = 0; - if (block == NULL) - return HAL_ERROR_IMPOSSIBLE; + if (block == NULL) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } for (int i = 0; i < NUM_FLASH_BLOCKS; i++) { @@ -625,7 +654,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc * Read one block. If the CRC is bad or the block type is * unknown, it's old data we don't understand, something we were * writing when we crashed, or bad flash; in any of these cases, - * we want the block to ends up near the end of the free list. + * we want the block to end up near the end of the free list. */ err = block_read(i, block); @@ -637,7 +666,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc block_types[i] = block_get_type(block); else - return err; + goto done; switch (block_types[i]) { case BLOCK_TYPE_KEY: @@ -718,7 +747,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc */ if ((err = hal_ks_index_setup(&db.ksi)) != HAL_OK) - return err; + goto done; /* * We might want to call hal_ks_index_fsck() here, if we can figure @@ -733,20 +762,22 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc * * For any tombstone we find, we start by looking for all the blocks * with a matching UUID, then see what valid sequences we can - * construct from what we found. + * construct from what we found. This basically works in reverse of + * the update sequence in ks_set_attributes(). * * If we can construct a valid sequence of live blocks, the complete - * update was written out, and we just need to zero the tombstones. + * update was written out, and we just need to finish zeroing the + * tombstones. * * Otherwise, if we can construct a complete sequence of tombstone * blocks, the update failed before it was completely written, so we * have to zero the incomplete sequence of live blocks then restore - * from the tombstones. + * the tombstones. * * Otherwise, if the live and tombstone blocks taken together form a * valid sequence, the update failed while deprecating the old live - * blocks, and the update itself was not written, so we need to - * restore the tombstones and leave the live blocks alone. + * blocks, and none of the new data was written, so we need to restore + * the tombstones and leave the live blocks alone. * * If none of the above applies, we don't understand what happened, * which is a symptom of either a bug or a hardware failure more @@ -764,13 +795,27 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc int where = -1; if ((err = hal_ks_index_find_range(&db.ksi, &name, 0, &n_blocks, NULL, &where, 0)) != HAL_OK) - return err; + goto done; + + /* + * hal_ks_index_find_range does a binary search, not a linear search, + * so it may not return the first instance of a block with the given + * name and chunk=0. Search backwards to make sure we have all chunks. + */ while (where > 0 && !hal_uuid_cmp(&name, &db.ksi.names[db.ksi.index[where - 1]].name)) { where--; n_blocks++; } + /* + * Rather than calling hal_ks_index_find_range with an array pointer + * to get the list of matching blocks (because of the binary search + * issue), we're going to fondle the index directly. This is really + * not something to do in regular code, but this is error-recovery + * code. + */ + int live_ok = 1, tomb_ok = 1, join_ok = 1; unsigned n_live = 0, n_tomb = 0; unsigned i_live = 0, i_tomb = 0; @@ -778,9 +823,9 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc for (int j = 0; j < n_blocks; j++) { unsigned b = db.ksi.index[where + j]; switch (block_status[b]) { - case BLOCK_STATUS_LIVE: n_live++; break; - case BLOCK_STATUS_TOMBSTONE: n_tomb++; break; - default: return HAL_ERROR_IMPOSSIBLE; + case BLOCK_STATUS_LIVE: n_live++; break; + case BLOCK_STATUS_TOMBSTONE: n_tomb++; break; + default: err = HAL_ERROR_IMPOSSIBLE; goto done; } } @@ -790,7 +835,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc unsigned b = db.ksi.index[where + j]; if ((err = block_read(b, block)) != HAL_OK) - return err; + goto done; join_ok &= block->header.this_chunk == j && block->header.total_chunks == n_blocks; @@ -804,18 +849,27 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc tomb_ok &= block->header.this_chunk == i_tomb++ && block->header.total_chunks == n_tomb; break; default: - return HAL_ERROR_IMPOSSIBLE; + err = HAL_ERROR_IMPOSSIBLE; + goto done; } } - if (!live_ok && !tomb_ok && !join_ok) - return HAL_ERROR_KEYSTORE_LOST_DATA; + if (!live_ok && !tomb_ok && !join_ok) { + err = HAL_ERROR_KEYSTORE_LOST_DATA; + goto done; + } + + /* + * If live_ok or tomb_ok, we have to zero out some blocks, and adjust + * the index. Again, don't fondle the index directly, outside of error + * recovery. + */ if (live_ok) { for (int j = 0; j < n_tomb; j++) { const unsigned b = tomb_blocks[j]; if ((err = block_zero(b)) != HAL_OK) - return err; + goto done; block_types[b] = BLOCK_TYPE_ZEROED; block_status[b] = BLOCK_STATUS_UNKNOWN; } @@ -825,7 +879,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc for (int j = 0; j < n_live; j++) { const unsigned b = live_blocks[j]; if ((err = block_zero(b)) != HAL_OK) - return err; + goto done; block_types[b] = BLOCK_TYPE_ZEROED; block_status[b] = BLOCK_STATUS_UNKNOWN; } @@ -849,23 +903,31 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc n_blocks = n_tomb; } + /* + * Restore tombstone blocks (tomb_ok or join_ok). + */ + for (int j = 0; j < n_blocks; j++) { int hint = where + j; unsigned b1 = db.ksi.index[hint], b2; if (block_status[b1] != BLOCK_STATUS_TOMBSTONE) continue; if ((err = block_read(b1, block)) != HAL_OK) - return err; + goto done; block->header.block_status = BLOCK_STATUS_LIVE; if ((err = hal_ks_index_replace(&db.ksi, &name, j, &b2, &hint)) != HAL_OK || (err = block_write(b2, block)) != HAL_OK) - return err; + goto done; block_types[b1] = BLOCK_TYPE_ZEROED; block_status[b1] = BLOCK_STATUS_UNKNOWN; block_status[b2] = BLOCK_STATUS_LIVE; } } + /* + * Fetch or create the PIN block. + */ + err = fetch_pin_block(NULL, &block); if (err == HAL_OK) { @@ -875,7 +937,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc } else if (err != HAL_ERROR_KEY_NOT_FOUND) - return err; + goto done; else { /* @@ -900,7 +962,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc block->pin.user_pin = db.user_pin; if ((err = hal_ks_index_add(&db.ksi, &pin_uuid, 0, &b, NULL)) != HAL_OK) - return err; + goto done; cache_mark_used(block, b); @@ -909,7 +971,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc cache_release(block); if (err != HAL_OK) - return err; + goto done; } /* @@ -918,7 +980,7 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc if (db.ksi.used < db.ksi.size && (err = block_erase_maybe(db.ksi.index[db.ksi.used])) != HAL_OK) - return err; + goto done; /* * And we're finally done. @@ -926,7 +988,11 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, const int alloc db.ks.driver = driver; - return HAL_OK; + err = HAL_OK; + + done: + hal_ks_unlock(); + return err; } static hal_error_t ks_shutdown(const hal_ks_driver_t * const driver) @@ -974,18 +1040,24 @@ static hal_error_t ks_store(hal_ks_t *ks, if (ks != &db.ks || slot == NULL || der == NULL || der_len == 0 || !acceptable_key_type(slot->type)) return HAL_ERROR_BAD_ARGUMENTS; - flash_block_t *block = cache_pick_lru(); - flash_key_block_t *k = &block->key; + hal_error_t err = HAL_OK; + flash_block_t *block; + flash_key_block_t *k; uint8_t kek[KEK_LENGTH]; size_t kek_len; - hal_error_t err; unsigned b; - if (block == NULL) - return HAL_ERROR_IMPOSSIBLE; + hal_ks_lock(); + + if ((block = cache_pick_lru()) == NULL) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } + + k = &block->key; if ((err = hal_ks_index_add(&db.ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) - return err; + goto done; cache_mark_used(block, b); @@ -1003,18 +1075,29 @@ static hal_error_t ks_store(hal_ks_t *ks, k->der_len = SIZEOF_FLASH_KEY_BLOCK_DER; k->attributes_len = 0; - if ((err = hal_mkm_get_kek(kek, &kek_len, sizeof(kek))) == HAL_OK) + if (db.ksi.used < db.ksi.size) + err = block_erase_maybe(db.ksi.index[db.ksi.used]); + + if (err == HAL_OK) + err = hal_mkm_get_kek(kek, &kek_len, sizeof(kek)); + + if (err == HAL_OK) err = hal_aes_keywrap(NULL, kek, kek_len, der, der_len, k->der, &k->der_len); memset(kek, 0, sizeof(kek)); - if (err == HAL_OK && - (err = block_write(b, block)) == HAL_OK) - return HAL_OK; + if (err == HAL_OK) + err = block_write(b, block); + + if (err == HAL_OK) + goto done; memset(block, 0, sizeof(*block)); cache_release(block); (void) hal_ks_index_delete(&db.ksi, &slot->name, 0, NULL, &slot->hint); + + done: + hal_ks_unlock(); return err; } @@ -1025,16 +1108,20 @@ static hal_error_t ks_fetch(hal_ks_t *ks, if (ks != &db.ks || slot == NULL) return HAL_ERROR_BAD_ARGUMENTS; + hal_error_t err = HAL_OK; flash_block_t *block; - hal_error_t err; unsigned b; + hal_ks_lock(); + if ((err = hal_ks_index_find(&db.ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK || (err = block_read_cached(b, &block)) != HAL_OK) - return err; + goto done; - if (block_get_type(block) != BLOCK_TYPE_KEY) - return HAL_ERROR_KEYSTORE_WRONG_BLOCK_TYPE; /* HAL_ERROR_KEY_NOT_FOUND */ + if (block_get_type(block) != BLOCK_TYPE_KEY) { + err = HAL_ERROR_KEYSTORE_WRONG_BLOCK_TYPE; /* HAL_ERROR_KEY_NOT_FOUND */ + goto done; + } cache_mark_used(block, b); @@ -1062,12 +1149,11 @@ static hal_error_t ks_fetch(hal_ks_t *ks, err = hal_aes_keyunwrap(NULL, kek, kek_len, k->der, k->der_len, der, der_len); memset(kek, 0, sizeof(kek)); - - if (err != HAL_OK) - return err; } - return HAL_OK; + done: + hal_ks_unlock(); + return err; } static hal_error_t ks_delete(hal_ks_t *ks, @@ -1076,25 +1162,50 @@ static hal_error_t ks_delete(hal_ks_t *ks, if (ks != &db.ks || slot == NULL) return HAL_ERROR_BAD_ARGUMENTS; - hal_error_t err; + hal_error_t err = HAL_OK; unsigned n; - if ((err = hal_ks_index_delete_range(&db.ksi, &slot->name, 0, &n, NULL, &slot->hint)) != HAL_OK) - return err; + hal_ks_lock(); - unsigned b[n]; + { + /* + * Get the count of blocks to delete. + */ - if ((err = hal_ks_index_delete_range(&db.ksi, &slot->name, n, NULL, b, &slot->hint)) != HAL_OK) - return err; + if ((err = hal_ks_index_delete_range(&db.ksi, &slot->name, 0, &n, NULL, &slot->hint)) != HAL_OK) + goto done; - for (int i = 0; i < n; i++) - cache_release(cache_find_block(b[i])); + /* + * Then delete them. + */ - for (int i = 0; i < n; i++) - if ((err = block_zero(b[i])) != HAL_OK) - return err; + unsigned b[n]; - return block_erase_maybe(db.ksi.index[db.ksi.used]); + if ((err = hal_ks_index_delete_range(&db.ksi, &slot->name, n, NULL, b, &slot->hint)) != HAL_OK) + goto done; + + for (int i = 0; i < n; i++) + cache_release(cache_find_block(b[i])); + + /* + * Zero the blocks, to mark them as recently used. + */ + + for (int i = 0; i < n; i++) + if ((err = block_zero(b[i])) != HAL_OK) + goto done; + + /* + * Erase the first block in the free list. In case of restart, this + * puts the block back at the head of the free list. + */ + + err = block_erase_maybe(db.ksi.index[db.ksi.used]); + } + + done: + hal_ks_unlock(); + return err; } static inline hal_error_t locate_attributes(flash_block_t *block, const unsigned chunk, @@ -1141,11 +1252,13 @@ static hal_error_t ks_match(hal_ks_t *ks, return HAL_ERROR_BAD_ARGUMENTS; uint8_t need_attr[attributes_len > 0 ? attributes_len : 1]; + hal_error_t err = HAL_OK; flash_block_t *block; int possible = 0; - hal_error_t err; int i = -1; + hal_ks_lock(); + *result_len = 0; err = hal_ks_index_find(&db.ksi, previous_uuid, 0, NULL, &i); @@ -1153,7 +1266,7 @@ static hal_error_t ks_match(hal_ks_t *ks, if (err == HAL_ERROR_KEY_NOT_FOUND) i--; else if (err != HAL_OK) - return err; + goto done; while (*result_len < result_max && ++i < db.ksi.used) { @@ -1166,7 +1279,7 @@ static hal_error_t ks_match(hal_ks_t *ks, continue; if ((err = block_read_cached(b, &block)) != HAL_OK) - return err; + goto done; if (db.ksi.names[b].chunk == 0) { memset(need_attr, 1, sizeof(need_attr)); @@ -1184,13 +1297,13 @@ static hal_error_t ks_match(hal_ks_t *ks, if ((err = locate_attributes(block, db.ksi.names[b].chunk, &bytes, &bytes_len, &attrs_len)) != HAL_OK) - return err; + goto done; if (*attrs_len > 0) { hal_pkey_attribute_t attrs[*attrs_len]; if ((err = hal_ks_attribute_scan(bytes, bytes_len, attrs, *attrs_len, NULL)) != HAL_OK) - return err; + goto done; for (int j = 0; possible && j < attributes_len; j++) { @@ -1220,7 +1333,11 @@ static hal_error_t ks_match(hal_ks_t *ks, possible = 0; } - return HAL_OK; + err = HAL_OK; + + done: + hal_ks_unlock(); + return err; } /* @@ -1259,419 +1376,453 @@ static hal_error_t ks_set_attributes(hal_ks_t *ks, */ unsigned updated_attributes_len = attributes_len; + hal_error_t err = HAL_OK; flash_block_t *block; unsigned chunk = 0; - hal_error_t err; unsigned b; - do { - int hint = slot->hint + chunk; + hal_ks_lock(); - if ((err = hal_ks_index_find(&db.ksi, &slot->name, chunk, &b, &hint)) != HAL_OK || - (err = block_read_cached(b, &block)) != HAL_OK) - return err; + { - if (block->header.this_chunk != chunk) - return HAL_ERROR_IMPOSSIBLE; + do { + int hint = slot->hint + chunk; - cache_mark_used(block, b); + if ((err = hal_ks_index_find(&db.ksi, &slot->name, chunk, &b, &hint)) != HAL_OK || + (err = block_read_cached(b, &block)) != HAL_OK) + goto done; - if (chunk == 0) - slot->hint = hint; + if (block->header.this_chunk != chunk) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } - uint8_t *bytes = NULL; - size_t bytes_len = 0; - unsigned *attrs_len; + cache_mark_used(block, b); - if ((err = locate_attributes(block, chunk, &bytes, &bytes_len, &attrs_len)) != HAL_OK) - return err; + if (chunk == 0) + slot->hint = hint; + + uint8_t *bytes = NULL; + size_t bytes_len = 0; + unsigned *attrs_len; - updated_attributes_len += *attrs_len; + if ((err = locate_attributes(block, chunk, &bytes, &bytes_len, &attrs_len)) != HAL_OK) + goto done; + + updated_attributes_len += *attrs_len; #if KS_SET_ATTRIBUTES_SINGLE_BLOCK_UPDATE_FAST_PATH - hal_pkey_attribute_t attrs[*attrs_len + attributes_len]; - size_t total; + hal_pkey_attribute_t attrs[*attrs_len + attributes_len]; + size_t total; - if ((err = hal_ks_attribute_scan(bytes, bytes_len, attrs, *attrs_len, &total)) != HAL_OK) - return err; + if ((err = hal_ks_attribute_scan(bytes, bytes_len, attrs, *attrs_len, &total)) != HAL_OK) + goto done; - for (int i = 0; err == HAL_OK && i < attributes_len; i++) - if (attributes[i].length == HAL_PKEY_ATTRIBUTE_NIL) - err = hal_ks_attribute_delete(bytes, bytes_len, attrs, attrs_len, &total, - attributes[i].type); - else - err = hal_ks_attribute_insert(bytes, bytes_len, attrs, attrs_len, &total, - attributes[i].type, - attributes[i].value, - attributes[i].length); + for (int i = 0; err == HAL_OK && i < attributes_len; i++) + if (attributes[i].length == HAL_PKEY_ATTRIBUTE_NIL) + err = hal_ks_attribute_delete(bytes, bytes_len, attrs, attrs_len, &total, + attributes[i].type); + else + err = hal_ks_attribute_insert(bytes, bytes_len, attrs, attrs_len, &total, + attributes[i].type, + attributes[i].value, + attributes[i].length); - if (err != HAL_OK) - cache_release(block); + if (err != HAL_OK) + cache_release(block); - if (err == HAL_ERROR_RESULT_TOO_LONG) - continue; + if (err == HAL_ERROR_RESULT_TOO_LONG) + continue; - if (err != HAL_OK) - return err; + if (err == HAL_OK) + err = block_update(b, block, &slot->name, chunk, &hint); - return block_update(b, block, &slot->name, chunk, &hint); + goto done; #endif /* KS_SET_ATTRIBUTES_SINGLE_BLOCK_UPDATE_FAST_PATH */ - } while (++chunk < block->header.total_chunks); + } while (++chunk < block->header.total_chunks); - /* - * If we get here, we're on the slow path, which requires rewriting - * all the chunks in this object but which can also add or remove - * chunks from this object. We need to keep track of all the old - * chunks so we can zero them at the end, and because we can't zero - * them until we've written out the new chunks, we need enough free - * blocks to hold all the new chunks. - * - * Calculating all of this is extremely tedious, but flash writes - * are so much more expensive than anything else we do here that - * it's almost certainly worth it. - * - * We don't need the attribute values to compute the sizes, just the - * attribute sizes, so we scan all the existing blocks, build up a - * structure with the current attribute types and sizes, modify that - * according to our arguments, and compute the needed size. Once we - * have that, we can start rewriting existing blocks. We put all - * the new stuff at the end, which simplifies this slightly. - * - * In theory, this process never requires us to have more than two - * blocks in memory at the same time (source and destination when - * copying across chunk boundaries), but having enough cache buffers - * to keep the whole set in memory will almost certainly make this - * run faster. - */ + /* + * If we get here, we're on the slow path, which requires rewriting + * all the chunks in this object but which can also add or remove + * chunks from this object. We need to keep track of all the old + * chunks so we can zero them at the end, and because we can't zero + * them until we've written out the new chunks, we need enough free + * blocks to hold all the new chunks. + * + * Calculating all of this is extremely tedious, but flash writes + * are so much more expensive than anything else we do here that + * it's almost certainly worth it. + * + * We don't need the attribute values to compute the sizes, just the + * attribute sizes, so we scan all the existing blocks, build up a + * structure with the current attribute types and sizes, modify that + * according to our arguments, and compute the needed size. Once we + * have that, we can start rewriting existing blocks. We put all + * the new stuff at the end, which simplifies this slightly. + * + * In theory, this process never requires us to have more than two + * blocks in memory at the same time (source and destination when + * copying across chunk boundaries), but having enough cache buffers + * to keep the whole set in memory will almost certainly make this + * run faster. + */ - hal_pkey_attribute_t updated_attributes[updated_attributes_len]; - const unsigned total_chunks_old = block->header.total_chunks; - size_t bytes_available = 0; + hal_pkey_attribute_t updated_attributes[updated_attributes_len]; + const unsigned total_chunks_old = block->header.total_chunks; + size_t bytes_available = 0; - updated_attributes_len = 0; + updated_attributes_len = 0; - /* - * Phase 0.1: Walk the old chunks to populate updated_attributes[]. - * This also initializes bytes_available, since we can only get that - * by reading old chunk zero. - */ + /* + * Phase 0.1: Walk the old chunks to populate updated_attributes[]. + * This also initializes bytes_available, since we can only get that + * by reading old chunk zero. + */ - for (chunk = 0; chunk < total_chunks_old; chunk++) { - int hint = slot->hint + chunk; + for (chunk = 0; chunk < total_chunks_old; chunk++) { + int hint = slot->hint + chunk; - if ((err = hal_ks_index_find(&db.ksi, &slot->name, chunk, &b, &hint)) != HAL_OK || - (err = block_read_cached(b, &block)) != HAL_OK) - return err; + if ((err = hal_ks_index_find(&db.ksi, &slot->name, chunk, &b, &hint)) != HAL_OK || + (err = block_read_cached(b, &block)) != HAL_OK) + goto done; - if (block->header.this_chunk != chunk) - return HAL_ERROR_IMPOSSIBLE; + if (block->header.this_chunk != chunk) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } - cache_mark_used(block, b); + cache_mark_used(block, b); - uint8_t *bytes = NULL; - size_t bytes_len = 0; - unsigned *attrs_len; + uint8_t *bytes = NULL; + size_t bytes_len = 0; + unsigned *attrs_len; - if ((err = locate_attributes(block, chunk, &bytes, &bytes_len, &attrs_len)) != HAL_OK) - return err; + if ((err = locate_attributes(block, chunk, &bytes, &bytes_len, &attrs_len)) != HAL_OK) + goto done; - hal_pkey_attribute_t attrs[*attrs_len]; - size_t total; + hal_pkey_attribute_t attrs[*attrs_len]; + size_t total; - if ((err = hal_ks_attribute_scan(bytes, bytes_len, attrs, *attrs_len, &total)) != HAL_OK) - return err; + if ((err = hal_ks_attribute_scan(bytes, bytes_len, attrs, *attrs_len, &total)) != HAL_OK) + goto done; - if (chunk == 0) - bytes_available = bytes_len; + if (chunk == 0) + bytes_available = bytes_len; - for (int i = 0; i < *attrs_len; i++) { + for (int i = 0; i < *attrs_len; i++) { - if (updated_attributes_len >= sizeof(updated_attributes)/sizeof(*updated_attributes)) - return HAL_ERROR_IMPOSSIBLE; + if (updated_attributes_len >= sizeof(updated_attributes)/sizeof(*updated_attributes)) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } - updated_attributes[updated_attributes_len].type = attrs[i].type; - updated_attributes[updated_attributes_len].length = attrs[i].length; - updated_attributes[updated_attributes_len].value = NULL; - updated_attributes_len++; + updated_attributes[updated_attributes_len].type = attrs[i].type; + updated_attributes[updated_attributes_len].length = attrs[i].length; + updated_attributes[updated_attributes_len].value = NULL; + updated_attributes_len++; + } } - } - /* - * Phase 0.2: Merge new attributes into updated_attributes[]. - */ + /* + * Phase 0.2: Merge new attributes into updated_attributes[]. + * For each new attribute type, mark any existing attributes of that + * type for deletion. Append new attributes to updated_attributes[]. + */ - for (int i = 0; i < attributes_len; i++) { + for (int i = 0; i < attributes_len; i++) { - for (int j = 0; j < updated_attributes_len; j++) - if (updated_attributes[j].type == attributes[i].type) - updated_attributes[j].length = HAL_PKEY_ATTRIBUTE_NIL; + for (int j = 0; j < updated_attributes_len; j++) + if (updated_attributes[j].type == attributes[i].type) + updated_attributes[j].length = HAL_PKEY_ATTRIBUTE_NIL; - if (updated_attributes_len >= sizeof(updated_attributes)/sizeof(*updated_attributes)) - return HAL_ERROR_IMPOSSIBLE; + if (updated_attributes_len >= sizeof(updated_attributes)/sizeof(*updated_attributes)) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } - updated_attributes[updated_attributes_len].type = attributes[i].type; - updated_attributes[updated_attributes_len].length = attributes[i].length; - updated_attributes[updated_attributes_len].value = attributes[i].value; - updated_attributes_len++; - } + updated_attributes[updated_attributes_len].type = attributes[i].type; + updated_attributes[updated_attributes_len].length = attributes[i].length; + updated_attributes[updated_attributes_len].value = attributes[i].value; + updated_attributes_len++; + } - /* - * Phase 0.3: Prune trailing deletion actions: we don't need them to - * maintain synchronization with existing attributes, and doing so - * simplifies logic for updating the final new chunk. - */ + /* + * Phase 0.3: Prune trailing deletion actions: we don't need them to + * maintain synchronization with existing attributes, and doing so + * simplifies logic for updating the final new chunk. + */ - while (updated_attributes_len > 0 && - updated_attributes[updated_attributes_len - 1].length == HAL_PKEY_ATTRIBUTE_NIL) - --updated_attributes_len; + while (updated_attributes_len > 0 && + updated_attributes[updated_attributes_len - 1].length == HAL_PKEY_ATTRIBUTE_NIL) + --updated_attributes_len; - /* - * Phase 0.4: Figure out how many chunks all this will occupy. - */ + /* + * Phase 0.4: Figure out how many chunks all this will occupy. + */ - chunk = 0; + chunk = 0; - for (int i = 0; i < updated_attributes_len; i++) { + for (int i = 0; i < updated_attributes_len; i++) { - if (updated_attributes[i].length == HAL_PKEY_ATTRIBUTE_NIL) - continue; + if (updated_attributes[i].length == HAL_PKEY_ATTRIBUTE_NIL) + continue; - const size_t needed = hal_ks_attribute_header_size + updated_attributes[i].length; + const size_t needed = hal_ks_attribute_header_size + updated_attributes[i].length; - if (needed > bytes_available) { - bytes_available = SIZEOF_FLASH_ATTRIBUTE_BLOCK_ATTRIBUTES; - chunk++; + if (needed > bytes_available) { + bytes_available = SIZEOF_FLASH_ATTRIBUTE_BLOCK_ATTRIBUTES; + chunk++; + } + + if (needed > bytes_available) { + err = HAL_ERROR_RESULT_TOO_LONG; + goto done; + } + + bytes_available -= needed; } - if (needed > bytes_available) - return HAL_ERROR_RESULT_TOO_LONG; + const unsigned total_chunks_new = chunk + 1; - bytes_available -= needed; - } + /* + * If there aren't enough free blocks, give up now, before changing anything. + */ - const unsigned total_chunks_new = chunk + 1; + if (db.ksi.used + total_chunks_new > db.ksi.size) { + err = HAL_ERROR_NO_KEY_INDEX_SLOTS; + goto done; + } - /* - * If there aren't enough free blocks, give up now, before changing anything. - */ + /* + * Phase 1: Deprecate all the old chunks, remember where they were. + */ - if (db.ksi.used + total_chunks_new > db.ksi.size) - return HAL_ERROR_NO_KEY_INDEX_SLOTS; + unsigned old_blocks[total_chunks_old]; - /* - * Phase 1: Deprecate all the old chunks, remember where they were. - */ + for (chunk = 0; chunk < total_chunks_old; chunk++) { + int hint = slot->hint + chunk; + if ((err = hal_ks_index_find(&db.ksi, &slot->name, chunk, &b, &hint)) != HAL_OK || + (err = block_deprecate(b)) != HAL_OK) + goto done; + old_blocks[chunk] = b; + } - unsigned old_blocks[total_chunks_old]; + /* + * Phase 2: Write new chunks, copying attributes from old chunks or + * from attributes[], as needed. + */ - for (chunk = 0; chunk < total_chunks_old; chunk++) { - int hint = slot->hint + chunk; - if ((err = hal_ks_index_find(&db.ksi, &slot->name, chunk, &b, &hint)) != HAL_OK || - (err = block_deprecate(b)) != HAL_OK) - return err; - old_blocks[chunk] = b; - } + { + hal_pkey_attribute_t old_attrs[updated_attributes_len], new_attrs[updated_attributes_len]; + unsigned *old_attrs_len = NULL, *new_attrs_len = NULL; + flash_block_t *old_block = NULL, *new_block = NULL; + uint8_t *old_bytes = NULL, *new_bytes = NULL; + size_t old_bytes_len = 0, new_bytes_len = 0; + unsigned old_chunk = 0, new_chunk = 0; + size_t old_total = 0, new_total = 0; - /* - * Phase 2: Write new chunks, copying attributes from old chunks or - * from attributes[], as needed. - */ + int updated_attributes_i = 0, old_attrs_i = 0; - { - hal_pkey_attribute_t old_attrs[updated_attributes_len], new_attrs[updated_attributes_len]; - unsigned *old_attrs_len = NULL, *new_attrs_len = NULL; - flash_block_t *old_block = NULL, *new_block = NULL; - uint8_t *old_bytes = NULL, *new_bytes = NULL; - size_t old_bytes_len = 0, new_bytes_len = 0; - unsigned old_chunk = 0, new_chunk = 0; - size_t old_total = 0, new_total = 0; + uint32_t new_attr_type; + size_t new_attr_length; + const uint8_t *new_attr_value; - int updated_attributes_i = 0, old_attrs_i = 0; + while (updated_attributes_i < updated_attributes_len) { - uint32_t new_attr_type; - size_t new_attr_length; - const uint8_t *new_attr_value; + if (old_chunk >= total_chunks_old || new_chunk >= total_chunks_new) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } - while (updated_attributes_i < updated_attributes_len) { + /* + * If we've gotten as far as new data that comes from + * attributes[], we have it in hand and can just copy it. + */ - if (old_chunk >= total_chunks_old || new_chunk >= total_chunks_new) - return HAL_ERROR_IMPOSSIBLE; + if (updated_attributes_len - updated_attributes_i <= attributes_len) { + new_attr_type = updated_attributes[updated_attributes_i].type; + new_attr_length = updated_attributes[updated_attributes_i].length; + new_attr_value = updated_attributes[updated_attributes_i].value; + } - /* - * If we've gotten as far as new data that comes from - * attributes[], we have it in hand and can just copy it. - */ + /* + * Otherwise, we have to read it from an old block, which may in + * turn require reading in the next old block. + */ - if (updated_attributes_len - updated_attributes_i <= attributes_len) { - new_attr_type = updated_attributes[updated_attributes_i].type; - new_attr_length = updated_attributes[updated_attributes_i].length; - new_attr_value = updated_attributes[updated_attributes_i].value; - } + else { - /* - * Otherwise, we have to read it from an old block, which may in - * turn require reading in the next old block. - */ + if (old_block == NULL) { - else { + if ((err = block_read_cached(old_blocks[old_chunk], &old_block)) != HAL_OK) + goto done; - if (old_block == NULL) { + if (old_block->header.this_chunk != old_chunk) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } - if ((err = block_read_cached(old_blocks[old_chunk], &old_block)) != HAL_OK) - return err; + if ((err = locate_attributes(old_block, old_chunk, + &old_bytes, &old_bytes_len, &old_attrs_len)) != HAL_OK || + (err = hal_ks_attribute_scan(old_bytes, old_bytes_len, + old_attrs, *old_attrs_len, &old_total)) != HAL_OK) + goto done; - if (old_block->header.this_chunk != old_chunk) - return HAL_ERROR_IMPOSSIBLE; + old_attrs_i = 0; + } - if ((err = locate_attributes(old_block, old_chunk, - &old_bytes, &old_bytes_len, &old_attrs_len)) != HAL_OK || - (err = hal_ks_attribute_scan(old_bytes, old_bytes_len, - old_attrs, *old_attrs_len, &old_total)) != HAL_OK) - return err; + if (old_attrs_i >= *old_attrs_len) { + old_chunk++; + old_block = NULL; + continue; + } + + new_attr_type = old_attrs[old_attrs_i].type; + new_attr_length = old_attrs[old_attrs_i].length; + new_attr_value = old_attrs[old_attrs_i].value; + + if (new_attr_type != updated_attributes[updated_attributes_i].type) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } - old_attrs_i = 0; + old_attrs_i++; } - if (old_attrs_i >= *old_attrs_len) { - old_chunk++; - old_block = NULL; - continue; + /* + * Unless this is a deletion, we should have something to write. + */ + + if (new_attr_length != HAL_PKEY_ATTRIBUTE_NIL && new_attr_value == NULL) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; } - new_attr_type = old_attrs[old_attrs_i].type; - new_attr_length = old_attrs[old_attrs_i].length; - new_attr_value = old_attrs[old_attrs_i].value; + /* + * Initialize the new block if necessary. If it's the new chunk + * zero, we need to copy all the non-attribute data from the old + * chunk zero; otherwise, it's a new empty attribute block. + */ + + if (new_block == NULL) { + + new_block = cache_pick_lru(); + memset(new_block, 0xFF, sizeof(*new_block)); + + if (new_chunk == 0) { + flash_block_t *tmp_block; + if ((err = block_read_cached(old_blocks[0], &tmp_block)) != HAL_OK) + goto done; + if (tmp_block->header.this_chunk != 0) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } + new_block->header.block_type = BLOCK_TYPE_KEY; + new_block->key.name = slot->name; + new_block->key.type = tmp_block->key.type; + new_block->key.curve = tmp_block->key.curve; + new_block->key.flags = tmp_block->key.flags; + new_block->key.der_len = tmp_block->key.der_len; + new_block->key.attributes_len = 0; + memcpy(new_block->key.der, tmp_block->key.der, tmp_block->key.der_len); + } + else { + new_block->header.block_type = BLOCK_TYPE_ATTR; + new_block->attr.name = slot->name; + new_block->attr.attributes_len = 0; + } - if (new_attr_type != updated_attributes[updated_attributes_i].type) - return HAL_ERROR_IMPOSSIBLE; + new_block->header.block_status = BLOCK_STATUS_LIVE; + new_block->header.total_chunks = total_chunks_new; + new_block->header.this_chunk = new_chunk; - old_attrs_i++; - } + if ((err = locate_attributes(new_block, new_chunk, + &new_bytes, &new_bytes_len, &new_attrs_len)) != HAL_OK) + goto done; - /* - * Unless this is a deletion, we should have something to write. - */ + new_total = 0; + } - if (new_attr_length != HAL_PKEY_ATTRIBUTE_NIL && new_attr_value == NULL) - return HAL_ERROR_IMPOSSIBLE; + /* + * After all that setup, we finally get to write the frelling attribute. + */ - /* - * Initialize the new block if necessary. If it's the new chunk - * zero, we need to copy all the non-attribute data from the old - * chunk zero; otherwise, it's a new empty attribute block. - */ + if (new_attr_length != HAL_PKEY_ATTRIBUTE_NIL) + err = hal_ks_attribute_insert(new_bytes, new_bytes_len, new_attrs, new_attrs_len, &new_total, + new_attr_type, new_attr_value, new_attr_length); - if (new_block == NULL) { - - new_block = cache_pick_lru(); - memset(new_block, 0xFF, sizeof(*new_block)); - - if (new_chunk == 0) { - flash_block_t *tmp_block; - if ((err = block_read_cached(old_blocks[0], &tmp_block)) != HAL_OK) - return err; - if (tmp_block->header.this_chunk != 0) - return HAL_ERROR_IMPOSSIBLE; - new_block->header.block_type = BLOCK_TYPE_KEY; - new_block->key.name = slot->name; - new_block->key.type = tmp_block->key.type; - new_block->key.curve = tmp_block->key.curve; - new_block->key.flags = tmp_block->key.flags; - new_block->key.der_len = tmp_block->key.der_len; - new_block->key.attributes_len = 0; - memcpy(new_block->key.der, tmp_block->key.der, tmp_block->key.der_len); - } - else { - new_block->header.block_type = BLOCK_TYPE_ATTR; - new_block->attr.name = slot->name; - new_block->attr.attributes_len = 0; - } + /* + * Figure out what to do next: immediately loop for next + * attribute, write current block, or bail out. + */ - new_block->header.block_status = BLOCK_STATUS_LIVE; - new_block->header.total_chunks = total_chunks_new; - new_block->header.this_chunk = new_chunk; + switch (err) { + case HAL_OK: + if (++updated_attributes_i < updated_attributes_len) + continue; + break; + case HAL_ERROR_RESULT_TOO_LONG: + if (new_chunk > 0 && new_attrs_len == 0) + goto done; + break; + default: + goto done; + } - if ((err = locate_attributes(new_block, new_chunk, - &new_bytes, &new_bytes_len, &new_attrs_len)) != HAL_OK) - return err; + /* + * If we get here, either the current new block is full or we + * finished the last block, so we need to write it out. + */ - new_total = 0; - } + int hint = slot->hint + new_chunk; - /* - * After all that setup, we finally get to write the frelling attribute. - */ + if (new_chunk < total_chunks_old) + err = hal_ks_index_replace(&db.ksi, &slot->name, new_chunk, &b, &hint); + else + err = hal_ks_index_add( &db.ksi, &slot->name, new_chunk, &b, &hint); - if (new_attr_length != HAL_PKEY_ATTRIBUTE_NIL) - err = hal_ks_attribute_insert(new_bytes, new_bytes_len, new_attrs, new_attrs_len, &new_total, - new_attr_type, new_attr_value, new_attr_length); + if (err != HAL_OK || (err = block_write(b, new_block)) != HAL_OK) + goto done; - /* - * Figure out what to do next: immediately loop for next - * attribute, write current block, or bail out. - */ + cache_mark_used(new_block, b); - switch (err) { - case HAL_OK: - if (++updated_attributes_i < updated_attributes_len) - continue; - break; - case HAL_ERROR_RESULT_TOO_LONG: - if (new_chunk > 0 && new_attrs_len == 0) - return err; - break; - default: - return err; + new_block = NULL; + new_chunk++; } /* - * If we get here, either the current new block is full or we - * finished the last block, so we need to write it out. + * If number of blocks shrank, we need to clear trailing entries from the index. */ - int hint = slot->hint + new_chunk; + for (old_chunk = total_chunks_new; old_chunk < total_chunks_old; old_chunk++) { + int hint = slot->hint + old_chunk; - if (new_chunk < total_chunks_old) - err = hal_ks_index_replace(&db.ksi, &slot->name, new_chunk, &b, &hint); - else - err = hal_ks_index_add( &db.ksi, &slot->name, new_chunk, &b, &hint); + err = hal_ks_index_delete(&db.ksi, &slot->name, old_chunk, NULL, &hint); - if (err != HAL_OK || (err = block_write(b, new_block)) != HAL_OK) - return err; - - cache_mark_used(new_block, b); + if (err != HAL_OK) + goto done; + } - new_block = NULL; - new_chunk++; } /* - * If number of blocks shrank, we need to clear trailing entries from the index. + * Phase 3: Zero the old chunks we deprecated in phase 1. */ - for (old_chunk = total_chunks_new; old_chunk < total_chunks_old; old_chunk++) { - int hint = slot->hint + old_chunk; + for (chunk = 0; chunk < total_chunks_old; chunk++) + if ((err = block_zero(old_blocks[chunk])) != HAL_OK) + goto done; - err = hal_ks_index_delete(&db.ksi, &slot->name, old_chunk, NULL, &hint); - - if (err != HAL_OK) - return err; - } + err = HAL_OK; } - /* - * Phase 3: Zero the old chunks we deprecated in phase 1. - */ - - for (chunk = 0; chunk < total_chunks_old; chunk++) - if ((err = block_zero(old_blocks[chunk])) != HAL_OK) - return err; - - return HAL_OK; + done: + hal_ks_unlock(); + return err; #warning What happens if something goes wrong partway through this awful mess? // We're left in a state with all the old blocks deprecated and @@ -1700,18 +1851,22 @@ static hal_error_t ks_get_attributes(hal_ks_t *ks, flash_block_t *block = NULL; unsigned chunk = 0; unsigned found = 0; - hal_error_t err; + hal_error_t err = HAL_OK; unsigned b; + hal_ks_lock(); + do { int hint = slot->hint + chunk; if ((err = hal_ks_index_find(&db.ksi, &slot->name, chunk, &b, &hint)) != HAL_OK || (err = block_read_cached(b, &block)) != HAL_OK) - return err; + goto done; - if (block->header.this_chunk != chunk) - return HAL_ERROR_IMPOSSIBLE; + if (block->header.this_chunk != chunk) { + err = HAL_ERROR_IMPOSSIBLE; + goto done; + } if (chunk == 0) slot->hint = hint; @@ -1723,7 +1878,7 @@ static hal_error_t ks_get_attributes(hal_ks_t *ks, unsigned *attrs_len; if ((err = locate_attributes(block, chunk, &bytes, &bytes_len, &attrs_len)) != HAL_OK) - return err; + goto done; if (*attrs_len == 0) continue; @@ -1731,7 +1886,7 @@ static hal_error_t ks_get_attributes(hal_ks_t *ks, hal_pkey_attribute_t attrs[*attrs_len]; if ((err = hal_ks_attribute_scan(bytes, bytes_len, attrs, *attrs_len, NULL)) != HAL_OK) - return err; + goto done; for (int i = 0; i < attributes_len; i++) { @@ -1750,8 +1905,10 @@ static hal_error_t ks_get_attributes(hal_ks_t *ks, if (attributes_buffer_len == 0) continue; - if (attrs[j].length > attributes_buffer + attributes_buffer_len - abuf) - return HAL_ERROR_RESULT_TOO_LONG; + if (attrs[j].length > attributes_buffer + attributes_buffer_len - abuf) { + err = HAL_ERROR_RESULT_TOO_LONG; + goto done; + } memcpy(abuf, attrs[j].value, attrs[j].length); attributes[i].value = abuf; @@ -1761,9 +1918,13 @@ static hal_error_t ks_get_attributes(hal_ks_t *ks, } while (found < attributes_len && ++chunk < block->header.total_chunks); if (found < attributes_len && attributes_buffer_len > 0) - return HAL_ERROR_ATTRIBUTE_NOT_FOUND; + err = HAL_ERROR_ATTRIBUTE_NOT_FOUND; + else + err = HAL_OK; - return HAL_OK; + done: + hal_ks_unlock(); + return err; } const hal_ks_driver_t hal_ks_token_driver[1] = {{ @@ -1799,6 +1960,8 @@ void hal_ks_init_read_only_pins_only(void) unsigned b, best_seen = ~0; flash_block_t block[1]; + hal_ks_lock(); + for (b = 0; b < NUM_FLASH_BLOCKS; b++) { if (block_read(b, block) != HAL_OK || block_get_type(block) != BLOCK_TYPE_PIN) continue; @@ -1818,6 +1981,8 @@ void hal_ks_init_read_only_pins_only(void) db.wheel_pin = block->pin.wheel_pin; db.so_pin = block->pin.so_pin; db.user_pin = block->pin.user_pin; + + hal_ks_unlock(); } /* @@ -1830,14 +1995,20 @@ hal_error_t hal_get_pin(const hal_user_t user, if (pin == NULL) return HAL_ERROR_BAD_ARGUMENTS; + hal_error_t err = HAL_OK; + + hal_ks_lock(); + switch (user) { case HAL_USER_WHEEL: *pin = &db.wheel_pin; break; case HAL_USER_SO: *pin = &db.so_pin; break; case HAL_USER_NORMAL: *pin = &db.user_pin; break; - default: return HAL_ERROR_BAD_ARGUMENTS; + default: err = HAL_ERROR_BAD_ARGUMENTS; } - return HAL_OK; + hal_ks_unlock(); + + return err; } /* @@ -1904,8 +2075,10 @@ hal_error_t hal_set_pin(const hal_user_t user, hal_error_t err; unsigned b; + hal_ks_lock(); + if ((err = fetch_pin_block(&b, &block)) != HAL_OK) - return err; + goto done; flash_pin_block_t new_data = block->pin; hal_ks_pin_t *dp, *bp; @@ -1914,7 +2087,7 @@ hal_error_t hal_set_pin(const hal_user_t user, case HAL_USER_WHEEL: bp = &new_data.wheel_pin; dp = &db.wheel_pin; break; case HAL_USER_SO: bp = &new_data.so_pin; dp = &db.so_pin; break; case HAL_USER_NORMAL: bp = &new_data.user_pin; dp = &db.user_pin; break; - default: return HAL_ERROR_BAD_ARGUMENTS; + default: err = HAL_ERROR_BAD_ARGUMENTS; goto done; } const hal_ks_pin_t old_pin = *dp; @@ -1923,6 +2096,8 @@ hal_error_t hal_set_pin(const hal_user_t user, if ((err = update_pin_block(b, block, &new_data)) != HAL_OK) *dp = old_pin; + done: + hal_ks_unlock(); return err; } @@ -1951,16 +2126,20 @@ hal_error_t hal_mkm_flash_read(uint8_t *buf, const size_t len) hal_error_t err; unsigned b; + hal_ks_lock(); + if ((err = fetch_pin_block(&b, &block)) != HAL_OK) - return err; + goto done; if (block->pin.kek_set != FLASH_KEK_SET) - return HAL_ERROR_MASTERKEY_NOT_SET; + err = HAL_ERROR_MASTERKEY_NOT_SET; - if (buf != NULL) + else if (buf != NULL) memcpy(buf, block->pin.kek, len); - return HAL_OK; + done: + hal_ks_unlock(); + return err; } hal_error_t hal_mkm_flash_write(const uint8_t * const buf, const size_t len) @@ -1975,15 +2154,21 @@ hal_error_t hal_mkm_flash_write(const uint8_t * const buf, const size_t len) hal_error_t err; unsigned b; + hal_ks_lock(); + if ((err = fetch_pin_block(&b, &block)) != HAL_OK) - return err; + goto done; flash_pin_block_t new_data = block->pin; new_data.kek_set = FLASH_KEK_SET; memcpy(new_data.kek, buf, len); - return update_pin_block(b, block, &new_data); + err = update_pin_block(b, block, &new_data); + + done: + hal_ks_unlock(); + return err; } hal_error_t hal_mkm_flash_erase(const size_t len) @@ -1995,15 +2180,21 @@ hal_error_t hal_mkm_flash_erase(const size_t len) hal_error_t err; unsigned b; + hal_ks_lock(); + if ((err = fetch_pin_block(&b, &block)) != HAL_OK) - return err; + goto done; flash_pin_block_t new_data = block->pin; new_data.kek_set = FLASH_KEK_SET; memset(new_data.kek, 0, len); - return update_pin_block(b, block, &new_data); + err = update_pin_block(b, block, &new_data); + + done: + hal_ks_unlock(); + return err; } #endif /* HAL_MKM_FLASH_BACKUP_KLUDGE */ @@ -55,8 +55,8 @@ static inline int ks_name_cmp(const hal_ks_name_t * const name1, const hal_ks_na } /* - * Return value indicates whether the name is present in the index. - * "where" indicates the name's position whether present or not. + * Find a block in the index, return true (found) or false (not found). + * "where" indicates the name's position, or the position of the first free block. * * NB: This does NOT return a block number, it returns an index into * ksi->index[]. @@ -145,6 +145,10 @@ static inline void ks_heapsort(hal_ks_index_t *ksi) } } +/* + * Perform a consistency check on the index. + */ + #define fsck(_ksi) \ do { hal_error_t _err = hal_ks_index_fsck(_ksi); if (_err != HAL_OK) return _err; } while (0) @@ -179,16 +183,16 @@ hal_error_t hal_ks_index_fsck(hal_ks_index_t *ksi) return HAL_OK; } +/* + * Set up the index. Only setup task we have at the moment is sorting the index. + */ + hal_error_t hal_ks_index_setup(hal_ks_index_t *ksi) { if (ksi == NULL || ksi->index == NULL || ksi->names == NULL || ksi->size == 0 || ksi->used > ksi->size) return HAL_ERROR_BAD_ARGUMENTS; - /* - * Only setup task we have at the moment is sorting the index. - */ - ks_heapsort(ksi); /* @@ -200,6 +204,10 @@ hal_error_t hal_ks_index_setup(hal_ks_index_t *ksi) return HAL_OK; } +/* + * Find a single block by name and chunk number. + */ + hal_error_t hal_ks_index_find(hal_ks_index_t *ksi, const hal_uuid_t * const name, const unsigned chunk, @@ -225,6 +233,11 @@ hal_error_t hal_ks_index_find(hal_ks_index_t *ksi, return ok ? HAL_OK : HAL_ERROR_KEY_NOT_FOUND; } +/* + * Find all blocks with the given name. + * If 'strict' is set, expect it to be a well-ordered set of chunks. + */ + hal_error_t hal_ks_index_find_range(hal_ks_index_t *ksi, const hal_uuid_t * const name, const unsigned max_blocks, @@ -266,6 +279,10 @@ hal_error_t hal_ks_index_find_range(hal_ks_index_t *ksi, return HAL_OK; } +/* + * Add a single block to the index. + */ + hal_error_t hal_ks_index_add(hal_ks_index_t *ksi, const hal_uuid_t * const name, const unsigned chunk, @@ -309,6 +326,10 @@ hal_error_t hal_ks_index_add(hal_ks_index_t *ksi, return HAL_OK; } +/* + * Delete a single block from the index. + */ + hal_error_t hal_ks_index_delete(hal_ks_index_t *ksi, const hal_uuid_t * const name, const unsigned chunk, @@ -348,6 +369,11 @@ hal_error_t hal_ks_index_delete(hal_ks_index_t *ksi, return HAL_OK; } +/* + * Delete all blocks with the given name. If blocknos is NULL, return a + * count of the matching blocks without deleting anything. + */ + hal_error_t hal_ks_index_delete_range(hal_ks_index_t *ksi, const hal_uuid_t * const name, const unsigned max_blocks, @@ -404,6 +430,10 @@ hal_error_t hal_ks_index_delete_range(hal_ks_index_t *ksi, return HAL_OK; } +/* + * Replace a single block in the index. + */ + hal_error_t hal_ks_index_replace(hal_ks_index_t *ksi, const hal_uuid_t * const name, const unsigned chunk, diff --git a/ks_volatile.c b/ks_volatile.c index 99ad68c..9762da3 100644 --- a/ks_volatile.c +++ b/ks_volatile.c @@ -187,6 +187,10 @@ static hal_error_t ks_init(const hal_ks_driver_t * const driver, static hal_error_t ks_volatile_init(const hal_ks_driver_t * const driver, const int alloc) { + hal_error_t err = HAL_OK; + + hal_ks_lock(); + const size_t len = (sizeof(*volatile_ks.db) + sizeof(*volatile_ks.db->ksi.index) * STATIC_KS_VOLATILE_SLOTS + sizeof(*volatile_ks.db->ksi.names) * STATIC_KS_VOLATILE_SLOTS + @@ -195,9 +199,12 @@ static hal_error_t ks_volatile_init(const hal_ks_driver_t * const driver, const uint8_t *mem = NULL; if (alloc && (mem = hal_allocate_static_memory(len)) == NULL) - return HAL_ERROR_ALLOCATION_FAILURE; + err = HAL_ERROR_ALLOCATION_FAILURE; + else + err = ks_init(driver, 1, &volatile_ks, mem, len); - return ks_init(driver, 1, &volatile_ks, mem, len); + hal_ks_unlock(); + return err; } static hal_error_t ks_volatile_shutdown(const hal_ks_driver_t * const driver) @@ -241,14 +248,18 @@ static hal_error_t ks_store(hal_ks_t *ks, return HAL_ERROR_BAD_ARGUMENTS; ks_t *ksv = ks_to_ksv(ks); - hal_error_t err; + hal_error_t err = HAL_OK; unsigned b; - if (ksv->db == NULL) - return HAL_ERROR_KEYSTORE_ACCESS; + hal_ks_lock(); + + if (ksv->db == NULL) { + err = HAL_ERROR_KEYSTORE_ACCESS; + goto done; + } if ((err = hal_ks_index_add(&ksv->db->ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) - return err; + goto done; uint8_t kek[KEK_LENGTH]; size_t kek_len; @@ -272,6 +283,8 @@ static hal_error_t ks_store(hal_ks_t *ks, else (void) hal_ks_index_delete(&ksv->db->ksi, &slot->name, 0, NULL, &slot->hint); + done: + hal_ks_unlock(); return err; } @@ -283,19 +296,25 @@ static hal_error_t ks_fetch(hal_ks_t *ks, return HAL_ERROR_BAD_ARGUMENTS; ks_t *ksv = ks_to_ksv(ks); - hal_error_t err; + hal_error_t err = HAL_OK; unsigned b; - if (ksv->db == NULL) - return HAL_ERROR_KEYSTORE_ACCESS; + hal_ks_lock(); + + if (ksv->db == NULL) { + err = HAL_ERROR_KEYSTORE_ACCESS; + goto done; + } if ((err = hal_ks_index_find(&ksv->db->ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) - return err; + goto done; const ks_key_t * const k = &ksv->db->keys[b]; - if (!key_visible_to_session(ksv, slot->client_handle, slot->session_handle, k)) - return HAL_ERROR_KEY_NOT_FOUND; + if (!key_visible_to_session(ksv, slot->client_handle, slot->session_handle, k)) { + err = HAL_ERROR_KEY_NOT_FOUND; + goto done; + } slot->type = k->type; slot->curve = k->curve; @@ -319,12 +338,11 @@ static hal_error_t ks_fetch(hal_ks_t *ks, err = hal_aes_keyunwrap(NULL, kek, kek_len, k->der, k->der_len, der, der_len); memset(kek, 0, sizeof(kek)); - - if (err != HAL_OK) - return err; } - return HAL_OK; + done: + hal_ks_unlock(); + return err; } static hal_error_t ks_delete(hal_ks_t *ks, @@ -334,24 +352,32 @@ static hal_error_t ks_delete(hal_ks_t *ks, return HAL_ERROR_BAD_ARGUMENTS; ks_t *ksv = ks_to_ksv(ks); - hal_error_t err; + hal_error_t err = HAL_OK; unsigned b; - if (ksv->db == NULL) - return HAL_ERROR_KEYSTORE_ACCESS; + hal_ks_lock(); + + if (ksv->db == NULL) { + err = HAL_ERROR_KEYSTORE_ACCESS; + goto done; + } if ((err = hal_ks_index_find(&ksv->db->ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) - return err; + goto done; - if (!key_visible_to_session(ksv, slot->client_handle, slot->session_handle, &ksv->db->keys[b])) - return HAL_ERROR_KEY_NOT_FOUND; + if (!key_visible_to_session(ksv, slot->client_handle, slot->session_handle, &ksv->db->keys[b])) { + err = HAL_ERROR_KEY_NOT_FOUND; + goto done; + } if ((err = hal_ks_index_delete(&ksv->db->ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) - return err; + goto done; memset(&ksv->db->keys[b], 0, sizeof(ksv->db->keys[b])); - return HAL_OK; + done: + hal_ks_unlock(); + return err; } static hal_error_t ks_match(hal_ks_t *ks, @@ -376,9 +402,11 @@ static hal_error_t ks_match(hal_ks_t *ks, if (ksv->db == NULL) return HAL_ERROR_KEYSTORE_ACCESS; - hal_error_t err; + hal_error_t err = HAL_OK; int i = -1; + hal_ks_lock(); + *result_len = 0; err = hal_ks_index_find(&ksv->db->ksi, previous_uuid, 0, NULL, &i); @@ -386,7 +414,7 @@ static hal_error_t ks_match(hal_ks_t *ks, if (err == HAL_ERROR_KEY_NOT_FOUND) i--; else if (err != HAL_OK) - return err; + goto done; while (*result_len < result_max && ++i < ksv->db->ksi.used) { @@ -415,7 +443,7 @@ static hal_error_t ks_match(hal_ks_t *ks, if ((err = hal_ks_attribute_scan(k->der + k->der_len, sizeof(k->der) - k->der_len, key_attrs, k->attributes_len, NULL)) != HAL_OK) - return err; + goto done; for (const hal_pkey_attribute_t *required = attributes; ok && required < attributes + attributes_len; required++) { @@ -437,7 +465,11 @@ static hal_error_t ks_match(hal_ks_t *ks, ++*result_len; } - return HAL_OK; + err = HAL_OK; + + done: + hal_ks_unlock(); + return err; } static hal_error_t ks_set_attributes(hal_ks_t *ks, @@ -449,40 +481,53 @@ static hal_error_t ks_set_attributes(hal_ks_t *ks, return HAL_ERROR_BAD_ARGUMENTS; ks_t *ksv = ks_to_ksv(ks); - hal_error_t err; + hal_error_t err = HAL_OK; unsigned b; - if (ksv->db == NULL) - return HAL_ERROR_KEYSTORE_ACCESS; + hal_ks_lock(); + + { + if (ksv->db == NULL) { + err = HAL_ERROR_KEYSTORE_ACCESS; + goto done; + } + + if ((err = hal_ks_index_find(&ksv->db->ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) + goto done; + + ks_key_t * const k = &ksv->db->keys[b]; + + if (!key_visible_to_session(ksv, slot->client_handle, slot->session_handle, k)) { + err = HAL_ERROR_KEY_NOT_FOUND; + goto done; + } + + hal_pkey_attribute_t attrs[k->attributes_len + attributes_len]; + uint8_t *bytes = k->der + k->der_len; + size_t bytes_len = sizeof(k->der) - k->der_len; + size_t total_len; + + if ((err = hal_ks_attribute_scan(bytes, bytes_len, attrs, k->attributes_len, &total_len)) != HAL_OK) + goto done; + + for (const hal_pkey_attribute_t *a = attributes; a < attributes + attributes_len; a++) { + if (a->length == HAL_PKEY_ATTRIBUTE_NIL) + err = hal_ks_attribute_delete(bytes, bytes_len, attrs, &k->attributes_len, &total_len, + a->type); + else + err = hal_ks_attribute_insert(bytes, bytes_len, attrs, &k->attributes_len, &total_len, + a->type, a->value, a->length); + if (err != HAL_OK) + goto done; + } + + err = HAL_OK; - if ((err = hal_ks_index_find(&ksv->db->ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) - return err; - - ks_key_t * const k = &ksv->db->keys[b]; - - if (!key_visible_to_session(ksv, slot->client_handle, slot->session_handle, k)) - return HAL_ERROR_KEY_NOT_FOUND; - - hal_pkey_attribute_t attrs[k->attributes_len + attributes_len]; - uint8_t *bytes = k->der + k->der_len; - size_t bytes_len = sizeof(k->der) - k->der_len; - size_t total_len; - - if ((err = hal_ks_attribute_scan(bytes, bytes_len, attrs, k->attributes_len, &total_len)) != HAL_OK) - return err; - - for (const hal_pkey_attribute_t *a = attributes; a < attributes + attributes_len; a++) { - if (a->length == HAL_PKEY_ATTRIBUTE_NIL) - err = hal_ks_attribute_delete(bytes, bytes_len, attrs, &k->attributes_len, &total_len, - a->type); - else - err = hal_ks_attribute_insert(bytes, bytes_len, attrs, &k->attributes_len, &total_len, - a->type, a->value, a->length); - if (err != HAL_OK) - return err; } - return HAL_OK; + done: + hal_ks_unlock(); + return err; } static hal_error_t ks_get_attributes(hal_ks_t *ks, @@ -497,53 +542,70 @@ static hal_error_t ks_get_attributes(hal_ks_t *ks, return HAL_ERROR_BAD_ARGUMENTS; ks_t *ksv = ks_to_ksv(ks); - hal_error_t err; + hal_error_t err = HAL_OK; unsigned b; - if (ksv->db == NULL) - return HAL_ERROR_KEYSTORE_ACCESS; + hal_ks_lock(); - if ((err = hal_ks_index_find(&ksv->db->ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) - return err; + { + if (ksv->db == NULL) { + err = HAL_ERROR_KEYSTORE_ACCESS; + goto done; + } - const ks_key_t * const k = &ksv->db->keys[b]; + if ((err = hal_ks_index_find(&ksv->db->ksi, &slot->name, 0, &b, &slot->hint)) != HAL_OK) + goto done; - if (!key_visible_to_session(ksv, slot->client_handle, slot->session_handle, k)) - return HAL_ERROR_KEY_NOT_FOUND; + const ks_key_t * const k = &ksv->db->keys[b]; - hal_pkey_attribute_t attrs[k->attributes_len > 0 ? k->attributes_len : 1]; + if (!key_visible_to_session(ksv, slot->client_handle, slot->session_handle, k)) { + err = HAL_ERROR_KEY_NOT_FOUND; + goto done; + } - if ((err = hal_ks_attribute_scan(k->der + k->der_len, sizeof(k->der) - k->der_len, - attrs, k->attributes_len, NULL)) != HAL_OK) - return err; + hal_pkey_attribute_t attrs[k->attributes_len > 0 ? k->attributes_len : 1]; - uint8_t *abuf = attributes_buffer; + if ((err = hal_ks_attribute_scan(k->der + k->der_len, sizeof(k->der) - k->der_len, + attrs, k->attributes_len, NULL)) != HAL_OK) + goto done; - for (int i = 0; i < attributes_len; i++) { - int j = 0; - while (j < k->attributes_len && attrs[j].type != attributes[i].type) - j++; - const int found = j < k->attributes_len; + uint8_t *abuf = attributes_buffer; - if (attributes_buffer_len == 0) { - attributes[i].value = NULL; - attributes[i].length = found ? attrs[j].length : 0; - continue; - } + for (int i = 0; i < attributes_len; i++) { + int j = 0; + while (j < k->attributes_len && attrs[j].type != attributes[i].type) + j++; + const int found = j < k->attributes_len; + + if (attributes_buffer_len == 0) { + attributes[i].value = NULL; + attributes[i].length = found ? attrs[j].length : 0; + continue; + } - if (!found) - return HAL_ERROR_ATTRIBUTE_NOT_FOUND; + if (!found) { + err = HAL_ERROR_ATTRIBUTE_NOT_FOUND; + goto done; + } + + if (attrs[j].length > attributes_buffer + attributes_buffer_len - abuf) { + err = HAL_ERROR_RESULT_TOO_LONG; + goto done; + } - if (attrs[j].length > attributes_buffer + attributes_buffer_len - abuf) - return HAL_ERROR_RESULT_TOO_LONG; + memcpy(abuf, attrs[j].value, attrs[j].length); + attributes[i].value = abuf; + attributes[i].length = attrs[j].length; + abuf += attrs[j].length; + } + + err = HAL_OK; - memcpy(abuf, attrs[j].value, attrs[j].length); - attributes[i].value = abuf; - attributes[i].length = attrs[j].length; - abuf += attrs[j].length; } - return HAL_OK; + done: + hal_ks_unlock(); + return err; } const hal_ks_driver_t hal_ks_volatile_driver[1] = {{ @@ -39,18 +39,28 @@ A Python interface to the Cryptech libhal RPC API. # not likely to want to use the full ONC RPC mechanism. import os -import sys -import time import uuid import xdrlib -import serial +import socket +import logging import contextlib +logger = logging.getLogger(__name__) + + SLIP_END = chr(0300) # indicates end of packet SLIP_ESC = chr(0333) # indicates byte stuffing SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte SLIP_ESC_ESC = chr(0335) # ESC ESC_ESC means ESC data byte + +def slip_encode(buffer): + return SLIP_END + buffer.replace(SLIP_ESC, SLIP_ESC + SLIP_ESC_ESC).replace(SLIP_END, SLIP_ESC + SLIP_ESC_END) + SLIP_END + +def slip_decode(buffer): + return buffer.strip(SLIP_END).replace(SLIP_ESC + SLIP_ESC_END, SLIP_END).replace(SLIP_ESC + SLIP_ESC_ESC, SLIP_ESC) + + HAL_OK = 0 class HALError(Exception): @@ -381,8 +391,9 @@ class PKey(Handle): def verify(self, hash = 0, data = "", signature = None): self.hsm.pkey_verify(self, hash = hash, data = data, signature = signature) - def set_attributes(self, attributes): - self.hsm.pkey_set_attributes(self, attributes) + def set_attributes(self, attributes = None, **kwargs): + assert if attributes is None or not kwargs + self.hsm.pkey_set_attributes(self, attributes or kwargs) def get_attributes(self, attributes): attrs = self.hsm.pkey_get_attributes(self, attributes, 0) @@ -394,79 +405,41 @@ class PKey(Handle): class HSM(object): - debug = False mixed_mode = False - - _send_delay = 0 # 0.1 + debug_io = False def _raise_if_error(self, status): if status != 0: raise HALError.table[status]() - def __init__(self, device = os.getenv("CRYPTECH_RPC_CLIENT_SERIAL_DEVICE", "/dev/ttyUSB0")): - while True: - try: - self.tty = serial.Serial(device, 921600, timeout = 0.1) - break - except serial.SerialException: - time.sleep(0.2) - - def _write(self, c): - if self.debug: - sys.stdout.write("{:02x}".format(ord(c))) - self.tty.write(c) - if self._send_delay > 0: - time.sleep(self._send_delay) + def __init__(self, sockname = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME", "/tmp/.cryptech_muxd.rpc")): + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.socket.connect(sockname) + self.sockfile = self.socket.makefile("rb") def _send(self, msg): # Expects an xdrlib.Packer - if self.debug: - sys.stdout.write("+send: ") - self._write(SLIP_END) - for c in msg.get_buffer(): - if c == SLIP_END: - self._write(SLIP_ESC) - self._write(SLIP_ESC_END) - elif c == SLIP_ESC: - self._write(SLIP_ESC) - self._write(SLIP_ESC_ESC) - else: - self._write(c) - self._write(SLIP_END) - if self.debug: - sys.stdout.write("\n") + msg = slip_encode(msg.get_buffer()) + if self.debug_io: + logger.debug("send: %s", ":".join("{:02x}".format(ord(c)) for c in msg)) + self.socket.sendall(msg) def _recv(self, code): # Returns an xdrlib.Unpacker - if self.debug: - sys.stdout.write("+recv: ") - msg = [] - esc = False + closed = False while True: - c = self.tty.read(1) - if self.debug and c: - sys.stdout.write("{:02x}".format(ord(c))) - if not c: - time.sleep(0.1) - elif c == SLIP_END and not msg: + msg = [self.sockfile.read(1)] + while msg[-1] != SLIP_END: + if msg[-1] == "": + raise HAL_ERROR_RPC_TRANSPORT() + msg.append(self.sockfile.read(1)) + if self.debug_io: + logger.debug("recv: %s", ":".join("{:02x}".format(ord(c)) for c in msg)) + msg = slip_decode("".join(msg)) + if not msg: continue - elif c == SLIP_END: - if self.debug: - sys.stdout.write("\n") - msg = xdrlib.Unpacker("".join(msg)) - if msg.unpack_uint() == code: - return msg - msg = [] - if self.debug: - sys.stdout.write("+recv: ") - elif c == SLIP_ESC: - esc = True - elif esc and c == SLIP_ESC_END: - esc = False - msg.append(SLIP_END) - elif esc and c == SLIP_ESC_ESC: - esc = False - msg.append(SLIP_ESC) - else: - msg.append(c) + msg = xdrlib.Unpacker("".join(msg)) + if msg.unpack_uint() != code: + continue + return msg _pack_builtin = (((int, long), "_pack_uint"), (str, "_pack_bytes"), @@ -576,25 +549,41 @@ class HSM(object): def pkey_load(self, type, curve, der, flags = 0, client = 0, session = 0): with self.rpc(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client) as r: - return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) + pkey = PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) + logger.debug("Loaded pkey %s", pkey.uuid) + return pkey def pkey_open(self, uuid, flags = 0, client = 0, session = 0): with self.rpc(RPC_FUNC_PKEY_OPEN, session, uuid, flags, client = client) as r: - return PKey(self, r.unpack_uint(), uuid) + pkey = PKey(self, r.unpack_uint(), uuid) + logger.debug("Opened pkey %s", pkey.uuid) + return pkey def pkey_generate_rsa(self, keylen, exponent = "\x01\x00\x01", flags = 0, client = 0, session = 0): with self.rpc(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client) as r: - return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) + pkey = PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) + logger.debug("Generated RSA pkey %s", pkey.uuid) + return pkey def pkey_generate_ec(self, curve, flags = 0, client = 0, session = 0): with self.rpc(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client) as r: - return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) + pkey = PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes())) + logger.debug("Generated EC pkey %s", pkey.uuid) + return pkey def pkey_close(self, pkey): + try: + logger.debug("Closing pkey %s", pkey.uuid) + except AttributeError: + pass with self.rpc(RPC_FUNC_PKEY_CLOSE, pkey): return def pkey_delete(self, pkey): + try: + logger.debug("Deleting pkey %s", pkey.uuid) + except AttributeError: + pass with self.rpc(RPC_FUNC_PKEY_DELETE, pkey): return @@ -0,0 +1,108 @@ +/* + * locks.c + * ------- + * Dummy lock code for libhal. + * + * Copyright (c) 2017, NORDUnet A/S All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * - Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * - Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * - Neither the name of the NORDUnet nor the names of its contributors may + * be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS + * IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <stdint.h> +#include <string.h> + +#include "hal.h" +#include "hal_internal.h" + +/* + * There are three slightly peculiar things about this module. + * + * 1) We want to include optional support for GNU weak functions, + * because they're convenient, but we don't want to require support + * for them. So we wrap this in a compilation time conditional + * which defaults to something compatible with C99, but allow this + * to be overriden via an external definition. + * + * 2) The functions in this module are all no-ops, just here so that + * things will link correctly on platforms that don't define them. + * Real definitions for these functions have to come from the port + * to a specific environment, eg, from sw/stm32/projects/hsm.c. + * + * 3) Because we want to expose as little as possible of the + * underlying mechanisms, some of the functions here are closures + * encapsulating objects things which would otherwise be arguments. + * So, for example, we have functions to lock and unlock the HSM + * keystore, rather than general lock and unlock functions which + * they HSM keystore lock as an argument. Since the versions in + * this file are the no-ops, the lock itself goes away here. + */ + +#ifndef ENABLE_WEAK_FUNCTIONS +#define ENABLE_WEAK_FUNCTIONS 0 +#endif + +#if ENABLE_WEAK_FUNCTIONS +#define WEAK_FUNCTION __attribute__((weak)) +#else +#define WEAK_FUNCTION +#endif + +/* + * Critical sections -- disable preemption BRIEFLY. + */ + +WEAK_FUNCTION void hal_critical_section_start(void) +{ + return; +} + +WEAK_FUNCTION void hal_critical_section_end(void) +{ + return; +} + +/* + * Keystore lock -- lock call blocks indefinitely. + */ + +WEAK_FUNCTION void hal_ks_lock(void) +{ + return; +} + +WEAK_FUNCTION void hal_ks_unlock(void) +{ + return; +} + +/* + * Local variables: + * indent-tabs-mode: nil + * End: + */ diff --git a/rpc_client_daemon.c b/rpc_client_daemon.c index dea352f..7ff3f21 100644 --- a/rpc_client_daemon.c +++ b/rpc_client_daemon.c @@ -45,17 +45,18 @@ static int sock = -1; hal_error_t hal_rpc_client_transport_init(void) { + const char *sockname = getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME"); struct sockaddr_un name; - int ret; - sock = socket(AF_UNIX, SOCK_SEQPACKET, 0); + sock = socket(AF_UNIX, SOCK_STREAM, 0); if (sock == -1) return perror("socket"), HAL_ERROR_RPC_TRANSPORT; + if (sockname == NULL) + sockname = HAL_CLIENT_DAEMON_DEFAULT_SOCKET_NAME; memset(&name, 0, sizeof(struct sockaddr_un)); name.sun_family = AF_UNIX; - strncpy(name.sun_path, HAL_CLIENT_DAEMON_DEFAULT_SOCKET_NAME, sizeof(name.sun_path) - 1); - ret = connect(sock, (const struct sockaddr *) &name, sizeof(struct sockaddr_un)); - if (ret == -1) + strncpy(name.sun_path, sockname, sizeof(name.sun_path) - 1); + if (connect(sock, (const struct sockaddr *) &name, sizeof(struct sockaddr_un)) < 0) return perror("connect"), HAL_ERROR_RPC_TRANSPORT; return HAL_OK; } @@ -69,17 +70,35 @@ hal_error_t hal_rpc_client_transport_close(void) return HAL_OK; } + hal_error_t hal_rpc_send(const uint8_t * const buf, const size_t len) { - ssize_t ret = send(sock, (const void *)buf, len, 0); - return (ret == -1) ? HAL_ERROR_RPC_TRANSPORT : HAL_OK; + return hal_slip_send(buf, len); } hal_error_t hal_rpc_recv(uint8_t * const buf, size_t * const len) { - ssize_t ret = recv(sock, (void *)buf, *len, 0); - if (ret == -1) - return HAL_ERROR_RPC_TRANSPORT; - *len = (size_t)ret; + size_t maxlen = *len; + *len = 0; + hal_error_t err = hal_slip_recv(buf, len, maxlen); + return err; +} + +/* + * These two are sort of mis-named, fix eventually, but this is what + * the code in slip.c expects. + */ + +hal_error_t hal_serial_send_char(const uint8_t c) +{ + if (write(sock, &c, 1) != 1) + return perror("write"), HAL_ERROR_RPC_TRANSPORT; + return HAL_OK; +} + +hal_error_t hal_serial_recv_char(uint8_t * const c) +{ + if (read(sock, c, 1) != 1) + return perror("read"), HAL_ERROR_RPC_TRANSPORT; return HAL_OK; } @@ -103,24 +103,32 @@ static client_slot_t client_handle[HAL_STATIC_CLIENT_STATE_BLOCKS]; static inline client_slot_t *alloc_slot(void) { + client_slot_t *slot = NULL; + hal_critical_section_start(); + #if HAL_STATIC_CLIENT_STATE_BLOCKS > 0 - for (int i = 0; i < sizeof(client_handle)/sizeof(*client_handle); i++) + for (int i = 0; slot == NULL && i < sizeof(client_handle)/sizeof(*client_handle); i++) if (client_handle[i].logged_in == HAL_USER_NONE) - return &client_handle[i]; + slot = &client_handle[i]; #endif - return NULL; + hal_critical_section_end(); + return slot; } static inline client_slot_t *find_handle(const hal_client_handle_t handle) { + client_slot_t *slot = NULL; + hal_critical_section_start(); + #if HAL_STATIC_CLIENT_STATE_BLOCKS > 0 - for (int i = 0; i < sizeof(client_handle)/sizeof(*client_handle); i++) + for (int i = 0; slot == NULL && i < sizeof(client_handle)/sizeof(*client_handle); i++) if (client_handle[i].logged_in != HAL_USER_NONE && client_handle[i].handle.handle == handle.handle) - return &client_handle[i]; + slot = &client_handle[i]; #endif - return NULL; + hal_critical_section_end(); + return slot; } static hal_error_t login(const hal_client_handle_t client, @@ -44,11 +44,11 @@ #endif #if HAL_STATIC_PKEY_STATE_BLOCKS > 0 -static hal_pkey_slot_t pkey_handle[HAL_STATIC_PKEY_STATE_BLOCKS]; +static hal_pkey_slot_t pkey_slot[HAL_STATIC_PKEY_STATE_BLOCKS]; #endif /* - * Handle allocation is simple: look for an unused (HAL_KEY_TYPE_NONE) + * Handle allocation is simple: look for an unused (HAL_HANDLE_NONE) * slot in the table, and, assuming we find one, construct a composite * handle consisting of the index into the table and a counter whose * sole purpose is to keep the same handle from reoccurring anytime @@ -61,6 +61,9 @@ static hal_pkey_slot_t pkey_handle[HAL_STATIC_PKEY_STATE_BLOCKS]; static inline hal_pkey_slot_t *alloc_slot(const hal_key_flags_t flags) { + hal_pkey_slot_t *slot = NULL; + hal_critical_section_start(); + #if HAL_STATIC_PKEY_STATE_BLOCKS > 0 static uint16_t next_glop = 0; uint32_t glop = ++next_glop << 16; @@ -71,17 +74,18 @@ static inline hal_pkey_slot_t *alloc_slot(const hal_key_flags_t flags) if ((flags & HAL_KEY_FLAG_TOKEN) != 0) glop |= HAL_PKEY_HANDLE_TOKEN_FLAG; - for (int i = 0; i < sizeof(pkey_handle)/sizeof(*pkey_handle); i++) { - if (pkey_handle[i].type != HAL_KEY_TYPE_NONE) + for (int i = 0; slot == NULL && i < sizeof(pkey_slot)/sizeof(*pkey_slot); i++) { + if (pkey_slot[i].pkey_handle.handle != HAL_HANDLE_NONE) continue; - memset(&pkey_handle[i], 0, sizeof(pkey_handle[i])); - pkey_handle[i].pkey_handle.handle = i | glop; - pkey_handle[i].hint = -1; - return &pkey_handle[i]; + memset(&pkey_slot[i], 0, sizeof(pkey_slot[i])); + pkey_slot[i].pkey_handle.handle = i | glop; + pkey_slot[i].hint = -1; + slot = &pkey_slot[i]; } #endif - return NULL; + hal_critical_section_end(); + return slot; } /* @@ -91,14 +95,18 @@ static inline hal_pkey_slot_t *alloc_slot(const hal_key_flags_t flags) static inline hal_pkey_slot_t *find_handle(const hal_pkey_handle_t handle) { + hal_pkey_slot_t *slot = NULL; + hal_critical_section_start(); + #if HAL_STATIC_PKEY_STATE_BLOCKS > 0 const int i = (int) (handle.handle & 0xFFFF); - if (i < sizeof(pkey_handle)/sizeof(*pkey_handle) && pkey_handle[i].pkey_handle.handle == handle.handle) - return &pkey_handle[i]; + if (i < sizeof(pkey_slot)/sizeof(*pkey_slot) && pkey_slot[i].pkey_handle.handle == handle.handle) + slot = &pkey_slot[i]; #endif - return NULL; + hal_critical_section_end(); + return slot; } /* @@ -219,7 +227,8 @@ static inline hal_error_t ks_open_from_flags(hal_ks_t **ks, const hal_key_flags_ } /* - * Receive key from application, store it with supplied name, return a key handle. + * Receive key from application, generate a name (UUID), store it, and + * return a key handle and the name. */ static hal_error_t pkey_local_load(const hal_client_handle_t client, diff --git a/unit-tests.py b/unit-tests.py index a8779c5..bc7edf7 100644 --- a/unit-tests.py +++ b/unit-tests.py @@ -39,6 +39,7 @@ LibHAL unit tests, using libhal.py and the Python unit_test framework. import unittest import datetime +import logging import sys from libhal import * @@ -66,10 +67,7 @@ except ImportError: ecdsa_loaded = False -def log(msg): - if not args.quiet: - sys.stderr.write(msg) - sys.stderr.write("\n") +logger = logging.getLogger("unit-tests") def main(): @@ -77,12 +75,20 @@ def main(): global args args = parse_arguments(argv[1:]) argv = argv[:1] + args.only_test - unittest.main(verbosity = 1 if args.quiet else 2, argv = argv, catchbreak = True, testRunner = TextTestRunner) + logging.basicConfig(level = logging.DEBUG if args.debug else logging.INFO, + datefmt = "%Y-%m-%d %H:%M:%S", + format = "%(asctime)-15s %(name)s[%(process)d]:%(levelname)s: %(message)s",) + unittest.main(verbosity = 1 if args.quiet else 2, + argv = argv, + catchbreak = True, + testRunner = TextTestRunner) def parse_arguments(argv = ()): from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter parser = ArgumentParser(description = __doc__, formatter_class = ArgumentDefaultsHelpFormatter) parser.add_argument("--quiet", action = "store_true", help = "suppress chatter") + parser.add_argument("--debug", action = "store_true", help = "debug-level logging") + parser.add_argument("--io-log", action = "store_true", help = "log HSM I/O stream") parser.add_argument("--wheel-pin", default = "fnord", help = "PIN for wheel user") parser.add_argument("--so-pin", default = "fnord", help = "PIN for security officer") parser.add_argument("--user-pin", default = "fnord", help = "PIN for normal user") @@ -99,6 +105,7 @@ pin_map = { HAL_USER_NORMAL : "user_pin", HAL_USER_SO : "so_pin", HAL_USER_WHEEL def setUpModule(): global hsm hsm = HSM() + hsm.debug_io = args.io_log def tearDownModule(): hsm.logout() @@ -125,6 +132,12 @@ class TextTestResult(unittest.TextTestResult): self.stream.flush() super(TextTestResult, self).addSuccess(test) + def addError(self, test, err): + if self.showAll: + self.stream.write("exception {!s} ".format(err[0].__name__)) # err[1] + self.stream.flush() + super(TextTestResult, self).addError(test, err) + class TextTestRunner(unittest.TextTestRunner): resultclass = TextTestResult @@ -335,93 +348,123 @@ class TestPKeyHashing(TestCaseLoggedIn): k1.verify(signature = sig, hash = self.h(alg, mixed_mode = True)) k2.verify(signature = sig, hash = self.h(alg, mixed_mode = True)) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_1024_sha256_data(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA256, 1024, self.sign_verify_data) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_2048_sha384_data(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA384, 2048, self.sign_verify_data) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_4096_sha512_data(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA512, 4096, self.sign_verify_data) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p256_sha256_data(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA256, HAL_CURVE_P256, self.sign_verify_data) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p384_sha384_data(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA384, HAL_CURVE_P384, self.sign_verify_data) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p521_sha512_data(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA512, HAL_CURVE_P521, self.sign_verify_data) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_1024_sha256_remote_remote(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA256, 1024, self.sign_verify_remote_remote) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_2048_sha384_remote_remote(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA384, 2048, self.sign_verify_remote_remote) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_4096_sha512_remote_remote(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA512, 4096, self.sign_verify_remote_remote) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p256_sha256_remote_remote(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA256, HAL_CURVE_P256, self.sign_verify_remote_remote) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p384_sha384_remote_remote(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA384, HAL_CURVE_P384, self.sign_verify_remote_remote) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p521_sha512_remote_remote(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA512, HAL_CURVE_P521, self.sign_verify_remote_remote) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_1024_sha256_remote_local(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA256, 1024, self.sign_verify_remote_local) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_2048_sha384_remote_local(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA384, 2048, self.sign_verify_remote_local) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_4096_sha512_remote_local(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA512, 4096, self.sign_verify_remote_local) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p256_sha256_remote_local(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA256, HAL_CURVE_P256, self.sign_verify_remote_local) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p384_sha384_remote_local(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA384, HAL_CURVE_P384, self.sign_verify_remote_local) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p521_sha512_remote_local(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA512, HAL_CURVE_P521, self.sign_verify_remote_local) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_1024_sha256_local_remote(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA256, 1024, self.sign_verify_local_remote) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_2048_sha384_local_remote(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA384, 2048, self.sign_verify_local_remote) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_4096_sha512_local_remote(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA512, 4096, self.sign_verify_local_remote) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p256_sha256_local_remote(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA256, HAL_CURVE_P256, self.sign_verify_local_remote) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p384_sha384_local_remote(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA384, HAL_CURVE_P384, self.sign_verify_local_remote) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p521_sha512_local_remote(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA512, HAL_CURVE_P521, self.sign_verify_local_remote) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_1024_sha256_local_local(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA256, 1024, self.sign_verify_local_local) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_2048_sha384_local_local(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA384, 2048, self.sign_verify_local_local) + @unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") def test_load_sign_verify_rsa_4096_sha512_local_local(self): self.load_sign_verify_rsa(HAL_DIGEST_ALGORITHM_SHA512, 4096, self.sign_verify_local_local) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p256_sha256_local_local(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA256, HAL_CURVE_P256, self.sign_verify_local_local) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p384_sha384_local_local(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA384, HAL_CURVE_P384, self.sign_verify_local_local) + @unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") def test_load_sign_verify_ecdsa_p521_sha512_local_local(self): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA512, HAL_CURVE_P521, self.sign_verify_local_local) @@ -494,7 +537,7 @@ class TestPKeyECDSAInterop(TestCaseLoggedIn): self.load_sign_verify_ecdsa(HAL_DIGEST_ALGORITHM_SHA512, SHA512, HAL_CURVE_P521) -class TestPKeyList(TestCaseLoggedIn): +class TestPKeyMatch(TestCaseLoggedIn): """ Tests involving PKey list and match functions. """ @@ -594,6 +637,7 @@ class TestPKeyAttribute(TestCaseLoggedIn): self.load_and_fill(HAL_KEY_FLAG_TOKEN, n_attrs = 4, n_fill = 512) # [16, 1024] +@unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") class TestPKeyAttributeP11(TestCaseLoggedIn): """ Attribute creation/lookup/deletion tests based on a PKCS #11 trace. @@ -658,6 +702,7 @@ class TestPKeyAttributeP11(TestCaseLoggedIn): 0x180 : "\x06\x08\x2a\x86\x48\xce\x3d\x03\x01\x07" }) +@unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") class TestPKeyAttributeWriteSpeedToken(TestCaseLoggedIn): """ Attribute speed tests. @@ -682,6 +727,7 @@ class TestPKeyAttributeWriteSpeedToken(TestCaseLoggedIn): def test_set_12_attributes(self): self.set_attributes(12) +@unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") class TestPKeyAttributeWriteSpeedVolatile(TestCaseLoggedIn): """ Attribute speed tests. @@ -706,6 +752,7 @@ class TestPKeyAttributeWriteSpeedVolatile(TestCaseLoggedIn): def test_set_12_attributes(self): self.set_attributes(12) +@unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") class TestPKeyAttributeReadSpeedToken(TestCaseLoggedIn): """ Attribute speed tests. @@ -737,6 +784,7 @@ class TestPKeyAttributeReadSpeedToken(TestCaseLoggedIn): def test_get_12_attributes(self): self.get_attributes(12) +@unittest.skipUnless(ecdsa_loaded, "Requires Python ECDSA package") class TestPKeyAttributeReadSpeedVolatile(TestCaseLoggedIn): """ Attribute speed tests. |