aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2017-01-10 23:57:16 -0500
committerRob Austein <sra@hactrn.net>2017-01-10 23:57:16 -0500
commit65dded3893635e8db89c1c84e1b91fd81e04aeea (patch)
tree7a148fb2ceac8f3b296a9f0f95866609b84825d7
parent3c20fd189648b8182edbafed572898d1af744aa6 (diff)
Handle connection close events properly, use logging library.
-rwxr-xr-xcryptech_console32
-rwxr-xr-xcryptech_muxd58
-rw-r--r--libhal.py18
3 files changed, 70 insertions, 38 deletions
diff --git a/cryptech_console b/cryptech_console
index 80ec15d..6e0bc80 100755
--- a/cryptech_console
+++ b/cryptech_console
@@ -37,6 +37,7 @@ import sys
import socket
import atexit
import termios
+import logging
import argparse
import tornado.iostream
@@ -44,6 +45,9 @@ import tornado.ioloop
import tornado.gen
+logger = logging.getLogger("cryptech_console")
+
+
class FemtoTerm(object):
def __init__(self, s):
@@ -64,7 +68,16 @@ class FemtoTerm(object):
termios.tcsetattr(self.fd, termios.TCSANOW, self.new_tcattr)
def termios_teardown(self):
- termios.tcsetattr(self.fd, termios.TCSAFLUSH, self.old_tcattr)
+ if self.fd is not None:
+ termios.tcsetattr(self.fd, termios.TCSAFLUSH, self.old_tcattr)
+ self.fd = None
+
+ def close_loops(self):
+ self.termios_teardown()
+ self.stdin_stream.close()
+ self.stdout_stream.close()
+ self.socket_stream.close()
+ self.closed = True
@tornado.gen.coroutine
def stdin_loop(self):
@@ -73,7 +86,7 @@ class FemtoTerm(object):
buffer = yield self.stdin_stream.read_bytes(1024, partial = True)
yield self.socket_stream.write(buffer.replace("\n", "\r"))
except tornado.iostream.StreamClosedError:
- self.closed = True
+ self.close_loops()
@tornado.gen.coroutine
def stdout_loop(self):
@@ -82,7 +95,7 @@ class FemtoTerm(object):
buffer = yield self.socket_stream.read_bytes(1024, partial = True)
yield self.stdout_stream.write(buffer.replace("\r\n", "\n"))
except tornado.iostream.StreamClosedError:
- self.closed = True
+ self.close_loops()
@tornado.gen.coroutine
@@ -100,20 +113,11 @@ def main():
s.connect(args.cty_socket)
term = FemtoTerm(s)
-
- if False:
- yield [term.stdin_loop(), term.stdout_loop()]
-
- else:
- stdout_future = term.stdout_loop()
- stdin_future = term.stdin_loop()
- yield stdout_future
- sys.stdin.close()
- yield stdin_future
-
+ yield [term.stdout_loop(), term.stdin_loop()]
if __name__ == "__main__":
try:
+ #logging.basicConfig(level = logging.DEBUG)
tornado.ioloop.IOLoop.current().run_sync(main)
except KeyboardInterrupt:
pass
diff --git a/cryptech_muxd b/cryptech_muxd
index 80be443..e188721 100755
--- a/cryptech_muxd
+++ b/cryptech_muxd
@@ -43,6 +43,7 @@ import time
import struct
import atexit
import weakref
+import logging
import argparse
import serial
@@ -56,6 +57,9 @@ import tornado.locks
import tornado.gen
+logger = logging.getLogger("cryptech_muxd")
+
+
SLIP_END = chr(0300) # Indicates end of SLIP packet
SLIP_ESC = chr(0333) # Indicates byte stuffing
SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte
@@ -85,9 +89,8 @@ class SerialIOStream(tornado.iostream.BaseIOStream):
Implementation of a Tornado IOStream over a PySerial device.
"""
- def __init__(self, device, baudrate = 921600, debug = False, *pargs, **kwargs):
+ def __init__(self, device, baudrate = 921600, *pargs, **kwargs):
self.serial = serial.Serial(device, baudrate, timeout = 0, write_timeout = 0)
- self.debug = debug
super(SerialIOStream, self).__init__(*pargs, **kwargs)
def fileno(self):
@@ -134,8 +137,7 @@ class RPCIOStream(SerialIOStream):
@tornado.gen.coroutine
def rpc_input(self, query, handle, queue):
"Send a query to the HSM."
- if self.debug:
- sys.stdout.write("+send: {}\n".format(":".join("{:02x}".format(ord(c)) for c in query)))
+ logger.debug("rpc send: %s", ":".join("{:02x}".format(ord(c)) for c in query))
self.queues[handle] = queue
with (yield self.rpc_input_lock.acquire()):
yield self.write(query)
@@ -144,15 +146,25 @@ class RPCIOStream(SerialIOStream):
def rpc_output_loop(self):
"Handle reply stream HSM -> network."
while True:
- reply = yield self.read_until(SLIP_END)
- if self.debug:
- sys.stdout.write("+recv: {}\n".format(":".join("{:02x}".format(ord(c)) for c in reply)))
- if len(reply) < 9:
+ try:
+ reply = yield self.read_until(SLIP_END)
+ except tornado.iostream.StreamClosedError:
+ logger.info("rpc uart closed")
+ for q in self.queues.itervalues():
+ q.put_nowait(None)
+ return
+ logger.debug("rpc recv: %s", ":".join("{:02x}".format(ord(c)) for c in reply))
+ try:
+ handle = client_handle_get(slip_decode(reply))
+ except:
continue
- handle = client_handle_get(slip_decode(reply))
self.queues[handle].put_nowait(reply)
+class QueuedStreamClosedError(tornado.iostream.StreamClosedError):
+ "Deferred StreamClosedError passed throught a Queue."
+
+
class RPCServer(PFUnixServer):
"""
Serve multiplexed Cryptech RPC over a PF_UNIX socket.
@@ -164,10 +176,10 @@ class RPCServer(PFUnixServer):
@tornado.gen.coroutine
def handle_stream(self, stream, address):
"Handle one network connection."
+ logger.info("rpc connected %r", stream)
handle = stream.socket.fileno()
queue = tornado.queues.Queue()
- closed = False
- while not closed:
+ while True:
try:
query = yield stream.read_until(SLIP_END)
if len(query) < 9:
@@ -175,9 +187,13 @@ class RPCServer(PFUnixServer):
query = slip_encode(client_handle_set(slip_decode(query), handle))
yield self.serial.rpc_input(query, handle, queue)
reply = yield queue.get()
+ if reply is None:
+ raise QueuedStreamClosedError()
yield stream.write(SLIP_END + reply)
except tornado.iostream.StreamClosedError:
- closed = True
+ logger.info("rpc closing %r", stream)
+ stream.close()
+ return
class CTYIOStream(SerialIOStream):
"""
@@ -191,7 +207,13 @@ class CTYIOStream(SerialIOStream):
@tornado.gen.coroutine
def cty_output_loop(self):
while True:
- buffer = yield self.read_bytes(1024, partial = True)
+ try:
+ buffer = yield self.read_bytes(1024, partial = True)
+ except tornado.iostream.StreamClosedError:
+ logger.info("cty uart closed")
+ if self.attached_cty is not None:
+ self.attached_cty.close()
+ return
try:
if self.attached_cty is not None:
yield self.attached_cty.write(buffer)
@@ -216,13 +238,16 @@ class CTYServer(PFUnixServer):
stream.close()
return
+ logger.info("cty connected to %r", stream)
+
try:
self.serial.attached_cty = stream
while self.serial.attached_cty is stream:
yield self.serial.write((yield stream.read_bytes(1024, partial = True)))
except tornado.iostream.StreamClosedError:
- pass
+ stream.close()
finally:
+ logger.info("cty disconnecting from %r", stream)
if self.serial.attached_cty is stream:
self.serial.attached_cty = None
@@ -247,13 +272,13 @@ def main():
futures = []
- rpc_stream = RPCIOStream(device = args.rpc_device, debug = args.debug)
+ rpc_stream = RPCIOStream(device = args.rpc_device)
rpc_server = RPCServer()
rpc_server.set_serial(rpc_stream)
rpc_server.listen(args.rpc_socket)
futures.append(rpc_stream.rpc_output_loop())
- cty_stream = CTYIOStream(device = args.cty_device, debug = args.debug)
+ cty_stream = CTYIOStream(device = args.cty_device)
cty_server = CTYServer()
cty_server.set_serial(cty_stream)
cty_server.listen(args.cty_socket)
@@ -267,6 +292,7 @@ def main():
if __name__ == "__main__":
try:
+ #logging.basicConfig(level = logging.DEBUG)
tornado.ioloop.IOLoop.current().run_sync(main)
except KeyboardInterrupt:
pass
diff --git a/libhal.py b/libhal.py
index e899d7b..93746e1 100644
--- a/libhal.py
+++ b/libhal.py
@@ -39,17 +39,21 @@ A Python interface to the Cryptech libhal RPC API.
# not likely to want to use the full ONC RPC mechanism.
import os
-import sys
import uuid
import xdrlib
import socket
+import logging
import contextlib
+logger = logging.getLogger(__name__)
+
+
SLIP_END = chr(0300) # indicates end of packet
SLIP_ESC = chr(0333) # indicates byte stuffing
SLIP_ESC_END = chr(0334) # ESC ESC_END means END data byte
SLIP_ESC_ESC = chr(0335) # ESC ESC_ESC means ESC data byte
+
def slip_encode(buffer):
return SLIP_END + buffer.replace(SLIP_ESC, SLIP_ESC + SLIP_ESC_ESC).replace(SLIP_END, SLIP_ESC + SLIP_ESC_END) + SLIP_END
@@ -400,7 +404,6 @@ class PKey(Handle):
class HSM(object):
- debug = False
mixed_mode = False
def _raise_if_error(self, status):
@@ -414,19 +417,18 @@ class HSM(object):
def _send(self, msg): # Expects an xdrlib.Packer
msg = slip_encode(msg.get_buffer())
- if self.debug:
- sys.stdout.write("+send: {}\n".format(":".join("{:02x}".format(ord(c)) for c in msg)))
+ #logger.debug("send: %s", ":".join("{:02x}".format(ord(c)) for c in msg))
self.socket.sendall(msg)
def _recv(self, code): # Returns an xdrlib.Unpacker
+ closed = False
while True:
- if self.debug:
- sys.stdout.write("+recv: ")
msg = [self.sockfile.read(1)]
while msg[-1] != SLIP_END:
+ if msg[-1] == "":
+ raise HAL_ERROR_RPC_TRANSPORT()
msg.append(self.sockfile.read(1))
- if self.debug:
- sys.stdout.write("{}\n".format(":".join("{:02x}".format(ord(c)) for c in msg)))
+ #logger.debug("recv: %s", ":".join("{:02x}".format(ord(c)) for c in msg))
msg = slip_decode("".join(msg))
if not msg:
continue