aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cryptech/libhal.py15
-rwxr-xr-xtests/parallel-signatures.py2
-rw-r--r--tests/test-ecdsa.py108
-rw-r--r--tests/test-rsa.py118
-rw-r--r--unit-tests.py5
5 files changed, 124 insertions, 124 deletions
diff --git a/cryptech/libhal.py b/cryptech/libhal.py
index 1e2dbc6..102e663 100644
--- a/cryptech/libhal.py
+++ b/cryptech/libhal.py
@@ -698,13 +698,14 @@ class HSM(object):
n = length
s = 0
while n == length:
- with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, mask, flags,
- attributes, s, length, u, client = client) as r:
- s = r.unpack_uint()
- n = r.unpack_uint()
- for i in range(n):
- u = UUID(bytes = r.unpack_bytes())
- yield u
+ r = self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, mask, flags,
+ attributes, s, length, u, client = client)
+ s = r.unpack_uint()
+ n = r.unpack_uint()
+ for i in range(n):
+ u = UUID(bytes = r.unpack_bytes())
+ yield u
+ r.done()
def pkey_set_attributes(self, pkey, attributes):
with self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTES, pkey, attributes):
diff --git a/tests/parallel-signatures.py b/tests/parallel-signatures.py
index 87bac0a..200e550 100755
--- a/tests/parallel-signatures.py
+++ b/tests/parallel-signatures.py
@@ -305,7 +305,7 @@ class Result(object):
@property
def mean(self):
- return self.sum / self.n
+ return self.sum / self.n
@property
def secs_per_sig(self):
diff --git a/tests/test-ecdsa.py b/tests/test-ecdsa.py
index 4c14d9f..4989619 100644
--- a/tests/test-ecdsa.py
+++ b/tests/test-ecdsa.py
@@ -52,27 +52,27 @@ from pyasn1.codec.der.decoder import decode as DER_Decode
wrapper = TextWrapper(width = 78, initial_indent = " " * 2, subsequent_indent = " " * 2)
def long_to_bytes(number, order):
- #
- # This is just plain nasty.
- #
- s = "{:x}".format(number)
- s = ("0" * (order/8 - len(s))) + s
- return unhexlify(s)
+ #
+ # This is just plain nasty.
+ #
+ s = "{:x}".format(number)
+ s = ("0" * (order/8 - len(s))) + s
+ return unhexlify(s)
def bytes_to_bits(b):
- #
- # This, on the other hand, is not just plain nasty, this is fancy nasty.
- # This is nasty with raisins in it.
- #
- s = bin(int(hexlify(b), 16))[2:]
- if len(s) % 8:
- s = ("0" * (8 - len(s) % 8)) + s
- return tuple(int(i) for i in s)
+ #
+ # This, on the other hand, is not just plain nasty, this is fancy nasty.
+ # This is nasty with raisins in it.
+ #
+ s = bin(int(hexlify(b), 16))[2:]
+ if len(s) % 8:
+ s = ("0" * (8 - len(s) % 8)) + s
+ return tuple(int(i) for i in s)
###
def encode_sig(r, s, order):
- return long_to_bytes(r, order) + long_to_bytes(s, order)
+ return long_to_bytes(r, order) + long_to_bytes(s, order)
p256_sig = encode_sig(p256_r, p256_s, 256)
p384_sig = encode_sig(p384_r, p384_s, 384)
@@ -80,23 +80,23 @@ p384_sig = encode_sig(p384_r, p384_s, 384)
###
class ECPrivateKey(Sequence):
- componentType = NamedTypes(
- NamedType("version", Integer(namedValues = NamedValues(("ecPrivkeyVer1", 1))
- ).subtype(subtypeSpec = Integer.subtypeSpec + SingleValueConstraint(1))),
- NamedType("privateKey", OctetString()),
- OptionalNamedType("parameters", ObjectIdentifier().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 0))),
- OptionalNamedType("publicKey", BitString().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 1))))
+ componentType = NamedTypes(
+ NamedType("version", Integer(namedValues = NamedValues(("ecPrivkeyVer1", 1))
+ ).subtype(subtypeSpec = Integer.subtypeSpec + SingleValueConstraint(1))),
+ NamedType("privateKey", OctetString()),
+ OptionalNamedType("parameters", ObjectIdentifier().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 0))),
+ OptionalNamedType("publicKey", BitString().subtype(explicitTag = Tag(tagClassContext, tagFormatSimple, 1))))
def encode_key(d, Qx, Qy, order, oid):
- private_key = long_to_bytes(d, order)
- public_key = bytes_to_bits(chr(0x04) + long_to_bytes(Qx, order) + long_to_bytes(Qy, order))
- parameters = oid
- key = ECPrivateKey()
- key["version"] = 1
- key["privateKey"] = private_key
- key["parameters"] = parameters
- key["publicKey"] = public_key
- return DER_Encode(key)
+ private_key = long_to_bytes(d, order)
+ public_key = bytes_to_bits(chr(0x04) + long_to_bytes(Qx, order) + long_to_bytes(Qy, order))
+ parameters = oid
+ key = ECPrivateKey()
+ key["version"] = 1
+ key["privateKey"] = private_key
+ key["parameters"] = parameters
+ key["publicKey"] = public_key
+ return DER_Encode(key)
p256_key = encode_key(p256_d, p256_Qx, p256_Qy, 256, "1.2.840.10045.3.1.7")
p384_key = encode_key(p384_d, p384_Qx, p384_Qy, 384, "1.3.132.0.34")
@@ -112,41 +112,41 @@ curves = ("p256", "p384")
vars = set()
for name in dir():
- head, sep, tail = name.partition("_")
- if head in curves:
- vars.add(tail)
+ head, sep, tail = name.partition("_")
+ if head in curves:
+ vars.add(tail)
vars = sorted(vars)
for curve in curves:
- order = int(curve[1:])
- for var in vars:
- name = curve + "_" + var
- value = globals().get(name, None)
- if isinstance(value, int):
- value = long_to_bytes(value, order)
- if value is not None:
- value = hexlify(value).decode("ascii")
- print()
- print("static const uint8_t {}[] = {{ /* {:d} bytes */".format(name, len(value)))
- print(wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2))))
- print("};")
+ order = int(curve[1:])
+ for var in vars:
+ name = curve + "_" + var
+ value = globals().get(name, None)
+ if isinstance(value, int):
+ value = long_to_bytes(value, order)
+ if value is not None:
+ value = hexlify(value).decode("ascii")
+ print()
+ print("static const uint8_t {}[] = {{ /* {:d} bytes */".format(name, len(value)))
+ print(wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2))))
+ print("};")
print()
print("typedef struct {")
print(" hal_curve_name_t curve;")
for var in vars:
- print(" const uint8_t *{0:>8}; size_t {0:>8}_len;".format(var))
+ print(" const uint8_t *{0:>8}; size_t {0:>8}_len;".format(var))
print("} ecdsa_tc_t;")
print()
print("static const ecdsa_tc_t ecdsa_tc[] = {")
for curve in curves:
- print(" {{ HAL_CURVE_{},".format(curve.upper()))
- for var in vars:
- name = curve + "_" + var
- if name in globals():
- print(" {:<14} sizeof({}),".format(name + ",", name))
- else:
- print(" {:<14} 0,".format("NULL,"))
- print(" },")
+ print(" {{ HAL_CURVE_{},".format(curve.upper()))
+ for var in vars:
+ name = curve + "_" + var
+ if name in globals():
+ print(" {:<14} sizeof({}),".format(name + ",", name))
+ else:
+ print(" {:<14} 0,".format("NULL,"))
+ print(" },")
print("};")
diff --git a/tests/test-rsa.py b/tests/test-rsa.py
index d0538ed..39f46cd 100644
--- a/tests/test-rsa.py
+++ b/tests/test-rsa.py
@@ -45,10 +45,10 @@ from textwrap import TextWrapper
import sys, os.path
def KeyLengthType(arg):
- val = int(arg)
- if val % 8 != 0:
- raise ValueError
- return val
+ val = int(arg)
+ if val % 8 != 0:
+ raise ValueError
+ return val
parser = ArgumentParser(description = __doc__)
parser.add_argument("--pad-to-modulus", action = "store_true",
@@ -71,22 +71,22 @@ scriptname = os.path.basename(sys.argv[0])
wrapper = TextWrapper(width = 78, initial_indent = " " * 2, subsequent_indent = " " * 2)
def printlines(*lines, **kwargs):
- for line in lines:
- args.output.write(line.format(**kwargs) + "\n")
+ for line in lines:
+ args.output.write(line.format(**kwargs) + "\n")
def trailing_comma(item, sequence):
- return "" if item == sequence[-1] else ","
+ return "" if item == sequence[-1] else ","
def print_hex(name, value, comment):
- value = hexlify(value).decode("ascii")
- printlines("static const uint8_t {name}[] = {{ /* {comment}, {length:d} bytes */",
- wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2)))
- "}};", "",
- name = name, comment = comment, length = len(value))
+ value = hexlify(value).decode("ascii")
+ printlines("static const uint8_t {name}[] = {{ /* {comment}, {length:d} bytes */",
+ wrapper.fill(", ".join("0x" + value[i : i + 2] for i in range(0, len(value), 2)))
+ "}};", "",
+ name = name, comment = comment, length = len(value))
def pad_to_blocksize(value, blocksize):
- extra = len(value) % blocksize
- return value if extra == 0 else (b"\x00" * (blocksize - extra)) + value
+ extra = len(value) % blocksize
+ return value if extra == 0 else (b"\x00" * (blocksize - extra)) + value
# Funnily enough, PyCrypto and Cryptlib use exactly the same names for
# RSA key components, see Cryptlib documentation pages 186-187 & 339.
@@ -109,48 +109,48 @@ fields = ("n", "e", "d", "p", "q", "dP", "dQ", "u", "m", "s")
for k_len in args.key_lengths:
- k = RSA.generate(k_len) # Cryptlib insists u < p, probably with good reason,
- while k.u >= k.p: # and I'm sure not going to argue the math with Peter,
- k = RSA.generate(k_len) # so keep trying until we pass this test
-
- m = EMSA_PKCS1_V1_5_ENCODE(h, k_len/8)
- s = PKCS115_SigScheme(k).sign(h)
- assert len(m) == len(s)
-
- if args.pad_to_modulus:
- blocksize = k_len/8
- if args.extra_word:
- blocksize += 4
- else:
- blocksize = 4
-
- printlines("/* {k_len:d}-bit RSA private key (PKCS #{pkcs:d})",
- k.exportKey(format = "PEM", pkcs = args.pkcs_encoding),
- "*/", "",
- k_len = k_len, pkcs = args.pkcs_encoding)
-
- # PyCrypto doesn't precalculate dP or dQ, and for some reason it
- # does u backwards (uses (1/p % q) and swaps the roles of p and q in
- # the CRT calculation to compensate), so we just calculate our own.
-
- for name in fields:
- if name in "ms":
- continue
- elif name == "dP":
- value = k.d % (k.p - 1)
- elif name == "dQ":
- value = k.d % (k.q - 1)
- elif name == "u":
- value = inverse(k.q, k.p)
- else:
- value = getattr(k, name)
+ k = RSA.generate(k_len) # Cryptlib insists u < p, probably with good reason,
+ while k.u >= k.p: # and I'm sure not going to argue the math with Peter,
+ k = RSA.generate(k_len) # so keep trying until we pass this test
- print_hex("{}_{:d}".format(name, k_len),
- long_to_bytes(value, blocksize = blocksize),
- "key component {}".format(name))
+ m = EMSA_PKCS1_V1_5_ENCODE(h, k_len/8)
+ s = PKCS115_SigScheme(k).sign(h)
+ assert len(m) == len(s)
- print_hex("m_{:d}".format(k_len), pad_to_blocksize(m, blocksize), "message to be signed")
- print_hex("s_{:d}".format(k_len), pad_to_blocksize(s, blocksize), "signed message")
+ if args.pad_to_modulus:
+ blocksize = k_len/8
+ if args.extra_word:
+ blocksize += 4
+ else:
+ blocksize = 4
+
+ printlines("/* {k_len:d}-bit RSA private key (PKCS #{pkcs:d})",
+ k.exportKey(format = "PEM", pkcs = args.pkcs_encoding),
+ "*/", "",
+ k_len = k_len, pkcs = args.pkcs_encoding)
+
+ # PyCrypto doesn't precalculate dP or dQ, and for some reason it
+ # does u backwards (uses (1/p % q) and swaps the roles of p and q in
+ # the CRT calculation to compensate), so we just calculate our own.
+
+ for name in fields:
+ if name in "ms":
+ continue
+ elif name == "dP":
+ value = k.d % (k.p - 1)
+ elif name == "dQ":
+ value = k.d % (k.q - 1)
+ elif name == "u":
+ value = inverse(k.q, k.p)
+ else:
+ value = getattr(k, name)
+
+ print_hex("{}_{:d}".format(name, k_len),
+ long_to_bytes(value, blocksize = blocksize),
+ "key component {}".format(name))
+
+ print_hex("m_{:d}".format(k_len), pad_to_blocksize(m, blocksize), "message to be signed")
+ print_hex("s_{:d}".format(k_len), pad_to_blocksize(s, blocksize), "signed message")
printlines("typedef struct {{ const uint8_t *val; size_t len; }} rsa_tc_bn_t;",
"typedef struct {{ size_t size; rsa_tc_bn_t {fields}; }} rsa_tc_t;",
@@ -158,9 +158,9 @@ printlines("typedef struct {{ const uint8_t *val; size_t len; }} rsa_tc_bn_t;",
"static const rsa_tc_t rsa_tc[] = {{",
fields = ", ".join(fields))
for k_len in args.key_lengths:
- printlines(" {{ {k_len:d},", k_len = k_len)
- for field in fields:
- printlines(" {{ {field}_{k_len:d}, sizeof({field}_{k_len:d}) }}{comma}",
- field = field, k_len = k_len, comma = trailing_comma(field, fields))
- printlines(" }}{comma}", comma = trailing_comma(k_len, args.key_lengths))
+ printlines(" {{ {k_len:d},", k_len = k_len)
+ for field in fields:
+ printlines(" {{ {field}_{k_len:d}, sizeof({field}_{k_len:d}) }}{comma}",
+ field = field, k_len = k_len, comma = trailing_comma(field, fields))
+ printlines(" }}{comma}", comma = trailing_comma(k_len, args.key_lengths))
printlines("}};")
diff --git a/unit-tests.py b/unit-tests.py
index fab09e8..b1e8a21 100644
--- a/unit-tests.py
+++ b/unit-tests.py
@@ -1227,7 +1227,6 @@ class TestPKeyMatch(TestCaseLoggedIn):
with hsm.pkey_load(obj.der, flags) as k:
self.addCleanup(self.cleanup_key, k.uuid)
uuids.add(k.uuid)
- #print k.uuid, k.key_type, k.key_curve, self.key_flag_names(k.key_flags)
k.set_attributes(dict((i, a) for i, a in enumerate((str(obj.keytype), str(obj.fn2)))))
return uuids
@@ -1256,14 +1255,14 @@ class TestPKeyMatch(TestCaseLoggedIn):
n = 0
for n, k in self.match(mask = mask, flags = flags, uuids = uuids, type = keytype):
self.assertEqual(k.key_type, keytype)
- self.assertEqual(k.get_attributes({0}).pop(0), bytes(keytype))
+ self.assertEqual(k.get_attributes({0}).pop(0).decode(), str(keytype))
self.assertEqual(n, sum(1 for t1, t2 in tags if t1 == keytype))
for curve in set(HALCurve.index.values()) - {HAL_CURVE_NONE}:
n = 0
for n, k in self.match(mask = mask, flags = flags, uuids = uuids, curve = curve):
self.assertEqual(k.key_curve, curve)
- self.assertEqual(k.get_attributes({1}).pop(1), str(curve))
+ self.assertEqual(k.get_attributes({1}).pop(1).decode(), str(curve))
self.assertIn(k.key_type, (HAL_KEY_TYPE_EC_PUBLIC,
HAL_KEY_TYPE_EC_PRIVATE))
self.assertEqual(n, sum(1 for t1, t2 in tags if t2 == curve))