diff options
-rwxr-xr-x | cryptech_muxd | 41 |
1 files changed, 32 insertions, 9 deletions
diff --git a/cryptech_muxd b/cryptech_muxd index d5de227..d306eaf 100755 --- a/cryptech_muxd +++ b/cryptech_muxd @@ -58,6 +58,8 @@ import tornado.queues import tornado.locks import tornado.gen +from cryptech.libhal import HAL_OK, RPC_FUNC_GET_VERSION, RPC_FUNC_LOGOUT, RPC_FUNC_LOGOUT_ALL + logger = logging.getLogger("cryptech_muxd") @@ -89,6 +91,10 @@ def client_handle_set(msg, handle): return msg[:4] + struct.pack(">L", handle) + msg[8:] +logout_msg = struct.pack(">LL", RPC_FUNC_LOGOUT, 0) +logout_all_msg = struct.pack(">LL", RPC_FUNC_LOGOUT_ALL, 0) + + class SerialIOStream(tornado.iostream.BaseIOStream): """ Implementation of a Tornado IOStream over a PySerial device. @@ -157,10 +163,11 @@ class RPCIOStream(SerialIOStream): self.rpc_input_lock = tornado.locks.Lock() @tornado.gen.coroutine - def rpc_input(self, query, handle, queue): + def rpc_input(self, query, handle = 0, queue = None): "Send a query to the HSM." logger.debug("RPC send: %s", ":".join("{:02x}".format(ord(c)) for c in query)) - self.queues[handle] = queue + if queue is not None: + self.queues[handle] = queue with (yield self.rpc_input_lock.acquire()): yield self.write(query) logger.debug("RPC sent") @@ -182,13 +189,18 @@ class RPCIOStream(SerialIOStream): continue try: handle = client_handle_get(slip_decode(reply)) - queue = self.queues[handle] except: logger.debug("RPC skipping bad packet") continue - logger.debug("RPC queue put: handle 0x%x, qsize %s, maxsize %s", - handle, queue.qsize(), queue.maxsize) - queue.put_nowait(reply) + if handle not in self.queues: + logger.debug("RPC ignoring response: handle 0x%x", handle) + continue + logger.debug("RPC queue put: handle 0x%x, qsize %s", handle, self.queues[handle].qsize()) + self.queues[handle].put_nowait(reply) + + def logout_all(self): + "Execute an RPC LOGOUT_ALL operation." + return self.rpc_input(slip_encode(logout_all_msg)) class QueuedStreamClosedError(tornado.iostream.StreamClosedError): @@ -203,7 +215,7 @@ class RPCServer(PFUnixServer): @tornado.gen.coroutine def handle_stream(self, stream, address): "Handle one network connection." - handle = stream.socket.fileno() + handle = self.next_client_handle() queue = tornado.queues.Queue() logger.info("RPC connected %r, handle 0x%x", stream, handle) while True: @@ -223,8 +235,18 @@ class RPCServer(PFUnixServer): except tornado.iostream.StreamClosedError: logger.info("RPC closing %r, handle 0x%x", stream, handle) stream.close() + query = slip_encode(client_handle_set(logout_msg, handle)) + yield self.serial.rpc_input(query, handle) return + client_handle = int(time.time()) << 4 + + @classmethod + def next_client_handle(cls): + cls.client_handle += 1 + cls.client_handle &= 0xFFFFFFFF + return cls.client_handle + class CTYIOStream(SerialIOStream): """ @@ -331,8 +353,8 @@ class ProbeIOStream(SerialIOStream): @tornado.gen.coroutine def run_probe(self): - RPC_query = chr(0) * 8 # client_handle = 0, function code = RPC_FUNC_GET_VERSION - RPC_reply = chr(0) * 12 # opcode = RPC_FUNC_GET_VERSION, client_handle = 0, valret = HAL_OK + RPC_query = struct.pack(">LL", RPC_FUNC_GET_VERSION, 0) + RPC_reply = struct.pack(">LLL", RPC_FUNC_GET_VERSION, 0, HAL_OK) probe_string = SLIP_END + Control_U + SLIP_END + RPC_query + SLIP_END + Control_U + Control_M @@ -434,6 +456,7 @@ def main(): rpc_stream = RPCIOStream(device = args.rpc_device) rpc_server = RPCServer(rpc_stream, args.rpc_socket) futures.append(rpc_stream.rpc_output_loop()) + futures.append(rpc_stream.logout_all()) if args.cty_device is None: logger.warn("No CTY device found") |