From d015707e0e26bafbfad72511c0a47cbb66c90971 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Wed, 10 Jun 2020 15:38:05 -0400 Subject: Fix remaining Python 3 unit test string encoding bug Really just one bug, but confusingly masked by an interaction between generators and our XDR context manager, so don't use the context manager in the one generator method in the cryptech.libhal API. Also run reindent.py on a few old test modules. --- tests/parallel-signatures.py | 2 +- tests/test-ecdsa.py | 108 +++++++++++++++++++-------------------- tests/test-rsa.py | 118 +++++++++++++++++++++---------------------- 3 files changed, 114 insertions(+), 114 deletions(-) (limited to 'tests') diff --git a/tests/parallel-signatures.py b/tests/parallel-signatures.py index 87bac0a..200e550 100755 --- a/tests/parallel-signatures.py +++ b/tests/parallel-signatures.py @@ -305,7 +305,7 @@ class Result(object): @property def mean(self): - return self.sum / self.n + return self.sum / self.n @property def secs_per_sig(self): diff --git a/tests/test-ecdsa.py b/tests/test-ecdsa.py index 4c14d9f..4989619 100644 --- a/tests/test-ecdsa.py +++ b/tests/test-ecdsa.py @@ -52,27 +52,27 @@ from pyasn1.codec.der.decoder import decode as DER_Decode wrapper = TextWrapper(width = 78, initial_indent = " " * 2, subsequent_indent = " " * 2) def long_to_bytes(number, order): - # - # This is just plain nasty. - # - s = "{:x}".format(number) - s = ("0" * (order/8 - len(s))) + s - return unhexlify(s) + # + # This is just plain nasty. + # + s = "{:x}".format(number) + s = ("0" * (order/8 - len(s))) + s + return unhexlify(s) def bytes_to_bits(b): - # - # This, on the other hand, is not just plain nasty, this is fancy nasty. - # This is nasty with raisins in it. - # - s = bin(int(hexlify(b), 16))[2:] - if len(s) % 8: - s = ("0" * (8 - len(s) % 8)) + s - return tuple(int(i) for i in s) + # + # This, on the other hand, is not just plain nasty, this is fancy nasty. + # This is nasty with raisins in it. + # + s = bin(int(hexlify(b), 16))[2:] + if len(s) % 8: + s = ("0" * (8 - len(s) % 8)) + s + return tuple(int(i) for i in s) ### def encode_sig(r, s, order): - return long_to_bytes(r, order) + long_to_bytes(s, order) + return long_to_bytes(r, order) + long_to_bytes(s, order) p256_sig = encode_sig(p256_r, p256_s, 256) p384_sig = encode_sig(p384_r, p384_s, 384) @@ -80,23 +80,23 @@ p384_sig = encode_sig(p384_r, p384_s, 384) ### class ECPrivateKey(Sequence): - componentType = NamedTypes( - NamedType("version", Integer(namedValues = NamedValues(("ecPrivkeyVer1", 1)) - ).subtype(subtypeSpec = Integer.subtypeSpec + SingleValueConstraint(1))), - NamedType("privateKey", OctetString()), - OptionalNamedType("parameters", ObjectIdentifier().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 0))), - OptionalNamedType("publicKey", BitString().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 1)))) + componentType = NamedTypes( + NamedType("version", Integer(namedValues = NamedValues(("ecPrivkeyVer1", 1)) + ).subtype(subtypeSpec = Integer.subtypeSpec + SingleValueConstraint(1))), + NamedType("privateKey", OctetString()), + OptionalNamedType("parameters", ObjectIdentifier().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 0))), + OptionalNamedType("publicKey", BitString().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 1)))) def encode_key(d, Qx, Qy, order, oid): - private_key = long_to_bytes(d, order) - public_key = bytes_to_bits(chr(0x04) + long_to_bytes(Qx, order) + long_to_bytes(Qy, order)) - parameters = oid - key = ECPrivateKey() - key["version"] = 1 - key["privateKey"] = private_key - key["parameters"] = parameters - key["publicKey"] = public_key - return DER_Encode(key) + private_key = long_to_bytes(d, order) + public_key = bytes_to_bits(chr(0x04) + long_to_bytes(Qx, order) + long_to_bytes(Qy, order)) + parameters = oid + key = ECPrivateKey() + key["version"] = 1 + key["privateKey"] = private_key + key["parameters"] = parameters + key["publicKey"] = public_key + return DER_Encode(key) p256_key = encode_key(p256_d, p256_Qx, p256_Qy, 256, "1.2.840.10045.3.1.7") p384_key = encode_key(p384_d, p384_Qx, p384_Qy, 384, "1.3.132.0.34") @@ -112,41 +112,41 @@ curves = ("p256", "p384") vars = set() for name in dir(): - head, sep, tail = name.partition("_") - if head in curves: - vars.add(tail) + head, sep, tail = name.partition("_") + if head in curves: + vars.add(tail) vars = sorted(vars) for curve in curves: - order = int(curve[1:]) - for var in vars: - name = curve + "_" + var - value = globals().get(name, None) - if isinstance(value, int): - value = long_to_bytes(value, order) - if value is not None: - value = hexlify(value).decode("ascii") - print() - print("static const uint8_t {}[] = {{ /* {:d} bytes */".format(name, len(value))) - print(wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2)))) - print("};") + order = int(curve[1:]) + for var in vars: + name = curve + "_" + var + value = globals().get(name, None) + if isinstance(value, int): + value = long_to_bytes(value, order) + if value is not None: + value = hexlify(value).decode("ascii") + print() + print("static const uint8_t {}[] = {{ /* {:d} bytes */".format(name, len(value))) + print(wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2)))) + print("};") print() print("typedef struct {") print(" hal_curve_name_t curve;") for var in vars: - print(" const uint8_t *{0:>8}; size_t {0:>8}_len;".format(var)) + print(" const uint8_t *{0:>8}; size_t {0:>8}_len;".format(var)) print("} ecdsa_tc_t;") print() print("static const ecdsa_tc_t ecdsa_tc[] = {") for curve in curves: - print(" {{ HAL_CURVE_{},".format(curve.upper())) - for var in vars: - name = curve + "_" + var - if name in globals(): - print(" {:<14} sizeof({}),".format(name + ",", name)) - else: - print(" {:<14} 0,".format("NULL,")) - print(" },") + print(" {{ HAL_CURVE_{},".format(curve.upper())) + for var in vars: + name = curve + "_" + var + if name in globals(): + print(" {:<14} sizeof({}),".format(name + ",", name)) + else: + print(" {:<14} 0,".format("NULL,")) + print(" },") print("};") diff --git a/tests/test-rsa.py b/tests/test-rsa.py index d0538ed..39f46cd 100644 --- a/tests/test-rsa.py +++ b/tests/test-rsa.py @@ -45,10 +45,10 @@ from textwrap import TextWrapper import sys, os.path def KeyLengthType(arg): - val = int(arg) - if val % 8 != 0: - raise ValueError - return val + val = int(arg) + if val % 8 != 0: + raise ValueError + return val parser = ArgumentParser(description = __doc__) parser.add_argument("--pad-to-modulus", action = "store_true", @@ -71,22 +71,22 @@ scriptname = os.path.basename(sys.argv[0]) wrapper = TextWrapper(width = 78, initial_indent = " " * 2, subsequent_indent = " " * 2) def printlines(*lines, **kwargs): - for line in lines: - args.output.write(line.format(**kwargs) + "\n") + for line in lines: + args.output.write(line.format(**kwargs) + "\n") def trailing_comma(item, sequence): - return "" if item == sequence[-1] else "," + return "" if item == sequence[-1] else "," def print_hex(name, value, comment): - value = hexlify(value).decode("ascii") - printlines("static const uint8_t {name}[] = {{ /* {comment}, {length:d} bytes */", - wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2))) - "}};", "", - name = name, comment = comment, length = len(value)) + value = hexlify(value).decode("ascii") + printlines("static const uint8_t {name}[] = {{ /* {comment}, {length:d} bytes */", + wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2))) + "}};", "", + name = name, comment = comment, length = len(value)) def pad_to_blocksize(value, blocksize): - extra = len(value) % blocksize - return value if extra == 0 else (b"\x00" * (blocksize - extra)) + value + extra = len(value) % blocksize + return value if extra == 0 else (b"\x00" * (blocksize - extra)) + value # Funnily enough, PyCrypto and Cryptlib use exactly the same names for # RSA key components, see Cryptlib documentation pages 186-187 & 339. @@ -109,48 +109,48 @@ fields = ("n", "e", "d", "p", "q", "dP", "dQ", "u", "m", "s") for k_len in args.key_lengths: - k = RSA.generate(k_len) # Cryptlib insists u < p, probably with good reason, - while k.u >= k.p: # and I'm sure not going to argue the math with Peter, - k = RSA.generate(k_len) # so keep trying until we pass this test - - m = EMSA_PKCS1_V1_5_ENCODE(h, k_len/8) - s = PKCS115_SigScheme(k).sign(h) - assert len(m) == len(s) - - if args.pad_to_modulus: - blocksize = k_len/8 - if args.extra_word: - blocksize += 4 - else: - blocksize = 4 - - printlines("/* {k_len:d}-bit RSA private key (PKCS #{pkcs:d})", - k.exportKey(format = "PEM", pkcs = args.pkcs_encoding), - "*/", "", - k_len = k_len, pkcs = args.pkcs_encoding) - - # PyCrypto doesn't precalculate dP or dQ, and for some reason it - # does u backwards (uses (1/p % q) and swaps the roles of p and q in - # the CRT calculation to compensate), so we just calculate our own. - - for name in fields: - if name in "ms": - continue - elif name == "dP": - value = k.d % (k.p - 1) - elif name == "dQ": - value = k.d % (k.q - 1) - elif name == "u": - value = inverse(k.q, k.p) - else: - value = getattr(k, name) + k = RSA.generate(k_len) # Cryptlib insists u < p, probably with good reason, + while k.u >= k.p: # and I'm sure not going to argue the math with Peter, + k = RSA.generate(k_len) # so keep trying until we pass this test - print_hex("{}_{:d}".format(name, k_len), - long_to_bytes(value, blocksize = blocksize), - "key component {}".format(name)) + m = EMSA_PKCS1_V1_5_ENCODE(h, k_len/8) + s = PKCS115_SigScheme(k).sign(h) + assert len(m) == len(s) - print_hex("m_{:d}".format(k_len), pad_to_blocksize(m, blocksize), "message to be signed") - print_hex("s_{:d}".format(k_len), pad_to_blocksize(s, blocksize), "signed message") + if args.pad_to_modulus: + blocksize = k_len/8 + if args.extra_word: + blocksize += 4 + else: + blocksize = 4 + + printlines("/* {k_len:d}-bit RSA private key (PKCS #{pkcs:d})", + k.exportKey(format = "PEM", pkcs = args.pkcs_encoding), + "*/", "", + k_len = k_len, pkcs = args.pkcs_encoding) + + # PyCrypto doesn't precalculate dP or dQ, and for some reason it + # does u backwards (uses (1/p % q) and swaps the roles of p and q in + # the CRT calculation to compensate), so we just calculate our own. + + for name in fields: + if name in "ms": + continue + elif name == "dP": + value = k.d % (k.p - 1) + elif name == "dQ": + value = k.d % (k.q - 1) + elif name == "u": + value = inverse(k.q, k.p) + else: + value = getattr(k, name) + + print_hex("{}_{:d}".format(name, k_len), + long_to_bytes(value, blocksize = blocksize), + "key component {}".format(name)) + + print_hex("m_{:d}".format(k_len), pad_to_blocksize(m, blocksize), "message to be signed") + print_hex("s_{:d}".format(k_len), pad_to_blocksize(s, blocksize), "signed message") printlines("typedef struct {{ const uint8_t *val; size_t len; }} rsa_tc_bn_t;", "typedef struct {{ size_t size; rsa_tc_bn_t {fields}; }} rsa_tc_t;", @@ -158,9 +158,9 @@ printlines("typedef struct {{ const uint8_t *val; size_t len; }} rsa_tc_bn_t;", "static const rsa_tc_t rsa_tc[] = {{", fields = ", ".join(fields)) for k_len in args.key_lengths: - printlines(" {{ {k_len:d},", k_len = k_len) - for field in fields: - printlines(" {{ {field}_{k_len:d}, sizeof({field}_{k_len:d}) }}{comma}", - field = field, k_len = k_len, comma = trailing_comma(field, fields)) - printlines(" }}{comma}", comma = trailing_comma(k_len, args.key_lengths)) + printlines(" {{ {k_len:d},", k_len = k_len) + for field in fields: + printlines(" {{ {field}_{k_len:d}, sizeof({field}_{k_len:d}) }}{comma}", + field = field, k_len = k_len, comma = trailing_comma(field, fields)) + printlines(" }}{comma}", comma = trailing_comma(k_len, args.key_lengths)) printlines("}};") -- cgit v1.2.3