From 86e3b101356659aadafd607ef9c9ddb2d3425fd5 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Mon, 1 Jun 2015 16:44:45 -0400 Subject: Add padding options to test workaround for current ModExp bugs. --- tests/test-rsa.py | 65 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/tests/test-rsa.py b/tests/test-rsa.py index 759d5a3..e6b6d56 100644 --- a/tests/test-rsa.py +++ b/tests/test-rsa.py @@ -1,13 +1,13 @@ -# Use PyCrypto to generate test data for Cryptech ModExp core. -# -# Funnily enough, PyCrypto and Cryptlib use exactly the same names for -# RSA key components, see Cryptlib documentation pages 186-187 & 339. +#!/usr/bin/env python -key_lengths = (1024, 2048, 4096) # Lengths in bits of keys to generate -pkcs_encoding = 8 # PKCS encoding for PEM comment (1 or 8) +""" +Use PyCrypto to generate test data for Cryptech ModExp core. +""" -plaintext = "You can hack anything you want with TECO and DDT." +# Funnily enough, PyCrypto and Cryptlib use exactly the same names for +# RSA key components, see Cryptlib documentation pages 186-187 & 339. +from argparse import ArgumentParser, FileType from Crypto import __version__ as PyCryptoVersion from Crypto.PublicKey import RSA from Crypto.Hash import SHA256 @@ -16,7 +16,27 @@ from Crypto.Signature.PKCS1_v1_5 import EMSA_PKCS1_V1_5_ENCODE, PKCS115_S from textwrap import TextWrapper import sys, os.path -assert all(key_length % 8 == 0 for key_length in key_lengths) +def KeyLengthType(arg): + val = int(arg) + if val % 8 != 0: + raise ValueError + return val + +parser = ArgumentParser(description = __doc__) +parser.add_argument("--pad-to-modulus", action = "store_true", + help = "zero-pad to modulus size (bug workaround)") +parser.add_argument("--extra-word", action = "store_true", + help = "add extra word of zero padding (bug workaround)") +parser.add_argument("-k", "--key-lengths", type = KeyLengthType, + nargs = "*", default = [1024, 2048, 4096], + help = "Lengths in bits of keys to generate") +parser.add_argument("--pkcs-encoding", type = int, choices = (1, 8), default = 8, + help = "PKCS encoding to use for PEM commented private key") +parser.add_argument("output", nargs = "?", type = FileType("w"), default = sys.stdout, + help = "output file") +args = parser.parse_args() + +plaintext = "You can hack anything you want with TECO and DDT." scriptname = os.path.basename(sys.argv[0]) @@ -24,7 +44,7 @@ wrapper = TextWrapper(width = 78, initial_indent = " " * 2, subsequent_indent = def printlines(*lines, **kwargs): for line in lines: - sys.stdout.write(line % kwargs + "\n") + args.output.write(line % kwargs + "\n") def trailing_comma(item, sequence): return "" if item == sequence[-1] else "," @@ -35,6 +55,10 @@ def print_hex(name, value, comment): "};", "", name = name, comment = comment, length = len(value)) +def pad_to_blocksize(value, blocksize): + extra = len(value) % blocksize + return value if extra == 0 else ("\x00" * (blocksize - extra)) + value + h = SHA256.new(plaintext) printlines("/*", @@ -49,7 +73,7 @@ printlines("/*", plaintext = plaintext, digest = h.hexdigest()) -for k_len in key_lengths: +for k_len in args.key_lengths: k = RSA.generate(k_len) # Cryptlib insists u < p, probably with good reason, while k.u >= k.p: # and I'm sure not going to argue the math with Peter, @@ -59,17 +83,24 @@ for k_len in key_lengths: s = PKCS115_SigScheme(k).sign(h) assert len(m) == len(s) + if args.pad_to_modulus: + blocksize = k_len/8 + if args.extra_word: + blocksize += 4 + else: + blocksize = 4 + printlines("/* %(k_len)d-bit RSA private key (PKCS #%(pkcs)d)", - k.exportKey(format = "PEM", pkcs = pkcs_encoding), + k.exportKey(format = "PEM", pkcs = args.pkcs_encoding), "*/", "", - k_len = k_len, pkcs = pkcs_encoding) + k_len = k_len, pkcs = args.pkcs_encoding) for component in k.keydata: print_hex("%s_%d" % (component, k_len), - long_to_bytes(getattr(k, component), blocksize = 4), + long_to_bytes(getattr(k, component), blocksize = blocksize), "key component %s" % component) - print_hex("m_%d" % k_len, m, "message to be signed") - print_hex("s_%d" % k_len, s, "signed message") + print_hex("m_%d" % k_len, pad_to_blocksize(m, blocksize), "message to be signed") + print_hex("s_%d" % k_len, pad_to_blocksize(s, blocksize), "signed message") fields = "nedpqums" printlines("typedef struct { const uint8_t *val; size_t len; } rsa_tc_bn_t;", @@ -77,10 +108,10 @@ printlines("typedef struct { const uint8_t *val; size_t len; } rsa_tc_bn_t;", "", "static const rsa_tc_t rsa_tc[] = {", fields = ", ".join(fields)) -for k_len in key_lengths: +for k_len in args.key_lengths: printlines(" { %(k_len)d,", k_len = k_len) for field in fields: printlines(" { %(field)s_%(k_len)d, sizeof(%(field)s_%(k_len)d) }%(comma)s", field = field, k_len = k_len, comma = trailing_comma(field, fields)) - printlines(" }%(comma)s", comma = trailing_comma(k_len, key_lengths)) + printlines(" }%(comma)s", comma = trailing_comma(k_len, args.key_lengths)) printlines("};") -- cgit v1.2.3