From 65dded3893635e8db89c1c84e1b91fd81e04aeea Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Tue, 10 Jan 2017 23:57:16 -0500 Subject: Handle connection close events properly, use logging library. --- cryptech_muxd | 58 ++++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 16 deletions(-) (limited to 'cryptech_muxd') diff --git a/cryptech_muxd b/cryptech_muxd index 80be443..e188721 100755 --- a/cryptech_muxd +++ b/cryptech_muxd @@ -43,6 +43,7 @@ import time import struct import atexit import weakref +import logging import argparse import serial @@ -56,6 +57,9 @@ import tornado.locks import tornado.gen +logger = logging.getLogger("cryptech_muxd") + + SLIP_END = chr(0300) # Indicates end of SLIP packet SLIP_ESC = chr(0333) # Indicates byte stuffing SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte @@ -85,9 +89,8 @@ class SerialIOStream(tornado.iostream.BaseIOStream): Implementation of a Tornado IOStream over a PySerial device. """ - def __init__(self, device, baudrate = 921600, debug = False, *pargs, **kwargs): + def __init__(self, device, baudrate = 921600, *pargs, **kwargs): self.serial = serial.Serial(device, baudrate, timeout = 0, write_timeout = 0) - self.debug = debug super(SerialIOStream, self).__init__(*pargs, **kwargs) def fileno(self): @@ -134,8 +137,7 @@ class RPCIOStream(SerialIOStream): @tornado.gen.coroutine def rpc_input(self, query, handle, queue): "Send a query to the HSM." - if self.debug: - sys.stdout.write("+send: {}\n".format(":".join("{:02x}".format(ord(c)) for c in query))) + logger.debug("rpc send: %s", ":".join("{:02x}".format(ord(c)) for c in query)) self.queues[handle] = queue with (yield self.rpc_input_lock.acquire()): yield self.write(query) @@ -144,15 +146,25 @@ class RPCIOStream(SerialIOStream): def rpc_output_loop(self): "Handle reply stream HSM -> network." while True: - reply = yield self.read_until(SLIP_END) - if self.debug: - sys.stdout.write("+recv: {}\n".format(":".join("{:02x}".format(ord(c)) for c in reply))) - if len(reply) < 9: + try: + reply = yield self.read_until(SLIP_END) + except tornado.iostream.StreamClosedError: + logger.info("rpc uart closed") + for q in self.queues.itervalues(): + q.put_nowait(None) + return + logger.debug("rpc recv: %s", ":".join("{:02x}".format(ord(c)) for c in reply)) + try: + handle = client_handle_get(slip_decode(reply)) + except: continue - handle = client_handle_get(slip_decode(reply)) self.queues[handle].put_nowait(reply) +class QueuedStreamClosedError(tornado.iostream.StreamClosedError): + "Deferred StreamClosedError passed throught a Queue." + + class RPCServer(PFUnixServer): """ Serve multiplexed Cryptech RPC over a PF_UNIX socket. @@ -164,10 +176,10 @@ class RPCServer(PFUnixServer): @tornado.gen.coroutine def handle_stream(self, stream, address): "Handle one network connection." + logger.info("rpc connected %r", stream) handle = stream.socket.fileno() queue = tornado.queues.Queue() - closed = False - while not closed: + while True: try: query = yield stream.read_until(SLIP_END) if len(query) < 9: @@ -175,9 +187,13 @@ class RPCServer(PFUnixServer): query = slip_encode(client_handle_set(slip_decode(query), handle)) yield self.serial.rpc_input(query, handle, queue) reply = yield queue.get() + if reply is None: + raise QueuedStreamClosedError() yield stream.write(SLIP_END + reply) except tornado.iostream.StreamClosedError: - closed = True + logger.info("rpc closing %r", stream) + stream.close() + return class CTYIOStream(SerialIOStream): """ @@ -191,7 +207,13 @@ class CTYIOStream(SerialIOStream): @tornado.gen.coroutine def cty_output_loop(self): while True: - buffer = yield self.read_bytes(1024, partial = True) + try: + buffer = yield self.read_bytes(1024, partial = True) + except tornado.iostream.StreamClosedError: + logger.info("cty uart closed") + if self.attached_cty is not None: + self.attached_cty.close() + return try: if self.attached_cty is not None: yield self.attached_cty.write(buffer) @@ -216,13 +238,16 @@ class CTYServer(PFUnixServer): stream.close() return + logger.info("cty connected to %r", stream) + try: self.serial.attached_cty = stream while self.serial.attached_cty is stream: yield self.serial.write((yield stream.read_bytes(1024, partial = True))) except tornado.iostream.StreamClosedError: - pass + stream.close() finally: + logger.info("cty disconnecting from %r", stream) if self.serial.attached_cty is stream: self.serial.attached_cty = None @@ -247,13 +272,13 @@ def main(): futures = [] - rpc_stream = RPCIOStream(device = args.rpc_device, debug = args.debug) + rpc_stream = RPCIOStream(device = args.rpc_device) rpc_server = RPCServer() rpc_server.set_serial(rpc_stream) rpc_server.listen(args.rpc_socket) futures.append(rpc_stream.rpc_output_loop()) - cty_stream = CTYIOStream(device = args.cty_device, debug = args.debug) + cty_stream = CTYIOStream(device = args.cty_device) cty_server = CTYServer() cty_server.set_serial(cty_stream) cty_server.listen(args.cty_socket) @@ -267,6 +292,7 @@ def main(): if __name__ == "__main__": try: + #logging.basicConfig(level = logging.DEBUG) tornado.ioloop.IOLoop.current().run_sync(main) except KeyboardInterrupt: pass -- cgit v1.2.3