#!/usr/bin/env python
#
# Copyright (c) 2016-2017, NORDUnet A/S All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# - Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# - Neither the name of the NORDUnet nor the names of its contributors may
# be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Implementation of Cryptech RPC protocol multiplexer in Python.
Unlike the original C implementation, this uses SLIP encapsulation
over a SOCK_STREAM channel, because support for SOCK_SEQPACKET is not
what we might wish. We outsource all the heavy lifting for serial and
network I/O to the PySerial and Tornado libraries, respectively.
"""
import os
import sys
import time
import struct
import atexit
import weakref
import logging
import argparse
import logging.handlers
import serial
import serial.tools.list_ports_posix
import tornado.tcpserver
import tornado.iostream
import tornado.netutil
import tornado.ioloop
import tornado.queues
import tornado.locks
import tornado.gen
from zlib import crc32
# Module-wide logger used by every class and coroutine in this file.
logger = logging.getLogger("cryptech_muxd")

# SLIP framing bytes and console control characters.  The values are written
# with the explicit "0o" octal prefix, which Python 2.6+ and Python 3 both
# accept; the bare leading-zero form is rejected by Python 3.
SLIP_END = chr(0o300)      # Indicates end of SLIP packet
SLIP_ESC = chr(0o333)      # Indicates byte stuffing
SLIP_ESC_END = chr(0o334)  # ESC ESC_END means END data byte
SLIP_ESC_ESC = chr(0o335)  # ESC ESC_ESC means ESC data byte

Control_U = chr(0o025)     # Console: clear line
Control_M = chr(0o015)     # Console: end of line
def slip_encode(buffer):
    """
    Wrap a buffer in SLIP framing.

    Every SLIP_ESC byte is stuffed first, then every SLIP_END byte, and
    the result is bracketed by a SLIP_END on each side.
    """
    # ESC must be stuffed before END, otherwise the ESC bytes introduced
    # while stuffing END would themselves be stuffed a second time.
    stuffed = buffer.replace(SLIP_ESC, SLIP_ESC + SLIP_ESC_ESC)
    stuffed = stuffed.replace(SLIP_END, SLIP_ESC + SLIP_ESC_END)
    return SLIP_END + stuffed + SLIP_END
def slip_decode(buffer):
    """
    Undo SLIP framing on a buffer.

    Leading and trailing SLIP_END bytes are removed, then the two escape
    sequences are turned back into the data bytes they stand for.
    """
    payload = buffer.strip(SLIP_END)
    payload = payload.replace(SLIP_ESC + SLIP_ESC_END, SLIP_END)
    payload = payload.replace(SLIP_ESC + SLIP_ESC_ESC, SLIP_ESC)
    return payload
def client_handle_get(msg):
    """
    Return the client_handle field of a Cryptech RPC message.

    The field is read from bytes 4..7 of the message as a big-endian
    unsigned 32-bit integer.
    """
    (handle,) = struct.unpack(">L", msg[4:8])
    return handle
def client_handle_set(msg, handle):
    """
    Return a copy of a Cryptech RPC message with a new client_handle.

    Bytes 4..7 are replaced by the handle packed as a big-endian unsigned
    32-bit integer; everything before and after is kept as it was.
    """
    prefix, suffix = msg[:4], msg[8:]
    packed_handle = struct.pack(">L", handle)
    return prefix + packed_handle + suffix
def send_checksum(msg):
    """
    Return the message with a CRC32 trailer appended.

    The trailer is the bitwise complement of zlib's crc32 of the message,
    masked to 32 bits and packed little-endian.
    """
    complemented = ~crc32(msg) & 0xffffffff
    trailer = struct.pack("<I", complemented)
    return msg + trailer
def verify_checksum(msg):
    """
    Verify the CRC32 trailer of a message and strip it.

    The crc32 of the whole message, trailer included, must equal
    0xffffffff.  On success the message is returned without its last
    four bytes.

    Raises ValueError, with a hex dump of the message and the computed
    CRC, when the check fails.
    """
    crc = crc32(msg) & 0xffffffff
    if crc != 0xffffffff:
        # bytearray() iterates as integers for both Python 2 str and
        # Python 3 bytes, so the dump does not depend on ord().
        dump = ":".join("{:02x}".format(octet) for octet in bytearray(msg))
        # Zero-pad the CRC so short values do not print as "0x    1234".
        raise ValueError("Bad CRC32 in message: {} (0x{:08x})".format(dump, crc))
    return msg[:-4]
class SerialIOStream(tornado.iostream.BaseIOStream):
    """
    Tornado IOStream whose underlying file descriptor is a PySerial port.

    The four BaseIOStream hooks (fileno, close_fd, write_to_fd and
    read_from_fd) are forwarded to the serial.Serial object.
    """

    def __init__(self, device):
        # The port is opened at 921600 baud with zero read and write
        # timeouts before the BaseIOStream machinery is initialised.
        self.serial_device = device
        self.serial = serial.Serial(device, 921600, timeout = 0, write_timeout = 0)
        super(SerialIOStream, self).__init__()

    def fileno(self):
        "Return the file descriptor of the serial port."
        return self.serial.fileno()

    def close_fd(self):
        "Close the serial port."
        self.serial.close()

    def write_to_fd(self, data):
        "Write data to the serial port, returning whatever Serial.write() returns."
        return self.serial.write(data)

    def read_from_fd(self):
        "Read up to read_chunk_size bytes; an empty read is reported as None."
        chunk = self.serial.read(self.read_chunk_size)
        return chunk if chunk else None
class PFUnixServer(tornado.tcpserver.TCPServer):
    """
    Variant on tornado.tcpserver.TCPServer, listening on a PF_UNIX
    (aka PF_LOCAL) socket instead of a TCP socket.

    serial_stream is stored as self.serial for use by subclasses;
    socket_filename is the path the socket is bound to, and mode is
    passed through to tornado.netutil.bind_unix_socket().
    """

    def __init__(self, serial_stream, socket_filename, mode = 0o600):
        super(PFUnixServer, self).__init__()
        self.serial = serial_stream
        self.socket_filename = socket_filename
        self.add_socket(tornado.netutil.bind_unix_socket(socket_filename, mode))
        # Remove the socket file again when the interpreter exits.
        atexit.register(self.atexit_unlink)

    def atexit_unlink(self):
        "Best-effort removal of the socket file; filesystem errors are ignored."
        try:
            os.unlink(self.socket_filename)
        except OSError:
            # Already gone or not removable: nothing useful to do at exit.
            pass
class RPCIOStream(SerialIOStream):
    """
    Tornado IOStream for a serial RPC channel.

    Queries are written to the UART one at a time under a lock.  Replies
    read from the UART are handed to the queue registered for the
    client_handle found in each reply.
    """

    def __init__(self, device):
        super(RPCIOStream, self).__init__(device)
        # handle -> reply queue.  Values are held weakly, so an entry
        # disappears once nothing else references its queue.
        self.queues = weakref.WeakValueDictionary()
        # Held while a query is being written to the UART.
        self.rpc_input_lock = tornado.locks.Lock()

    @tornado.gen.coroutine
    def rpc_input(self, query, handle, queue):
        "Send a query to the HSM, registering queue as the reply queue for handle."
        logger.debug("RPC send: %s", ":".join("{:02x}".format(ord(c)) for c in query))
        self.queues[handle] = queue
        with (yield self.rpc_input_lock.acquire()):
            yield self.write(query)
        logger.debug("RPC sent")

    @tornado.gen.coroutine
    def rpc_output_loop(self):
        "Handle reply stream HSM -> network."
        while True:
            try:
                logger.debug("RPC UART read")
                reply = yield self.read_until(SLIP_END)
            except tornado.iostream.StreamClosedError:
                logger.info("RPC UART closed")
                # A None entry tells each registered queue that the UART is gone.
                for q in list(self.queues.values()):
                    q.put_nowait(None)
                return

            logger.debug("RPC recv: %s", ":".join("{:02x}".format(ord(c)) for c in reply))
            reply = slip_decode(reply)
            if len(reply) < 5:
                # Too short to be worth checking; skip it.
                continue

            # Check CRC
            try:
                reply = verify_checksum(reply)
            except ValueError:
                logger.error("RPC response CRC fail: %s",
                             ":".join("{:02x}".format(ord(c)) for c in reply))
                continue

            try:
                handle = client_handle_get(reply)
            except struct.error:
                # Reply is too short to contain a client_handle field.
                continue

            queue = self.queues.get(handle)
            if queue is None:
                logger.debug("Invalid RPC handle: 0x%08x / %s", handle, handle)
                continue

            logger.debug("RPC queue put: handle 0x%x, qsize %s, maxsize %s",
                         handle, queue.qsize(), queue.maxsize)
            try:
                queue.put_nowait(slip_encode(reply))
            except tornado.queues.QueueFull:
                # Keep the loop alive for the other clients; report the drop
                # accurately instead of blaming the handle.
                logger.error("RPC queue full, dropping reply for handle 0x%x", handle)
class QueuedStreamClosedError(tornado.iostream.StreamClosedError):
    "Deferred StreamClosedError passed through a Queue."
class RPCServer(PFUnixServer):
    """
    Serve multiplexed Cryptech RPC over a PF_UNIX socket.
    """

    @tornado.gen.coroutine
    def handle_stream(self, stream, address):
        """
        Handle one network connection.

        The connection's socket file descriptor is used as its RPC client
        handle.  Each SLIP frame read from the client is re-stamped with
        that handle, checksummed, re-framed and sent to the serial stream;
        the next item from this connection's reply queue is then written
        back to the client.
        """
        handle = stream.socket.fileno()
        replies = tornado.queues.Queue()
        logger.info("RPC connected %r, handle 0x%x", stream, handle)
        try:
            while True:
                logger.debug("RPC socket read, handle 0x%x", handle)
                frame = yield stream.read_until(SLIP_END)
                if len(frame) < 9:
                    # Too short to be a usable query; wait for the next frame.
                    continue
                stamped = client_handle_set(slip_decode(frame), handle)
                outbound = slip_encode(send_checksum(stamped))
                yield self.serial.rpc_input(outbound, handle, replies)
                logger.debug("RPC queue wait, handle 0x%x", handle)
                answer = yield replies.get()
                if answer is None:
                    # None on the queue stands for a closed stream.
                    raise QueuedStreamClosedError()
                logger.debug("RPC socket write, handle 0x%x", handle)
                yield stream.write(SLIP_END + answer)
        except tornado.iostream.StreamClosedError:
            logger.info("RPC closing %r, handle 0x%x", stream, handle)
            stream.close()
            return
class CTYIOStream(SerialIOStream):
    """
    Tornado IOStream for a serial console channel.
    """

    def __init__(self, device):
        super(CTYIOStream, self).__init__(device)
        # Stream to which console output is forwarded, or None.
        self.attached_cty = None

    @tornado.gen.coroutine
    def cty_output_loop(self):
        """
        Forward console output from the UART to the attached stream.

        Data read while nothing is attached is discarded.  The loop ends
        when the UART stream is closed, closing any attached stream too.
        """
        while True:
            try:
                chunk = yield self.read_bytes(self.read_chunk_size, partial = True)
            except tornado.iostream.StreamClosedError:
                logger.info("CTY UART closed")
                client = self.attached_cty
                if client is not None:
                    client.close()
                return

            client = self.attached_cty
            if client is None:
                continue
            try:
                yield client.write(chunk)
            except tornado.iostream.StreamClosedError:
                # The attached stream went away; keep reading from the UART.
                pass
class CTYServer(PFUnixServer):
    """
    Serve Cryptech console over a PF_UNIX socket.
    """

    @tornado.gen.coroutine
    def handle_stream(self, stream, address):
        """
        Handle one network connection.

        Only one connection may be attached to the console at a time; a
        second one is sent a short notice and closed.  While attached,
        everything read from the connection is written to the serial
        stream.
        """
        uart = self.serial

        if uart.attached_cty is not None:
            yield stream.write("[Console already in use, sorry]\n")
            stream.close()
            return

        logger.info("CTY connected to %r", stream)
        try:
            uart.attached_cty = stream
            while uart.attached_cty is stream:
                typed = yield stream.read_bytes(1024, partial = True)
                yield uart.write(typed)
        except tornado.iostream.StreamClosedError:
            stream.close()
        finally:
            logger.info("CTY disconnected from %r", stream)
            # Detach only if this connection is still the attached one.
            if uart.attached_cty is stream:
                uart.attached_cty = None
class ProbeIOStream(SerialIOStream):
    """
    Tornado IOStream for probing a serial port.  This is nasty.

    A probe writes a string that combines an RPC query with console
    control characters, then inspects the response to decide whether
    the port behaves like the RPC channel or the console.
    """

    def __init__(self, device):
        super(ProbeIOStream, self).__init__(device)

    @classmethod
    @tornado.gen.coroutine
    def run_probes(cls, args):
        """
        Probe candidate devices and fill in args.rpc_device / args.cty_device.

        Does nothing when both devices are already set.  Candidates are
        the devices listed in args.probe or, when that list is empty,
        every port whose hwid contains "VID:PID=0403:6014".  Devices
        already assigned are not probed.
        """
        if args.rpc_device is not None and args.cty_device is not None:
            return

        if args.probe:
            devs = set(args.probe)
        else:
            devs = set(str(port)
                       for port, desc, hwid in serial.tools.list_ports_posix.comports()
                       if "VID:PID=0403:6014" in hwid)

        devs.discard(args.rpc_device)
        devs.discard(args.cty_device)

        if not devs:
            return

        # Use the module logger, like every other message in this file.
        logger.debug("Probing candidate devices %s", " ".join(devs))

        # Yielding a dict of futures waits for all probes to complete.
        results = yield dict((dev, ProbeIOStream(dev).run_probe()) for dev in devs)

        for dev, result in results.items():
            if result == "cty" and args.cty_device is None:
                logger.info("Selecting %s as CTY device", dev)
                args.cty_device = dev
            if result == "rpc" and args.rpc_device is None:
                logger.info("Selecting %s as RPC device", dev)
                args.rpc_device = dev

    @tornado.gen.coroutine
    def run_probe(self):
        """
        Probe this port and return "cty", "rpc" or None.

        The stream is closed before the result is returned.
        """
        RPC_query = send_checksum(chr(0) * 8)   # client_handle = 0, function code = RPC_FUNC_GET_VERSION
        RPC_reply = chr(0) * 12                 # opcode = RPC_FUNC_GET_VERSION, client_handle = 0, valret = HAL_OK

        probe_string = SLIP_END + Control_U + SLIP_END + RPC_query + SLIP_END + Control_U + Control_M

        logger.debug("Probing %s with: %s", self.serial_device,
                     ":".join("{:02x}".format(ord(c)) for c in probe_string))
        yield self.write(probe_string)
        # Give the device time to answer before reading.
        yield tornado.gen.sleep(0.5)
        response = yield self.read_bytes(self.read_chunk_size, partial = True)
        logger.debug("Probing %s response: %r %s", self.serial_device, response,
                     ":".join("{:02x}".format(ord(c)) for c in response))

        is_cty = any(prompt in response for prompt in ("Username:", "Password:", "cryptech>"))

        try:
            reply_idx = response.index(SLIP_END + RPC_reply)
            reply_len = len(SLIP_END + RPC_reply)
            logger.debug("Reply index %s, length %s", reply_idx, reply_len)
            end_offs = reply_idx + reply_len + 8 # RPC_reply is followed by 4 bytes of version data and a CRC32 checksum
            is_rpc = response[end_offs] == SLIP_END
            # The message reports the same +8 offset that end_offs uses.
            logger.debug("Response[%s + %s + 8] = 0x%x (is_rpc %s)",
                         reply_idx, reply_len, ord(response[end_offs]), is_rpc)
        except (ValueError, IndexError):
            # ValueError: reply pattern not present in the response.
            # IndexError: response ends before the expected closing SLIP_END.
            is_rpc = False

        # A port is not expected to look like both channels at once.
        assert not is_cty or not is_rpc

        result = None

        if is_cty:
            result = "cty"
            yield self.write(Control_U)

        if is_rpc:
            result = "rpc"
            yield self.write(SLIP_END)

        self.close()
        raise tornado.gen.Return(result)
@tornado.gen.coroutine
def main():
    """
    Parse the command line, set up logging, optionally probe for the
    serial devices, then start the RPC and console servers and wait on
    their output loops.
    """
    parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("-v", "--verbose",
                        action = "count",
                        help = "blather about what we're doing")

    parser.add_argument("-l", "--log-file",
                        help = "log to file instead of stderr")

    parser.add_argument("-p", "--probe",
                        nargs = "*",
                        metavar = "DEVICE",
                        help = "probe for device UARTs")

    parser.add_argument("--rpc-device",
                        help = "RPC serial device name",
                        default = os.getenv("CRYPTECH_RPC_CLIENT_SERIAL_DEVICE"))

    parser.add_argument("--rpc-socket",
                        help = "RPC PF_UNIX socket name",
                        default = os.getenv("CRYPTECH_RPC_CLIENT_SOCKET_NAME",
                                            "/tmp/.cryptech_muxd.rpc"))

    parser.add_argument("--cty-device",
                        help = "CTY serial device name",
                        default = os.getenv("CRYPTECH_CTY_CLIENT_SERIAL_DEVICE"))

    parser.add_argument("--cty-socket",
                        help = "CTY PF_UNIX socket name",
                        default = os.getenv("CRYPTECH_CTY_CLIENT_SOCKET_NAME",
                                            "/tmp/.cryptech_muxd.cty"))

    args = parser.parse_args()

    if args.log_file is not None:
        # Replace whatever handlers the root logger has with a single file handler.
        logging.getLogger().handlers[:] = [logging.handlers.WatchedFileHandler(args.log_file)]

    # NOTE(review): assumes the root logger already has at least one
    # handler at this point, even without --log-file -- confirm.
    logging.getLogger().handlers[0].setFormatter(
        logging.Formatter("%(asctime)-15s %(name)s[%(process)d]:%(levelname)s: %(message)s",
                          "%Y-%m-%d %H:%M:%S"))

    if args.verbose:
        # -v selects INFO, -vv or more selects DEBUG.
        logging.getLogger().setLevel(logging.DEBUG if args.verbose > 1 else logging.INFO)

    if args.probe is not None:
        yield ProbeIOStream.run_probes(args)

    futures = []

    if args.rpc_device is None:
        logger.warning("No RPC device found")
    else:
        rpc_stream = RPCIOStream(device = args.rpc_device)
        rpc_server = RPCServer(rpc_stream, args.rpc_socket)
        futures.append(rpc_stream.rpc_output_loop())

    if args.cty_device is None:
        logger.warning("No CTY device found")
    else:
        cty_stream = CTYIOStream(device = args.cty_device)
        cty_server = CTYServer(cty_stream, args.cty_socket)
        futures.append(cty_stream.cty_output_loop())

    # Might want to use WaitIterator(dict(...)) here so we can
    # diagnose and restart output loops if they fail?

    if futures:
        # Yielding the list waits until every output loop has finished.
        yield futures
if __name__ == "__main__":
    # Run main() on the current IOLoop until it completes; Ctrl-C ends the
    # program without a traceback.
    try:
        event_loop = tornado.ioloop.IOLoop.current()
        event_loop.run_sync(main)
    except KeyboardInterrupt:
        pass