aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2017-04-27 12:45:52 -0400
committerRob Austein <sra@hactrn.net>2017-04-27 12:45:52 -0400
commit99c8452187b2f99a92fbbea50ea1968d209b7c44 (patch)
treee4ac4491cd93c9053652948367203ae2445f981c
parent0ced3ff3f64ca7ee0fc804d8f6dcfc89d2c5492f (diff)
Refactor cryptech_upload to work either directly or via cryptech_muxd.
-rwxr-xr-xprojects/hsm/cryptech_upload190
1 files changed, 131 insertions, 59 deletions
diff --git a/projects/hsm/cryptech_upload b/projects/hsm/cryptech_upload
index 559195d..b6e02bd 100755
--- a/projects/hsm/cryptech_upload
+++ b/projects/hsm/cryptech_upload
@@ -37,6 +37,7 @@ import sys
import time
import struct
import serial
+import socket
import getpass
import os.path
import tarfile
@@ -46,7 +47,7 @@ import platform
from binascii import crc32
FIRMWARE_CHUNK_SIZE = 4096
-FPGA_CHUNK_SIZE = 4096
+FPGA_CHUNK_SIZE = 4096
def parse_args():
@@ -70,6 +71,12 @@ def parse_args():
help = "Name of management port USB serial device",
)
+ parser.add_argument("--socket",
+ default = os.getenv("CRYPTECH_CTY_CLIENT_SOCKET_NAME",
+ "/tmp/.cryptech_muxd.cty"),
+ help = "Name of cryptech_muxd management port socket",
+ )
+
parser.add_argument("--firmware-tarball",
type = argparse.FileType("rb"),
default = default_tarball,
@@ -126,63 +133,126 @@ def parse_args():
return parser.parse_args()
-def _write(dst, data):
- numeric = isinstance(data, (int, long))
- if numeric:
- data = struct.pack("<I", data)
- dst.write(data)
- dst.flush()
- if args.debug:
+class ManagementPortAbstract(object):
+ """
+ Abstract class encapsulating actions on the HSM management port.
+ """
+
+ def __init__(self, args):
+ self.args = args
+
+ def write(self, data):
+ numeric = isinstance(data, (int, long))
if numeric:
- print("Wrote 0x{!s}".format(data.encode("hex")))
- else:
- print("Wrote {!r}".format(data))
-
-def _read(dst):
- res = ""
- x = dst.read(1)
- while not x:
- x = dst.read(1)
- while x:
- res += x
- x = dst.read(1)
- if args.debug:
- print ("Read {!r}".format(res))
- return res
-
-def _execute(dst, cmd):
- global args
- _write(dst, "\r")
- prompt = _read(dst)
- #if prompt.endswith("This is the bootloader speaking..."):
- # prompt = _read(dst)
- if prompt.endswith("Username: "):
- _write(dst, args.username + "\r")
- prompt = _read(dst)
- if prompt.endswith("Password: "):
- if not args.pin or args.separate_pins:
- args.pin = getpass.getpass("{} PIN: ".format(args.username))
- _write(dst, args.pin + "\r")
- prompt = _read(dst)
- if not prompt.endswith(("> ", "# ")):
- print("Device does not seem to be ready for a file transfer (got {!r})".format(prompt))
- return prompt
- _write(dst, cmd + "\r")
- response = _read(dst)
- return response
+ data = struct.pack("<I", data)
+ self.send(data)
+ if self.args.debug:
+ if numeric:
+ print("Wrote 0x{!s}".format(data.encode("hex")))
+ else:
+ print("Wrote {!r}".format(data))
+
+ def read(self):
+ res = ""
+ x = self.recv()
+ while not x:
+ x = self.recv()
+ while x:
+ res += x
+ x = self.recv()
+ if self.args.debug:
+ print ("Read {!r}".format(res))
+ return res
+
+ def execute(self, cmd):
+ self.write("\r")
+ prompt = self.read()
+ #if prompt.endswith("This is the bootloader speaking..."):
+ # prompt = self.read()
+ if prompt.endswith("Username: "):
+ self.write(self.args.username + "\r")
+ prompt = self.read()
+ if prompt.endswith("Password: "):
+ if not self.args.pin or self.args.separate_pins:
+ self.args.pin = getpass.getpass("{} PIN: ".format(self.args.username))
+ self.write(self.args.pin + "\r")
+ prompt = self.read()
+ if not prompt.endswith(("> ", "# ")):
+ print("Device does not seem to be ready for a file transfer (got {!r})".format(prompt))
+ return prompt
+ self.write(cmd + "\r")
+ response = self.read()
+ return response
+
+
+class ManagementPortSerial(ManagementPortAbstract):
+ """
+ Implmentation of HSM management port abstraction over a direct
+ serial connection.
+ """
+
+ def __init__(self, args, timeout = 1):
+ super(ManagementPortSerial, self).__init__(args)
+ self.serial = serial.Serial(args.device, 921600, timeout = timeout)
+
+ def send(self, data):
+ self.serial.write(data)
+ self.serial.flush()
+
+ def recv(self):
+ return self.serial.read(1)
+
+ def set_timeout(self, timeout):
+ self.serial.timeout = timeout
+
+ def close(self):
+ self.serial.close()
+
+class ManagementPortSocket(ManagementPortAbstract):
+ """
+ Implmentation of HSM management port abstraction over a PF_UNIX
+ socket connection to the cryptech_muxd management socket.
+ """
+
+ def __init__(self, args, timeout = 1):
+ super(ManagementPortSocket, self).__init__(args)
+ self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self.socket.connect(args.socket)
+ self.socket.settimeout(timeout)
+
+ def send(self, data):
+ self.socket.sendall(data)
+
+ def recv(self):
+ try:
+ return self.socket.recv(1)
+ except socket.timeout:
+ return ""
+
+ def set_timeout(self, timeout):
+ self.socket.settimeout(timeout)
+
+ def close(self):
+ self.socket.close()
+
def send_file(src, size, args, dst):
+ """
+ Upload an image from some file-like source to the management port.
+ Details depend on what kind of image it is.
+ """
+
if args.fpga:
chunk_size = FPGA_CHUNK_SIZE
- response = _execute(dst, "fpga bitstream upload")
+ response = dst.execute("fpga bitstream upload")
elif args.firmware:
chunk_size = FIRMWARE_CHUNK_SIZE
- response = _execute(dst, "firmware upload")
+ response = dst.execute("firmware upload")
if "Rebooting" in response:
- response = _execute(dst, "firmware upload")
+ response = dst.execute("firmware upload")
elif args.bootloader:
chunk_size = FIRMWARE_CHUNK_SIZE
- response = _execute(dst, "bootloader upload")
+ response = dst.execute("bootloader upload")
if "Access denied" in response:
print "Access denied"
return False
@@ -190,12 +260,12 @@ def send_file(src, size, args, dst):
print("Device did not accept the upload command (got {!r})".format(response))
return False
- dst.timeout = 0.001
+ dst.set_timeout(0.001)
crc = 0
counter = 0
# 1. Write size of file (4 bytes)
- _write(dst, struct.pack("<I", size))
- response = _read(dst)
+ dst.write(struct.pack("<I", size))
+ response = dst.read()
if not response.startswith("Send "):
print response
return False
@@ -205,13 +275,12 @@ def send_file(src, size, args, dst):
for counter in xrange(chunks):
data = src.read(chunk_size)
dst.write(data)
- dst.flush()
if not args.quiet:
print("Wrote {!s} bytes (chunk {!s}/{!s})".format(len(data), counter + 1, chunks))
# read ACK (a counter of number of 4k chunks received)
ack_bytes = ""
while len(ack_bytes) < 4:
- ack_bytes += _read(dst)
+ ack_bytes += dst.read()
ack = struct.unpack("<I", ack_bytes[:4])[0]
if ack != counter + 1:
print("ERROR: Did not receive the expected counter as ACK (got {!r}/{!r}, not {!r})".format(ack, ack_bytes, counter))
@@ -221,8 +290,8 @@ def send_file(src, size, args, dst):
crc = crc32(data, crc) & 0xffffffff
# 3. Write CRC-32 (4 bytes)
- _write(dst, struct.pack("<I", crc))
- response = _read(dst)
+ dst.write(struct.pack("<I", crc))
+ response = dst.read()
if not args.quiet:
print response
@@ -230,10 +299,10 @@ def send_file(src, size, args, dst):
if args.fpga:
# tell the fpga to read its new configuration
- _execute(dst, "fpga reset")
+ dst.execute("fpga reset")
# log out of the CLI
# (firmware/bootloader upgrades reboot, don't need an exit)
- _execute(dst, "exit")
+ dst.execute("exit")
return True
@@ -300,8 +369,11 @@ def main():
print "Uploading {} from {}".format(name, args.firmware_tarball.name)
if not args.quiet:
- print "Initializing serial port and synchronizing with HSM, this may take a few seconds"
- dst = serial.Serial(args.device, 921600, timeout = 1)
+ print "Initializing management port and synchronizing with HSM, this may take a few seconds"
+ try:
+ dst = ManagementPortSocket(args, timeout = 1)
+ except socket.error as e:
+ dst = ManagementPortSerial(args, timeout = 1)
send_file(src, size, args, dst)
dst.close()