diff options
author | Rob Austein <sra@hactrn.net> | 2017-05-31 01:41:38 -0400 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2017-05-31 01:41:38 -0400 |
commit | 35a88083a9936b2ed3d0091c0461530be81287c1 (patch) | |
tree | 78b15a1945f345870202c9ff6749784ce46d8333 /cryptech_muxd | |
parent | 6b881dfa81a0d51d4897c62de5abdb94c1aba0b7 (diff) |
Automatic logout when client disconnects or muxd restarts.
The HSM itself should be detecting carrier drop on its RPC port, but I
haven't figured out where the DCD bit is hiding in the STM32 UART API,
and the muxd has to be involved in this in any case, since only the
muxd knows when an individual client connection has dropped. So, for
the moment, we handle all of this in the muxd.
Diffstat (limited to 'cryptech_muxd')
-rwxr-xr-x | cryptech_muxd | 41 |
1 files changed, 32 insertions, 9 deletions
diff --git a/cryptech_muxd b/cryptech_muxd index d5de227..d306eaf 100755 --- a/cryptech_muxd +++ b/cryptech_muxd @@ -58,6 +58,8 @@ import tornado.queues import tornado.locks import tornado.gen +from cryptech.libhal import HAL_OK, RPC_FUNC_GET_VERSION, RPC_FUNC_LOGOUT, RPC_FUNC_LOGOUT_ALL + logger = logging.getLogger("cryptech_muxd") @@ -89,6 +91,10 @@ def client_handle_set(msg, handle): return msg[:4] + struct.pack(">L", handle) + msg[8:] +logout_msg = struct.pack(">LL", RPC_FUNC_LOGOUT, 0) +logout_all_msg = struct.pack(">LL", RPC_FUNC_LOGOUT_ALL, 0) + + class SerialIOStream(tornado.iostream.BaseIOStream): """ Implementation of a Tornado IOStream over a PySerial device. @@ -157,10 +163,11 @@ class RPCIOStream(SerialIOStream): self.rpc_input_lock = tornado.locks.Lock() @tornado.gen.coroutine - def rpc_input(self, query, handle, queue): + def rpc_input(self, query, handle = 0, queue = None): "Send a query to the HSM." logger.debug("RPC send: %s", ":".join("{:02x}".format(ord(c)) for c in query)) - self.queues[handle] = queue + if queue is not None: + self.queues[handle] = queue with (yield self.rpc_input_lock.acquire()): yield self.write(query) logger.debug("RPC sent") @@ -182,13 +189,18 @@ class RPCIOStream(SerialIOStream): continue try: handle = client_handle_get(slip_decode(reply)) - queue = self.queues[handle] except: logger.debug("RPC skipping bad packet") continue - logger.debug("RPC queue put: handle 0x%x, qsize %s, maxsize %s", - handle, queue.qsize(), queue.maxsize) - queue.put_nowait(reply) + if handle not in self.queues: + logger.debug("RPC ignoring response: handle 0x%x", handle) + continue + logger.debug("RPC queue put: handle 0x%x, qsize %s", handle, self.queues[handle].qsize()) + self.queues[handle].put_nowait(reply) + + def logout_all(self): + "Execute an RPC LOGOUT_ALL operation." + return self.rpc_input(slip_encode(logout_all_msg)) class QueuedStreamClosedError(tornado.iostream.StreamClosedError): @@ -203,7 +215,7 @@ class RPCServer(PFUnixServer): @tornado.gen.coroutine def handle_stream(self, stream, address): "Handle one network connection." - handle = stream.socket.fileno() + handle = self.next_client_handle() queue = tornado.queues.Queue() logger.info("RPC connected %r, handle 0x%x", stream, handle) while True: @@ -223,8 +235,18 @@ class RPCServer(PFUnixServer): except tornado.iostream.StreamClosedError: logger.info("RPC closing %r, handle 0x%x", stream, handle) stream.close() + query = slip_encode(client_handle_set(logout_msg, handle)) + yield self.serial.rpc_input(query, handle) return + client_handle = int(time.time()) << 4 + + @classmethod + def next_client_handle(cls): + cls.client_handle += 1 + cls.client_handle &= 0xFFFFFFFF + return cls.client_handle + class CTYIOStream(SerialIOStream): """ @@ -331,8 +353,8 @@ class ProbeIOStream(SerialIOStream): @tornado.gen.coroutine def run_probe(self): - RPC_query = chr(0) * 8 # client_handle = 0, function code = RPC_FUNC_GET_VERSION - RPC_reply = chr(0) * 12 # opcode = RPC_FUNC_GET_VERSION, client_handle = 0, valret = HAL_OK + RPC_query = struct.pack(">LL", RPC_FUNC_GET_VERSION, 0) + RPC_reply = struct.pack(">LLL", RPC_FUNC_GET_VERSION, 0, HAL_OK) probe_string = SLIP_END + Control_U + SLIP_END + RPC_query + SLIP_END + Control_U + Control_M @@ -434,6 +456,7 @@ def main(): rpc_stream = RPCIOStream(device = args.rpc_device) rpc_server = RPCServer(rpc_stream, args.rpc_socket) futures.append(rpc_stream.rpc_output_loop()) + futures.append(rpc_stream.logout_all()) if args.cty_device is None: logger.warn("No CTY device found") |