aboutsummaryrefslogtreecommitdiff
path: root/cryptech_muxd
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2017-05-31 01:41:38 -0400
committerRob Austein <sra@hactrn.net>2017-05-31 01:41:38 -0400
commit35a88083a9936b2ed3d0091c0461530be81287c1 (patch)
tree78b15a1945f345870202c9ff6749784ce46d8333 /cryptech_muxd
parent6b881dfa81a0d51d4897c62de5abdb94c1aba0b7 (diff)
Automatic logout when client disconnects or muxd restarts.
The HSM itself should be detecting carrier drop on its RPC port, but I haven't figured out where the DCD bit is hiding in the STM32 UART API, and the muxd has to be involved in this in any case, since only the muxd knows when an individual client connection has dropped. So, for the moment, we handle all of this in the muxd.
Diffstat (limited to 'cryptech_muxd')
-rwxr-xr-xcryptech_muxd41
1 files changed, 32 insertions, 9 deletions
diff --git a/cryptech_muxd b/cryptech_muxd
index d5de227..d306eaf 100755
--- a/cryptech_muxd
+++ b/cryptech_muxd
@@ -58,6 +58,8 @@ import tornado.queues
import tornado.locks
import tornado.gen
+from cryptech.libhal import HAL_OK, RPC_FUNC_GET_VERSION, RPC_FUNC_LOGOUT, RPC_FUNC_LOGOUT_ALL
+
logger = logging.getLogger("cryptech_muxd")
@@ -89,6 +91,10 @@ def client_handle_set(msg, handle):
return msg[:4] + struct.pack(">L", handle) + msg[8:]
+logout_msg = struct.pack(">LL", RPC_FUNC_LOGOUT, 0)
+logout_all_msg = struct.pack(">LL", RPC_FUNC_LOGOUT_ALL, 0)
+
+
class SerialIOStream(tornado.iostream.BaseIOStream):
"""
Implementation of a Tornado IOStream over a PySerial device.
@@ -157,10 +163,11 @@ class RPCIOStream(SerialIOStream):
self.rpc_input_lock = tornado.locks.Lock()
@tornado.gen.coroutine
- def rpc_input(self, query, handle, queue):
+ def rpc_input(self, query, handle = 0, queue = None):
"Send a query to the HSM."
logger.debug("RPC send: %s", ":".join("{:02x}".format(ord(c)) for c in query))
- self.queues[handle] = queue
+ if queue is not None:
+ self.queues[handle] = queue
with (yield self.rpc_input_lock.acquire()):
yield self.write(query)
logger.debug("RPC sent")
@@ -182,13 +189,18 @@ class RPCIOStream(SerialIOStream):
continue
try:
handle = client_handle_get(slip_decode(reply))
- queue = self.queues[handle]
except:
logger.debug("RPC skipping bad packet")
continue
- logger.debug("RPC queue put: handle 0x%x, qsize %s, maxsize %s",
- handle, queue.qsize(), queue.maxsize)
- queue.put_nowait(reply)
+ if handle not in self.queues:
+ logger.debug("RPC ignoring response: handle 0x%x", handle)
+ continue
+ logger.debug("RPC queue put: handle 0x%x, qsize %s", handle, self.queues[handle].qsize())
+ self.queues[handle].put_nowait(reply)
+
+ def logout_all(self):
+ "Execute an RPC LOGOUT_ALL operation."
+ return self.rpc_input(slip_encode(logout_all_msg))
class QueuedStreamClosedError(tornado.iostream.StreamClosedError):
@@ -203,7 +215,7 @@ class RPCServer(PFUnixServer):
@tornado.gen.coroutine
def handle_stream(self, stream, address):
"Handle one network connection."
- handle = stream.socket.fileno()
+ handle = self.next_client_handle()
queue = tornado.queues.Queue()
logger.info("RPC connected %r, handle 0x%x", stream, handle)
while True:
@@ -223,8 +235,18 @@ class RPCServer(PFUnixServer):
except tornado.iostream.StreamClosedError:
logger.info("RPC closing %r, handle 0x%x", stream, handle)
stream.close()
+ query = slip_encode(client_handle_set(logout_msg, handle))
+ yield self.serial.rpc_input(query, handle)
return
+ client_handle = int(time.time()) << 4
+
+ @classmethod
+ def next_client_handle(cls):
+ cls.client_handle += 1
+ cls.client_handle &= 0xFFFFFFFF
+ return cls.client_handle
+
class CTYIOStream(SerialIOStream):
"""
@@ -331,8 +353,8 @@ class ProbeIOStream(SerialIOStream):
@tornado.gen.coroutine
def run_probe(self):
- RPC_query = chr(0) * 8 # client_handle = 0, function code = RPC_FUNC_GET_VERSION
- RPC_reply = chr(0) * 12 # opcode = RPC_FUNC_GET_VERSION, client_handle = 0, valret = HAL_OK
+ RPC_query = struct.pack(">LL", RPC_FUNC_GET_VERSION, 0)
+ RPC_reply = struct.pack(">LLL", RPC_FUNC_GET_VERSION, 0, HAL_OK)
probe_string = SLIP_END + Control_U + SLIP_END + RPC_query + SLIP_END + Control_U + Control_M
@@ -434,6 +456,7 @@ def main():
rpc_stream = RPCIOStream(device = args.rpc_device)
rpc_server = RPCServer(rpc_stream, args.rpc_socket)
futures.append(rpc_stream.rpc_output_loop())
+ futures.append(rpc_stream.logout_all())
if args.cty_device is None:
logger.warn("No CTY device found")