aboutsummaryrefslogtreecommitdiff
path: root/cryptech_muxd
diff options
context:
space:
mode:
Diffstat (limited to 'cryptech_muxd')
-rwxr-xr-xcryptech_muxd54
1 files changed, 46 insertions, 8 deletions
diff --git a/cryptech_muxd b/cryptech_muxd
index 269ac15..3dcf449 100755
--- a/cryptech_muxd
+++ b/cryptech_muxd
@@ -58,6 +58,7 @@ import tornado.queues
import tornado.locks
import tornado.gen
+from zlib import crc32
logger = logging.getLogger("cryptech_muxd")
@@ -89,6 +90,19 @@ def client_handle_set(msg, handle):
return msg[:4] + struct.pack(">L", handle) + msg[8:]
+def send_checksum(msg):
+ "Add a CRC32 checksum at the end of the message."
+ crc = (~crc32(msg)) & 0xffffffff
+ return msg + struct.pack("<I", crc)
+
+def verify_checksum(msg):
+ "Verify the CRC32 checksum at the end of the message."
+ crc = crc32(msg) & 0xffffffff
+ if crc != 0xffffffff:
+ raise ValueError('Bad CRC32 in message: {} (0x{:8x})'.format(':'.join('{:02x}'.format(ord(c)) for c in msg), crc))
+ return msg[:-4]
+
+
class SerialIOStream(tornado.iostream.BaseIOStream):
"""
Implementation of a Tornado IOStream over a PySerial device.
@@ -164,14 +178,30 @@ class RPCIOStream(SerialIOStream):
q.put_nowait(None)
return
logger.debug("RPC recv: %s", ":".join("{:02x}".format(ord(c)) for c in reply))
+
+ reply = slip_decode(reply)
+
+ if len(reply) < 5:
+ continue
+
+ # Check CRC
try:
- handle = client_handle_get(slip_decode(reply))
+ reply = verify_checksum(reply)
+ except ValueError:
+ logger.error("RPC response CRC fail: {}".format(":".join("{:02x}".format(ord(c)) for c in reply)))
+ continue
+
+ try:
+ handle = client_handle_get(reply)
except:
continue
- logger.debug("RPC queue put: handle 0x%x, qsize %s, maxsize %s",
- handle, self.queues[handle].qsize(), self.queues[handle].maxsize)
- self.queues[handle].put_nowait(reply)
+ try:
+ logger.debug("RPC queue put: handle 0x%x, qsize %s, maxsize %s",
+ handle, self.queues[handle].qsize(), self.queues[handle].maxsize)
+ self.queues[handle].put_nowait(slip_encode(reply))
+ except:
+ logger.debug("Invalid RPC handle: 0x{:08x} / {}".format(handle, handle))
class QueuedStreamClosedError(tornado.iostream.StreamClosedError):
"Deferred StreamClosedError passed throught a Queue."
@@ -194,7 +224,7 @@ class RPCServer(PFUnixServer):
query = yield stream.read_until(SLIP_END)
if len(query) < 9:
continue
- query = slip_encode(client_handle_set(slip_decode(query), handle))
+ query = slip_encode(send_checksum(client_handle_set(slip_decode(query), handle)))
yield self.serial.rpc_input(query, handle, queue)
logger.debug("RPC queue wait, handle 0x%x", handle)
reply = yield queue.get()
@@ -307,21 +337,29 @@ class ProbeIOStream(SerialIOStream):
@tornado.gen.coroutine
def run_probe(self):
- RPC_query = chr(0) * 8 # client_handle = 0, function code = RPC_FUNC_GET_VERSION
+ RPC_query = send_checksum(chr(0) * 8) # client_handle = 0, function code = RPC_FUNC_GET_VERSION
RPC_reply = chr(0) * 12 # opcode = RPC_FUNC_GET_VERSION, client_handle = 0, valret = HAL_OK
probe_string = SLIP_END + Control_U + SLIP_END + RPC_query + SLIP_END + Control_U + Control_M
+ logger.debug("Probing %s with: %s", self.serial_device, ":".join("{:02x}".format(ord(c)) for c in probe_string))
+
yield self.write(probe_string)
yield tornado.gen.sleep(0.5)
response = yield self.read_bytes(self.read_chunk_size, partial = True)
- logger.debug("Probing %s: %r %s", self.serial_device, response, ":".join("{:02x}".format(ord(c)) for c in response))
+ logger.debug("Probing %s response: %r %s", self.serial_device, response, ":".join("{:02x}".format(ord(c)) for c in response))
is_cty = any(prompt in response for prompt in ("Username:", "Password:", "cryptech>"))
try:
- is_rpc = response[response.index(SLIP_END + RPC_reply) + len(SLIP_END + RPC_reply) + 4] == SLIP_END
+ reply_idx = response.index(SLIP_END + RPC_reply)
+ reply_len = len(SLIP_END + RPC_reply)
+ logger.debug("Reply index {}, length {}".format(reply_idx, reply_len))
+ end_offs = reply_idx + reply_len + 8 # RPC_reply is followed by 4 bytes of version data and a CRC32 checksum
+ is_rpc = response[end_offs] == SLIP_END
+ logger.debug("Response[{} + {} + 4] = 0x{:x} (is_rpc {})".format(
+ reply_idx, reply_len, ord(response[end_offs]), is_rpc))
except ValueError:
is_rpc = False
except IndexError: