#!/usr/bin/python3
#
#
# ModExpNG core math model.
#
#
# Copyright (c) 2019, NORDUnet A/S
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# - Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# - Neither the name of the NORDUnet nor the names of its contributors may
# be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# -------
# Imports
#--------
import sys
import importlib
from enum import Enum, auto
# --------------
# Model Settings
# --------------

# bit length of the public modulus
KEY_LENGTH = 1024

# number of multiply-accumulate units working side by side
NUM_MULTS = 8

# ---------------
# Internal Values
# ---------------

# operands of each CRT half are half as long as the modulus
_KEY_LENGTH_HALF = KEY_LENGTH // 2

# nominal word width of the math pipeline, and the extended width a word may
# temporarily occupy (two extra bits above the nominal width)
_WORD_WIDTH = 16
_WORD_WIDTH_EXT = _WORD_WIDTH + 2

# subdirectory (appended to sys.path[0]) holding the test vector modules
_VECTOR_PATH = "/vector"

# class to instantiate from a test vector module
_VECTOR_CLASS = "Vector"

# ------------------
# Debugging Settings
# ------------------

# move low bits of T1 words into the extended bits of the previous word
FORCE_OVERFLOW = False
# print operands as Verilog assignments
DUMP_VECTORS = False
# print multiplier a-word indices
DUMP_INDICES = False
# print multiplier MAC inputs
DUMP_MACS_INPUTS = False
# print multiplier MAC clear events
DUMP_MACS_CLEARING = False
# print multiplier MAC accumulators
DUMP_MACS_ACCUMULATION = False
# print parts stored by the multiplier
DUMP_MULT_PARTS = False
# print part recombinator activity
DUMP_RCMB = False
# print the final reduction addition
DUMP_REDUCTION = False
#
# Multi-Precision Integer
#
class ModExpNG_Operand():
    """Multi-precision integer stored as a little-endian list of words.

    Words are nominally _WORD_WIDTH bits wide but may hold up to
    _WORD_WIDTH_EXT bits. number() adds every word at its _WORD_WIDTH-aligned
    position, so extended bits carry the weight of the next word.

    Two ways to construct:
      * ModExpNG_Operand(number, length): split the non-negative integer
        `number` into length // _WORD_WIDTH words. Here `length` is in bits.
      * ModExpNG_Operand(None, count, words): wrap a copy of an existing list
        of `count` words. Here `number` is ignored.

    Raises Exception on a bad bit length, a wrong word count, a negative or
    oversized number, or a word that is negative or wider than
    _WORD_WIDTH_EXT bits.
    """

    def __init__(self, number, length, words = None):
        if words is None:
            # a bit length must map onto a whole number of words
            if (length % _WORD_WIDTH) > 0:
                raise Exception("Bad number length!")
            self._init_from_number(number, length)
        else:
            # here `length` is a word count and must match the list
            if len(words) != length:
                raise Exception("Bad words count!")
            self._init_from_words(words, length)

    def format_verilog_concat(self, name):
        """Print the words as Verilog assignments `name[i] = 18'h.....;`, four per line."""
        for i, word in enumerate(self.words):
            if i > 0:
                if (i % 4) == 0: print("")
                else: print(" ", end='')
            print("%s[%2d] = 18'h%05x;" % (name, i, word), end='')
        print("")

    def _init_from_words(self, words, count):
        """Validate `count` words and store a private copy of the list."""
        limit = 1 << _WORD_WIDTH_EXT
        for i in range(count):
            # every word must fit into the extended width
            if words[i] < 0:
                raise Exception("Word is negative!")
            if words[i] >= limit:
                raise Exception("Word is too large!")
        # copy so that in-place updates (e.g. carry propagation) never touch
        # the caller's list
        self.words = list(words)

    def _init_from_number(self, number, length):
        """Split `number` (< 2**length) into length // _WORD_WIDTH words."""
        # the value must fit into `length` bits; smaller values are
        # implicitly zero-extended by the word split below
        if number < 0:
            raise Exception("Number is negative!")
        if number >> length:
            raise Exception("Number is too large!")
        word_mask = (1 << _WORD_WIDTH) - 1
        # least significant word first
        self.words = [(number >> shift) & word_mask
                      for shift in range(0, length, _WORD_WIDTH)]

    def number(self):
        """Return the integer value: the sum of words[i] << (i * _WORD_WIDTH)."""
        return sum(word << (i * _WORD_WIDTH) for i, word in enumerate(self.words))
#
# Test Vector
#
class ModExpNG_TestVector():
    """Test vector for the configured KEY_LENGTH, loaded from disk.

    Imports module ``vector_<KEY_LENGTH>_randomized`` from the directory
    ``sys.path[0] + _VECTOR_PATH`` and instantiates its ``_VECTOR_CLASS``
    class. Every quantity becomes a ModExpNG_Operand:

      KEY_LENGTH bits:                     m, n, d, n_factor, x, y
      _KEY_LENGTH_HALF bits:               p, q, dp, dq, qinv, p_factor, q_factor
      one extra _WORD_WIDTH word on top:   n_coeff, p_coeff, q_coeff
    """

    def __init__(self):
        # format target filename
        filename = "vector_" + str(KEY_LENGTH) + "_randomized"
        # add ./vector to the import search path; the membership test keeps
        # repeated instantiation from stacking duplicate sys.path entries
        vector_dir = sys.path[0] + _VECTOR_PATH
        if vector_dir not in sys.path:
            sys.path.insert(1, vector_dir)
        # import the module and instantiate its vector class
        vector_module = importlib.import_module(filename)
        vector_inst = getattr(vector_module, _VECTOR_CLASS)()
        full = KEY_LENGTH
        half = _KEY_LENGTH_HALF
        # obtain parts of vector
        self.m = ModExpNG_Operand(vector_inst.m, full)
        self.n = ModExpNG_Operand(vector_inst.n, full)
        self.d = ModExpNG_Operand(vector_inst.d, full)
        self.p = ModExpNG_Operand(vector_inst.p, half)
        self.q = ModExpNG_Operand(vector_inst.q, half)
        self.dp = ModExpNG_Operand(vector_inst.dp, half)
        self.dq = ModExpNG_Operand(vector_inst.dq, half)
        self.qinv = ModExpNG_Operand(vector_inst.qinv, half)
        self.n_factor = ModExpNG_Operand(vector_inst.n_factor, full)
        self.p_factor = ModExpNG_Operand(vector_inst.p_factor, half)
        self.q_factor = ModExpNG_Operand(vector_inst.q_factor, half)
        # the reduction coefficients are one word longer than their modulus
        self.n_coeff = ModExpNG_Operand(vector_inst.n_coeff, full + _WORD_WIDTH)
        self.p_coeff = ModExpNG_Operand(vector_inst.p_coeff, half + _WORD_WIDTH)
        self.q_coeff = ModExpNG_Operand(vector_inst.q_coeff, half + _WORD_WIDTH)
        self.x = ModExpNG_Operand(vector_inst.x, full)
        self.y = ModExpNG_Operand(vector_inst.y, full)
class ModExpNG_WideBankEnum(Enum):
    """Selectors for the slots of a wide bank (A-E and modulus N)."""
    A = 1
    B = 2
    C = 3
    D = 4
    E = 5
    N = 6
class ModExpNG_NarrowBankEnum(Enum):
    """Selectors for the slots of a narrow bank (A-E, N_COEFF and read-only I)."""
    A = 1
    B = 2
    C = 3
    D = 4
    E = 5
    N_COEFF = 6
    I = 7
class ModExpNG_WideBank():
    """Wide operand bank: slots A-E plus the modulus N, all initially None.

    Slots are addressed with ModExpNG_WideBankEnum members; any other
    selector raises Exception.
    """

    def __init__(self):
        self.a = self.b = self.c = self.d = self.e = self.n = None

    @staticmethod
    def _slot_name(sel):
        # Map a selector to the attribute backing it (None when unknown),
        # comparing against the members in declaration order.
        W = ModExpNG_WideBankEnum
        for member, name in ((W.A, "a"), (W.B, "b"), (W.C, "c"),
                             (W.D, "d"), (W.E, "e"), (W.N, "n")):
            if sel == member:
                return name
        return None

    def _get_value(self, sel):
        name = self._slot_name(sel)
        if name is None:
            raise Exception("ModExpNG_WideBank._get_value(): Invalid selector!")
        return getattr(self, name)

    def _set_value(self, sel, value):
        name = self._slot_name(sel)
        if name is None:
            raise Exception("ModExpNG_WideBank._set_value(): Invalid selector!")
        setattr(self, name, value)
class ModExpNG_NarrowBank():
    """Narrow operand bank: slots A-E, N_COEFF and the fixed value I.

    All slots start as None except I, which is set once by the constructor.
    I can be read but not written: _set_value rejects it like any other
    unknown selector.
    """

    def __init__(self, i):
        self.a = self.b = self.c = self.d = self.e = self.n_coeff = None
        self.i = i

    @staticmethod
    def _slot_name(sel, include_i):
        # Map a selector to its backing attribute (None when unknown); the
        # I slot is only considered when include_i is true.
        NB = ModExpNG_NarrowBankEnum
        table = [(NB.A, "a"), (NB.B, "b"), (NB.C, "c"), (NB.D, "d"),
                 (NB.E, "e"), (NB.N_COEFF, "n_coeff")]
        if include_i:
            table.append((NB.I, "i"))
        for member, name in table:
            if sel == member:
                return name
        return None

    def _get_value(self, sel):
        name = self._slot_name(sel, True)
        if name is None:
            raise Exception("ModExpNG_NarrowBank._get_value(): Invalid selector!")
        return getattr(self, name)

    def _set_value(self, sel, value):
        name = self._slot_name(sel, False)
        if name is None:
            raise Exception("ModExpNG_NarrowBank._set_value(): Invalid selector!")
        setattr(self, name, value)
class ModExpNG_BanksPair():
    """One wide bank plus one narrow bank; `i` seeds the narrow bank's I slot."""

    def __init__(self, i):
        self.wide, self.narrow = ModExpNG_WideBank(), ModExpNG_NarrowBank(i)

    def _get_value_wide(self, sel):
        """Read slot `sel` of the wide bank."""
        bank = self.wide
        return bank._get_value(sel)

    def _get_value_narrow(self, sel):
        """Read slot `sel` of the narrow bank."""
        bank = self.narrow
        return bank._get_value(sel)
class ModExpNG_BanksLadder():
    """Bank pairs for the two ladder values, ladder_x and ladder_y."""

    def __init__(self, i):
        self.ladder_x = ModExpNG_BanksPair(i)
        self.ladder_y = ModExpNG_BanksPair(i)

    def set_modulus(self, n, n_coeff):
        """Store `n` (wide N slot) and `n_coeff` (narrow N_COEFF slot) in both ladders."""
        ladders = (self.ladder_x, self.ladder_y)
        for pair in ladders:
            pair.wide._set_value(ModExpNG_WideBankEnum.N, n)
        for pair in ladders:
            pair.narrow._set_value(ModExpNG_NarrowBankEnum.N_COEFF, n_coeff)

    def set_operand(self, sel_wide, sel_narrow, x, y):
        """Store `x` in ladder_x and `y` in ladder_y.

        The wide slot `sel_wide` and the narrow slot `sel_narrow` are
        written; either selector may be None to skip that bank.
        """
        targets = ((self.ladder_x, x), (self.ladder_y, y))
        if sel_wide is not None:
            for pair, value in targets:
                pair.wide._set_value(sel_wide, value)
        if sel_narrow is not None:
            for pair, value in targets:
                pair.narrow._set_value(sel_narrow, value)
class ModExpNG_BanksCRT():
    """Two independent ladder bank sets, crt_x and crt_y (presumably one per CRT half)."""

    def __init__(self, i):
        self.crt_x, self.crt_y = ModExpNG_BanksLadder(i), ModExpNG_BanksLadder(i)
class ModExpNG_PartRecombinator():
    """Model of the part recombinator.

    Converts the raw MAC accumulator values ("parts") produced by
    ModExpNG_WordMultiplier into words. Each part is cut into three
    16-bit-aligned slices and fed through a three-latch shifting pipeline
    that emits one word per push.
    """

    def _bit_select(self, x, msb, lsb):
        """Return bits msb..lsb (inclusive) of x shifted down to bit 0; 0 if msb < lsb."""
        if msb < lsb:
            return 0
        return (x >> lsb) & ((1 << (msb - lsb + 1)) - 1)

    def _flush_pipeline(self, dump):
        """Clear the three pipeline latches."""
        self.z0, self.y0, self.x0 = 0, 0, 0
        if dump and DUMP_RCMB:
            print("RCMB -> flush()")

    def _push_pipeline(self, part, dump):
        """Feed one part into the pipeline and return the word it emits.

        The part is cut into bits 46:32, 31:16 and 15:0 (higher bits are
        ignored). Each slice is added to the latch one position to its right.
        The rightmost latch's lower 16 bits leave as the result word, and its
        upper bits carry into the new rightmost latch.
        """
        # split next part into 16-bit words
        z = self._bit_select(part, 46, 32)
        y = self._bit_select(part, 31, 16)
        x = self._bit_select(part, 15, 0)
        # shift to the right
        z1 = z
        y1 = y + self.z0
        x1 = x + self.y0 + (self.x0 >> _WORD_WIDTH)  # IMPORTANT: This carry can be up to two bits wide!!
        # save lower 16 bits of the rightmost cell
        t = self.x0 & 0xffff
        # update internal latches
        self.z0, self.y0, self.x0 = z1, y1, x1
        if dump and DUMP_RCMB:
            print("RCMB -> push(): part = 0x%012x, word = 0x%04x" % (part, t))
        return t

    def _recombine_run(self, parts, first, count, ticks, dump):
        """Flush, then push parts[first:first + count] followed by zeros.

        Runs for `ticks` ticks in total. The word emitted on the first tick
        is dropped because the pipeline is still empty then; the remaining
        ticks - 1 words are returned.
        """
        self._flush_pipeline(dump)
        words = []
        for i in range(ticks):
            next_part = parts[first + i] if i < count else 0
            next_word = self._push_pipeline(next_part, dump)
            if i > 0:
                words.append(next_word)
        return words

    def _merge_halves(self, words_lsb, words_msb, ab_num_words, msb_count):
        """Join the lower and upper halves of a recombined product.

        Keeps the first ab_num_words words of words_lsb, then appends the
        first msb_count words of words_msb. The two words words_lsb produced
        past ab_num_words overlap the first two upper words and are added in.
        """
        words = words_lsb[:ab_num_words]
        for x in range(msb_count):
            next_word = words_msb[x]
            if x < 2:
                next_word += words_lsb[x + ab_num_words]
            words.append(next_word)
        return words

    def recombine_square(self, parts, ab_num_words, dump):
        """Recombine the 2n-1 parts of multiply_square() into 2n words."""
        # lower half: n parts; the first tick yields nothing and the last part
        # produces three words, so two extra ticks are needed
        words_lsb = self._recombine_run(parts, 0, ab_num_words, ab_num_words + 1 + 2, dump)
        # upper half: n-1 parts; the first tick yields nothing
        words_msb = self._recombine_run(parts, ab_num_words, ab_num_words - 1, ab_num_words + 1, dump)
        return self._merge_halves(words_lsb, words_msb, ab_num_words, ab_num_words)

    def recombine_triangle(self, parts, ab_num_words, dump):
        """Recombine the n+1 parts of multiply_triangle() into n+1 words."""
        # n+1 parts plus the initial empty tick: n+2 ticks, n+1 words kept
        return self._recombine_run(parts, 0, ab_num_words + 1, ab_num_words + 2, dump)

    def recombine_rectangle(self, parts, ab_num_words, dump):
        """Recombine the 2n parts of multiply_rectangle() into 2n+1 words."""
        # lower half: n parts; the last part produces three words, so two
        # extra ticks are needed after the initial empty one
        words_lsb = self._recombine_run(parts, 0, ab_num_words, ab_num_words + 1 + 2, dump)
        # upper half: n parts; the last part produces two words, so one extra tick
        words_msb = self._recombine_run(parts, ab_num_words, ab_num_words, ab_num_words + 2, dump)
        return self._merge_halves(words_lsb, words_msb, ab_num_words, ab_num_words + 1)
class ModExpNG_WordMultiplier():
    """Model of the NUM_MULTS-wide multiply-accumulate (MAC) array.

    Operands are processed column by column. A column is a group of NUM_MULTS
    consecutive a-word positions, one per MAC. At time t every MAC adds
    a_wide.words[index] * b_word(t) to its accumulator, where `index` is
    rotated down by one each tick. Finished accumulators are stored as
    "parts", which ModExpNG_PartRecombinator turns into words.

    a-words may be up to 18 bits wide (0x3FFFF); b-words must fit in 16 bits.
    """

    def __init__(self):
        # one accumulator and one a-word index per MAC; only element 0 of
        # the auxiliary lists is ever used
        self._macs = [0] * NUM_MULTS
        self._indices = [0] * NUM_MULTS
        self._mac_aux = [0] * NUM_MULTS
        self._index_aux = [0] * NUM_MULTS

    def _clear_all_macs(self, t, col, dump):
        for x in range(NUM_MULTS):
            self._macs[x] = 0
        if dump and DUMP_MACS_CLEARING:
            print("t=%2d, col=%2d > clear > all" % (t, col))

    def _clear_one_mac(self, x, t, col, dump):
        self._macs[x] = 0
        if dump and DUMP_MACS_CLEARING:
            print("t=%2d, col=%2d > clear > x=%d" % (t, col, x))

    def _clear_mac_aux(self, t, col, dump):
        # only ever called at t == 0, hence the fixed time in the message
        self._mac_aux[0] = 0
        if dump and DUMP_MACS_CLEARING:
            print("t= 0, col=%2d > clear > aux" % (col))

    def _update_one_mac(self, x, t, col, a, b, dump, need_aux=False):
        """Accumulate a * b into MAC x; need_aux defers the dump line break to the aux MAC."""
        if a > 0x3FFFF:
            raise Exception("a > 0x3FFFF!")
        if b > 0xFFFF:
            raise Exception("b > 0xFFFF!")
        p = a * b
        if dump and DUMP_MACS_INPUTS:
            if x == 0: print("t=%2d, col=%2d > b=%05x > " % (t, col, b), end='')
            if x > 0: print("; ", end='')
            print("MAC[%d]: a=%05x" % (x, a), end='')
            if x == (NUM_MULTS-1) and not need_aux: print("")
        self._macs[x] += p

    def _update_mac_aux(self, y, col, a, b, dump):
        """Accumulate a * b into the auxiliary MAC (`y` is the time index, unused)."""
        if a > 0x3FFFF:
            raise Exception("a > 0x3FFFF!")
        if b > 0xFFFF:
            raise Exception("b > 0xFFFF!")
        p = a * b
        if dump and DUMP_MACS_INPUTS:
            print("; AUX: a=%05x" % a)
        self._mac_aux[0] += p

    def _preset_indices(self, col):
        # MAC x starts at a-word col * NUM_MULTS + x
        for x in range(len(self._indices)):
            self._indices[x] = col * len(self._indices) + x

    def _preset_index_aux(self, num_cols):
        self._index_aux[0] = num_cols * len(self._indices)

    def _dump_macs_helper(self, t, col, aux=False):
        print("t=%2d, col=%2d > "% (t, col), end='')
        for i in range(NUM_MULTS):
            if i > 0: print(" | ", end='')
            print("mac[%d]: 0x%012x" % (i, self._macs[i]), end='')
        if aux:
            print(" | mac_aux[ 0]: 0x%012x" % (self._mac_aux[0]), end='')
        print("")

    def _dump_macs(self, t, col):
        self._dump_macs_helper(t, col)

    def _dump_macs_with_aux(self, t, col):
        self._dump_macs_helper(t, col, True)

    def _dump_indices_helper(self, t, col, aux=False):
        print("t=%2d, col=%2d > indices:" % (t, col), end='')
        for i in range(NUM_MULTS):
            print(" %2d" % self._indices[i], end='')
        if aux:
            print(" %2d" % self._index_aux[0], end='')
        print("")

    def _dump_indices(self, t, col):
        self._dump_indices_helper(t, col)

    def _dump_indices_with_aux(self, t, col):
        self._dump_indices_helper(t, col, True)

    def _rotate_indices(self, num_words):
        # decrement every index, wrapping from 0 to num_words - 1
        for x in range(len(self._indices)):
            if self._indices[x] > 0:
                self._indices[x] -= 1
            else:
                self._indices[x] = num_words - 1

    def _rotate_index_aux(self):
        self._index_aux[0] -= 1

    def _mult_store_part(self, parts, time, column, part_index, mac_index, dump):
        parts[part_index] = self._macs[mac_index]
        if dump and DUMP_MULT_PARTS:
            print("t=%2d, col=%2d > parts[%2d]: mac[%d] = 0x%012x" %
                  (time, column, part_index, mac_index, parts[part_index]))

    def _mult_store_part_aux(self, parts, time, column, part_index, dump):
        parts[part_index] = self._mac_aux[0]
        if dump and DUMP_MULT_PARTS:
            print("t=%2d, col=%2d > parts[%2d]: mac_aux[%d] = 0x%012x" %
                  (time, column, part_index, 0, parts[part_index]))

    def _num_columns(self, ab_num_words):
        # Each column feeds all NUM_MULTS MACs, so the operand must split into
        # whole columns; otherwise trailing words would be silently skipped.
        if ab_num_words % NUM_MULTS:
            raise Exception("Number of words must be a multiple of NUM_MULTS!")
        return ab_num_words // NUM_MULTS

    def multiply_square(self, a_wide, b_narrow, ab_num_words, dump=False):
        """Full product of two n-word operands, returned as 2n-1 parts.

        b-words may carry up to two extra bits (bits 17:16); these are folded
        into the next b-word as a carry. The carry out of the top b-word is
        dropped.
        """
        num_cols = self._num_columns(ab_num_words)
        parts = [0] * (2 * ab_num_words - 1)
        for col in range(num_cols):
            b_carry = 0
            for t in range(ab_num_words):
                # take care of indices
                if t == 0: self._preset_indices(col)
                else: self._rotate_indices(ab_num_words)
                # take care of macs: a MAC that stored its lower part on the
                # previous tick starts over for the upper half
                if t == 0:
                    self._clear_all_macs(t, col, dump)
                else:
                    t1 = t - 1
                    if (t1 // NUM_MULTS) == col:
                        self._clear_one_mac(t1 % NUM_MULTS, t, col, dump)
                if dump and DUMP_INDICES: self._dump_indices(t, col)
                # current b-word with the 2-bit carry from the previous one
                # TODO: Explain how the 18th bit carry works!!
                bt = b_narrow.words[t] + b_carry
                b_carry = (bt & 0x30000) >> 16
                bt &= 0xFFFF
                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, t, col, ax, bt, dump)
                    if t == (col * NUM_MULTS + x):
                        self._mult_store_part(parts, t, col, t, x, dump)
                if dump and DUMP_MACS_ACCUMULATION: self._dump_macs(t, col)
                # save the upper parts of the product at the end of the column;
                # for the last column don't save the very last part
                if t == (ab_num_words - 1):
                    for x in range(NUM_MULTS):
                        if not (col == (num_cols - 1) and x == (NUM_MULTS - 1)):
                            part_index = ab_num_words + col * NUM_MULTS + x
                            self._mult_store_part(parts, t, col, part_index, x, dump)
        return parts

    def multiply_triangle(self, a_wide, b_narrow, ab_num_words, dump=False):
        """Lower n+1 parts of a_wide * b_narrow (b_narrow has n+1 words).

        The last column also drives the auxiliary MAC, which produces part n.
        Other columns stop as soon as their own parts are stored.
        """
        num_cols = self._num_columns(ab_num_words)
        parts = [0] * (ab_num_words + 1)
        for col in range(num_cols):
            last_col = col == (num_cols - 1)
            for t in range(ab_num_words + 1):
                # take care of indices
                if t == 0: self._preset_indices(col)
                else: self._rotate_indices(ab_num_words)
                # take care of auxiliary index
                if last_col:
                    if t == 0: self._preset_index_aux(num_cols)
                    else: self._rotate_index_aux()
                # take care of macs
                if t == 0: self._clear_all_macs(t, col, dump)
                # take care of auxiliary mac
                if last_col:
                    if t == 0: self._clear_mac_aux(t, col, dump)
                if dump and DUMP_INDICES: self._dump_indices_with_aux(t, col)
                # current b-word
                bt = b_narrow.words[t]
                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, t, col, ax, bt, dump, last_col)
                    if t == (col * NUM_MULTS + x):
                        self._mult_store_part(parts, t, col, t, x, dump)
                # aux multiplier
                if last_col:
                    ax = a_wide.words[self._index_aux[0]]
                    self._update_mac_aux(t, col, ax, bt, dump)
                    if t == ab_num_words:
                        self._mult_store_part_aux(parts, t, col, t, dump)
                if dump and DUMP_MACS_ACCUMULATION: self._dump_macs_with_aux(t, col)
                # shortcut: this column has stored all of its parts
                if not last_col:
                    if t == (NUM_MULTS * (col + 1) - 1): break
        return parts

    def multiply_rectangle(self, a_wide, b_narrow, ab_num_words, dump=False):
        """Full product of an n-word a_wide and an (n+1)-word b_narrow, as 2n parts."""
        num_cols = self._num_columns(ab_num_words)
        parts = [0] * (2 * ab_num_words)
        for col in range(num_cols):
            for t in range(ab_num_words + 1):
                # take care of indices
                if t == 0: self._preset_indices(col)
                else: self._rotate_indices(ab_num_words)
                # take care of macs
                if t == 0:
                    self._clear_all_macs(t, col, dump)
                else:
                    t1 = t - 1
                    if (t1 // NUM_MULTS) == col:
                        self._clear_one_mac(t1 % NUM_MULTS, t, col, dump)
                if dump and DUMP_INDICES: self._dump_indices(t, col)
                # current b-word
                bt = b_narrow.words[t]
                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, t, col, ax, bt, dump)
                    # don't save one value for the very last time instant per column
                    if t < ab_num_words and t == (col * NUM_MULTS + x):
                        self._mult_store_part(parts, t, col, t, x, dump)
                if dump and DUMP_MACS_ACCUMULATION: self._dump_macs(t, col)
                # save the upper parts of product at end of column
                if t == ab_num_words:
                    for x in range(NUM_MULTS):
                        part_index = ab_num_words + col * NUM_MULTS + x
                        self._mult_store_part(parts, t, col, part_index, x, dump)
        return parts
class ModExpNG_LowlevelOperator():
    """Single-word add/subtract primitives with a one-bit carry/borrow.

    Operand words must lie in [0, 2**_WORD_WIDTH); carry/borrow inputs must
    be 0 or 1. Violations raise Exception.
    """

    def __init__(self):
        # all-ones mask covering exactly one word
        self._word_mask = (1 << _WORD_WIDTH) - 1

    def _check_word(self, a):
        if a < 0 or a >= (1 << _WORD_WIDTH):
            raise Exception("Word out of range!")

    def _check_carry_borrow(self, cb):
        if cb < 0 or cb > 1:
            raise Exception("Carry or borrow out of range!")

    def add_words(self, a, b, c_in):
        """Return (carry_out, sum_word) for a + b + c_in."""
        self._check_word(a)
        self._check_word(b)
        self._check_carry_borrow(c_in)
        total = a + b + c_in
        return ((total >> _WORD_WIDTH) & 1, total & self._word_mask)

    def sub_words(self, a, b, b_in):
        """Return (borrow_out, difference_word) for a - b - b_in."""
        self._check_word(a)
        self._check_word(b)
        self._check_carry_borrow(b_in)
        diff = a - b - b_in
        if diff < 0:
            # borrow from the next word: wrap the difference into range
            return (1, diff + (1 << _WORD_WIDTH))
        return (0, diff)
class ModExpNG_Worker():
    """Arithmetic engine built from the multiplier, recombinator and
    low-level word operator: Montgomery multiplication, ladder
    exponentiation, and the add/subtract/carry-propagation helpers."""

    def __init__(self):
        self.recombinator = ModExpNG_PartRecombinator()
        self.multiplier = ModExpNG_WordMultiplier()
        self.lowlevel = ModExpNG_LowlevelOperator()

    def _force_overflow(self, t1, num_words):
        # Debug aid (FORCE_OVERFLOW): wherever bits 17:16 of word i-1 are
        # clear, move the two low bits of word i there. The integer value is
        # unchanged, but t1 now uses extended-width words.
        T1X = list(t1.words)
        for i in range(1, num_words):
            bits = T1X[i-1] & (3 << 16)
            if bits == 0:
                bits = T1X[i] & 3
                T1X[i] = T1X[i] ^ bits
                T1X[i-1] |= (bits << 16)
        for i in range(num_words):
            t1.words[i] = T1X[i]

    def exponentiate(self, iz, bz, e, n, n_factor, n_coeff, num_words, dump_index=-1, dump_mode=""):
        """Montgomery ladder over all _WORD_WIDTH * num_words bits of `e`, left to right.

        iz, bz      -- initial ladder values T1, T2
        n, n_coeff  -- modulus and reduction coefficient passed to multiply()
        n_factor    -- accepted for interface compatibility; unused
        dump_index  -- bit index at which debug output is produced (-1: never)
        dump_mode   -- prefix for dumped vector names

        Returns the final T1. Progress is printed to stdout.
        """
        t1, t2 = iz, bz
        num_bits = _WORD_WIDTH * num_words
        # the exponent is constant during the ladder; convert it only once
        e_number = e.number()
        # length-1, length-2, length-3, ..., 1, 0 (left-to-right)
        for bit in range(num_bits - 1, -1, -1):
            debug_dump = bit == dump_index
            bit_value = (e_number >> bit) & 1
            if debug_dump:
                print("\rladder_mode = %d" % bit_value)
                # force the rarely seen overflow
                if FORCE_OVERFLOW:
                    self._force_overflow(t1, num_words)
                if DUMP_VECTORS:
                    print("num_words = %d" % num_words)
                    t1.format_verilog_concat("%s_T1" % dump_mode)
                    t2.format_verilog_concat("%s_T2" % dump_mode)
                    n.format_verilog_concat("%s_N" % dump_mode)
                    n_coeff.format_verilog_concat("%s_N_COEFF" % dump_mode)
            if bit_value:
                p1 = self.multiply(t1, t2, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="X")
                p2 = self.multiply(t2, t2, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="Y")
            else:
                p1 = self.multiply(t1, t1, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="X")
                p2 = self.multiply(t2, t1, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="Y")
            t1, t2 = p1, p2
            if debug_dump and DUMP_VECTORS:
                t1.format_verilog_concat("%s_X" % dump_mode)
                t2.format_verilog_concat("%s_Y" % dump_mode)
            if (bit % 8) == 0:
                pct = float((num_bits - bit) / num_bits) * 100.0
                print("\rpct: %5.1f%%" % pct, end='')
        print("")
        return t1

    def subtract(self, a, b, n, ab_num_words):
        """Return a - b, or a - b + n when the subtraction borrows out of the top word.

        Both candidates are computed word by word; the inputs must hold plain
        _WORD_WIDTH-bit words (enforced by the low-level operator).
        """
        c_in = 0
        b_in = 0
        ab = []
        ab_n = []
        for x in range(ab_num_words):
            (b_out, d_out) = self.lowlevel.sub_words(a.words[x], b.words[x], b_in)
            (c_out, s_out) = self.lowlevel.add_words(d_out, n.words[x], c_in)
            ab.append(d_out)
            ab_n.append(s_out)
            (c_in, b_in) = (c_out, b_out)
        # b_in now holds the borrow out of the most significant word
        d = ab_n if b_in else ab
        return ModExpNG_Operand(None, ab_num_words, d)

    def add(self, a, b, ab_num_words):
        """Return a + b as a 2 * ab_num_words-word operand.

        `a` contributes ab_num_words words (zero-extended); `b` contributes
        2 * ab_num_words words. The final carry out is discarded.
        """
        c_in = 0
        ab = []
        for x in range(2 * ab_num_words):
            a_word = a.words[x] if x < ab_num_words else 0
            (c_out, s_out) = self.lowlevel.add_words(a_word, b.words[x], c_in)
            ab.append(s_out)
            c_in = c_out
        return ModExpNG_Operand(None, 2 * ab_num_words, ab)

    def multiply(self, a, b, n, n_coeff, ab_num_words, reduce_only=False, multiply_only=False, dump=False, dump_mode="", dump_phase=""):
        """Montgomery-multiply `a` by `b` modulo `n`.

        Steps: AB = A * B; Q = LSB(AB) * N_COEFF; M = Q * N; R = AB + M, where
        the lower ab_num_words + 1 words of the sum only contribute their
        carry. R is returned as ab_num_words (possibly extended-width) words.

        reduce_only   -- skip step 1 and use `a` itself as the double-width AB
        multiply_only -- stop after step 1 and return AB (n, n_coeff unused)

        If the internal self-check M == Q * N fails, prints MISMATCH and exits
        the process with status 1.
        """
        # 1. AB = A * B
        if dump: print("multiply_square(%s_%s)" % (dump_mode, dump_phase))
        if reduce_only:
            ab = a
        else:
            ab_parts = self.multiplier.multiply_square(a, b, ab_num_words, dump)
            ab_words = self.recombinator.recombine_square(ab_parts, ab_num_words, dump)
            ab = ModExpNG_Operand(None, 2 * ab_num_words, ab_words)
            if dump and DUMP_VECTORS:
                ab.format_verilog_concat("%s_%s_AB" % (dump_mode, dump_phase))
        if multiply_only:
            # `ab` is defined on both paths above, unlike the raw word list
            return ab
        # 2. Q = LSB(AB) * N_COEFF
        if dump: print("multiply_triangle(%s_%s)" % (dump_mode, dump_phase))
        q_parts = self.multiplier.multiply_triangle(ab, n_coeff, ab_num_words, dump)
        q_words = self.recombinator.recombine_triangle(q_parts, ab_num_words, dump)
        q = ModExpNG_Operand(None, ab_num_words + 1, q_words)
        if dump and DUMP_VECTORS:
            q.format_verilog_concat("%s_%s_Q" % (dump_mode, dump_phase))
        # 3. M = Q * N
        if dump: print("multiply_rectangle(%s_%s)" % (dump_mode, dump_phase))
        m_parts = self.multiplier.multiply_rectangle(n, q, ab_num_words, dump)
        m_words = self.recombinator.recombine_rectangle(m_parts, ab_num_words, dump)
        m = ModExpNG_Operand(None, 2 * ab_num_words + 1, m_words)
        if dump and DUMP_VECTORS:
            m.format_verilog_concat("%s_%s_M" % (dump_mode, dump_phase))
        if m.number() != (q.number() * n.number()):
            print("MISMATCH")
            # a failed self-check must not look like a successful run
            sys.exit(1)
        # 4. R = AB + M
        return self._add_and_shift(ab, m, ab_num_words, dump)

    def _add_and_shift(self, ab, m, ab_num_words, dump):
        # Step 4 of multiply(): R = (AB + M) without its lower
        # ab_num_words + 1 words, which only contribute a carry.
        # 4a. compute carry (actual sum is all zeroes and need not be stored);
        # it can be up to two bits since we're adding extended words
        r_cy = 0
        for i in range(ab_num_words + 1):
            s = ab.words[i] + m.words[i] + r_cy
            r_cy_new = s >> 16
            if dump and DUMP_REDUCTION:
                print("[%2d] 0x%05x + 0x%05x + 0x%x => {0x%x, [0x%05x]}" %
                      (i, ab.words[i], m.words[i], r_cy, r_cy_new, s & 0xffff))
            r_cy = r_cy_new
        # the carry is expected to fit into two bits; warn once otherwise
        if r_cy > 3: print("\rR_CY = %d!" % r_cy)
        # 4b/4c. compute the upper part of the sum, taking the carry into account
        R = [0] * ab_num_words
        for i in range(ab_num_words):
            if dump and DUMP_REDUCTION:
                print("[%2d]" % i, end='')
            # AB has 2 * ab_num_words words, so it contributes nothing to the top word
            ab_word = ab.words[ab_num_words + i + 1] if i < (ab_num_words - 1) else 0
            if dump and DUMP_REDUCTION:
                print(" 0x%05x" % ab_word, end='')
            m_word = m.words[ab_num_words + i + 1]
            if dump and DUMP_REDUCTION:
                print(" + 0x%05x" % m_word, end='')
            R[i] = r_cy if i == 0 else 0
            if dump and DUMP_REDUCTION:
                print(" + 0x%x" % R[i], end='')
            R[i] += ab_word
            R[i] += m_word
            if dump and DUMP_REDUCTION:
                print(" = 0x%05x" % R[i])
        return ModExpNG_Operand(None, ab_num_words, R)

    def reduce(self, a, num_words):
        """Propagate carries in place so the first num_words words of `a` fit _WORD_WIDTH bits.

        At most two carry bits move into each next word, and the carry out of
        the last word is discarded. Returns None.
        """
        carry = 0
        for x in range(num_words):
            a.words[x] += carry
            carry = (a.words[x] >> _WORD_WIDTH) & 3
            a.words[x] &= self.lowlevel._word_mask
class ModExpNG_CoreOutputEnum(Enum):
    """Selectors for the core output registers XM, YM and S."""
    XM = 1
    YM = 2
    S = 3
class ModExpNG_CoreOutput():
    """Output registers of a core (XM, YM, S), all initially None.

    Registers are addressed with ModExpNG_CoreOutputEnum members; any other
    selector raises Exception.
    """

    def __init__(self):
        self._xm = self._ym = self._s = None

    @staticmethod
    def _register_name(sel):
        # Map a selector to its backing attribute, or None when unknown.
        O = ModExpNG_CoreOutputEnum
        for member, name in ((O.XM, "_xm"), (O.YM, "_ym"), (O.S, "_s")):
            if sel == member:
                return name
        return None

    def _set_value(self, sel, value):
        name = self._register_name(sel)
        if name is None:
            raise Exception("ModExpNG_CoreOutput._set_value(): invalid selector!")
        setattr(self, name, value)

    def get_value(self, sel):
        name = self._register_name(sel)
        if name is None:
            raise Exception("ModExpNG_CoreOutput.get_value(): invalid selector!")
        return getattr(self, name)
class ModExpNG_Core():
    """One core: a worker, a CRT bank set (crt_x / crt_y) and output registers."""

    def __init__(self, i):
        self.wrk = ModExpNG_Worker()
        self.bnk = ModExpNG_BanksCRT(i)
        self.out = ModExpNG_CoreOutput()

    def _all_pairs(self):
        # the four bank pairs in a fixed order: crt_x.x, crt_x.y, crt_y.x, crt_y.y
        half_x, half_y = self.bnk.crt_x, self.bnk.crt_y
        return (half_x.ladder_x, half_x.ladder_y, half_y.ladder_x, half_y.ladder_y)

    def multiply(self, sel_wide_in, sel_narrow_in, sel_wide_out, sel_narrow_out, num_words, mode=(True, True)):
        """Multiply all four ladders' wide slot `sel_wide_in` by a narrow operand.

        Each half's modulus and coefficient come from its X ladder. The
        narrow operand of crt_x is taken from ladder_y when mode[0] is true,
        otherwise from ladder_x; mode[1] selects likewise for crt_y. Results
        go to the wide slot `sel_wide_out` and/or the narrow slot
        `sel_narrow_out` of the same ladders; None skips a bank.
        """
        W, NB = ModExpNG_WideBankEnum, ModExpNG_NarrowBankEnum
        half_x, half_y = self.bnk.crt_x, self.bnk.crt_y
        xn = half_x.ladder_x.wide._get_value(W.N)
        yn = half_y.ladder_x.wide._get_value(W.N)
        xn_coeff = half_x.ladder_x.narrow._get_value(NB.N_COEFF)
        yn_coeff = half_y.ladder_x.narrow._get_value(NB.N_COEFF)
        xxa = half_x.ladder_x.wide._get_value(sel_wide_in)
        xya = half_x.ladder_y.wide._get_value(sel_wide_in)
        yxa = half_y.ladder_x.wide._get_value(sel_wide_in)
        yya = half_y.ladder_y.wide._get_value(sel_wide_in)
        xb = (half_x.ladder_y if mode[0] else half_x.ladder_x).narrow._get_value(sel_narrow_in)
        yb = (half_y.ladder_y if mode[1] else half_y.ladder_x).narrow._get_value(sel_narrow_in)
        products = (
            self.wrk.multiply(xxa, xb, xn, xn_coeff, num_words),
            self.wrk.multiply(xya, xb, xn, xn_coeff, num_words),
            self.wrk.multiply(yxa, yb, yn, yn_coeff, num_words),
            self.wrk.multiply(yya, yb, yn, yn_coeff, num_words),
        )
        targets = tuple(zip(self._all_pairs(), products))
        if sel_wide_out is not None:
            for pair, value in targets:
                pair.wide._set_value(sel_wide_out, value)
        if sel_narrow_out is not None:
            for pair, value in targets:
                pair.narrow._set_value(sel_narrow_out, value)

    def simply_reduce(self, sel_narrow, num_words):
        """Propagate carries in narrow slot `sel_narrow` of all four ladders."""
        for pair in self._all_pairs():
            self.wrk.reduce(pair.narrow._get_value(sel_narrow), num_words)

    def set_output(self, sel_output, banks_ladder, sel_narrow):
        """Copy narrow slot `sel_narrow` of banks_ladder's X ladder into output `sel_output`."""
        value = banks_ladder.ladder_x.narrow._get_value(sel_narrow)
        self.out._set_value(sel_output, value)

    def mirror_yx(self, sel_wide, sel_narrow):
        """Copy the given slots from crt_y into crt_x (ladder by ladder); None skips a bank."""
        src, dst = self.bnk.crt_y, self.bnk.crt_x
        ladders = ((src.ladder_x, dst.ladder_x), (src.ladder_y, dst.ladder_y))
        if sel_wide is not None:
            for s, d in ladders:
                d.wide._set_value(sel_wide, s.wide._get_value(sel_wide))
        if sel_narrow is not None:
            for s, d in ladders:
                d.narrow._set_value(sel_narrow, s.narrow._get_value(sel_narrow))
if __name__ == "__main__":
    # Run the model end to end on the test vector, compare the results with
    # Python's built-in pow(), and exit with status 1 if any check fails.

    # set numbers of words
    n_num_words = KEY_LENGTH // _WORD_WIDTH
    pq_num_words = n_num_words // 2
    # create helper quantity (the number one)
    i = ModExpNG_Operand(1, KEY_LENGTH)
    # load test vector
    vector = ModExpNG_TestVector()
    # create the core; `i` becomes the value of every narrow bank's I slot
    core = ModExpNG_Core(i)
    # obtain known good reference value with built-in math
    s_known = pow(vector.m.number(), vector.d.number(), vector.n.number())
    # mutate blinding quantities with built-in math
    x_mutated_known = pow(vector.x.number(), 2, vector.n.number())
    y_mutated_known = pow(vector.y.number(), 2, vector.n.number())

    W = ModExpNG_WideBankEnum
    N = ModExpNG_NarrowBankEnum
    O = ModExpNG_CoreOutputEnum

    # both halves work modulo n here
    core.bnk.crt_x.set_modulus(vector.n, vector.n_coeff)
    core.bnk.crt_y.set_modulus(vector.n, vector.n_coeff)
    # slot A: ladder_x holds x (crt_x) or y (crt_y), ladder_y holds n_factor
    core.bnk.crt_x.set_operand(W.A, N.A, vector.x, vector.n_factor)
    core.bnk.crt_y.set_operand(W.A, N.A, vector.y, vector.n_factor)
    # slot E: message m in every ladder
    core.bnk.crt_x.set_operand(W.E, N.E, vector.m, vector.m)
    core.bnk.crt_y.set_operand(W.E, N.E, vector.m, vector.m)

    # bring blinding coefficients into Montgomery domain (glue 2**(2*r) to x
    # and y): slot B of ladder_x becomes XF (crt_x) / YF (crt_y)
    core.multiply(W.A, N.A, W.B, N.B, n_num_words)
    # mutate blinding factors: square XF / YF into slot C ...
    core.multiply(W.B, N.B, W.C, N.C, n_num_words, mode=(False, False))
    # ... and multiply by I (one) into slot D to leave the Montgomery domain
    core.multiply(W.C, N.I, W.D, N.D, n_num_words)
    core.simply_reduce(N.D, n_num_words)
    core.set_output(O.XM, core.bnk.crt_x, N.D)
    core.set_output(O.YM, core.bnk.crt_y, N.D)

    # blind message: slot C of ladder_x becomes M*XF (crt_x) / M*YF (crt_y)
    core.multiply(W.E, N.B, W.C, N.C, n_num_words, mode=(False, False))
    # copy crt_y's slot C into crt_x, so both hold MB = M*YF
    core.mirror_yx(W.C, N.C)
    # convert message to non-redundant representation
    core.simply_reduce(N.C, n_num_words)

    XF = core.bnk.crt_x.ladder_x.wide._get_value(W.B)
    YF = core.bnk.crt_y.ladder_x.wide._get_value(W.B)
    MB = core.bnk.crt_y.ladder_x.narrow._get_value(N.C)

    # first reduce message, this glues 2**-r to the message as a side effect
    PMBZ = core.wrk.multiply(MB, None, vector.p, vector.p_coeff, pq_num_words, reduce_only=True)  # mod_reduce (mod p)
    QMBZ = core.wrk.multiply(MB, None, vector.q, vector.q_coeff, pq_num_words, reduce_only=True)  # mod_reduce (mod q)
    # unglue 2**-r from message by gluing 2**r to it to compensate
    mp_blind = core.wrk.multiply(PMBZ, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)
    mq_blind = core.wrk.multiply(QMBZ, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)
    # bring message into Montgomery domain (glue 2**r to message)
    mp_blind_factor = core.wrk.multiply(mp_blind, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)
    mq_blind_factor = core.wrk.multiply(mq_blind, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)
    # bring one into Montgomery domain (glue 2**r to one)
    ip_factor = core.wrk.multiply(i, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)
    iq_factor = core.wrk.multiply(i, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)
    # do "easier" exponentiations (dump_index=99: debug output at bit 99)
    sp_blind_factor = core.wrk.exponentiate(ip_factor, mp_blind_factor, vector.dp, vector.p, vector.p_factor, vector.p_coeff, pq_num_words, dump_index=99, dump_mode="P")
    sq_blind_factor = core.wrk.exponentiate(iq_factor, mq_blind_factor, vector.dq, vector.q, vector.q_factor, vector.q_coeff, pq_num_words, dump_index=99, dump_mode="Q")
    # return "easier" parts from Montgomery domain (unglue 2**r from result)
    SPB = core.wrk.multiply(i, sp_blind_factor, vector.p, vector.p_coeff, pq_num_words)
    SQB = core.wrk.multiply(i, sq_blind_factor, vector.q, vector.q_coeff, pq_num_words)
    core.wrk.reduce(SPB, len(SPB.words))  # just_reduce
    core.wrk.reduce(SQB, len(SQB.words))  # just_reduce

    # do the "Garner's formula" part
    # r = sp - sq mod p
    sr_blind = core.wrk.subtract(SPB, SQB, vector.p, pq_num_words)
    # sr_qinv = sr * qinv mod p
    sr_qinv_blind_inverse_factor = core.wrk.multiply(sr_blind, vector.qinv, vector.p, vector.p_coeff, pq_num_words)
    sr_qinv_blind = core.wrk.multiply(sr_qinv_blind_inverse_factor, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)
    # q_sr_qinv = q * sr_qinv
    q_sr_qinv_blind = core.wrk.multiply(vector.q, sr_qinv_blind, None, None, pq_num_words, multiply_only=True)  # just_multiply
    core.wrk.reduce(q_sr_qinv_blind, n_num_words)  # just_reduce
    # s_crt = sq + q_sr_qinv
    SB = core.wrk.add(SQB, q_sr_qinv_blind, pq_num_words)  # just_add

    # unblind s
    S = core.wrk.multiply(SB, XF, vector.n, vector.n_coeff, n_num_words)
    core.wrk.reduce(S, len(S.words))  # just_reduce

    # check
    XM = core.out.get_value(O.XM)
    YM = core.out.get_value(O.YM)
    failed = False
    if S.number() != s_known:
        print("ERROR: s_crt_unblinded != s_known!")
        failed = True
    else:
        print("s is OK")
    if XM.number() != x_mutated_known:
        print("ERROR: x_mutated != x_mutated_known!")
        failed = True
    else:
        print("x_mutated is OK")
    if YM.number() != y_mutated_known:
        print("ERROR: y_mutated != y_mutated_known!")
        failed = True
    else:
        print("y_mutated is OK")
    # report failure through the exit status, not only through stdout
    if failed:
        sys.exit(1)
#
# End-of-File
#