From 3ef1813079662305cf62ac68dc6a7729d3961d84 Mon Sep 17 00:00:00 2001 From: "Pavel V. Shatov (Meister)" Date: Sat, 23 Mar 2019 10:54:59 +0300 Subject: ModExpNG ("Next Generation") math model. --- modexpng_fpga_model.py | 690 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 690 insertions(+) create mode 100644 modexpng_fpga_model.py diff --git a/modexpng_fpga_model.py b/modexpng_fpga_model.py new file mode 100644 index 0000000..1152bdf --- /dev/null +++ b/modexpng_fpga_model.py @@ -0,0 +1,690 @@ +#!/usr/bin/python3 +# +# +# ModExpNG core math model. +# +# +# Copyright (c) 2019, NORDUnet A/S +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# - Neither the name of the NORDUnet nor the names of its contributors may +# be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + + +# ------- +# Imports +#-------- + +import sys +import importlib + + +# -------------- +# Model Settings +# -------------- + +# length of public key +KEY_LENGTH = 1024 + +# how many parallel multipliers to use +NUM_MULTS = 8 + + +# --------------- +# Internal Values +# --------------- + +# half of key length +_KEY_LENGTH_HALF = KEY_LENGTH // 2 + +# width of internal math pipeline +_WORD_WIDTH = 16 + +# folder with test vector scripts +_VECTOR_PATH = "/vector" + +# name of test vector class +_VECTOR_CLASS = "Vector" + + +# +# Multi-Precision Integer +# +class ModExpNG_Operand(): + + def __init__(self, number, length, words = None): + + if words is None: + + # length must be divisible by word width + if (length % _WORD_WIDTH) > 0: + raise Exception("Bad number length!") + + self._init_from_number(number, length) + + else: + + # length must match words count + if len(words) != length: + raise Exception("Bad words count!") + + self._init_from_words(words, length) + + + def _init_from_words(self, words, count): + + for i in range(count): + + # word must not exceed 17 bits + if words[i] >= (2 ** (_WORD_WIDTH + 1)): + raise Exception("Word is too large!") + + self.words = words + + def _init_from_number(self, number, length): + + num_hexchars_per_word = _WORD_WIDTH // 4 + num_hexchars_total = length // num_hexchars_per_word + + value_hex = format(number, 'x') + + # value must not be larger than specified, but it can be smaller, so + # we may need to prepend it with zeroes + if len(value_hex) > num_hexchars_total: + raise Exception("Number is too large!") + else: + while len(value_hex) < num_hexchars_total: + value_hex = "0" + value_hex + + # create empty list + self.words = list() + + # fill in words + while len(value_hex) > 0: + value_hex_part = value_hex[-num_hexchars_per_word:] + value_hex = value_hex[:-num_hexchars_per_word] + self.words.append(int(value_hex_part, 16)) + + def number(self): + ret = 0 + shift = 0 + for word in self.words: + ret += word << shift + shift += _WORD_WIDTH + return ret + + +# +# Test Vector +# +class ModExpNG_TestVector(): + + def __init__(self): + + # format target filename + filename = "vector_" + str(KEY_LENGTH) + "_randomized" + + # add ./vector to import search path + sys.path.insert(1, sys.path[0] + _VECTOR_PATH) + + # import from filename + vector_module = importlib.import_module(filename) + + # get vector class + vector_class = getattr(vector_module, _VECTOR_CLASS) + + # instantiate vector class + vector_inst = vector_class() + + # obtain parts of vector + self.m = ModExpNG_Operand(vector_inst.m, KEY_LENGTH) + self.n = ModExpNG_Operand(vector_inst.n, KEY_LENGTH) + self.d = ModExpNG_Operand(vector_inst.d, KEY_LENGTH) + self.p = ModExpNG_Operand(vector_inst.p, _KEY_LENGTH_HALF) + self.q = ModExpNG_Operand(vector_inst.q, _KEY_LENGTH_HALF) + self.dp = ModExpNG_Operand(vector_inst.dp, _KEY_LENGTH_HALF) + self.dq = ModExpNG_Operand(vector_inst.dq, _KEY_LENGTH_HALF) + self.qinv = ModExpNG_Operand(vector_inst.qinv, _KEY_LENGTH_HALF) + self.n_factor = ModExpNG_Operand(vector_inst.n_factor, KEY_LENGTH) + self.p_factor = ModExpNG_Operand(vector_inst.p_factor, _KEY_LENGTH_HALF) + self.q_factor = ModExpNG_Operand(vector_inst.q_factor, _KEY_LENGTH_HALF) + self.n_coeff = ModExpNG_Operand(vector_inst.n_coeff, KEY_LENGTH + _WORD_WIDTH) + self.p_coeff = ModExpNG_Operand(vector_inst.p_coeff, _KEY_LENGTH_HALF + _WORD_WIDTH) + self.q_coeff = ModExpNG_Operand(vector_inst.q_coeff, _KEY_LENGTH_HALF + _WORD_WIDTH) + + +class ModExpNG_PartRecombinator(): + + def _bit_select(self, x, msb, lsb): + y = 0 + for pos in range(lsb, msb+1): + y |= (x & (1 << pos)) >> lsb + return y + + def _flush_pipeline(self): + self.z0, self.y0, self.x0 = 0, 0, 0 + + def _push_pipeline(self, part): + + # split next part into 16-bit words + z = self._bit_select(part, 47, 32) + y = self._bit_select(part, 31, 16) + x = self._bit_select(part, 15, 0) + + # shift to the right + z1 = z + y1 = y + self.z0 + x1 = x + self.y0 + (self.x0 >> 16) # IMPORTANT: This carry can be up to two bits wide!! + + # save lower 16 bits of the rightmost cell + t = self.x0 & 0xffff + + # update internal latches + self.z0, self.y0, self.x0 = z1, y1, x1 + + # done + return t + + def recombine_square(self, parts, ab_num_words): + + # empty result so far + words = list() + + # flush recombinator pipeline + self._flush_pipeline() + + # the first tick produces null result, the last part produces + # two words, so we need (2*n - 1) + 2 = 2*n + 1 ticks total + # and should only save the result word during the last 2 * n ticks + for i in range(2 * ab_num_words + 1): + + next_part = parts[i] if i < (2 * ab_num_words - 1) else 0 + next_word = self._push_pipeline(next_part) + + if i > 0: + words.append(next_word) + + return words + + def recombine_triangle(self, parts, ab_num_words): + + # empty result so far + words = list() + + # flush recombinator pipeline + self._flush_pipeline() + + # the first tick produces null result, so we need n + 1 + 1 = n + 2 + # ticks total and should only save the result word during the last n ticks + for i in range(ab_num_words + 2): + + next_part = parts[i] if i < (ab_num_words + 1) else 0 + next_word = self._push_pipeline(next_part) + + if i > 0: + words.append(next_word) + + return words + + def recombine_rectangle(self, parts, ab_num_words): + + # empty result so far + words = list() + + # flush recombinator pipeline + self._flush_pipeline() + + # the first tick produces null result, the last part produces + # two words, so we need 2 * n + 2 ticks total and should only save + # the result word during the last 2 * n + 1 ticks + for i in range(2 * ab_num_words + 2): + + next_part = parts[i] if i < (2 * ab_num_words) else 0 + next_word = self._push_pipeline(next_part) + + if i > 0: + words.append(next_word) + + return words + + +class ModExpNG_WordMultiplier(): + + def __init__(self): + + self._macs = list() + self._indices = list() + + self._mac_aux = list() + self._index_aux = list() + + for x in range(NUM_MULTS): + self._macs.append(0) + self._indices.append(0) + + self._mac_aux.append(0) + self._index_aux.append(0) + + def _clear_all_macs(self): + for x in range(NUM_MULTS): + self._macs[x] = 0 + + def _clear_one_mac(self, x): + self._macs[x] = 0 + + def _clear_mac_aux(self): + self._mac_aux[0] = 0 + + def _update_one_mac(self, x, value): + self._macs[x] += value + + def _update_mac_aux(self, value): + self._mac_aux[0] += value + + def _preset_indices(self, col): + for x in range(len(self._indices)): + self._indices[x] = col * len(self._indices) + x + + def _preset_index_aux(self, num_cols): + self._index_aux[0] = num_cols * len(self._indices) + + def _rotate_indices(self, num_words): + for x in range(len(self._indices)): + self._indices[x] -= 1 + if self._indices[x] < 0: + self._indices[x] += num_words + + def _rotate_index_aux(self): + self._index_aux[0] -= 1 + + def multiply_square(self, a_wide, b_narrow, ab_num_words): + + num_cols = ab_num_words // NUM_MULTS + + parts = list() + for i in range(2 * ab_num_words - 1): + parts.append(0) + + for col in range(num_cols): + + self._clear_all_macs() + self._preset_indices(col) + + for t in range(ab_num_words): + + # current b-word + bt = b_narrow.words[t] + + # multiply by a-words + for x in range(NUM_MULTS): + ax = a_wide.words[self._indices[x]] + self._update_one_mac(x, ax * bt) + + if t == (col * NUM_MULTS + x): + parts[t] = self._macs[x] + self._clear_one_mac(x) + + # save the uppers part of product at end of column, + # for the last column don't save the very last part + if t == (ab_num_words - 1): + for x in range(NUM_MULTS): + if not (col == (num_cols - 1) and x == (NUM_MULTS - 1)): + parts[ab_num_words + col * NUM_MULTS + x] = self._macs[x] + + self._rotate_indices(ab_num_words) + + return parts + + def multiply_triangle(self, a_wide, b_narrow, ab_num_words): + + num_cols = ab_num_words // NUM_MULTS + + parts = list() + for i in range(ab_num_words + 1): + parts.append(0) + + for col in range(num_cols): + + last_col = col == (num_cols - 1) + + self._clear_all_macs() + self._preset_indices(col) + + if last_col: + self._clear_mac_aux() + self._preset_index_aux(num_cols) + + for t in range(ab_num_words + 1): + + # current b-word + bt = b_narrow.words[t] + + # multiply by a-words + for x in range(NUM_MULTS): + ax = a_wide.words[self._indices[x]] + self._update_one_mac(x, ax * bt) + + if t == (col * NUM_MULTS + x): + parts[t] = self._macs[x] + + # aux multiplier + if last_col: + ax = a_wide.words[self._index_aux[0]] + self._update_mac_aux(ax * bt) + + if t == ab_num_words: + parts[t] = self._mac_aux[0] + + # shortcut + if not last_col: + if t == (NUM_MULTS * (col + 1) - 1): break + + # advance indices + self._rotate_indices(ab_num_words) + if last_col: + self._rotate_index_aux() + + return parts + + def multiply_rectangle(self, a_wide, b_narrow, ab_num_words): + + num_cols = ab_num_words // NUM_MULTS + + parts = list() + for i in range(2 * ab_num_words): + parts.append(0) + + for col in range(num_cols): + + self._clear_all_macs() + self._preset_indices(col) + + for t in range(ab_num_words+1): + + # current b-word + bt = b_narrow.words[t] + + # multiply by a-words + for x in range(NUM_MULTS): + ax = a_wide.words[self._indices[x]] + self._update_one_mac(x, ax * bt) + + # don't save one value for the very last time instant per column + if t < ab_num_words and t == (col * NUM_MULTS + x): + parts[t] = self._macs[x] + self._clear_one_mac(x) + + # save the uppers part of product at end of column + if t == ab_num_words: + for x in range(NUM_MULTS): + parts[ab_num_words + col * NUM_MULTS + x] = self._macs[x] + + self._rotate_indices(ab_num_words) + + return parts + + +class ModExpNG_LowlevelOperator(): + + def __init__(self): + self._word_mask = 0 + for x in range(_WORD_WIDTH): + self._word_mask |= (1 << x) + + def _check_word(self, a): + if a < 0 or a >= (2 ** _WORD_WIDTH): + raise Exception("Word out of range!") + + def _check_carry_borrow(self, cb): + if cb < 0 or cb > 1: + raise Exception("Carry or borrow out of range!") + + def add_words(self, a, b, c_in): + + self._check_word(a) + self._check_word(b) + self._check_carry_borrow(c_in) + + sum = a + b + c_in + + sum_s = sum & self._word_mask + sum_c = (sum >> _WORD_WIDTH) & 1 + + return (sum_c, sum_s) + + def sub_words(self, a, b, b_in): + self._check_word(a) + self._check_word(b) + self._check_carry_borrow(b_in) + + dif = a - b - b_in + + if dif < 0: + dif_b = 1 + dif_d = dif + 2 ** _WORD_WIDTH + else: + dif_b = 0 + dif_d = dif + + return (dif_b, dif_d) + + +class ModExpNG_Worker(): + + def __init__(self): + self.recombinator = ModExpNG_PartRecombinator() + self.multiplier = ModExpNG_WordMultiplier() + self.lowlevel = ModExpNG_LowlevelOperator() + + def exponentiate(self, iz, bz, e, n, n_factor, n_coeff, num_words): + + # working variables + t1, t2 = iz, bz + + # length-1, length-2, length-3, ..., 1, 0 (left-to-right) + for bit in range(_WORD_WIDTH * num_words - 1, -1, -1): + + if e.number() & (1 << bit): + p1 = self.multiply(t1, t2, n, n_coeff, num_words) + p2 = self.multiply(t2, t2, n, n_coeff, num_words) + else: + p1 = self.multiply(t1, t1, n, n_coeff, num_words) + p2 = self.multiply(t2, t1, n, n_coeff, num_words) + + t1, t2 = p1, p2 + + if (bit % 8) == 0: + pct = float((_WORD_WIDTH * num_words - bit) / (_WORD_WIDTH * num_words)) * 100.0 + print("\rpct: %5.1f%%" % pct, end='') + + print("") + + return t1 + + def subtract(self, a, b, n, ab_num_words): + + c_in = 0 + b_in = 0 + + ab = list() + ab_n = list() + + for x in range(ab_num_words): + + a_word = a.words[x] + b_word = b.words[x] + + (b_out, d_out) = self.lowlevel.sub_words(a_word, b_word, b_in) + (c_out, s_out) = self.lowlevel.add_words(d_out, n.words[x], c_in) + + ab.append(d_out) + ab_n.append(s_out) + + (c_in, b_in) = (c_out, b_out) + + d = ab if not b_out else ab_n + + return ModExpNG_Operand(None, ab_num_words, d) + + def add(self, a, b, ab_num_words): + + c_in = 0 + + ab = list() + + for x in range(2 * ab_num_words): + + a_word = a.words[x] if x < ab_num_words else 0 + b_word = b.words[x] + + (c_out, s_out) = self.lowlevel.add_words(a_word, b_word, c_in) + + ab.append(s_out) + + c_in = c_out + + return ModExpNG_Operand(None, 2*ab_num_words, ab) + + def multiply(self, a, b, n, n_coeff, ab_num_words, reduce_only=False, multiply_only=False): + + # 1. + if reduce_only: + ab = a + else: + ab_parts = self.multiplier.multiply_square(a, b, ab_num_words) + ab_words = self.recombinator.recombine_square(ab_parts, ab_num_words) + ab = ModExpNG_Operand(None, 2 * ab_num_words, ab_words) + + if multiply_only: + return ModExpNG_Operand(None, 2*ab_num_words, ab_words) + + # 2. + q_parts = self.multiplier.multiply_triangle(ab, n_coeff, ab_num_words) + q_words = self.recombinator.recombine_triangle(q_parts, ab_num_words) + q = ModExpNG_Operand(None, ab_num_words + 1, q_words) + + # 3. + m_parts = self.multiplier.multiply_rectangle(n, q, ab_num_words) + m_words = self.recombinator.recombine_rectangle(m_parts, ab_num_words) + m = ModExpNG_Operand(None, 2 * ab_num_words + 1, m_words) + + # 4. + r_xwords = list() + for i in range(2*ab_num_words): + r_xwords.append(ab.words[i] + m.words[i]) + + r_xwords.append(m.words[2 * ab_num_words]) + + cy = 0 + for i in range(ab_num_words+1): + s = r_xwords[i] + cy + cy = s >> 16 + + R = list() + for i in range(ab_num_words): + R.append(0) + + R[0] += cy # !!! + + for i in range(ab_num_words): + R[i] += r_xwords[ab_num_words + i + 1] + + return ModExpNG_Operand(None, ab_num_words, R) + + +if __name__ == "__main__": + + # load test vector + vector = ModExpNG_TestVector() + + # create worker + worker = ModExpNG_Worker() + + # number of words + pq_num_words = _KEY_LENGTH_HALF // _WORD_WIDTH + + # obtain known good reference values with built-in math + s_known = pow(vector.m.number(), vector.d.number(), vector.n.number()) + sp_known = pow(vector.m.number(), vector.dp.number(), vector.p.number()) + sq_known = pow(vector.m.number(), vector.dq.number(), vector.q.number()) + + # first reduce message, this glues 2**-r to the message as a side effect + mpa = worker.multiply(vector.m, None, vector.p, vector.p_coeff, pq_num_words, reduce_only=True) + mqa = worker.multiply(vector.m, None, vector.q, vector.q_coeff, pq_num_words, reduce_only=True) + + # unglue 2**-r from message by gluing 2**r to it to compensate + mp = worker.multiply(mpa, vector.p_factor, vector.p, vector.p_coeff, pq_num_words) + mq = worker.multiply(mqa, vector.q_factor, vector.q, vector.q_coeff, pq_num_words) + + # one + i = ModExpNG_Operand(1, _KEY_LENGTH_HALF) + + # bring one into Montgomery domain (glue 2**r to one) + ipz = worker.multiply(i, vector.p_factor, vector.p, vector.p_coeff, pq_num_words) + iqz = worker.multiply(i, vector.q_factor, vector.q, vector.q_coeff, pq_num_words) + + # bring message into Montgomery domain (glue 2**r to message) + mpz = worker.multiply(mp, vector.p_factor, vector.p, vector.p_coeff, pq_num_words) + mqz = worker.multiply(mq, vector.q_factor, vector.q, vector.q_coeff, pq_num_words) + + # do "easier" exponentiations + spz = worker.exponentiate(ipz, mpz, vector.dp, vector.p, vector.p_factor, vector.p_coeff, pq_num_words) + sqz = worker.exponentiate(iqz, mqz, vector.dq, vector.q, vector.q_factor, vector.q_coeff, pq_num_words) + + # return "easier" parts from Montgomery domain (unglue 2**r from result) + sp = worker.multiply(i, spz, vector.p, vector.p_coeff, pq_num_words) + sq = worker.multiply(i, sqz, vector.q, vector.q_coeff, pq_num_words) + + # check "easier" results + if sp.number() == sp_known: print("sp is OK") + else: print("sp is WRONG!") + + if sq.number() == sq_known: print("sq is OK") + else: print("sq is WRONG!") + + + # do the "Garner's formula" part + + # 1. r = sp - sq mod p + sr = worker.subtract(sp, sq, vector.p, pq_num_words) + + # 2. sr_qinv = sr * qinv mod p + sr_qinv_a = worker.multiply(sr, vector.qinv, vector.p, vector.p_coeff, pq_num_words) + sr_qinv = worker.multiply(sr_qinv_a, vector.p_factor, vector.p, vector.p_coeff, pq_num_words) + + # 3. q_sr_qinv = q * sr_qinv + q_sr_qinv = worker.multiply(vector.q, sr_qinv, None, None, pq_num_words, multiply_only=True) + + # 4. s_crt = sq + q_sr_qinv + s_crt = worker.add(sq, q_sr_qinv, pq_num_words) + + # check + if s_crt.number() != s_known: + print("ERROR: s_crt != s_known!") + else: + print("s is OK") + -- cgit v1.2.3