aboutsummaryrefslogblamecommitdiff
path: root/tests/test-rsa.py
blob: 6b52eb99391dd1b52dd1883982d60a495232ed28 (plain) (tree)
1
2
3
4
5
6
7
8
9
                     
 


                                                            
 
                     

                                  
 




                                                                          
 


                                                                        
 


                                                                          
 










                                                                          
 
                                                                       


                                                                             
                                                                     



                                                                                        




















                                                                                      






                                                                                        
                                           









                                                                                       



                                                                        


                                                                      













                                                                                         

                                                             
                              








                                                                                      






                         
                                                                 
                                                                    
                      
                                                       
 



















                                                                      

                                                                                   
 




                                                                               
                              



                                                                                   
                                                                             
                
#!/usr/bin/env python

"""
Use PyCrypto to generate test data for Cryptech ModExp core.
"""

# Author: Rob Austein
# Copyright (c) 2015, NORDUnet A/S
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# - Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
#
# - Neither the name of the NORDUnet nor the names of its contributors may
#   be used to endorse or promote products derived from this software
#   without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from argparse                           import ArgumentParser, FileType
from Crypto                             import __version__ as PyCryptoVersion
from Crypto.PublicKey                   import RSA
from Crypto.Hash                        import SHA256
from Crypto.Util.number                 import long_to_bytes, inverse
from Crypto.Signature.PKCS1_v1_5        import EMSA_PKCS1_V1_5_ENCODE, PKCS115_SigScheme
from textwrap                           import TextWrapper
import sys, os.path

def KeyLengthType(arg):
  """
  argparse ``type=`` converter for RSA key lengths in bits.

  Returns ``arg`` converted to int.  Raises ValueError (reported by
  argparse as an invalid argument) unless the value is a positive
  multiple of 8: the script divides the length by 8 to get the modulus
  size in bytes, and zero or negative lengths would only fail later,
  inside key generation.
  """
  val = int(arg)
  if val <= 0 or val % 8 != 0:
    raise ValueError("key length must be a positive multiple of 8, got %r" % (arg,))
  return val

# Command line parsing.  --key-lengths requires at least one value: an
# empty list would make the rsa_tc[] initializer written at the end of
# the output a zero-length array, which ISO C does not allow.
parser = ArgumentParser(description = __doc__)
parser.add_argument("--pad-to-modulus", action = "store_true",
                    help = "zero-pad to modulus size (bug workaround)")
parser.add_argument("--extra-word", action = "store_true",
                    help = "add extra word of zero padding (bug workaround)")
parser.add_argument("-k", "--key-lengths", type = KeyLengthType,
                    nargs = "+", default = [1024, 2048, 4096],
                    help = "Lengths in bits of keys to generate")
parser.add_argument("--pkcs-encoding", type = int, choices = (1, 8), default = 8,
                    help = "PKCS encoding to use for PEM commented private key")
parser.add_argument("output", nargs = "?", type = FileType("w"), default = sys.stdout,
                    help = "output file")
args = parser.parse_args()

# Fixed message whose SHA-256 digest every generated key signs.
plaintext = "You can hack anything you want with TECO and DDT."

# Name this script was invoked as; quoted in the generated file header.
scriptname = os.path.basename(sys.argv[0])

# Wraps the hex byte lists at 78 columns, each line indented two spaces.
wrapper = TextWrapper(subsequent_indent = "  ",
                      initial_indent    = "  ",
                      width             = 78)

def printlines(*lines, **kwargs):
  """
  Write each of ``lines`` to ``args.output``, newline-terminated.

  Every line is %-formatted against the keyword arguments first, so
  ``%(name)s``-style placeholders are filled from ``kwargs``.  This
  happens even for lines without placeholders, so a literal percent
  sign in any line has to be doubled.
  """
  for formatted in (text % kwargs for text in lines):
    args.output.write("%s\n" % formatted)

def trailing_comma(item, sequence):
  """
  Return the separator to emit after ``item`` in a C initializer list:
  "" if ``item`` equals the last element of ``sequence``, else ",".

  The test is by value, not position, so if the final value also occurs
  earlier in ``sequence`` those earlier occurrences get "" too.
  """
  last = sequence[-1]
  if item == last:
    return ""
  return ","

def print_hex(name, value, comment):
  """
  Emit ``value`` as a C ``static const uint8_t`` array named ``name``.

  Each element becomes a "0x%02x" literal via ord(), so ``value`` is
  expected to be a sequence of one-character strings (a Python 2 byte
  string).  The comma-separated list is wrapped by the module-level
  ``wrapper``.  The opening line carries ``comment`` and the length in
  bytes, and the array is followed by a blank line.
  """
  hex_bytes = ["0x%02x" % ord(ch) for ch in value]
  body = wrapper.fill(", ".join(hex_bytes))
  printlines("static const uint8_t %(name)s[] = { /* %(comment)s, %(length)d bytes */",
             body,
             "};",
             "",
             comment = comment,
             length  = len(value),
             name    = name)

def pad_to_blocksize(value, blocksize):
  """
  Left-pad ``value`` with zero bytes up to a multiple of ``blocksize``.

  Padding goes on the front, which preserves the value of a big-endian
  number.  A value whose length is already a multiple of ``blocksize``
  (including an empty one) is returned unchanged.

  The pad is built from a literal of the same type as ``value``.  This
  works both for Python 2 ``str`` and for Python 3 ``bytes``; on Python 3,
  concatenating a ``str`` pad to ``bytes`` raises TypeError.
  """
  extra = len(value) % blocksize
  if extra == 0:
    return value
  zero = b"\x00" if isinstance(value, bytes) else "\x00"
  return zero * (blocksize - extra) + value

# PyCrypto and Cryptlib happen to use exactly the same names for the RSA
# key components (see Cryptlib documentation pages 186-187 & 339).

# Digest of the fixed plaintext; every key below signs this same hash.
h = SHA256.new(plaintext)

# Banner comment at the top of the generated C file.
printlines(*("/*",
             " * RSA signature test data for Cryptech project, automatically generated by",
             " * %(scriptname)s using PyCrypto version %(version)s. Do not edit.",
             " *",
             " * Plaintext: \"%(plaintext)s\"",
             " * SHA-256: %(digest)s",
             " */",
             ""),
           **{"digest":     h.hexdigest(),
              "plaintext":  plaintext,
              "scriptname": scriptname,
              "version":    PyCryptoVersion})

# Per-entry field names of the rsa_tc_t table, in output order.  n, e, d,
# p and q are read straight off the PyCrypto key object; dP, dQ and u are
# computed below; m and s are the encoded message and its signature.
fields = ("n", "e", "d", "p", "q", "dP", "dQ", "u", "m", "s")

for k_len in args.key_lengths:

  k = RSA.generate(k_len)       # Cryptlib insists u < p, probably with good reason,
  while k.u >= k.p:             # and I'm sure not going to argue the math with Peter,
    k = RSA.generate(k_len)     # so keep trying until we pass this test

  # NOTE(review): k.u tested above is PyCrypto's own u, which (per the
  # comment on the CRT values below) is 1/p mod q.  The u actually written
  # out is inverse(k.q, k.p), which is reduced mod p and so presumably is
  # always < p already.  Confirm the loop above tests the intended value.

  # k_len // 8 is the modulus size in bytes.  Floor division is identical
  # to "/" for ints on Python 2 and stays an int on Python 3.
  m = EMSA_PKCS1_V1_5_ENCODE(h, k_len // 8)   # EMSA-PKCS1-v1_5 encoding of h
  s = PKCS115_SigScheme(k).sign(h)            # signature over the same hash
  assert len(m) == len(s)

  if args.pad_to_modulus:
    blocksize = k_len // 8      # pad every value out to the full modulus length,
    if args.extra_word:
      blocksize += 4            # plus one extra 4-byte word of zeros
  else:
    blocksize = 4               # otherwise just round up to whole 4-byte words

  printlines("/* %(k_len)d-bit RSA private key (PKCS #%(pkcs)d)",
             k.exportKey(format = "PEM", pkcs = args.pkcs_encoding),
             "*/", "",
             k_len = k_len, pkcs  = args.pkcs_encoding)

  # PyCrypto doesn't precalculate dP or dQ, and for some reason it
  # does u backwards (uses (1/p % q) and swaps the roles of p and q in
  # the CRT calculation to compensate), so we just calculate our own.

  for name in fields:
    if name in ("m", "s"):      # test vectors, not key components; emitted below
      continue
    elif name == "dP":
      value = k.d % (k.p - 1)
    elif name == "dQ":
      value = k.d % (k.q - 1)
    elif name == "u":
      value = inverse(k.q, k.p)
    else:
      value = getattr(k, name)

    print_hex("%s_%d" % (name, k_len),
              long_to_bytes(value, blocksize = blocksize),
              "key component %s" % name)

  print_hex("m_%d" % k_len, pad_to_blocksize(m, blocksize), "message to be signed")
  print_hex("s_%d" % k_len, pad_to_blocksize(s, blocksize), "signed message")

# Closing table: one rsa_tc_t per key length.  Each entry holds the key
# size plus a { pointer, sizeof } pair for every array named in fields.
printlines("typedef struct { const uint8_t *val; size_t len; } rsa_tc_bn_t;",
           "typedef struct { size_t size; rsa_tc_bn_t %(fields)s; } rsa_tc_t;",
           "",
           "static const rsa_tc_t rsa_tc[] = {",
           fields = ", ".join(fields))

for key_bits in args.key_lengths:
  printlines("  { %(k_len)d,", k_len = key_bits)
  for component in fields:
    sep = trailing_comma(component, fields)
    printlines("    { %(field)s_%(k_len)d, sizeof(%(field)s_%(k_len)d) }%(comma)s",
               comma = sep,
               field = component,
               k_len = key_bits)
  row_sep = trailing_comma(key_bits, args.key_lengths)
  printlines("  }%(comma)s", comma = row_sep)

printlines("};")