diff options
-rwxr-xr-x | projects/hsm/cryptech_upload | 190 |
1 files changed, 131 insertions, 59 deletions
diff --git a/projects/hsm/cryptech_upload b/projects/hsm/cryptech_upload index 559195d..b6e02bd 100755 --- a/projects/hsm/cryptech_upload +++ b/projects/hsm/cryptech_upload @@ -37,6 +37,7 @@ import sys import time import struct import serial +import socket import getpass import os.path import tarfile @@ -46,7 +47,7 @@ import platform from binascii import crc32 FIRMWARE_CHUNK_SIZE = 4096 -FPGA_CHUNK_SIZE = 4096 +FPGA_CHUNK_SIZE = 4096 def parse_args(): @@ -70,6 +71,12 @@ def parse_args(): help = "Name of management port USB serial device", ) + parser.add_argument("--socket", + default = os.getenv("CRYPTECH_CTY_CLIENT_SOCKET_NAME", + "/tmp/.cryptech_muxd.cty"), + help = "Name of cryptech_muxd management port socket", + ) + parser.add_argument("--firmware-tarball", type = argparse.FileType("rb"), default = default_tarball, @@ -126,63 +133,126 @@ def parse_args(): return parser.parse_args() -def _write(dst, data): - numeric = isinstance(data, (int, long)) - if numeric: - data = struct.pack("<I", data) - dst.write(data) - dst.flush() - if args.debug: +class ManagementPortAbstract(object): + """ + Abstract class encapsulating actions on the HSM management port. + """ + + def __init__(self, args): + self.args = args + + def write(self, data): + numeric = isinstance(data, (int, long)) if numeric: - print("Wrote 0x{!s}".format(data.encode("hex"))) - else: - print("Wrote {!r}".format(data)) - -def _read(dst): - res = "" - x = dst.read(1) - while not x: - x = dst.read(1) - while x: - res += x - x = dst.read(1) - if args.debug: - print ("Read {!r}".format(res)) - return res - -def _execute(dst, cmd): - global args - _write(dst, "\r") - prompt = _read(dst) - #if prompt.endswith("This is the bootloader speaking..."): - # prompt = _read(dst) - if prompt.endswith("Username: "): - _write(dst, args.username + "\r") - prompt = _read(dst) - if prompt.endswith("Password: "): - if not args.pin or args.separate_pins: - args.pin = getpass.getpass("{} PIN: ".format(args.username)) - _write(dst, args.pin + "\r") - prompt = _read(dst) - if not prompt.endswith(("> ", "# ")): - print("Device does not seem to be ready for a file transfer (got {!r})".format(prompt)) - return prompt - _write(dst, cmd + "\r") - response = _read(dst) - return response + data = struct.pack("<I", data) + self.send(data) + if self.args.debug: + if numeric: + print("Wrote 0x{!s}".format(data.encode("hex"))) + else: + print("Wrote {!r}".format(data)) + + def read(self): + res = "" + x = self.recv() + while not x: + x = self.recv() + while x: + res += x + x = self.recv() + if self.args.debug: + print ("Read {!r}".format(res)) + return res + + def execute(self, cmd): + self.write("\r") + prompt = self.read() + #if prompt.endswith("This is the bootloader speaking..."): + # prompt = self.read() + if prompt.endswith("Username: "): + self.write(self.args.username + "\r") + prompt = self.read() + if prompt.endswith("Password: "): + if not self.args.pin or self.args.separate_pins: + self.args.pin = getpass.getpass("{} PIN: ".format(self.args.username)) + self.write(self.args.pin + "\r") + prompt = self.read() + if not prompt.endswith(("> ", "# ")): + print("Device does not seem to be ready for a file transfer (got {!r})".format(prompt)) + return prompt + self.write(cmd + "\r") + response = self.read() + return response + + +class ManagementPortSerial(ManagementPortAbstract): + """ + Implmentation of HSM management port abstraction over a direct + serial connection. + """ + + def __init__(self, args, timeout = 1): + super(ManagementPortSerial, self).__init__(args) + self.serial = serial.Serial(args.device, 921600, timeout = timeout) + + def send(self, data): + self.serial.write(data) + self.serial.flush() + + def recv(self): + return self.serial.read(1) + + def set_timeout(self, timeout): + self.serial.timeout = timeout + + def close(self): + self.serial.close() + +class ManagementPortSocket(ManagementPortAbstract): + """ + Implmentation of HSM management port abstraction over a PF_UNIX + socket connection to the cryptech_muxd management socket. + """ + + def __init__(self, args, timeout = 1): + super(ManagementPortSocket, self).__init__(args) + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.socket.connect(args.socket) + self.socket.settimeout(timeout) + + def send(self, data): + self.socket.sendall(data) + + def recv(self): + try: + return self.socket.recv(1) + except socket.timeout: + return "" + + def set_timeout(self, timeout): + self.socket.settimeout(timeout) + + def close(self): + self.socket.close() + def send_file(src, size, args, dst): + """ + Upload an image from some file-like source to the management port. + Details depend on what kind of image it is. + """ + if args.fpga: chunk_size = FPGA_CHUNK_SIZE - response = _execute(dst, "fpga bitstream upload") + response = dst.execute("fpga bitstream upload") elif args.firmware: chunk_size = FIRMWARE_CHUNK_SIZE - response = _execute(dst, "firmware upload") + response = dst.execute("firmware upload") if "Rebooting" in response: - response = _execute(dst, "firmware upload") + response = dst.execute("firmware upload") elif args.bootloader: chunk_size = FIRMWARE_CHUNK_SIZE - response = _execute(dst, "bootloader upload") + response = dst.execute("bootloader upload") if "Access denied" in response: print "Access denied" return False @@ -190,12 +260,12 @@ def send_file(src, size, args, dst): print("Device did not accept the upload command (got {!r})".format(response)) return False - dst.timeout = 0.001 + dst.set_timeout(0.001) crc = 0 counter = 0 # 1. Write size of file (4 bytes) - _write(dst, struct.pack("<I", size)) - response = _read(dst) + dst.write(struct.pack("<I", size)) + response = dst.read() if not response.startswith("Send "): print response return False @@ -205,13 +275,12 @@ def send_file(src, size, args, dst): for counter in xrange(chunks): data = src.read(chunk_size) dst.write(data) - dst.flush() if not args.quiet: print("Wrote {!s} bytes (chunk {!s}/{!s})".format(len(data), counter + 1, chunks)) # read ACK (a counter of number of 4k chunks received) ack_bytes = "" while len(ack_bytes) < 4: - ack_bytes += _read(dst) + ack_bytes += dst.read() ack = struct.unpack("<I", ack_bytes[:4])[0] if ack != counter + 1: print("ERROR: Did not receive the expected counter as ACK (got {!r}/{!r}, not {!r})".format(ack, ack_bytes, counter)) @@ -221,8 +290,8 @@ def send_file(src, size, args, dst): crc = crc32(data, crc) & 0xffffffff # 3. Write CRC-32 (4 bytes) - _write(dst, struct.pack("<I", crc)) - response = _read(dst) + dst.write(struct.pack("<I", crc)) + response = dst.read() if not args.quiet: print response @@ -230,10 +299,10 @@ def send_file(src, size, args, dst): if args.fpga: # tell the fpga to read its new configuration - _execute(dst, "fpga reset") + dst.execute("fpga reset") # log out of the CLI # (firmware/bootloader upgrades reboot, don't need an exit) - _execute(dst, "exit") + dst.execute("exit") return True @@ -300,8 +369,11 @@ def main(): print "Uploading {} from {}".format(name, args.firmware_tarball.name) if not args.quiet: - print "Initializing serial port and synchronizing with HSM, this may take a few seconds" - dst = serial.Serial(args.device, 921600, timeout = 1) + print "Initializing management port and synchronizing with HSM, this may take a few seconds" + try: + dst = ManagementPortSocket(args, timeout = 1) + except socket.error as e: + dst = ManagementPortSerial(args, timeout = 1) send_file(src, size, args, dst) dst.close() |