diff options
-rw-r--r-- | cryptech_rpcmuxd | 180 | ||||
-rw-r--r-- | libhal.py | 89 |
2 files changed, 208 insertions, 61 deletions
diff --git a/cryptech_rpcmuxd b/cryptech_rpcmuxd new file mode 100644 index 0000000..d08d6df --- /dev/null +++ b/cryptech_rpcmuxd @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# +# Copyright (c) 2016, NORDUnet A/S All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# - Neither the name of the NORDUnet nor the names of its contributors may +# be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Implementation of Cryptech RPC protocol multiplexer in Python. + +Unlike the original C implementation, this uses SLIP encapsulation +over a SOCK_STREAM channel, because support for SOCK_SEQPACKET is not +what we might wish. We outsource all the heavy lifting for serial and +network I/O to the PySerial and Tornado libraries, respectively. +""" + +import os +import sys +import time +import struct +import atexit +import weakref +import argparse + +import serial + +import tornado.tcpserver +import tornado.iostream +import tornado.netutil +import tornado.ioloop +import tornado.queues +import tornado.locks +import tornado.gen + + +SLIP_END = chr(0300) # Indicates end of SLIP packet +SLIP_ESC = chr(0333) # Indicates byte stuffing +SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte +SLIP_ESC_ESC = chr(0335) # ESC ESC_ESC means ESC data byte + + +def slip_encode(buffer): + "Encode a buffer using SLIP encapsulation." + return SLIP_END + buffer.replace(SLIP_ESC, SLIP_ESC + SLIP_ESC_ESC).replace(SLIP_END, SLIP_ESC + SLIP_ESC_END) + SLIP_END + +def slip_decode(buffer): + "Decode a SLIP-encapsulated buffer." + return buffer.strip(SLIP_END).replace(SLIP_ESC + SLIP_ESC_END, SLIP_END).replace(SLIP_ESC + SLIP_ESC_ESC, SLIP_ESC) + + +def client_handle_get(msg): + "Extract client_handle field from a Cryptech RPC message." + return struct.unpack(">L", msg[4:8])[0] + +def client_handle_set(msg, handle): + "Replace client_handle field in a Cryptech RPC message." + return msg[:4] + struct.pack(">L", handle) + msg[8:] + + +class SerialIOStream(tornado.iostream.BaseIOStream): + """ + Implementation of a Tornado IOStream over a PySerial device. + """ + + def __init__(self, device, baudrate = 921600, debug = False, *pargs, **kwargs): + self.serial = serial.Serial(device, baudrate, timeout = 0, write_timeout = 0) + self.queues = weakref.WeakValueDictionary() + self.debug = debug + self.hsm_write_lock = tornado.locks.Lock() + super(SerialIOStream, self).__init__(*pargs, **kwargs) + + # The next four methods are required: BaseIOStream is an abstract + # class, we provide a driver by overriding them. + + def fileno(self): + return self.serial.fileno() + + def close_fd(self): + self.serial.close() + + def write_to_fd(self, data): + return self.serial.write(data) + + def read_from_fd(self): + return self.serial.read(self.read_chunk_size) or None + + @tornado.gen.coroutine + def hsm_write(self, query, handle, queue): + "Send a query to the HSM." + if self.debug: + sys.stdout.write("+send: {}\n".format(":".join("{:02x}".format(ord(c)) for c in query))) + self.queues[handle] = queue + with (yield self.hsm_write_lock.acquire()): + yield self.write(query) + + @tornado.gen.coroutine + def hsm_read_loop(self): + "Handle reply stream HSM -> network." + while True: + reply = yield self.read_until(SLIP_END) + if self.debug: + sys.stdout.write("+recv: {}\n".format(":".join("{:02x}".format(ord(c)) for c in reply))) + if len(reply) < 9: + continue + handle = client_handle_get(slip_decode(reply)) + self.queues[handle].put_nowait(reply) + + +class UnixServer(tornado.tcpserver.TCPServer): + """ + Variant on tornado.tcpserver.TCPServer, listening on a PF_UNIX + (aka PF_LOCAL) socket instead of a TCP socket. + """ + + def listen(self, filename, mode = 0600): + self.add_socket(tornado.netutil.bind_unix_socket(filename, mode)) + + def set_serial(self, serial_stream): + self.serial = serial_stream + + @tornado.gen.coroutine + def handle_stream(self, stream, address): + "Handle one network connection." + handle = stream.socket.fileno() + queue = tornado.queues.Queue() + closed = False + while not closed: + try: + query = yield stream.read_until(SLIP_END) + if len(query) < 9: + continue + query = slip_encode(client_handle_set(slip_decode(query), handle)) + yield self.serial.hsm_write(query, handle, queue) + reply = yield queue.get() + yield stream.write(SLIP_END + reply) + except tornado.iostream.StreamClosedError: + closed = True + + +parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument("-v", "--verbose", action = "store_true", help = "produce human-readable output") +parser.add_argument("-d", "--debug", action = "store_true", help = "blather about what we're doing") +parser.add_argument("device", nargs = "?", help = "serial device name", + default = os.getenv("CRYPTECH_RPC_CLIENT_SERIAL_DEVICE", "/dev/ttyUSB0")) +parser.add_argument("socket", nargs = "?", help = "PF_UNIX socket name", + default = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME", "/tmp/.cryptech_rpcmuxd")) +args = parser.parse_args() + +serial_stream = SerialIOStream(device = args.device, debug = args.debug) + +unix_server = UnixServer() +unix_server.set_serial(serial_stream) +unix_server.listen(args.socket) + +atexit.register(os.unlink, args.socket) + +tornado.ioloop.IOLoop.current().run_sync(serial_stream.hsm_read_loop) @@ -40,10 +40,9 @@ A Python interface to the Cryptech libhal RPC API. import os import sys -import time import uuid import xdrlib -import serial +import socket import contextlib SLIP_END = chr(0300) # indicates end of packet @@ -51,6 +50,13 @@ SLIP_ESC = chr(0333) # indicates byte stuffing SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte SLIP_ESC_ESC = chr(0335) # ESC ESC_ESC means ESC data byte +def slip_encode(buffer): + return SLIP_END + buffer.replace(SLIP_ESC, SLIP_ESC + SLIP_ESC_ESC).replace(SLIP_END, SLIP_ESC + SLIP_ESC_END) + SLIP_END + +def slip_decode(buffer): + return buffer.strip(SLIP_END).replace(SLIP_ESC + SLIP_ESC_END, SLIP_END).replace(SLIP_ESC + SLIP_ESC_ESC, SLIP_ESC) + + HAL_OK = 0 class HALError(Exception): @@ -397,76 +403,37 @@ class HSM(object): debug = False mixed_mode = False - _send_delay = 0 # 0.1 - def _raise_if_error(self, status): if status != 0: raise HALError.table[status]() - def __init__(self, device = os.getenv("CRYPTECH_RPC_CLIENT_SERIAL_DEVICE", "/dev/ttyUSB0")): - while True: - try: - self.tty = serial.Serial(device, 921600, timeout = 0.1) - break - except serial.SerialException: - time.sleep(0.2) - - def _write(self, c): - if self.debug: - sys.stdout.write("{:02x}".format(ord(c))) - self.tty.write(c) - if self._send_delay > 0: - time.sleep(self._send_delay) + def __init__(self, sockname = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME", "/tmp/.cryptech_rpcmuxd")): + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.socket.connect(sockname) + self.sockfile = self.socket.makefile("rb") def _send(self, msg): # Expects an xdrlib.Packer + msg = slip_encode(msg.get_buffer()) if self.debug: - sys.stdout.write("+send: ") - self._write(SLIP_END) - for c in msg.get_buffer(): - if c == SLIP_END: - self._write(SLIP_ESC) - self._write(SLIP_ESC_END) - elif c == SLIP_ESC: - self._write(SLIP_ESC) - self._write(SLIP_ESC_ESC) - else: - self._write(c) - self._write(SLIP_END) - if self.debug: - sys.stdout.write("\n") + sys.stdout.write("+send: {}\n".format(":".join("{:02x}".format(ord(c)) for c in msg))) + self.socket.sendall(msg) def _recv(self, code): # Returns an xdrlib.Unpacker - if self.debug: - sys.stdout.write("+recv: ") - msg = [] - esc = False while True: - c = self.tty.read(1) - if self.debug and c: - sys.stdout.write("{:02x}".format(ord(c))) - if not c: - time.sleep(0.1) - elif c == SLIP_END and not msg: + if self.debug: + sys.stdout.write("+recv: ") + msg = [self.sockfile.read(1)] + while msg[-1] != SLIP_END: + msg.append(self.sockfile.read(1)) + if self.debug: + sys.stdout.write("{}\n".format(":".join("{:02x}".format(ord(c)) for c in msg))) + msg = slip_decode("".join(msg)) + if not msg: continue - elif c == SLIP_END: - if self.debug: - sys.stdout.write("\n") - msg = xdrlib.Unpacker("".join(msg)) - if msg.unpack_uint() == code: - return msg - msg = [] - if self.debug: - sys.stdout.write("+recv: ") - elif c == SLIP_ESC: - esc = True - elif esc and c == SLIP_ESC_END: - esc = False - msg.append(SLIP_END) - elif esc and c == SLIP_ESC_ESC: - esc = False - msg.append(SLIP_ESC) - else: - msg.append(c) + msg = xdrlib.Unpacker("".join(msg)) + if msg.unpack_uint() != code: + continue + return msg _pack_builtin = (((int, long), "_pack_uint"), (str, "_pack_bytes"), |