diff options
-rw-r--r-- | cryptech/libhal.py | 15 | ||||
-rwxr-xr-x | tests/parallel-signatures.py | 2 | ||||
-rw-r--r-- | tests/test-ecdsa.py | 108 | ||||
-rw-r--r-- | tests/test-rsa.py | 118 | ||||
-rw-r--r-- | unit-tests.py | 5 |
5 files changed, 124 insertions, 124 deletions
diff --git a/cryptech/libhal.py b/cryptech/libhal.py index 1e2dbc6..102e663 100644 --- a/cryptech/libhal.py +++ b/cryptech/libhal.py @@ -698,13 +698,14 @@ class HSM(object): n = length s = 0 while n == length: - with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, mask, flags, - attributes, s, length, u, client = client) as r: - s = r.unpack_uint() - n = r.unpack_uint() - for i in range(n): - u = UUID(bytes = r.unpack_bytes()) - yield u + r = self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, mask, flags, + attributes, s, length, u, client = client) + s = r.unpack_uint() + n = r.unpack_uint() + for i in range(n): + u = UUID(bytes = r.unpack_bytes()) + yield u + r.done() def pkey_set_attributes(self, pkey, attributes): with self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTES, pkey, attributes): diff --git a/tests/parallel-signatures.py b/tests/parallel-signatures.py index 87bac0a..200e550 100755 --- a/tests/parallel-signatures.py +++ b/tests/parallel-signatures.py @@ -305,7 +305,7 @@ class Result(object): @property def mean(self): - return self.sum / self.n + return self.sum / self.n @property def secs_per_sig(self): diff --git a/tests/test-ecdsa.py b/tests/test-ecdsa.py index 4c14d9f..4989619 100644 --- a/tests/test-ecdsa.py +++ b/tests/test-ecdsa.py @@ -52,27 +52,27 @@ from pyasn1.codec.der.decoder import decode as DER_Decode wrapper = TextWrapper(width = 78, initial_indent = " " * 2, subsequent_indent = " " * 2) def long_to_bytes(number, order): - # - # This is just plain nasty. - # - s = "{:x}".format(number) - s = ("0" * (order/8 - len(s))) + s - return unhexlify(s) + # + # This is just plain nasty. + # + s = "{:x}".format(number) + s = ("0" * (order/8 - len(s))) + s + return unhexlify(s) def bytes_to_bits(b): - # - # This, on the other hand, is not just plain nasty, this is fancy nasty. - # This is nasty with raisins in it. - # - s = bin(int(hexlify(b), 16))[2:] - if len(s) % 8: - s = ("0" * (8 - len(s) % 8)) + s - return tuple(int(i) for i in s) + # + # This, on the other hand, is not just plain nasty, this is fancy nasty. + # This is nasty with raisins in it. + # + s = bin(int(hexlify(b), 16))[2:] + if len(s) % 8: + s = ("0" * (8 - len(s) % 8)) + s + return tuple(int(i) for i in s) ### def encode_sig(r, s, order): - return long_to_bytes(r, order) + long_to_bytes(s, order) + return long_to_bytes(r, order) + long_to_bytes(s, order) p256_sig = encode_sig(p256_r, p256_s, 256) p384_sig = encode_sig(p384_r, p384_s, 384) @@ -80,23 +80,23 @@ p384_sig = encode_sig(p384_r, p384_s, 384) ### class ECPrivateKey(Sequence): - componentType = NamedTypes( - NamedType("version", Integer(namedValues = NamedValues(("ecPrivkeyVer1", 1)) - ).subtype(subtypeSpec = Integer.subtypeSpec + SingleValueConstraint(1))), - NamedType("privateKey", OctetString()), - OptionalNamedType("parameters", ObjectIdentifier().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 0))), - OptionalNamedType("publicKey", BitString().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 1)))) + componentType = NamedTypes( + NamedType("version", Integer(namedValues = NamedValues(("ecPrivkeyVer1", 1)) + ).subtype(subtypeSpec = Integer.subtypeSpec + SingleValueConstraint(1))), + NamedType("privateKey", OctetString()), + OptionalNamedType("parameters", ObjectIdentifier().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 0))), + OptionalNamedType("publicKey", BitString().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 1)))) def encode_key(d, Qx, Qy, order, oid): - private_key = long_to_bytes(d, order) - public_key = bytes_to_bits(chr(0x04) + long_to_bytes(Qx, order) + long_to_bytes(Qy, order)) - parameters = oid - key = ECPrivateKey() - key["version"] = 1 - key["privateKey"] = private_key - key["parameters"] = parameters - key["publicKey"] = public_key - return DER_Encode(key) + private_key = long_to_bytes(d, order) + public_key = bytes_to_bits(chr(0x04) + long_to_bytes(Qx, order) + long_to_bytes(Qy, order)) + parameters = oid + key = ECPrivateKey() + key["version"] = 1 + key["privateKey"] = private_key + key["parameters"] = parameters + key["publicKey"] = public_key + return DER_Encode(key) p256_key = encode_key(p256_d, p256_Qx, p256_Qy, 256, "1.2.840.10045.3.1.7") p384_key = encode_key(p384_d, p384_Qx, p384_Qy, 384, "1.3.132.0.34") @@ -112,41 +112,41 @@ curves = ("p256", "p384") vars = set() for name in dir(): - head, sep, tail = name.partition("_") - if head in curves: - vars.add(tail) + head, sep, tail = name.partition("_") + if head in curves: + vars.add(tail) vars = sorted(vars) for curve in curves: - order = int(curve[1:]) - for var in vars: - name = curve + "_" + var - value = globals().get(name, None) - if isinstance(value, int): - value = long_to_bytes(value, order) - if value is not None: - value = hexlify(value).decode("ascii") - print() - print("static const uint8_t {}[] = {{ /* {:d} bytes */".format(name, len(value))) - print(wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2)))) - print("};") + order = int(curve[1:]) + for var in vars: + name = curve + "_" + var + value = globals().get(name, None) + if isinstance(value, int): + value = long_to_bytes(value, order) + if value is not None: + value = hexlify(value).decode("ascii") + print() + print("static const uint8_t {}[] = {{ /* {:d} bytes */".format(name, len(value))) + print(wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2)))) + print("};") print() print("typedef struct {") print(" hal_curve_name_t curve;") for var in vars: - print(" const uint8_t *{0:>8}; size_t {0:>8}_len;".format(var)) + print(" const uint8_t *{0:>8}; size_t {0:>8}_len;".format(var)) print("} ecdsa_tc_t;") print() print("static const ecdsa_tc_t ecdsa_tc[] = {") for curve in curves: - print(" {{ HAL_CURVE_{},".format(curve.upper())) - for var in vars: - name = curve + "_" + var - if name in globals(): - print(" {:<14} sizeof({}),".format(name + ",", name)) - else: - print(" {:<14} 0,".format("NULL,")) - print(" },") + print(" {{ HAL_CURVE_{},".format(curve.upper())) + for var in vars: + name = curve + "_" + var + if name in globals(): + print(" {:<14} sizeof({}),".format(name + ",", name)) + else: + print(" {:<14} 0,".format("NULL,")) + print(" },") print("};") diff --git a/tests/test-rsa.py b/tests/test-rsa.py index d0538ed..39f46cd 100644 --- a/tests/test-rsa.py +++ b/tests/test-rsa.py @@ -45,10 +45,10 @@ from textwrap import TextWrapper import sys, os.path def KeyLengthType(arg): - val = int(arg) - if val % 8 != 0: - raise ValueError - return val + val = int(arg) + if val % 8 != 0: + raise ValueError + return val parser = ArgumentParser(description = __doc__) parser.add_argument("--pad-to-modulus", action = "store_true", @@ -71,22 +71,22 @@ scriptname = os.path.basename(sys.argv[0]) wrapper = TextWrapper(width = 78, initial_indent = " " * 2, subsequent_indent = " " * 2) def printlines(*lines, **kwargs): - for line in lines: - args.output.write(line.format(**kwargs) + "\n") + for line in lines: + args.output.write(line.format(**kwargs) + "\n") def trailing_comma(item, sequence): - return "" if item == sequence[-1] else "," + return "" if item == sequence[-1] else "," def print_hex(name, value, comment): - value = hexlify(value).decode("ascii") - printlines("static const uint8_t {name}[] = {{ /* {comment}, {length:d} bytes */", - wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2))) - "}};", "", - name = name, comment = comment, length = len(value)) + value = hexlify(value).decode("ascii") + printlines("static const uint8_t {name}[] = {{ /* {comment}, {length:d} bytes */", + wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2))) + "}};", "", + name = name, comment = comment, length = len(value)) def pad_to_blocksize(value, blocksize): - extra = len(value) % blocksize - return value if extra == 0 else (b"\x00" * (blocksize - extra)) + value + extra = len(value) % blocksize + return value if extra == 0 else (b"\x00" * (blocksize - extra)) + value # Funnily enough, PyCrypto and Cryptlib use exactly the same names for # RSA key components, see Cryptlib documentation pages 186-187 & 339. @@ -109,48 +109,48 @@ fields = ("n", "e", "d", "p", "q", "dP", "dQ", "u", "m", "s") for k_len in args.key_lengths: - k = RSA.generate(k_len) # Cryptlib insists u < p, probably with good reason, - while k.u >= k.p: # and I'm sure not going to argue the math with Peter, - k = RSA.generate(k_len) # so keep trying until we pass this test - - m = EMSA_PKCS1_V1_5_ENCODE(h, k_len/8) - s = PKCS115_SigScheme(k).sign(h) - assert len(m) == len(s) - - if args.pad_to_modulus: - blocksize = k_len/8 - if args.extra_word: - blocksize += 4 - else: - blocksize = 4 - - printlines("/* {k_len:d}-bit RSA private key (PKCS #{pkcs:d})", - k.exportKey(format = "PEM", pkcs = args.pkcs_encoding), - "*/", "", - k_len = k_len, pkcs = args.pkcs_encoding) - - # PyCrypto doesn't precalculate dP or dQ, and for some reason it - # does u backwards (uses (1/p % q) and swaps the roles of p and q in - # the CRT calculation to compensate), so we just calculate our own. - - for name in fields: - if name in "ms": - continue - elif name == "dP": - value = k.d % (k.p - 1) - elif name == "dQ": - value = k.d % (k.q - 1) - elif name == "u": - value = inverse(k.q, k.p) - else: - value = getattr(k, name) + k = RSA.generate(k_len) # Cryptlib insists u < p, probably with good reason, + while k.u >= k.p: # and I'm sure not going to argue the math with Peter, + k = RSA.generate(k_len) # so keep trying until we pass this test - print_hex("{}_{:d}".format(name, k_len), - long_to_bytes(value, blocksize = blocksize), - "key component {}".format(name)) + m = EMSA_PKCS1_V1_5_ENCODE(h, k_len/8) + s = PKCS115_SigScheme(k).sign(h) + assert len(m) == len(s) - print_hex("m_{:d}".format(k_len), pad_to_blocksize(m, blocksize), "message to be signed") - print_hex("s_{:d}".format(k_len), pad_to_blocksize(s, blocksize), "signed message") + if args.pad_to_modulus: + blocksize = k_len/8 + if args.extra_word: + blocksize += 4 + else: + blocksize = 4 + + printlines("/* {k_len:d}-bit RSA private key (PKCS #{pkcs:d})", + k.exportKey(format = "PEM", pkcs = args.pkcs_encoding), + "*/", "", + k_len = k_len, pkcs = args.pkcs_encoding) + + # PyCrypto doesn't precalculate dP or dQ, and for some reason it + # does u backwards (uses (1/p % q) and swaps the roles of p and q in + # the CRT calculation to compensate), so we just calculate our own. + + for name in fields: + if name in "ms": + continue + elif name == "dP": + value = k.d % (k.p - 1) + elif name == "dQ": + value = k.d % (k.q - 1) + elif name == "u": + value = inverse(k.q, k.p) + else: + value = getattr(k, name) + + print_hex("{}_{:d}".format(name, k_len), + long_to_bytes(value, blocksize = blocksize), + "key component {}".format(name)) + + print_hex("m_{:d}".format(k_len), pad_to_blocksize(m, blocksize), "message to be signed") + print_hex("s_{:d}".format(k_len), pad_to_blocksize(s, blocksize), "signed message") printlines("typedef struct {{ const uint8_t *val; size_t len; }} rsa_tc_bn_t;", "typedef struct {{ size_t size; rsa_tc_bn_t {fields}; }} rsa_tc_t;", @@ -158,9 +158,9 @@ printlines("typedef struct {{ const uint8_t *val; size_t len; }} rsa_tc_bn_t;", "static const rsa_tc_t rsa_tc[] = {{", fields = ", ".join(fields)) for k_len in args.key_lengths: - printlines(" {{ {k_len:d},", k_len = k_len) - for field in fields: - printlines(" {{ {field}_{k_len:d}, sizeof({field}_{k_len:d}) }}{comma}", - field = field, k_len = k_len, comma = trailing_comma(field, fields)) - printlines(" }}{comma}", comma = trailing_comma(k_len, args.key_lengths)) + printlines(" {{ {k_len:d},", k_len = k_len) + for field in fields: + printlines(" {{ {field}_{k_len:d}, sizeof({field}_{k_len:d}) }}{comma}", + field = field, k_len = k_len, comma = trailing_comma(field, fields)) + printlines(" }}{comma}", comma = trailing_comma(k_len, args.key_lengths)) printlines("}};") diff --git a/unit-tests.py b/unit-tests.py index fab09e8..b1e8a21 100644 --- a/unit-tests.py +++ b/unit-tests.py @@ -1227,7 +1227,6 @@ class TestPKeyMatch(TestCaseLoggedIn): with hsm.pkey_load(obj.der, flags) as k: self.addCleanup(self.cleanup_key, k.uuid) uuids.add(k.uuid) - #print k.uuid, k.key_type, k.key_curve, self.key_flag_names(k.key_flags) k.set_attributes(dict((i, a) for i, a in enumerate((str(obj.keytype), str(obj.fn2))))) return uuids @@ -1256,14 +1255,14 @@ class TestPKeyMatch(TestCaseLoggedIn): n = 0 for n, k in self.match(mask = mask, flags = flags, uuids = uuids, type = keytype): self.assertEqual(k.key_type, keytype) - self.assertEqual(k.get_attributes({0}).pop(0), bytes(keytype)) + self.assertEqual(k.get_attributes({0}).pop(0).decode(), str(keytype)) self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == keytype)) for curve in set(HALCurve.index.values()) - {HAL_CURVE_NONE}: n = 0 for n, k in self.match(mask = mask, flags = flags, uuids = uuids, curve = curve): self.assertEqual(k.key_curve, curve) - self.assertEqual(k.get_attributes({1}).pop(1), str(curve)) + self.assertEqual(k.get_attributes({1}).pop(1).decode(), str(curve)) self.assertIn(k.key_type, (HAL_KEY_TYPE_EC_PUBLIC, HAL_KEY_TYPE_EC_PRIVATE)) self.assertEqual(n, sum(1 for t1, t2 in tags if t2 == curve)) |