#!/usr/bin/python3 # # # ModExpNG core math model. # # # Copyright (c) 2019, NORDUnet A/S # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # - Redistributions of source code must retain the above copyright notice, # this list of conditions and the following disclaimer. # # - Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # # - Neither the name of the NORDUnet nor the names of its contributors may # be used to endorse or promote products derived from this software # without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED # TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED # TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# -------
# Imports
# -------
import sys
import importlib
from enum import Enum, auto


# --------------
# Model Settings
# --------------

# length of public key
KEY_LENGTH = 1024

# how many parallel multipliers to use
NUM_MULTS = 8


# ---------------
# Internal Values
# ---------------

# half of key length
_KEY_LENGTH_HALF = KEY_LENGTH // 2

# width of internal math pipeline
_WORD_WIDTH = 16
_WORD_WIDTH_EXT = 18

# folder with test vector scripts
_VECTOR_PATH = "/vector"

# name of test vector class
_VECTOR_CLASS = "Vector"


# ------------------
# Debugging Settings
# ------------------
FORCE_OVERFLOW = False
DUMP_VECTORS = False
DUMP_INDICES = False
DUMP_MACS_INPUTS = False
DUMP_MACS_CLEARING = False
DUMP_MACS_ACCUMULATION = False
DUMP_MULT_PARTS = False
DUMP_RCMB = False
DUMP_REDUCTION = False


#
# Multi-Precision Integer
#
class ModExpNG_Operand:
    """Multi-word integer stored as a little-endian list of words.

    An operand is built either from a Python integer (split into
    _WORD_WIDTH-bit words) or from a ready-made list of words, each of which
    may use up to _WORD_WIDTH_EXT bits.
    """

    def __init__(self, number, length, words=None):
        """Create an operand.

        number -- integer value, used only when `words` is None
        length -- bit length when built from `number` (must be a multiple of
                  _WORD_WIDTH), or the word count when built from `words`
        words  -- optional list of words, least significant first; the list
                  is stored by reference, not copied

        Raises Exception when the length or any word is out of range.
        """
        if words is None:
            # length must be divisible by word width
            if (length % _WORD_WIDTH) > 0:
                raise Exception("Bad number length!")
            self._init_from_number(number, length)
        else:
            # length must match words count
            if len(words) != length:
                raise Exception("Bad words count!")
            self._init_from_words(words, length)

    def format_verilog_concat(self, name):
        """Print the words as Verilog assignments, four per output line."""
        for i, word in enumerate(self.words):
            if i > 0:
                if (i % 4) == 0:
                    print("")
                else:
                    print(" ", end='')
            print("%s[%2d] = 18'h%05x;" % (name, i, word), end='')
        print("")

    def _init_from_words(self, words, count):
        """Adopt `words` after checking the first `count` entries."""
        for i in range(count):
            # word must not exceed 18 bits
            if words[i] >= (2 ** (_WORD_WIDTH_EXT)):
                raise Exception("Word is too large!")
        self.words = words

    def _init_from_number(self, number, length):
        """Split `number` into `length // _WORD_WIDTH` words."""
        num_hexchars_per_word = _WORD_WIDTH // 4
        # derive the digit count from the word count, so that it stays right
        # for word widths other than 16 bits
        num_hexchars_total = (length // _WORD_WIDTH) * num_hexchars_per_word

        value_hex = format(number, 'x')

        # value must not be larger than specified, but it can be smaller, so
        # we may need to prepend it with zeroes
        if len(value_hex) > num_hexchars_total:
            raise Exception("Number is too large!")
        value_hex = value_hex.rjust(num_hexchars_total, "0")

        # fill in words, starting from the least significant one
        self.words = list()
        while len(value_hex) > 0:
            value_hex_part = value_hex[-num_hexchars_per_word:]
            value_hex = value_hex[:-num_hexchars_per_word]
            self.words.append(int(value_hex_part, 16))

    def number(self):
        """Return the integer value of the operand.

        Words wider than _WORD_WIDTH bits are handled, because every word is
        added (not OR-ed) at its own weight.
        """
        ret = 0
        shift = 0
        for word in self.words:
            ret += word << shift
            shift += _WORD_WIDTH
        return ret


#
# Test Vector
#
class ModExpNG_TestVector:
    """Loads the randomized test vector module for the current KEY_LENGTH."""

    def __init__(self):
        # format target filename
        filename = "vector_" + str(KEY_LENGTH) + "_randomized"

        # add ./vector to import search path
        sys.path.insert(1, sys.path[0] + _VECTOR_PATH)

        # import from filename
        vector_module = importlib.import_module(filename)

        # get vector class
        vector_class = getattr(vector_module, _VECTOR_CLASS)

        # instantiate vector class
        vector_inst = vector_class()

        # obtain parts of vector
        self.m = ModExpNG_Operand(vector_inst.m, KEY_LENGTH)
        self.n = ModExpNG_Operand(vector_inst.n, KEY_LENGTH)
        self.d = ModExpNG_Operand(vector_inst.d, KEY_LENGTH)
        self.p = ModExpNG_Operand(vector_inst.p, _KEY_LENGTH_HALF)
        self.q = ModExpNG_Operand(vector_inst.q, _KEY_LENGTH_HALF)
        self.dp = ModExpNG_Operand(vector_inst.dp, _KEY_LENGTH_HALF)
        self.dq = ModExpNG_Operand(vector_inst.dq, _KEY_LENGTH_HALF)
        self.qinv = ModExpNG_Operand(vector_inst.qinv, _KEY_LENGTH_HALF)
        self.n_factor = ModExpNG_Operand(vector_inst.n_factor, KEY_LENGTH)
        self.p_factor = ModExpNG_Operand(vector_inst.p_factor, _KEY_LENGTH_HALF)
        self.q_factor = ModExpNG_Operand(vector_inst.q_factor, _KEY_LENGTH_HALF)
        self.n_coeff = ModExpNG_Operand(vector_inst.n_coeff, KEY_LENGTH + _WORD_WIDTH)
        self.p_coeff = ModExpNG_Operand(vector_inst.p_coeff, _KEY_LENGTH_HALF + _WORD_WIDTH)
        self.q_coeff = ModExpNG_Operand(vector_inst.q_coeff, _KEY_LENGTH_HALF + _WORD_WIDTH)
        self.x = ModExpNG_Operand(vector_inst.x, KEY_LENGTH)
        self.y = ModExpNG_Operand(vector_inst.y, KEY_LENGTH)


class ModExpNG_WideBankEnum(Enum):
    """Selectors for the slots of a wide bank."""
    A = auto()
    B = auto()
    C = auto()
    D = auto()
    E = auto()
    N = auto()


class ModExpNG_NarrowBankEnum(Enum):
    """Selectors for the slots of a narrow bank."""
    A = auto()
    B = auto()
    C = auto()
    D = auto()
    E = auto()
    N_COEFF = auto()
    I = auto()


class ModExpNG_WideBank:
    """Six operand slots addressed by ModExpNG_WideBankEnum."""

    # selector -> attribute name
    _SLOTS = {
        ModExpNG_WideBankEnum.A: "a",
        ModExpNG_WideBankEnum.B: "b",
        ModExpNG_WideBankEnum.C: "c",
        ModExpNG_WideBankEnum.D: "d",
        ModExpNG_WideBankEnum.E: "e",
        ModExpNG_WideBankEnum.N: "n",
    }

    def __init__(self):
        self.a = None
        self.b = None
        self.c = None
        self.d = None
        self.e = None
        self.n = None

    def _get_value(self, sel):
        """Return the slot selected by `sel`; raise Exception if unknown."""
        slot = self._SLOTS.get(sel)
        if slot is None:
            raise Exception("ModExpNG_WideBank._get_value(): Invalid selector!")
        return getattr(self, slot)

    def _set_value(self, sel, value):
        """Store `value` in the slot selected by `sel`; raise if unknown."""
        slot = self._SLOTS.get(sel)
        if slot is None:
            raise Exception("ModExpNG_WideBank._set_value(): Invalid selector!")
        setattr(self, slot, value)


class ModExpNG_NarrowBank:
    """Seven operand slots addressed by ModExpNG_NarrowBankEnum.

    Slot I is set once by the constructor; _set_value() does not accept it.
    """

    # selector -> attribute name for the slots _set_value() may write
    _WRITABLE_SLOTS = {
        ModExpNG_NarrowBankEnum.A: "a",
        ModExpNG_NarrowBankEnum.B: "b",
        ModExpNG_NarrowBankEnum.C: "c",
        ModExpNG_NarrowBankEnum.D: "d",
        ModExpNG_NarrowBankEnum.E: "e",
        ModExpNG_NarrowBankEnum.N_COEFF: "n_coeff",
    }

    # selector -> attribute name for the slots _get_value() may read
    _READABLE_SLOTS = dict(_WRITABLE_SLOTS)
    _READABLE_SLOTS[ModExpNG_NarrowBankEnum.I] = "i"

    def __init__(self, i):
        self.a = None
        self.b = None
        self.c = None
        self.d = None
        self.e = None
        self.n_coeff = None
        self.i = i

    def _get_value(self, sel):
        """Return the slot selected by `sel`; raise Exception if unknown."""
        slot = self._READABLE_SLOTS.get(sel)
        if slot is None:
            raise Exception("ModExpNG_NarrowBank._get_value(): Invalid selector!")
        return getattr(self, slot)

    def _set_value(self, sel, value):
        """Store `value` in the slot selected by `sel`; raise if unknown."""
        slot = self._WRITABLE_SLOTS.get(sel)
        if slot is None:
            raise Exception("ModExpNG_NarrowBank._set_value(): Invalid selector!")
        setattr(self, slot, value)


class ModExpNG_BanksPair:
    """One wide bank together with one narrow bank."""

    def __init__(self, i):
        self.wide = ModExpNG_WideBank()
        self.narrow = ModExpNG_NarrowBank(i)

    def _get_value_wide(self, sel):
        return self.wide._get_value(sel)

    def _get_value_narrow(self, sel):
        return self.narrow._get_value(sel)


class ModExpNG_BanksLadder:
    """Two bank pairs (ladder_x and ladder_y)."""

    def __init__(self, i):
        self.ladder_x = ModExpNG_BanksPair(i)
        self.ladder_y = ModExpNG_BanksPair(i)

    def set_modulus(self, n, n_coeff):
        """Store `n` in both wide N slots and `n_coeff` in both narrow N_COEFF slots."""
        self.ladder_x.wide._set_value(ModExpNG_WideBankEnum.N, n)
        self.ladder_y.wide._set_value(ModExpNG_WideBankEnum.N, n)
        self.ladder_x.narrow._set_value(ModExpNG_NarrowBankEnum.N_COEFF, n_coeff)
        self.ladder_y.narrow._set_value(ModExpNG_NarrowBankEnum.N_COEFF, n_coeff)

    def set_operand(self, sel_wide, sel_narrow, x, y):
        """Store `x` in ladder_x and `y` in ladder_y.

        A selector that is None leaves the corresponding banks untouched.
        """
        if sel_wide is not None:
            self.ladder_x.wide._set_value(sel_wide, x)
            self.ladder_y.wide._set_value(sel_wide, y)
        if sel_narrow is not None:
            self.ladder_x.narrow._set_value(sel_narrow, x)
            self.ladder_y.narrow._set_value(sel_narrow, y)


class ModExpNG_BanksCRT:
    """Two bank ladders (crt_x and crt_y)."""

    def __init__(self, i):
        self.crt_x = ModExpNG_BanksLadder(i)
        self.crt_y = ModExpNG_BanksLadder(i)


class ModExpNG_PartRecombinator:
    """Turns the wide partial sums produced by the multiplier into words.

    Parts are pushed one per tick through a three-cell pipeline (z0, y0, x0);
    every push returns the low 16 bits of the rightmost cell as it was before
    the push, so the first returned word of a run carries no data.
    """

    def _bit_select(self, x, msb, lsb):
        """Return bits msb..lsb (inclusive) of `x`, shifted down to bit 0."""
        return (x >> lsb) & ((1 << (msb - lsb + 1)) - 1)

    def _flush_pipeline(self, dump):
        """Reset all three pipeline cells to zero."""
        self.z0, self.y0, self.x0 = 0, 0, 0
        if dump and DUMP_RCMB:
            print("RCMB -> flush()")

    def _push_pipeline(self, part, dump):
        """Push one part and return the word leaving the pipeline."""
        # split next part into 16-bit words
        z = self._bit_select(part, 46, 32)
        y = self._bit_select(part, 31, 16)
        x = self._bit_select(part, 15, 0)

        # shift to the right
        z1 = z
        y1 = y + self.z0
        # IMPORTANT: This carry can be up to two bits wide!!
        x1 = x + self.y0 + (self.x0 >> _WORD_WIDTH)

        # save lower 16 bits of the rightmost cell
        t = self.x0 & 0xffff

        # update internal latches
        self.z0, self.y0, self.x0 = z1, y1, x1

        # dump
        if dump and DUMP_RCMB:
            print("RCMB -> push(): part = 0x%012x, word = 0x%04x" % (part, t))

        # done
        return t

    def _run_pipeline(self, parts, offset, num_parts, num_ticks, dump):
        """Flush, then push parts[offset : offset + num_parts] followed by zeroes.

        Runs `num_ticks` ticks in total and returns the words produced by
        every tick except the first, whose result is null.
        """
        words = list()
        self._flush_pipeline(dump)
        for i in range(num_ticks):
            next_part = parts[offset + i] if i < num_parts else 0
            next_word = self._push_pipeline(next_part, dump)
            if i > 0:
                words.append(next_word)
        return words

    def _merge_halves(self, words_lsb, words_msb, ab_num_words, num_msb_words):
        """Concatenate both halves, adding the two overlapping words.

        The lower half contributes its first `ab_num_words` words; its two
        extra words are added to the first two words of the upper half.
        """
        words = list(words_lsb[:ab_num_words])
        for x in range(num_msb_words):
            next_word = words_msb[x]
            if x < 2:
                next_word += words_lsb[x + ab_num_words]
            words.append(next_word)
        return words

    def recombine_square(self, parts, ab_num_words, dump):
        """Recombine 2n-1 parts into 2n words (n = ab_num_words)."""
        # recombine the lower half (n parts)
        # the first tick produces null result, the last part
        # produces three words and needs two extra ticks
        words_lsb = self._run_pipeline(parts, 0, ab_num_words, ab_num_words + 1 + 2, dump)

        # recombine the upper half (n-1 parts)
        # the first tick produces null result
        words_msb = self._run_pipeline(parts, ab_num_words, ab_num_words - 1, ab_num_words + 1, dump)

        return self._merge_halves(words_lsb, words_msb, ab_num_words, ab_num_words)

    def recombine_triangle(self, parts, ab_num_words, dump):
        """Recombine n+1 parts into n+1 words (n = ab_num_words)."""
        # the first tick produces null result, so we need n + 1 + 1 = n + 2
        # ticks total and should only save the result word during the last
        # n + 1 ticks
        return self._run_pipeline(parts, 0, ab_num_words + 1, ab_num_words + 2, dump)

    def recombine_rectangle(self, parts, ab_num_words, dump):
        """Recombine 2n parts into 2n+1 words (n = ab_num_words)."""
        # recombine the lower half (n parts)
        # the first tick produces null result, the last part
        # produces three words and needs two extra ticks
        words_lsb = self._run_pipeline(parts, 0, ab_num_words, ab_num_words + 1 + 2, dump)

        # recombine the upper half (n parts)
        # the first tick produces null result, the last part
        # produces two words and needs an extra tick
        words_msb = self._run_pipeline(parts, ab_num_words, ab_num_words, ab_num_words + 2, dump)

        return self._merge_halves(words_lsb, words_msb, ab_num_words, ab_num_words + 1)


class ModExpNG_WordMultiplier:
    """Model of NUM_MULTS parallel multiply-accumulate units plus one auxiliary unit.

    Each multiply_*() method processes the a-operand in columns of NUM_MULTS
    words; at every time instant `t` one b-word is multiplied by the a-words
    currently addressed by the rotating indices.
    """

    def __init__(self):
        self._macs = [0] * NUM_MULTS
        self._indices = [0] * NUM_MULTS
        self._mac_aux = [0] * NUM_MULTS
        self._index_aux = [0] * NUM_MULTS

    def _clear_all_macs(self, t, col, dump):
        for x in range(NUM_MULTS):
            self._macs[x] = 0
        if dump and DUMP_MACS_CLEARING:
            print("t=%2d, col=%2d > clear > all" % (t, col))

    def _clear_one_mac(self, x, t, col, dump):
        self._macs[x] = 0
        if dump and DUMP_MACS_CLEARING:
            print("t=%2d, col=%2d > clear > x=%d" % (t, col, x))

    def _clear_mac_aux(self, t, col, dump):
        self._mac_aux[0] = 0
        if dump and DUMP_MACS_CLEARING:
            print("t= 0, col=%2d > clear > aux" % (col))

    def _update_one_mac(self, x, t, col, a, b, dump, need_aux=False):
        """Accumulate a * b into MAC `x`.

        Raises Exception when `a` exceeds 18 bits or `b` exceeds 16 bits.
        `need_aux` only affects the debug output: it suppresses the line
        break that _update_mac_aux() prints instead.
        """
        if a > 0x3FFFF:
            raise Exception("a > 0x3FFFF!")
        if b > 0xFFFF:
            raise Exception("b > 0xFFFF!")
        p = a * b
        if dump and DUMP_MACS_INPUTS:
            if x == 0:
                print("t=%2d, col=%2d > b=%05x > " % (t, col, b), end='')
            if x > 0:
                print("; ", end='')
            print("MAC[%d]: a=%05x" % (x, a), end='')
            if x == (NUM_MULTS-1) and not need_aux:
                print("")
        self._macs[x] += p

    def _update_mac_aux(self, y, col, a, b, dump):
        """Accumulate a * b into the auxiliary MAC (same range checks)."""
        if a > 0x3FFFF:
            raise Exception("a > 0x3FFFF!")
        if b > 0xFFFF:
            raise Exception("b > 0xFFFF!")
        p = a * b
        if dump and DUMP_MACS_INPUTS:
            print("; AUX: a=%05x" % a)
        self._mac_aux[0] += p

    def _preset_indices(self, col):
        """Point MAC x at a-word col * NUM_MULTS + x."""
        for x in range(len(self._indices)):
            self._indices[x] = col * len(self._indices) + x

    def _preset_index_aux(self, num_cols):
        """Point the auxiliary MAC at the a-word right after the last column."""
        self._index_aux[0] = num_cols * len(self._indices)

    def _dump_macs_helper(self, t, col, aux=False):
        print("t=%2d, col=%2d > "% (t, col), end='')
        for i in range(NUM_MULTS):
            if i > 0:
                print(" | ", end='')
            print("mac[%d]: 0x%012x" % (i, self._macs[i]), end='')
        if aux:
            print(" | mac_aux[ 0]: 0x%012x" % (self._mac_aux[0]), end='')
        print("")

    def _dump_macs(self, t, col):
        self._dump_macs_helper(t, col)

    def _dump_macs_with_aux(self, t, col):
        self._dump_macs_helper(t, col, True)

    def _dump_indices_helper(self, t, col, aux=False):
        print("t=%2d, col=%2d > indices:" % (t, col), end='')
        for i in range(NUM_MULTS):
            print(" %2d" % self._indices[i], end='')
        if aux:
            print(" %2d" % self._index_aux[0], end='')
        print("")

    def _dump_indices(self, t, col):
        self._dump_indices_helper(t, col)

    def _dump_indices_with_aux(self, t, col):
        self._dump_indices_helper(t, col, True)

    def _rotate_indices(self, num_words):
        """Decrement every index, wrapping from 0 to num_words - 1."""
        for x in range(len(self._indices)):
            if self._indices[x] > 0:
                self._indices[x] -= 1
            else:
                self._indices[x] = num_words - 1

    def _rotate_index_aux(self):
        self._index_aux[0] -= 1

    def _mult_store_part(self, parts, time, column, part_index, mac_index, dump):
        parts[part_index] = self._macs[mac_index]
        if dump and DUMP_MULT_PARTS:
            print("t=%2d, col=%2d > parts[%2d]: mac[%d] = 0x%012x" %
                (time, column, part_index, mac_index, parts[part_index]))

    def _mult_store_part_aux(self, parts, time, column, part_index, dump):
        parts[part_index] = self._mac_aux[0]
        if dump and DUMP_MULT_PARTS:
            print("t=%2d, col=%2d > parts[%2d]: mac_aux[%d] = 0x%012x" %
                (time, column, part_index, 0, parts[part_index]))

    def multiply_square(self, a_wide, b_narrow, ab_num_words, dump=False):
        """Return the 2n-1 partial sums of the full n-by-n word product.

        a_wide words may use up to 18 bits. b_narrow words may also carry two
        extra bits, which are propagated into the next b-word; the carry out
        of the last b-word is discarded.
        """
        num_cols = ab_num_words // NUM_MULTS

        parts = [0] * (2 * ab_num_words - 1)

        for col in range(num_cols):

            b_carry = 0

            for t in range(ab_num_words):

                # take care of indices
                if t == 0:
                    self._preset_indices(col)
                else:
                    self._rotate_indices(ab_num_words)

                # take care of macs: a MAC of this column is cleared on the
                # tick right after its lower part was stored
                if t == 0:
                    self._clear_all_macs(t, col, dump)
                else:
                    t1 = t - 1
                    if (t1 // NUM_MULTS) == col:
                        self._clear_one_mac(t1 % NUM_MULTS, t, col, dump)

                # debug output
                if dump and DUMP_INDICES:
                    self._dump_indices(t, col)

                # current b-word
                # TODO: Explain how the 18th bit carry works!!
                bt = b_narrow.words[t] + b_carry
                b_carry = (bt & 0x30000) >> 16
                bt &= 0xFFFF

                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, t, col, ax, bt, dump)

                    if t == (col * NUM_MULTS + x):
                        part_index = t
                        self._mult_store_part(parts, t, col, part_index, x, dump)

                # debug output
                if dump and DUMP_MACS_ACCUMULATION:
                    self._dump_macs(t, col)

                # save the upper parts of product at end of column,
                # for the last column don't save the very last part
                if t == (ab_num_words - 1):
                    for x in range(NUM_MULTS):
                        if not (col == (num_cols - 1) and x == (NUM_MULTS - 1)):
                            part_index = ab_num_words + col * NUM_MULTS + x
                            self._mult_store_part(parts, t, col, part_index, x, dump)

        return parts

    def multiply_triangle(self, a_wide, b_narrow, ab_num_words, dump=False):
        """Return the n+1 lowest partial sums of the product.

        The extra (n+1)-th part is accumulated by the auxiliary MAC while the
        last column is processed.
        """
        num_cols = ab_num_words // NUM_MULTS

        parts = [0] * (ab_num_words + 1)

        for col in range(num_cols):

            last_col = col == (num_cols - 1)

            for t in range(ab_num_words + 1):

                # take care of indices
                if t == 0:
                    self._preset_indices(col)
                else:
                    self._rotate_indices(ab_num_words)

                # take care of auxiliary index
                if last_col:
                    if t == 0:
                        self._preset_index_aux(num_cols)
                    else:
                        self._rotate_index_aux()

                # take care of macs
                if t == 0:
                    self._clear_all_macs(t, col, dump)

                # take care of auxiliary mac
                if last_col:
                    if t == 0:
                        self._clear_mac_aux(t, col, dump)

                # debug output
                if dump and DUMP_INDICES:
                    self._dump_indices_with_aux(t, col)

                # current b-word
                bt = b_narrow.words[t]

                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, t, col, ax, bt, dump, last_col)

                    if t == (col * NUM_MULTS + x):
                        part_index = t
                        self._mult_store_part(parts, t, col, part_index, x, dump)

                # aux multiplier
                if last_col:
                    ax = a_wide.words[self._index_aux[0]]
                    self._update_mac_aux(t, col, ax, bt, dump)

                    if t == ab_num_words:
                        part_index = t
                        self._mult_store_part_aux(parts, t, col, part_index, dump)

                # debug output
                if dump and DUMP_MACS_ACCUMULATION:
                    self._dump_macs_with_aux(t, col)

                # shortcut: all parts of a non-last column are stored by now
                if not last_col:
                    if t == (NUM_MULTS * (col + 1) - 1):
                        break

        return parts

    def multiply_rectangle(self, a_wide, b_narrow, ab_num_words, dump=False):
        """Return the 2n partial sums of an n-word by (n+1)-word product."""
        num_cols = ab_num_words // NUM_MULTS

        parts = [0] * (2 * ab_num_words)

        for col in range(num_cols):

            for t in range(ab_num_words + 1):

                # take care of indices
                if t == 0:
                    self._preset_indices(col)
                else:
                    self._rotate_indices(ab_num_words)

                # take care of macs: a MAC of this column is cleared on the
                # tick right after its lower part was stored
                if t == 0:
                    self._clear_all_macs(t, col, dump)
                else:
                    t1 = t - 1
                    if (t1 // NUM_MULTS) == col:
                        self._clear_one_mac(t1 % NUM_MULTS, t, col, dump)

                # debug output
                if dump and DUMP_INDICES:
                    self._dump_indices(t, col)

                # current b-word
                bt = b_narrow.words[t]

                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, t, col, ax, bt, dump)

                    # don't save one value for the very last time instant per column
                    if t < ab_num_words and t == (col * NUM_MULTS + x):
                        part_index = t
                        self._mult_store_part(parts, t, col, part_index, x, dump)

                # debug output
                if dump and DUMP_MACS_ACCUMULATION:
                    self._dump_macs(t, col)

                # save the upper parts of product at end of column
                if t == ab_num_words:
                    for x in range(NUM_MULTS):
                        part_index = ab_num_words + col * NUM_MULTS + x
                        self._mult_store_part(parts, t, col, part_index, x, dump)

        return parts


class ModExpNG_LowlevelOperator:
    """Single-word addition and subtraction with carry/borrow."""

    def __init__(self):
        # mask with the low _WORD_WIDTH bits set
        self._word_mask = (1 << _WORD_WIDTH) - 1

    def _check_word(self, a):
        if a < 0 or a >= (2 ** _WORD_WIDTH):
            raise Exception("Word out of range!")

    def _check_carry_borrow(self, cb):
        if cb < 0 or cb > 1:
            raise Exception("Carry or borrow out of range!")

    def add_words(self, a, b, c_in):
        """Return (carry_out, sum_word) of a + b + c_in."""
        self._check_word(a)
        self._check_word(b)
        self._check_carry_borrow(c_in)

        total = a + b + c_in

        sum_s = total & self._word_mask
        sum_c = (total >> _WORD_WIDTH) & 1

        return (sum_c, sum_s)

    def sub_words(self, a, b, b_in):
        """Return (borrow_out, difference_word) of a - b - b_in."""
        self._check_word(a)
        self._check_word(b)
        self._check_carry_borrow(b_in)

        dif = a - b - b_in

        if dif < 0:
            dif_b = 1
            dif_d = dif + 2 ** _WORD_WIDTH
        else:
            dif_b = 0
            dif_d = dif

        return (dif_b, dif_d)


class ModExpNG_Worker:
    """Multi-word arithmetic built on the multiplier, recombinator and word operators."""

    def __init__(self):
        self.recombinator = ModExpNG_PartRecombinator()
        self.multiplier = ModExpNG_WordMultiplier()
        self.lowlevel = ModExpNG_LowlevelOperator()

    def exponentiate(self, iz, bz, e, n, n_factor, n_coeff, num_words, dump_index=-1, dump_mode=""):
        """Run a two-variable ladder over the bits of `e`, MSB first.

        Starting from (t1, t2) = (iz, bz), every bit replaces the pair with
        (t1*t2, t2*t2) when the bit is set and (t1*t1, t2*t1) otherwise, using
        multiply(). Progress is printed while running. `n_factor` is accepted
        but not used here. Returns the final t1.

        dump_index -- bit number for which debug output is enabled
        dump_mode  -- prefix used in the debug output
        """
        # working variables
        t1, t2 = iz, bz

        # the exponent does not change inside the loop, convert it only once
        e_value = e.number()
        num_bits = _WORD_WIDTH * num_words

        # length-1, length-2, length-3, ..., 1, 0 (left-to-right)
        for bit in range(num_bits - 1, -1, -1):

            debug_dump = bit == dump_index

            bit_value = (e_value >> bit) & 1

            if debug_dump:
                print("\rladder_mode = %d" % bit_value)

                # NOTE(review): the nesting of the next two blocks under
                # `if debug_dump:` is reconstructed from a source whose
                # indentation was lost -- confirm against the upstream file.

                # force the rarely seen overflow: where a word has no
                # overflow bits yet, move the two low bits of the next word
                # into bits 16..17 of it (the operand's value is unchanged)
                if FORCE_OVERFLOW:
                    T1X = list(t1.words)
                    for i in range(num_words):
                        if i > 0:
                            bits = T1X[i-1] & (3 << 16)
                            if bits == 0:
                                bits = T1X[i] & 3
                                T1X[i] = T1X[i] ^ bits
                                T1X[i-1] |= (bits << 16)

                    for i in range(num_words):
                        t1.words[i] = T1X[i]

                if DUMP_VECTORS:
                    print("num_words = %d" % num_words)
                    t1.format_verilog_concat("%s_T1" % dump_mode)
                    t2.format_verilog_concat("%s_T2" % dump_mode)
                    n.format_verilog_concat("%s_N" % dump_mode)
                    n_coeff.format_verilog_concat("%s_N_COEFF" % dump_mode)

            if bit_value:
                p1 = self.multiply(t1, t2, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="X")
                p2 = self.multiply(t2, t2, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="Y")
            else:
                p1 = self.multiply(t1, t1, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="X")
                p2 = self.multiply(t2, t1, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="Y")

            t1, t2 = p1, p2

            if debug_dump and DUMP_VECTORS:
                t1.format_verilog_concat("%s_X" % dump_mode)
                t2.format_verilog_concat("%s_Y" % dump_mode)

            if (bit % 8) == 0:
                pct = float((num_bits - bit) / num_bits) * 100.0
                print("\rpct: %5.1f%%" % pct, end='')

        print("")

        return t1

    def subtract(self, a, b, n, ab_num_words):
        """Return a - b, or a - b + n when the subtraction borrows.

        Both candidates are computed word by word; the final borrow selects
        which one is returned. All input words must fit in _WORD_WIDTH bits.
        """
        c_in = 0
        b_in = 0

        ab = list()
        ab_n = list()

        for x in range(ab_num_words):

            a_word = a.words[x]
            b_word = b.words[x]

            (b_out, d_out) = self.lowlevel.sub_words(a_word, b_word, b_in)
            (c_out, s_out) = self.lowlevel.add_words(d_out, n.words[x], c_in)

            ab.append(d_out)
            ab_n.append(s_out)

            (c_in, b_in) = (c_out, b_out)

        # b_in holds the final borrow (and stays 0 for zero-length operands)
        d = ab if not b_in else ab_n

        return ModExpNG_Operand(None, ab_num_words, d)

    def add(self, a, b, ab_num_words):
        """Return a + b as 2 * ab_num_words words.

        `a` contributes its first ab_num_words words, `b` all
        2 * ab_num_words of them; the final carry is discarded.
        """
        c_in = 0

        ab = list()

        for x in range(2 * ab_num_words):

            a_word = a.words[x] if x < ab_num_words else 0
            b_word = b.words[x]

            (c_out, s_out) = self.lowlevel.add_words(a_word, b_word, c_in)

            ab.append(s_out)

            c_in = c_out

        return ModExpNG_Operand(None, 2*ab_num_words, ab)

    def multiply(self, a, b, n, n_coeff, ab_num_words, reduce_only=False, multiply_only=False, dump=False, dump_mode="", dump_phase=""):
        """Multiply and/or reduce.

        The full operation computes AB = A * B, Q = LSB(AB) * N_COEFF,
        M = Q * N and returns the words of AB + M above the lowest
        ab_num_words + 1 words (plus the carry out of those lowest words),
        i.e. ab_num_words words that may be wider than 16 bits.

        reduce_only   -- skip the multiplication and treat `a` as AB
        multiply_only -- return AB (2 * ab_num_words words) without reducing

        Prints "MISMATCH" and exits with status 1 when the self-check
        M == Q * N fails.
        """
        # 1. AB = A * B
        if dump:
            print("multiply_square(%s_%s)" % (dump_mode, dump_phase))

        if reduce_only:
            ab = a
        else:
            ab_parts = self.multiplier.multiply_square(a, b, ab_num_words, dump)
            ab_words = self.recombinator.recombine_square(ab_parts, ab_num_words, dump)
            ab = ModExpNG_Operand(None, 2 * ab_num_words, ab_words)

        if dump and DUMP_VECTORS:
            ab.format_verilog_concat("%s_%s_AB" % (dump_mode, dump_phase))

        if multiply_only:
            # use ab.words, so that this also works together with reduce_only
            return ModExpNG_Operand(None, 2*ab_num_words, ab.words)

        # 2. Q = LSB(AB) * N_COEFF
        if dump:
            print("multiply_triangle(%s_%s)" % (dump_mode, dump_phase))

        q_parts = self.multiplier.multiply_triangle(ab, n_coeff, ab_num_words, dump)
        q_words = self.recombinator.recombine_triangle(q_parts, ab_num_words, dump)
        q = ModExpNG_Operand(None, ab_num_words + 1, q_words)

        if dump and DUMP_VECTORS:
            q.format_verilog_concat("%s_%s_Q" % (dump_mode, dump_phase))

        # 3. M = Q * N
        if dump:
            print("multiply_rectangle(%s_%s)" % (dump_mode, dump_phase))

        m_parts = self.multiplier.multiply_rectangle(n, q, ab_num_words, dump)
        m_words = self.recombinator.recombine_rectangle(m_parts, ab_num_words, dump)
        m = ModExpNG_Operand(None, 2 * ab_num_words + 1, m_words)

        if dump and DUMP_VECTORS:
            m.format_verilog_concat("%s_%s_M" % (dump_mode, dump_phase))

        # self-check of the model against built-in math
        # NOTE(review): placed at function level; the original nesting was
        # lost together with the indentation -- confirm against upstream.
        if (m.number() != (q.number() * n.number())):
            print("MISMATCH")
            sys.exit(1)

        # 4. R = AB + M

        # 4a. compute carry (actual sum is all zeroes and need not be stored)
        r_cy = 0  # this can be up to two bits, since we're adding extended words!!
        for i in range(ab_num_words + 1):
            s = ab.words[i] + m.words[i] + r_cy
            r_cy_new = s >> 16

            if dump and DUMP_REDUCTION:
                print("[%2d] 0x%05x + 0x%05x + 0x%x => {0x%x, [0x%05x]}" %
                    (i, ab.words[i], m.words[i], r_cy, r_cy_new, s & 0xffff))

            r_cy = r_cy_new

        # 4b. Initialize empty result
        R = [0] * ab_num_words

        # 4c. compute the actual upper part of sum (take carry into account)
        for i in range(ab_num_words):

            if dump and DUMP_REDUCTION:
                print("[%2d]" % i, end='')

            # AB has one word less than M in its upper part
            ab_word = ab.words[ab_num_words + i + 1] if i < (ab_num_words - 1) else 0
            if dump and DUMP_REDUCTION:
                print(" 0x%05x" % ab_word, end='')

            m_word = m.words[ab_num_words + i + 1]
            if dump and DUMP_REDUCTION:
                print(" + 0x%05x" % m_word, end='')

            # the carry from the lower part only enters the lowest word
            if i == 0:
                R[i] = r_cy
            else:
                R[i] = 0

            if (r_cy > 3):
                print("\rR_CY = %d!" % r_cy)

            if dump and DUMP_REDUCTION:
                print(" + 0x%x" % R[i], end='')

            R[i] += ab_word
            R[i] += m_word

            if dump and DUMP_REDUCTION:
                print(" = 0x%05x" % R[i])

        return ModExpNG_Operand(None, ab_num_words, R)

    def reduce(self, a, num_words):
        """Propagate the extra bits of the first `num_words` words of `a` in place.

        Every word is cut to _WORD_WIDTH bits and its two bits above that are
        added to the next word; the carry out of the last word is discarded.
        """
        carry = 0
        for x in range(num_words):
            a.words[x] += carry
            carry = (a.words[x] >> _WORD_WIDTH) & 3
            a.words[x] &= self.lowlevel._word_mask


class ModExpNG_CoreOutputEnum(Enum):
    """Selectors for the output values of the core."""
    XM = auto()
    YM = auto()
    S = auto()


class ModExpNG_CoreOutput:
    """Three output slots addressed by ModExpNG_CoreOutputEnum."""

    # selector -> attribute name
    _SLOTS = {
        ModExpNG_CoreOutputEnum.XM: "_xm",
        ModExpNG_CoreOutputEnum.YM: "_ym",
        ModExpNG_CoreOutputEnum.S: "_s",
    }

    def __init__(self):
        self._xm = None
        self._ym = None
        self._s = None

    def _set_value(self, sel, value):
        """Store `value` in the slot selected by `sel`; raise if unknown."""
        slot = self._SLOTS.get(sel)
        if slot is None:
            raise Exception("ModExpNG_CoreOutput._set_value(): invalid selector!")
        setattr(self, slot, value)

    def get_value(self, sel):
        """Return the slot selected by `sel`; raise Exception if unknown."""
        slot = self._SLOTS.get(sel)
        if slot is None:
            raise Exception("ModExpNG_CoreOutput.get_value(): invalid selector!")
        return getattr(self, slot)


class ModExpNG_Core:
    """A worker together with its operand banks and output slots."""

    def __init__(self, i):
        self.wrk = ModExpNG_Worker()
        self.bnk = ModExpNG_BanksCRT(i)
        self.out = ModExpNG_CoreOutput()

    def multiply(self, sel_wide_in, sel_narrow_in, sel_wide_out, sel_narrow_out, num_words, mode=(True, True)):
        """Run four worker multiplications, one per bank pair.

        Each wide operand selected by `sel_wide_in` is multiplied by a narrow
        operand selected by `sel_narrow_in`. Per CRT half the narrow operand
        is taken from ladder_y when the matching `mode` entry is true and from
        ladder_x otherwise (mode[0] for crt_x, mode[1] for crt_y). The modulus
        and its coefficient are read from the ladder_x banks. Results go to
        the output selectors that are not None.
        """
        xn = self.bnk.crt_x.ladder_x.wide._get_value(ModExpNG_WideBankEnum.N)
        yn = self.bnk.crt_y.ladder_x.wide._get_value(ModExpNG_WideBankEnum.N)

        xn_coeff = self.bnk.crt_x.ladder_x.narrow._get_value(ModExpNG_NarrowBankEnum.N_COEFF)
        yn_coeff = self.bnk.crt_y.ladder_x.narrow._get_value(ModExpNG_NarrowBankEnum.N_COEFF)

        xxa = self.bnk.crt_x.ladder_x.wide._get_value(sel_wide_in)
        xya = self.bnk.crt_x.ladder_y.wide._get_value(sel_wide_in)
        yxa = self.bnk.crt_y.ladder_x.wide._get_value(sel_wide_in)
        yya = self.bnk.crt_y.ladder_y.wide._get_value(sel_wide_in)

        xxb = self.bnk.crt_x.ladder_x.narrow._get_value(sel_narrow_in)
        xyb = self.bnk.crt_x.ladder_y.narrow._get_value(sel_narrow_in)
        yxb = self.bnk.crt_y.ladder_x.narrow._get_value(sel_narrow_in)
        yyb = self.bnk.crt_y.ladder_y.narrow._get_value(sel_narrow_in)

        xb = xyb if mode[0] else xxb
        yb = yyb if mode[1] else yxb

        xxp = self.wrk.multiply(xxa, xb, xn, xn_coeff, num_words)
        xyp = self.wrk.multiply(xya, xb, xn, xn_coeff, num_words)
        yxp = self.wrk.multiply(yxa, yb, yn, yn_coeff, num_words)
        yyp = self.wrk.multiply(yya, yb, yn, yn_coeff, num_words)

        if sel_wide_out is not None:
            self.bnk.crt_x.ladder_x.wide._set_value(sel_wide_out, xxp)
            self.bnk.crt_x.ladder_y.wide._set_value(sel_wide_out, xyp)
            self.bnk.crt_y.ladder_x.wide._set_value(sel_wide_out, yxp)
            self.bnk.crt_y.ladder_y.wide._set_value(sel_wide_out, yyp)

        if sel_narrow_out is not None:
            self.bnk.crt_x.ladder_x.narrow._set_value(sel_narrow_out, xxp)
            self.bnk.crt_x.ladder_y.narrow._set_value(sel_narrow_out, xyp)
            self.bnk.crt_y.ladder_x.narrow._set_value(sel_narrow_out, yxp)
            self.bnk.crt_y.ladder_y.narrow._set_value(sel_narrow_out, yyp)

    def simply_reduce(self, sel_narrow, num_words):
        """Apply the worker's in-place reduce() to all four narrow operands."""
        self.wrk.reduce(self.bnk.crt_x.ladder_x.narrow._get_value(sel_narrow), num_words)
        self.wrk.reduce(self.bnk.crt_x.ladder_y.narrow._get_value(sel_narrow), num_words)
        self.wrk.reduce(self.bnk.crt_y.ladder_x.narrow._get_value(sel_narrow), num_words)
        self.wrk.reduce(self.bnk.crt_y.ladder_y.narrow._get_value(sel_narrow), num_words)

    def set_output(self, sel_output, banks_ladder, sel_narrow):
        """Copy a narrow operand of banks_ladder.ladder_x into an output slot."""
        self.out._set_value(sel_output, banks_ladder.ladder_x.narrow._get_value(sel_narrow))

    def mirror_yx(self, sel_wide, sel_narrow):
        """Copy the selected crt_y operands over the crt_x ones.

        A selector that is None leaves the corresponding banks untouched.
        """
        if sel_wide is not None:
            self.bnk.crt_x.ladder_x.wide._set_value(sel_wide, self.bnk.crt_y.ladder_x.wide._get_value(sel_wide))
            self.bnk.crt_x.ladder_y.wide._set_value(sel_wide, self.bnk.crt_y.ladder_y.wide._get_value(sel_wide))

        if sel_narrow is not None:
            self.bnk.crt_x.ladder_x.narrow._set_value(sel_narrow, self.bnk.crt_y.ladder_x.narrow._get_value(sel_narrow))
            self.bnk.crt_x.ladder_y.narrow._set_value(sel_narrow, self.bnk.crt_y.ladder_y.narrow._get_value(sel_narrow))


if __name__ == "__main__":

    # set numbers of words
    n_num_words = KEY_LENGTH // _WORD_WIDTH
    pq_num_words = n_num_words // 2

    # create helper quantity
    i = ModExpNG_Operand(1, KEY_LENGTH)

    # load test vector
    vector = ModExpNG_TestVector()

    # create worker
    core = ModExpNG_Core(i)

    # obtain known good reference value with built-in math
    s_known = pow(vector.m.number(), vector.d.number(), vector.n.number())

    # mutate blinding quantities with built-in math
    x_mutated_known = pow(vector.x.number(), 2, vector.n.number())
    y_mutated_known = pow(vector.y.number(), 2, vector.n.number())

    # Outline of the computation below:
    #
    # bring one into Montgomery domain (glue 2**r to one)
    # bring blinding coefficients into Montgomery domain (glue 2**(2*r) to x and y)
    # blind message
    # convert message to non-redundant representation
    # first reduce message, this glues 2**-r to the message as a side effect
    # unglue 2**-r from message by gluing 2**r to it to compensate
    # bring message into Montgomery domain (glue 2**r to message)
    # do "easier" exponentiations
    # return "easier" parts from Montgomery domain (unglue 2**r from result)
    # do the "Garner's formula" part
    #   r = sp - sq mod p
    #   sr_qinv = sr * qinv mod p
    #   q_sr_qinv = q * sr_qinv
    #   s_crt = sq + q_sr_qinv
    # unblind s
    # mutate blinding factors

    W = ModExpNG_WideBankEnum
    N = ModExpNG_NarrowBankEnum
    O = ModExpNG_CoreOutputEnum

    core.bnk.crt_x.set_modulus(vector.n, vector.n_coeff)
    core.bnk.crt_y.set_modulus(vector.n, vector.n_coeff)

    core.bnk.crt_x.set_operand(W.A, N.A, vector.x, vector.n_factor)
    core.bnk.crt_y.set_operand(W.A, N.A, vector.y, vector.n_factor)

    core.bnk.crt_x.set_operand(W.E, N.E, vector.m, vector.m)
    core.bnk.crt_y.set_operand(W.E, N.E, vector.m, vector.m)

    # Bank contents after each step ("wide ; narrow" where the two differ):
    #
    #                           | A              | B     | C       | D     | E |
    #                           +----------------+-------+---------+-------+---+
    # initially                 | X,Y ; N_FACTOR | ?     | ?       | ?     | M |

    # (YF, XF) =(Y,X)*N_FACTOR  | X,Y ; N_FACTOR | XF,YF | ?       | ?     | M |
    core.multiply(W.A, N.A, W.B, N.B, n_num_words)

    # (YMF,XMF)=(YF*YF,XF*XF)   | X,Y ; N_FACTOR | XF,YF | YMF,XMF | ?     | M |
    core.multiply(W.B, N.B, W.C, N.C, n_num_words, mode=(False, False))

    # (YM, XM) =(YMF,XMF)*1     | X,Y ; N_FACTOR | XF,YF | YMF,XMF | XM,YM | M |
    core.multiply(W.C, N.I, W.D, N.D, n_num_words)

    core.simply_reduce(N.D, n_num_words)
    core.set_output(O.XM, core.bnk.crt_x, N.D)
    core.set_output(O.YM, core.bnk.crt_y, N.D)

    # (MB, _)  =(M*YF,M*XF)     | X,Y ; N_FACTOR | XF,YF | MB,_    | XM,YM | M |
    core.multiply(W.E, N.B, W.C, N.C, n_num_words, mode=(False, False))

    #                           | X,Y ; N_FACTOR | XF,YF | MB,MB   | XM,YM | M |
    core.mirror_yx(W.C, N.C)
    core.simply_reduce(N.C, n_num_words)

    XF = core.bnk.crt_x.ladder_x.wide._get_value(W.B)
    YF = core.bnk.crt_y.ladder_x.wide._get_value(W.B)

    MB = core.bnk.crt_y.ladder_x.narrow._get_value(N.C)

    PMBZ = core.wrk.multiply(MB, None, vector.p, vector.p_coeff, pq_num_words, reduce_only=True)  # mod_reduce (mod p)
    QMBZ = core.wrk.multiply(MB, None, vector.q, vector.q_coeff, pq_num_words, reduce_only=True)  # mod_reduce (mod q)

    mp_blind = core.wrk.multiply(PMBZ, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)  # mod_multiply
    mq_blind = core.wrk.multiply(QMBZ, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)  # mod_multiply

    mp_blind_factor = core.wrk.multiply(mp_blind, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)  # mod_multiply
    mq_blind_factor = core.wrk.multiply(mq_blind, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)  # mod_multiply

    ip_factor = core.wrk.multiply(i, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)  # mod_multiply
    iq_factor = core.wrk.multiply(i, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)  # mod_multiply

    sp_blind_factor = core.wrk.exponentiate(ip_factor, mp_blind_factor, vector.dp, vector.p, vector.p_factor, vector.p_coeff, pq_num_words, dump_index=99, dump_mode="P")  # mod_multiply
    sq_blind_factor = core.wrk.exponentiate(iq_factor, mq_blind_factor, vector.dq, vector.q, vector.q_factor, vector.q_coeff, pq_num_words, dump_index=99, dump_mode="Q")  # mod_multiply

    SPB = core.wrk.multiply(i, sp_blind_factor, vector.p, vector.p_coeff, pq_num_words)  # mod_multiply
    SQB = core.wrk.multiply(i, sq_blind_factor, vector.q, vector.q_coeff, pq_num_words)  # mod_multiply

    core.wrk.reduce(SPB, len(SPB.words))  # just_reduce
    core.wrk.reduce(SQB, len(SQB.words))  # just_reduce

    sr_blind = core.wrk.subtract(SPB, SQB, vector.p, pq_num_words)  # mod_subtract

    sr_qinv_blind_inverse_factor = core.wrk.multiply(sr_blind, vector.qinv, vector.p, vector.p_coeff, pq_num_words)  # mod_multiply
    sr_qinv_blind = core.wrk.multiply(sr_qinv_blind_inverse_factor, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)  # mod_multiply

    q_sr_qinv_blind = core.wrk.multiply(vector.q, sr_qinv_blind, None, None, pq_num_words, multiply_only=True)  # just_multiply
    core.wrk.reduce(q_sr_qinv_blind, n_num_words)  # just_reduce

    SB = core.wrk.add(SQB, q_sr_qinv_blind, pq_num_words)  # just_add

    S = core.wrk.multiply(SB, XF, vector.n, vector.n_coeff, n_num_words)  # mod_multiply
    core.wrk.reduce(S, len(S.words))  # just_reduce

    # check
    XM = core.out.get_value(O.XM)
    YM = core.out.get_value(O.YM)

    if S.number() != s_known:
        print("ERROR: s_crt_unblinded != s_known!")
    else:
        print("s is OK")

    if XM.number() != x_mutated_known:
        print("ERROR: x_mutated != x_mutated_known!")
    else:
        print("x_mutated is OK")

    if YM.number() != y_mutated_known:
        print("ERROR: y_mutated != y_mutated_known!")
    else:
        print("y_mutated is OK")


#
# End-of-File
#