aboutsummaryrefslogblamecommitdiff
path: root/test/format_test_vectors.py
blob: f9e4ba8155ccd44cf0215d6a64c18c3b1e9469c4 (plain) (tree)
















































































                                                                                       







































                                                                                    



















                                                                                                 
                                                              

                                             









                                                                 









































                                                           
                                                                     


                                                                          

                                                          
        
                                                                             
                                                      

                                                         







                                                                                     


















                                                                                                                                    








































































































                                                                                      














                                                                                                      


                                               
















                                                                                        
                                                          
                                                                                                                                           
                                                                                                                                                  










                                           
#
# format_test_vectors.py
# ------------------------------------------
# Formats test vectors for modexp_fpga_model
#
# Author: Pavel Shatov
# Copyright (c) 2017, NORDUnet A/S
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# - Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
#
# - Neither the name of the NORDUnet nor the names of its contributors may
#   be used to endorse or promote products derived from this software
#   without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

#
# This script reads the test vectors generated by regenerate_test_vectors.py
# and writes them out as a nicely formatted C header file and a Verilog
# include file.
#

#
# imports
#
import subprocess

#
# get part of string between two markers
#
def string_between(s, s_left, s_right):
	"""Return the text of s lying between the first s_left and the first
	s_right that follows it.

	Raises ValueError (from str.index) if either marker is absent.
	"""
	head_end = s.index(s_left) + len(s_left)
	return s[head_end:s.index(s_right, head_end)]

#
# load message from file
#
def read_message(key):
	"""Return the first line of the file <key>.txt, trailing newline included.

	Raises IndexError if the file is empty.
	"""
	message_path = key + ".txt"
	with open(message_path, "r") as message_file:
		all_lines = message_file.readlines()
	return all_lines[0]
	
#
# read modulus from file
#
def read_modulus(key):
	"""Return the modulus of <key>.key as printed by `openssl rsa -modulus`.

	The text after the first '=' of openssl's stripped output is returned;
	the caller parses it with int(..., 16).
	"""
	command = ["openssl", "rsa", "-in", key + ".key", "-noout", "-modulus"]
	raw_output = subprocess.check_output(command)
	line = raw_output.decode("utf-8").strip()
	fields = line.split("=")
	return fields[1]

#
# read private exponent from file
#
def read_secret(key):
	"""Return the RSA private exponent d of <key>.key as a hex string.

	Runs `openssl rsa -noout -text` and extracts the hex dump printed
	between the "privateExponent" and "prime1" labels. The ':' byte
	separators and all whitespace (spaces, newlines, tabs and the carriage
	returns of CRLF output) are removed, so the result can be passed
	straight to int(..., 16).
	"""
	openssl_command = ["openssl", "rsa", "-in", key + ".key", "-noout", "-text"]
	openssl_stdout = subprocess.check_output(openssl_command).decode("utf-8")
	openssl_secret = string_between(openssl_stdout, "privateExponent", "prime1")
		# str.split() with no argument splits on any run of whitespace,
		# so joining the pieces drops '\r' as well as ' ' and '\n'
	return "".join(openssl_secret.replace(":", "").split())

#
# read the prime factors p and q of the modulus from file
#
def read_prime1(key):
	"""Return the first RSA prime factor p of <key>.key as a hex string.

	Extracts the hex dump printed between the "prime1" and "prime2" labels
	of `openssl rsa -noout -text`. The ':' byte separators and all
	whitespace (including carriage returns of CRLF output) are removed, so
	the result can be passed straight to int(..., 16).
	"""
	openssl_command = ["openssl", "rsa", "-in", key + ".key", "-noout", "-text"]
	openssl_stdout = subprocess.check_output(openssl_command).decode("utf-8")
	openssl_secret = string_between(openssl_stdout, "prime1", "prime2")
		# split() without arguments removes every kind of whitespace
	return "".join(openssl_secret.replace(":", "").split())
def read_prime2(key):
	"""Return the second RSA prime factor q of <key>.key as a hex string.

	Extracts the hex dump printed between the "prime2" and "exponent1"
	labels of `openssl rsa -noout -text`. The ':' byte separators and all
	whitespace (including carriage returns of CRLF output) are removed, so
	the result can be passed straight to int(..., 16).
	"""
	openssl_command = ["openssl", "rsa", "-in", key + ".key", "-noout", "-text"]
	openssl_stdout = subprocess.check_output(openssl_command).decode("utf-8")
	openssl_secret = string_between(openssl_stdout, "prime2", "exponent1")
		# split() without arguments removes every kind of whitespace
	return "".join(openssl_secret.replace(":", "").split())

#
# read the CRT private exponents dp and dq from file
#
def read_exponent1(key):
	"""Return the CRT private exponent dp of <key>.key as a hex string.

	Extracts the hex dump printed between the "exponent1" and "exponent2"
	labels of `openssl rsa -noout -text`. The ':' byte separators and all
	whitespace (including carriage returns of CRLF output) are removed, so
	the result can be passed straight to int(..., 16).
	"""
	openssl_command = ["openssl", "rsa", "-in", key + ".key", "-noout", "-text"]
	openssl_stdout = subprocess.check_output(openssl_command).decode("utf-8")
	openssl_secret = string_between(openssl_stdout, "exponent1", "exponent2")
		# split() without arguments removes every kind of whitespace
	return "".join(openssl_secret.replace(":", "").split())
def read_exponent2(key):
	"""Return the CRT private exponent dq of <key>.key as a hex string.

	Extracts the hex dump printed between the "exponent2" and "coefficient"
	labels of `openssl rsa -noout -text`. The ':' byte separators and all
	whitespace (including carriage returns of CRLF output) are removed, so
	the result can be passed straight to int(..., 16).
	"""
	openssl_command = ["openssl", "rsa", "-in", key + ".key", "-noout", "-text"]
	openssl_stdout = subprocess.check_output(openssl_command).decode("utf-8")
	openssl_secret = string_between(openssl_stdout, "exponent2", "coefficient")
		# split() without arguments removes every kind of whitespace
	return "".join(openssl_secret.replace(":", "").split())

# 
# https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm
#
def egcd(a, b):
    """Extended Euclidean algorithm.

    Return a tuple (g, x, y) with a*x + b*y == g, where g is gcd(a, b)
    for non-negative a and b.

    This is the iterative form of the recursive wikibooks version: it
    performs the same quotient steps and returns the same (g, x, y), but
    it does not recurse once per division step. The recursive form can
    exceed Python's default recursion limit for long Euclid chains, which
    occur with large (roughly 2048-bit and up) operands.
    """
    x0, x1, y0, y1 = 0, 1, 1, 0
    while a != 0:
        (q, a), b = divmod(b, a), a
        y0, y1 = y1, y0 - q * y1
        x0, x1 = x1, x0 - q * x1
    return (b, x0, y0)

def modinv(a, m):
    """Return the inverse of a modulo m, reduced into the range [0, m).

    Raises ValueError (an Exception subclass) if gcd(a, m) != 1.
    """
    (g, x, y) = egcd(a, m)
    if g != 1:
        # str() is required: concatenating str + int raised TypeError
        # instead of reporting which value could not be inverted
        raise ValueError("Can't invert a = " + str(a))
    return x % m
		
#
# format one test vector as C #define array initializers
#
def format_c_header(f, key, n, m, d, s, p, q, dp, dq, mp, mq):
	"""Write one test vector to f as a series of C '#define' initializers.

	Each value is emitted through format_c_array() under the macro name
	<LABEL>_<key>, in the order N, M, D, S, P, Q, DP, DQ, MP, MQ.
	"""
	labelled_values = [
		("N", n), ("M", m), ("D", d), ("S", s),
		("P", p), ("Q", q), ("DP", dp), ("DQ", dq),
		("MP", mp), ("MQ", mq),
	]
		# common tail of every '#define' line, including the line continuation
	suffix = "_" + str(key) + " \\\n"
	for label, value in labelled_values:
		format_c_array(f, value, "#define " + label + suffix)

#
# calculate Montgomery factor
#
def calc_montgomery_factor(k, n):
	"""Return 2**(2*k) reduced modulo n by repeated doubling.

	Starting from 1, the value is doubled 2*k times, and n is subtracted
	whenever the doubled value is not below n. For n > 1 the running value
	stays below n, so the result equals pow(2, 2 * k, n).
	"""
	factor = 1
	for _ in range(2 * k):
		factor <<= 1
		if factor >= n:
			factor -= n
	return factor


#
# calculate Montgomery modulus-dependent helper coefficient
#
def calc_montgomery_n_coeff(k, n):
	"""Return the Montgomery helper coefficient of n for a k-bit word.

	The result is built bit by bit: starting from 1, each bit position
	1 .. k-1 is set whenever the same bit of (result * (2**k - n)) mod 2**k
	is set. For odd n this lifting gives r with n * r == -1 (mod 2**k).
	"""
	word = 1 << k
	neg_n = word - n
	coeff = 1
	for bit in range(1, k):
		bit_mask = 1 << bit
		if ((coeff * neg_n) % word) & bit_mask:
				# this bit is still clear in coeff, so OR-ing equals adding
			coeff |= bit_mask
	return coeff
				

	
#
# format one test vector as Verilog localparam declarations
#
def format_verilog_include(f, key, n, m, d, s, p, q, dp, dq, mp, mq):
	"""Write one test vector to f as Verilog 'localparam' concatenations.

	key is the modulus width in bits. Like format_c_header, it may be given
	as a str or an int. Besides the vector itself, the Montgomery helper
	values are derived here:
	  - the factors that bring operands into the Montgomery domain;
	  - the calc_montgomery_n_coeff() coefficients;
	  - COEFF, the inverse of 2**key modulo n;
	  - M_FACTOR, the message m converted into Montgomery representation.
	Full-width values are declared [key-1:0] and named <LABEL>_<key>.
	The half-width CRT values are declared [key//2-1:0] and named
	<LABEL>_<key//2>.
	"""
	key_bits = int(key)
	half_bits = key_bits // 2

		# shared pieces of the declaration lines
	full_range = "localparam [" + str(key_bits - 1) + ":0] "
	half_range = "localparam [" + str(half_bits - 1) + ":0] "
	full_suffix = "_" + str(key) + " =\n"
	half_suffix = "_" + str(half_bits) + " =\n"

		# calculate factors to bring operands into Montgomery domain
	factor   = calc_montgomery_factor(key_bits,  n)
	factor_p = calc_montgomery_factor(half_bits, p)
	factor_q = calc_montgomery_factor(half_bits, q)

		# calculate helper coefficients for Montgomery multiplication
	n_coeff = calc_montgomery_n_coeff(key_bits,  n)
	p_coeff = calc_montgomery_n_coeff(half_bits, p)
	q_coeff = calc_montgomery_n_coeff(half_bits, q)

		# calculate the extra coefficient Montgomery multiplication brings in
	coeff = modinv(1 << key_bits, n)

		# convert m into Montgomery representation
	m_factor = (m * factor * coeff) % n

	full_width = (
		("M", m), ("N", n), ("N_COEFF", n_coeff), ("FACTOR", factor),
		("COEFF", coeff), ("M_FACTOR", m_factor), ("D", d), ("S", s),
	)
	half_width = (
		("P", p), ("Q", q), ("P_COEFF", p_coeff), ("Q_COEFF", q_coeff),
		("FACTOR_P", factor_p), ("FACTOR_Q", factor_q),
		("DP", dp), ("DQ", dq), ("MP", mp), ("MQ", mq),
	)

		# write all numbers, full-width ones first
	for label, value in full_width:
		format_verilog_concatenation(f, value, full_range + label + full_suffix)
	for label, value in half_width:
		format_verilog_concatenation(f, value, half_range + label + half_suffix)
	
	
#
# nicely format multi-word integer into C array initializer
#
def format_c_array(f, n, s):
	"""Write the non-negative integer n to f as a C array initializer.

	s is written first, verbatim. Callers pass a '#define NAME' line that
	ends in a backslash. The value is then split into big-endian 32-bit
	words, printed as 0x-prefixed hex, four words per line. Every line
	except the last ends in a backslash, so the initializer continues the
	macro. A blank line follows the closing brace.

	Raises ValueError if n is negative. Nothing is written in that case,
	instead of silently dropping the sign.
	"""
	if n < 0:
		raise ValueError("format_c_array: cannot format negative value " + str(n))

		# print '#define ZZZ \'
	f.write(s)

		# hex digits without prefix or sign ("%x" also avoids the 'L' suffix
		# hex() adds to Python 2 longs), left-padded with zeroes to a whole
		# number of 32-bit (8-digit) words
	n_hex = "%x" % n
	num_words = (len(n_hex) + 7) // 8
	n_hex = n_hex.zfill(8 * num_words)

	parts = []
	for w in range(num_words):

			# open the brace on the first line, indent continuation lines
		if w == 0:
			parts.append("\t{")
		elif (w % 4) == 0:
			parts.append("\t ")

			# add current word
		parts.append("0x" + n_hex[8 * w : 8 * (w + 1)])

			# closing brace after the last word, otherwise a separator;
			# every fourth word ends the line with a macro continuation
		if (w + 1) == num_words:
			parts.append("}\n")
		else:
			parts.append(", ")
			if (w % 4) == 3:
				parts.append("\\\n")

		# write the initializer and the final newline in one go
	parts.append("\n")
	f.write("".join(parts))

	
def format_verilog_concatenation(f, n, s):
	"""Write the non-negative integer n to f as a Verilog concatenation.

	s is written first, verbatim. Callers pass a 'localparam ... =' line.
	The value is then split into big-endian 32-bit words, printed as 32'h
	literals, four words per line, inside braces terminated by '};'. A
	blank line follows.

	Raises ValueError if n is negative. Nothing is written in that case,
	instead of silently dropping the sign.
	"""
	if n < 0:
		raise ValueError("format_verilog_concatenation: cannot format negative value " + str(n))

		# print 'localparam ZZZ ='
	f.write(s)

		# hex digits without prefix or sign ("%x" also avoids the 'L' suffix
		# hex() adds to Python 2 longs), left-padded with zeroes to a whole
		# number of 32-bit (8-digit) words
	n_hex = "%x" % n
	num_words = (len(n_hex) + 7) // 8
	n_hex = n_hex.zfill(8 * num_words)

	parts = []
	for w in range(num_words):

			# open the brace on the first line, indent continuation lines
		if w == 0:
			parts.append("\t{")
		elif (w % 4) == 0:
			parts.append("\t ")

			# add current word as a 32-bit hex literal
		parts.append("32'h" + n_hex[8 * w : 8 * (w + 1)])

			# closing brace after the last word, otherwise a separator;
			# every fourth word ends the line
		if (w + 1) == num_words:
			parts.append("};\n")
		else:
			parts.append(", ")
			if (w % 4) == 3:
				parts.append("\n")

		# write the concatenation and the final newline in one go
	parts.append("\n")
	f.write("".join(parts))
	
if __name__ == "__main__":

		# list of key lengths to process
	keys = ["384", "512"]

		# open output files; the with-statement guarantees that BOTH files are
		# flushed and closed (previously the .v file was never closed), even
		# when the CRT self-check below raises
	with open('modexp_fpga_model_vectors.h', 'w') as file_h, open('modexp_fpga_model_vectors.v', 'w') as file_v:

			# write headers
		file_h.write("/* Generated automatically, do not edit. */\n\n")
		file_v.write("/* Generated automatically, do not edit. */\n\n")

			# process all the keys
		for key in keys:

				# prepare all the numbers
			modulus = int(read_modulus(key), 16)			# read number n from .key file
			message = int(read_message(key), 16)			# read number m from .txt file
			secret  = int(read_secret(key),  16)			# read number d from .key file
			signature = pow(message, secret, modulus)		# calculate signature
			prime1 = int(read_prime1(key), 16)				# read p
			prime2 = int(read_prime2(key), 16)				# read q
			exponent1 = int(read_exponent1(key), 16)		# read dp
			exponent2 = int(read_exponent2(key), 16)		# read dq
			message1 = pow(message, exponent1, prime1)		# calculate mp = m ^ dp mod p
			message2 = pow(message, exponent2, prime2)		# calculate mq = m ^ dq mod q
			coefficient = modinv(prime2, prime1)			# calculate qinv = q ^ -1 mod p

				# recombine mp and mq via CRT to make sure everything is correct
			h = coefficient * (message1 - message2) % prime1
			crt = message2 + h * prime2

				# print all the numbers
			print("key = " + key)
			print("  modulus     = " + hex(modulus))
			print("  message     = " + hex(message))
			print("  secret      = " + hex(secret))
			print("  signature   = " + hex(signature))
			print("  prime1      = " + hex(prime1))
			print("  prime2      = " + hex(prime2))
			print("  exponent1   = " + hex(exponent1))
			print("  exponent2   = " + hex(exponent2))
			print("  message1    = " + hex(message1))
			print("  message2    = " + hex(message2))
			print("  coefficient = " + hex(coefficient))
			print("  crt         = " + hex(crt))

				# check
			if crt != signature:
				raise Exception("Error, crt != signature (?)")

				# format numbers and write to files
			format_c_header(file_h, key, modulus, message, secret, signature, prime1, prime2, exponent1, exponent2, message1, message2)
			format_verilog_include(file_v, key, modulus, message, secret, signature, prime1, prime2, exponent1, exponent2, message1, message2)

		# everything went just fine
	print("Test vectors formatted.")
	
#
# End of file
#