From 65dded3893635e8db89c1c84e1b91fd81e04aeea Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Tue, 10 Jan 2017 23:57:16 -0500 Subject: Handle connection close events properly, use logging library. --- cryptech_console | 32 +++++++++++++++++-------------- cryptech_muxd | 58 ++++++++++++++++++++++++++++++++++++++++---------------- libhal.py | 18 ++++++++++-------- 3 files changed, 70 insertions(+), 38 deletions(-) diff --git a/cryptech_console b/cryptech_console index 80ec15d..6e0bc80 100755 --- a/cryptech_console +++ b/cryptech_console @@ -37,6 +37,7 @@ import sys import socket import atexit import termios +import logging import argparse import tornado.iostream @@ -44,6 +45,9 @@ import tornado.ioloop import tornado.gen +logger = logging.getLogger("cryptech_console") + + class FemtoTerm(object): def __init__(self, s): @@ -64,7 +68,16 @@ class FemtoTerm(object): termios.tcsetattr(self.fd, termios.TCSANOW, self.new_tcattr) def termios_teardown(self): - termios.tcsetattr(self.fd, termios.TCSAFLUSH, self.old_tcattr) + if self.fd is not None: + termios.tcsetattr(self.fd, termios.TCSAFLUSH, self.old_tcattr) + self.fd = None + + def close_loops(self): + self.termios_teardown() + self.stdin_stream.close() + self.stdout_stream.close() + self.socket_stream.close() + self.closed = True @tornado.gen.coroutine def stdin_loop(self): @@ -73,7 +86,7 @@ class FemtoTerm(object): buffer = yield self.stdin_stream.read_bytes(1024, partial = True) yield self.socket_stream.write(buffer.replace("\n", "\r")) except tornado.iostream.StreamClosedError: - self.closed = True + self.close_loops() @tornado.gen.coroutine def stdout_loop(self): @@ -82,7 +95,7 @@ class FemtoTerm(object): buffer = yield self.socket_stream.read_bytes(1024, partial = True) yield self.stdout_stream.write(buffer.replace("\r\n", "\n")) except tornado.iostream.StreamClosedError: - self.closed = True + self.close_loops() @tornado.gen.coroutine @@ -100,20 +113,11 @@ def main(): s.connect(args.cty_socket) term = FemtoTerm(s) - - if False: - yield [term.stdin_loop(), term.stdout_loop()] - - else: - stdout_future = term.stdout_loop() - stdin_future = term.stdin_loop() - yield stdout_future - sys.stdin.close() - yield stdin_future - + yield [term.stdout_loop(), term.stdin_loop()] if __name__ == "__main__": try: + #logging.basicConfig(level = logging.DEBUG) tornado.ioloop.IOLoop.current().run_sync(main) except KeyboardInterrupt: pass diff --git a/cryptech_muxd b/cryptech_muxd index 80be443..e188721 100755 --- a/cryptech_muxd +++ b/cryptech_muxd @@ -43,6 +43,7 @@ import time import struct import atexit import weakref +import logging import argparse import serial @@ -56,6 +57,9 @@ import tornado.locks import tornado.gen +logger = logging.getLogger("cryptech_muxd") + + SLIP_END = chr(0300) # Indicates end of SLIP packet SLIP_ESC = chr(0333) # Indicates byte stuffing SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte @@ -85,9 +89,8 @@ class SerialIOStream(tornado.iostream.BaseIOStream): Implementation of a Tornado IOStream over a PySerial device. """ - def __init__(self, device, baudrate = 921600, debug = False, *pargs, **kwargs): + def __init__(self, device, baudrate = 921600, *pargs, **kwargs): self.serial = serial.Serial(device, baudrate, timeout = 0, write_timeout = 0) - self.debug = debug super(SerialIOStream, self).__init__(*pargs, **kwargs) def fileno(self): @@ -134,8 +137,7 @@ class RPCIOStream(SerialIOStream): @tornado.gen.coroutine def rpc_input(self, query, handle, queue): "Send a query to the HSM." - if self.debug: - sys.stdout.write("+send: {}\n".format(":".join("{:02x}".format(ord(c)) for c in query))) + logger.debug("rpc send: %s", ":".join("{:02x}".format(ord(c)) for c in query)) self.queues[handle] = queue with (yield self.rpc_input_lock.acquire()): yield self.write(query) @@ -144,15 +146,25 @@ class RPCIOStream(SerialIOStream): def rpc_output_loop(self): "Handle reply stream HSM -> network." while True: - reply = yield self.read_until(SLIP_END) - if self.debug: - sys.stdout.write("+recv: {}\n".format(":".join("{:02x}".format(ord(c)) for c in reply))) - if len(reply) < 9: + try: + reply = yield self.read_until(SLIP_END) + except tornado.iostream.StreamClosedError: + logger.info("rpc uart closed") + for q in self.queues.itervalues(): + q.put_nowait(None) + return + logger.debug("rpc recv: %s", ":".join("{:02x}".format(ord(c)) for c in reply)) + try: + handle = client_handle_get(slip_decode(reply)) + except: continue - handle = client_handle_get(slip_decode(reply)) self.queues[handle].put_nowait(reply) +class QueuedStreamClosedError(tornado.iostream.StreamClosedError): + "Deferred StreamClosedError passed throught a Queue." + + class RPCServer(PFUnixServer): """ Serve multiplexed Cryptech RPC over a PF_UNIX socket. @@ -164,10 +176,10 @@ class RPCServer(PFUnixServer): @tornado.gen.coroutine def handle_stream(self, stream, address): "Handle one network connection." + logger.info("rpc connected %r", stream) handle = stream.socket.fileno() queue = tornado.queues.Queue() - closed = False - while not closed: + while True: try: query = yield stream.read_until(SLIP_END) if len(query) < 9: @@ -175,9 +187,13 @@ class RPCServer(PFUnixServer): query = slip_encode(client_handle_set(slip_decode(query), handle)) yield self.serial.rpc_input(query, handle, queue) reply = yield queue.get() + if reply is None: + raise QueuedStreamClosedError() yield stream.write(SLIP_END + reply) except tornado.iostream.StreamClosedError: - closed = True + logger.info("rpc closing %r", stream) + stream.close() + return class CTYIOStream(SerialIOStream): """ @@ -191,7 +207,13 @@ class CTYIOStream(SerialIOStream): @tornado.gen.coroutine def cty_output_loop(self): while True: - buffer = yield self.read_bytes(1024, partial = True) + try: + buffer = yield self.read_bytes(1024, partial = True) + except tornado.iostream.StreamClosedError: + logger.info("cty uart closed") + if self.attached_cty is not None: + self.attached_cty.close() + return try: if self.attached_cty is not None: yield self.attached_cty.write(buffer) @@ -216,13 +238,16 @@ class CTYServer(PFUnixServer): stream.close() return + logger.info("cty connected to %r", stream) + try: self.serial.attached_cty = stream while self.serial.attached_cty is stream: yield self.serial.write((yield stream.read_bytes(1024, partial = True))) except tornado.iostream.StreamClosedError: - pass + stream.close() finally: + logger.info("cty disconnecting from %r", stream) if self.serial.attached_cty is stream: self.serial.attached_cty = None @@ -247,13 +272,13 @@ def main(): futures = [] - rpc_stream = RPCIOStream(device = args.rpc_device, debug = args.debug) + rpc_stream = RPCIOStream(device = args.rpc_device) rpc_server = RPCServer() rpc_server.set_serial(rpc_stream) rpc_server.listen(args.rpc_socket) futures.append(rpc_stream.rpc_output_loop()) - cty_stream = CTYIOStream(device = args.cty_device, debug = args.debug) + cty_stream = CTYIOStream(device = args.cty_device) cty_server = CTYServer() cty_server.set_serial(cty_stream) cty_server.listen(args.cty_socket) @@ -267,6 +292,7 @@ def main(): if __name__ == "__main__": try: + #logging.basicConfig(level = logging.DEBUG) tornado.ioloop.IOLoop.current().run_sync(main) except KeyboardInterrupt: pass diff --git a/libhal.py b/libhal.py index e899d7b..93746e1 100644 --- a/libhal.py +++ b/libhal.py @@ -39,17 +39,21 @@ A Python interface to the Cryptech libhal RPC API. # not likely to want to use the full ONC RPC mechanism. import os -import sys import uuid import xdrlib import socket +import logging import contextlib +logger = logging.getLogger(__name__) + + SLIP_END = chr(0300) # indicates end of packet SLIP_ESC = chr(0333) # indicates byte stuffing SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte SLIP_ESC_ESC = chr(0335) # ESC ESC_ESC means ESC data byte + def slip_encode(buffer): return SLIP_END + buffer.replace(SLIP_ESC, SLIP_ESC + SLIP_ESC_ESC).replace(SLIP_END, SLIP_ESC + SLIP_ESC_END) + SLIP_END @@ -400,7 +404,6 @@ class PKey(Handle): class HSM(object): - debug = False mixed_mode = False def _raise_if_error(self, status): @@ -414,19 +417,18 @@ class HSM(object): def _send(self, msg): # Expects an xdrlib.Packer msg = slip_encode(msg.get_buffer()) - if self.debug: - sys.stdout.write("+send: {}\n".format(":".join("{:02x}".format(ord(c)) for c in msg))) + #logger.debug("send: %s", ":".join("{:02x}".format(ord(c)) for c in msg)) self.socket.sendall(msg) def _recv(self, code): # Returns an xdrlib.Unpacker + closed = False while True: - if self.debug: - sys.stdout.write("+recv: ") msg = [self.sockfile.read(1)] while msg[-1] != SLIP_END: + if msg[-1] == "": + raise HAL_ERROR_RPC_TRANSPORT() msg.append(self.sockfile.read(1)) - if self.debug: - sys.stdout.write("{}\n".format(":".join("{:02x}".format(ord(c)) for c in msg))) + #logger.debug("recv: %s", ":".join("{:02x}".format(ord(c)) for c in msg)) msg = slip_decode("".join(msg)) if not msg: continue -- cgit v1.2.3