#!/usr/bin/python3
#
#
# ModExpNG core math model.
#
#
# Copyright (c) 2019, NORDUnet A/S
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# - Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# - Neither the name of the NORDUnet nor the names of its contributors may
# be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# -------
# Imports
#--------
import sys
import importlib
# --------------
# Model Settings
# --------------
# bit length of the public key (the modulus n)
KEY_LENGTH = 1024
# number of multiply-accumulate units working in parallel
NUM_MULTS = 8
# ---------------
# Internal Values
# ---------------
# bit length of each CRT half (the primes p and q), i.e. half the key length
_KEY_LENGTH_HALF = KEY_LENGTH // 2
# bit width of one word of the internal math pipeline
_WORD_WIDTH = 16
# sub-folder with test vector scripts, appended to sys.path[0] when loading
_VECTOR_PATH = "/vector"
# class name every test vector module is expected to define
_VECTOR_CLASS = "Vector"
#
# Multi-Precision Integer
#
class ModExpNG_Operand():
    """Multi-precision integer stored as a list of words, least significant first.

    Word i has weight 2 ** (_WORD_WIDTH * i). Operands built from a number
    hold fully normalized _WORD_WIDTH-bit words. Operands built from an
    explicit word list may carry words up to _WORD_WIDTH + 1 bits wide
    (not-yet-carry-propagated values).
    """

    def __init__(self, number, length, words = None):
        """Create an operand either from an integer or from a word list.

        number -- non-negative integer value; used only when words is None
        length -- bit length when building from number (must be a multiple
                  of _WORD_WIDTH), or the expected word count when words
                  is given
        words  -- optional list of words, least significant first (adopted
                  as-is, not copied)

        Raises Exception on a bad length, word count, or oversized value, and
        ValueError on a negative number.
        """
        if words is None:
            # length must be divisible by word width
            if (length % _WORD_WIDTH) > 0:
                raise Exception("Bad number length!")
            self._init_from_number(number, length)
        else:
            # length must match words count
            if len(words) != length:
                raise Exception("Bad words count!")
            self._init_from_words(words, length)

    def _init_from_words(self, words, count):
        """Adopt the first count words after checking none exceeds _WORD_WIDTH + 1 bits."""
        limit = 2 ** (_WORD_WIDTH + 1)
        for word in words[:count]:
            if word >= limit:
                raise Exception("Word is too large!")
        self.words = words

    def _init_from_number(self, number, length):
        """Split number into length // _WORD_WIDTH words, least significant first.

        The value may be smaller than length bits (upper words become zero)
        but never larger. Words are extracted arithmetically, so this works
        for any _WORD_WIDTH (the former hex-string split only sized the
        value correctly when _WORD_WIDTH was 16).
        """
        if number < 0:
            raise ValueError("Number is negative!")
        if number >= (1 << length):
            raise Exception("Number is too large!")
        mask = (1 << _WORD_WIDTH) - 1
        self.words = [(number >> (_WORD_WIDTH * i)) & mask
                      for i in range(length // _WORD_WIDTH)]

    def number(self):
        """Return the integer value; words wider than _WORD_WIDTH are summed with their carries."""
        return sum(word << (_WORD_WIDTH * i) for i, word in enumerate(self.words))
#
# Test Vector
#
class ModExpNG_TestVector():
    """Test vector for the configured KEY_LENGTH, converted to operands.

    Imports the module "vector_<KEY_LENGTH>_randomized" from the folder
    sys.path[0] + _VECTOR_PATH, instantiates its _VECTOR_CLASS class and
    wraps every value it exposes in a ModExpNG_Operand of the right length.
    """

    def __init__(self):
        # format target filename
        filename = "vector_" + str(KEY_LENGTH) + "_randomized"
        # make the vector folder importable; insert it only once so that
        # creating several test vectors does not keep growing sys.path
        vector_path = sys.path[0] + _VECTOR_PATH
        if vector_path not in sys.path:
            sys.path.insert(1, vector_path)
        # import the vector module and instantiate its vector class
        vector_module = importlib.import_module(filename)
        vector_class = getattr(vector_module, _VECTOR_CLASS)
        vector_inst = vector_class()
        # full-length values (message, modulus, private exponent)
        self.m = ModExpNG_Operand(vector_inst.m, KEY_LENGTH)
        self.n = ModExpNG_Operand(vector_inst.n, KEY_LENGTH)
        self.d = ModExpNG_Operand(vector_inst.d, KEY_LENGTH)
        # half-length CRT values
        self.p = ModExpNG_Operand(vector_inst.p, _KEY_LENGTH_HALF)
        self.q = ModExpNG_Operand(vector_inst.q, _KEY_LENGTH_HALF)
        self.dp = ModExpNG_Operand(vector_inst.dp, _KEY_LENGTH_HALF)
        self.dq = ModExpNG_Operand(vector_inst.dq, _KEY_LENGTH_HALF)
        self.qinv = ModExpNG_Operand(vector_inst.qinv, _KEY_LENGTH_HALF)
        # per-modulus factors, same length as their modulus
        self.n_factor = ModExpNG_Operand(vector_inst.n_factor, KEY_LENGTH)
        self.p_factor = ModExpNG_Operand(vector_inst.p_factor, _KEY_LENGTH_HALF)
        self.q_factor = ModExpNG_Operand(vector_inst.q_factor, _KEY_LENGTH_HALF)
        # per-modulus coefficients are one word longer than their modulus
        self.n_coeff = ModExpNG_Operand(vector_inst.n_coeff, KEY_LENGTH + _WORD_WIDTH)
        self.p_coeff = ModExpNG_Operand(vector_inst.p_coeff, _KEY_LENGTH_HALF + _WORD_WIDTH)
        self.q_coeff = ModExpNG_Operand(vector_inst.q_coeff, _KEY_LENGTH_HALF + _WORD_WIDTH)
class ModExpNG_PartRecombinator():
    """Model of the pipeline that turns multiplier parts into 16-bit result words.

    A part pushed at tick i has weight 2 ** (16 * i); only its bits 47..0
    are used. The part is split into three 16-bit slices that are added into
    three latches. After part i has been pushed, x0, y0 and z0 hold the
    partial sums for positions i, i + 1 and i + 2. On every tick the lower
    16 bits of x0 leave the pipeline as the next result word, while its
    upper bits carry into the next position.
    """

    def _bit_select(self, x, msb, lsb):
        """Return bits msb..lsb (inclusive) of x, shifted down to bit 0."""
        if msb < lsb:
            return 0
        return (x >> lsb) & ((1 << (msb - lsb + 1)) - 1)

    def _flush_pipeline(self):
        """Clear all three latches."""
        self.z0, self.y0, self.x0 = 0, 0, 0

    def _push_pipeline(self, part):
        """Push one part into the pipeline and return the word shifted out."""
        # split next part into 16-bit words
        z = self._bit_select(part, 47, 32)
        y = self._bit_select(part, 31, 16)
        x = self._bit_select(part, 15, 0)
        # shift to the right
        z1 = z
        y1 = y + self.z0
        x1 = x + self.y0 + (self.x0 >> 16) # IMPORTANT: This carry can be up to two bits wide!!
        # save lower 16 bits of the rightmost cell
        t = self.x0 & 0xffff
        # update internal latches
        self.z0, self.y0, self.x0 = z1, y1, x1
        # done
        return t

    def _recombine(self, parts, num_parts, num_ticks):
        """Feed parts[:num_parts], then zeroes, for num_ticks ticks.

        The first tick only shifts out the flushed (empty) latch, so its word
        is dropped and num_ticks - 1 words are returned, least significant
        first. Anything still left in the latches afterwards is discarded.
        """
        self._flush_pipeline()
        words = list()
        for i in range(num_ticks):
            next_part = parts[i] if i < num_parts else 0
            next_word = self._push_pipeline(next_part)
            if i > 0:
                words.append(next_word)
        return words

    def recombine_square(self, parts, ab_num_words):
        """Recombine the 2 * n - 1 parts of an n-by-n word product into 2 * n words.

        One tick is lost to the empty pipeline and two trailing zero ticks
        flush the upper words out: 2 * n + 1 ticks in total.
        """
        return self._recombine(parts, 2 * ab_num_words - 1, 2 * ab_num_words + 1)

    def recombine_triangle(self, parts, ab_num_words):
        """Recombine the n + 1 parts of a truncated product into n + 1 words.

        One tick is lost to the empty pipeline and one trailing zero tick
        flushes the top word out: n + 2 ticks in total, saving the words of
        the last n + 1 ticks. Higher words are not produced.
        """
        return self._recombine(parts, ab_num_words + 1, ab_num_words + 2)

    def recombine_rectangle(self, parts, ab_num_words):
        """Recombine the 2 * n parts of an n-by-(n + 1) word product into 2 * n + 1 words.

        One tick is lost to the empty pipeline and two trailing zero ticks
        flush the upper words out: 2 * n + 2 ticks in total.
        """
        return self._recombine(parts, 2 * ab_num_words, 2 * ab_num_words + 2)
class ModExpNG_WordMultiplier():
    """Model of NUM_MULTS parallel multiply-accumulate (MAC) units.

    The multiply_* methods return lists of "parts". parts[i] is the sum of
    the word products a_wide.words[j] * b_narrow.words[k] with j + k == i,
    with no carry propagation; ModExpNG_PartRecombinator turns parts into
    words. multiply_square and multiply_rectangle produce every position of
    the product, while multiply_triangle produces only positions
    0..ab_num_words.

    The work is split into ab_num_words // NUM_MULTS columns. In column col,
    MAC x is responsible for position col * NUM_MULTS + x.
    NOTE(review): ab_num_words appears to be assumed a multiple of
    NUM_MULTS (num_cols is truncated otherwise); this is not checked here.
    """

    def __init__(self):
        # one accumulator and one a-word index per MAC, plus the auxiliary
        # MAC state (only element [0] of the aux lists is ever used)
        self._macs = list()
        self._indices = list()
        self._mac_aux = list()
        self._index_aux = list()
        for x in range(NUM_MULTS):
            self._macs.append(0)
            self._indices.append(0)
            self._mac_aux.append(0)
            self._index_aux.append(0)

    def _clear_all_macs(self):
        """Zero every MAC accumulator."""
        for x in range(NUM_MULTS):
            self._macs[x] = 0

    def _clear_one_mac(self, x):
        """Zero MAC x."""
        self._macs[x] = 0

    def _clear_mac_aux(self):
        """Zero the auxiliary MAC."""
        self._mac_aux[0] = 0

    def _update_one_mac(self, x, value):
        """Accumulate value into MAC x."""
        self._macs[x] += value

    def _update_mac_aux(self, value):
        """Accumulate value into the auxiliary MAC."""
        self._mac_aux[0] += value

    def _preset_indices(self, col):
        """Point MAC x at a-word col * NUM_MULTS + x (its own output position)."""
        for x in range(len(self._indices)):
            self._indices[x] = col * len(self._indices) + x

    def _preset_index_aux(self, num_cols):
        """Point the auxiliary MAC at a-word num_cols * NUM_MULTS."""
        self._index_aux[0] = num_cols * len(self._indices)

    def _rotate_indices(self, num_words):
        """Step every MAC to the previous a-word, wrapping modulo num_words."""
        for x in range(len(self._indices)):
            self._indices[x] -= 1
            if self._indices[x] < 0:
                self._indices[x] += num_words

    def _rotate_index_aux(self):
        """Step the auxiliary MAC to the previous a-word (no wrap-around)."""
        self._index_aux[0] -= 1

    def multiply_square(self, a_wide, b_narrow, ab_num_words):
        """Return the 2 * ab_num_words - 1 parts of the full product a_wide * b_narrow.

        Only the first ab_num_words words of each operand are used (a-word
        indices wrap modulo ab_num_words). MAC x of column col first
        accumulates the terms of position k = col * NUM_MULTS + x, stores
        them at t == k and restarts. It then accumulates the wrapped-around
        terms, which belong to position ab_num_words + k.
        """
        num_cols = ab_num_words // NUM_MULTS
        parts = list()
        for i in range(2 * ab_num_words - 1):
            parts.append(0)
        for col in range(num_cols):
            self._clear_all_macs()
            self._preset_indices(col)
            for t in range(ab_num_words):
                # current b-word
                bt = b_narrow.words[t]
                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, ax * bt)
                    # lower position of this MAC is complete: store it and
                    # restart accumulation for the upper position
                    if t == (col * NUM_MULTS + x):
                        parts[t] = self._macs[x]
                        self._clear_one_mac(x)
                # save the upper part of product at end of column,
                # for the last column don't save the very last part
                # (position 2 * ab_num_words - 1 has no terms)
                if t == (ab_num_words - 1):
                    for x in range(NUM_MULTS):
                        if not (col == (num_cols - 1) and x == (NUM_MULTS - 1)):
                            parts[ab_num_words + col * NUM_MULTS + x] = self._macs[x]
                self._rotate_indices(ab_num_words)
        return parts

    def multiply_triangle(self, a_wide, b_narrow, ab_num_words):
        """Return parts 0..ab_num_words (the lower ab_num_words + 1 positions) of a_wide * b_narrow.

        b_narrow must have ab_num_words + 1 words and a_wide at least
        ab_num_words + 1 words. Non-last columns stop as soon as their last
        position has been stored. The last column runs the full
        ab_num_words + 1 time steps and also drives the auxiliary MAC. That
        MAC walks the a-words down from index num_cols * NUM_MULTS to
        produce parts[ab_num_words].
        """
        num_cols = ab_num_words // NUM_MULTS
        parts = list()
        for i in range(ab_num_words + 1):
            parts.append(0)
        for col in range(num_cols):
            last_col = col == (num_cols - 1)
            self._clear_all_macs()
            self._preset_indices(col)
            if last_col:
                self._clear_mac_aux()
                self._preset_index_aux(num_cols)
            for t in range(ab_num_words + 1):
                # current b-word
                bt = b_narrow.words[t]
                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, ax * bt)
                    # position is complete; later (wrapped) terms are never stored
                    if t == (col * NUM_MULTS + x):
                        parts[t] = self._macs[x]
                # aux multiplier
                if last_col:
                    ax = a_wide.words[self._index_aux[0]]
                    self._update_mac_aux(ax * bt)
                    if t == ab_num_words:
                        parts[t] = self._mac_aux[0]
                # shortcut: every position of this column has been stored
                if not last_col:
                    if t == (NUM_MULTS * (col + 1) - 1): break
                # advance indices
                self._rotate_indices(ab_num_words)
                if last_col:
                    self._rotate_index_aux()
        return parts

    def multiply_rectangle(self, a_wide, b_narrow, ab_num_words):
        """Return the 2 * ab_num_words parts of the full product a_wide * b_narrow.

        a_wide has ab_num_words words (indices wrap) and b_narrow has
        ab_num_words + 1 words, hence one more time step than
        multiply_square. The upper parts of every column are stored at the
        final step t == ab_num_words.
        """
        num_cols = ab_num_words // NUM_MULTS
        parts = list()
        for i in range(2 * ab_num_words):
            parts.append(0)
        for col in range(num_cols):
            self._clear_all_macs()
            self._preset_indices(col)
            for t in range(ab_num_words+1):
                # current b-word
                bt = b_narrow.words[t]
                # multiply by a-words
                for x in range(NUM_MULTS):
                    ax = a_wide.words[self._indices[x]]
                    self._update_one_mac(x, ax * bt)
                    # don't save one value for the very last time instant per column
                    if t < ab_num_words and t == (col * NUM_MULTS + x):
                        parts[t] = self._macs[x]
                        self._clear_one_mac(x)
                # save the upper part of product at end of column
                if t == ab_num_words:
                    for x in range(NUM_MULTS):
                        parts[ab_num_words + col * NUM_MULTS + x] = self._macs[x]
                self._rotate_indices(ab_num_words)
        return parts
class ModExpNG_LowlevelOperator():
    """Range-checked word-level add and subtract primitives.

    Both operations take two _WORD_WIDTH-bit words plus a one-bit carry or
    borrow and return a (carry/borrow out, result word) tuple.
    """

    def __init__(self):
        # mask selecting the lower _WORD_WIDTH bits
        self._word_mask = (1 << _WORD_WIDTH) - 1

    def _check_word(self, a):
        """Raise unless a fits into one _WORD_WIDTH-bit word."""
        if a < 0 or a >= (2 ** _WORD_WIDTH):
            raise Exception("Word out of range!")

    def _check_carry_borrow(self, cb):
        """Raise unless cb lies in [0, 1]."""
        if cb < 0 or cb > 1:
            raise Exception("Carry or borrow out of range!")

    def add_words(self, a, b, c_in):
        """Return (carry_out, sum_word) for a + b + c_in."""
        self._check_word(a)
        self._check_word(b)
        self._check_carry_borrow(c_in)
        total = a + b + c_in
        # a, b < 2**W and c_in <= 1, so the carry out is a single bit
        return ((total >> _WORD_WIDTH) & 1, total & self._word_mask)

    def sub_words(self, a, b, b_in):
        """Return (borrow_out, difference_word) for a - b - b_in."""
        self._check_word(a)
        self._check_word(b)
        self._check_carry_borrow(b_in)
        difference = a - b - b_in
        if difference < 0:
            # wrap around modulo 2**_WORD_WIDTH and signal a borrow
            return (1, difference + 2 ** _WORD_WIDTH)
        return (0, difference)
class ModExpNG_Worker():
    """Modular arithmetic built from the word-level hardware models."""

    def __init__(self):
        self.recombinator = ModExpNG_PartRecombinator()
        self.multiplier = ModExpNG_WordMultiplier()
        self.lowlevel = ModExpNG_LowlevelOperator()

    def exponentiate(self, iz, bz, e, n, n_factor, n_coeff, num_words):
        """Montgomery-ladder exponentiation of bz by e, seeded with iz.

        iz, bz   -- initial values of the two ladder variables (in the
                    __main__ driver: one and the base, both already in the
                    Montgomery domain)
        e        -- exponent; all _WORD_WIDTH * num_words bits are scanned
                    from the most significant bit down
        n, n_coeff -- modulus and coefficient passed on to multiply()
        n_factor -- accepted for interface symmetry; not used here

        Prints a progress percentage every 8 bits and returns the first
        ladder variable.
        """
        t1, t2 = iz, bz
        # the exponent does not change, so convert it to an integer once
        e_number = e.number()
        num_bits = _WORD_WIDTH * num_words
        # length-1, length-2, length-3, ..., 1, 0 (left-to-right)
        for bit in range(num_bits - 1, -1, -1):
            if e_number & (1 << bit):
                p1 = self.multiply(t1, t2, n, n_coeff, num_words)
                p2 = self.multiply(t2, t2, n, n_coeff, num_words)
            else:
                p1 = self.multiply(t1, t1, n, n_coeff, num_words)
                p2 = self.multiply(t2, t1, n, n_coeff, num_words)
            t1, t2 = p1, p2
            if (bit % 8) == 0:
                pct = float((num_bits - bit) / num_bits) * 100.0
                print("\rpct: %5.1f%%" % pct, end='')
        print("")
        return t1

    def subtract(self, a, b, n, ab_num_words):
        """Return a - b, corrected by + n when the subtraction borrows.

        Computes a - b and (a - b) + n word by word in parallel. The corrected
        value is selected when the top word borrows, giving (a - b) mod n
        when a and b are both below n (assumed, not checked). All words must
        be normalized _WORD_WIDTH-bit words.
        """
        c_in = 0
        b_in = 0
        ab = list()
        ab_n = list()
        for x in range(ab_num_words):
            b_in, d_out = self.lowlevel.sub_words(a.words[x], b.words[x], b_in)
            c_in, s_out = self.lowlevel.add_words(d_out, n.words[x], c_in)
            ab.append(d_out)
            ab_n.append(s_out)
        # b_in now holds the borrow out of the top word (0 for zero words)
        d = ab_n if b_in else ab
        return ModExpNG_Operand(None, ab_num_words, d)

    def add(self, a, b, ab_num_words):
        """Return a + b as a 2 * ab_num_words-word operand.

        a has ab_num_words words (zero-extended) and b has
        2 * ab_num_words words. A carry out of the top word is dropped.
        """
        c_in = 0
        ab = list()
        for x in range(2 * ab_num_words):
            a_word = a.words[x] if x < ab_num_words else 0
            c_in, s_out = self.lowlevel.add_words(a_word, b.words[x], c_in)
            ab.append(s_out)
        return ModExpNG_Operand(None, 2 * ab_num_words, ab)

    def multiply(self, a, b, n, n_coeff, ab_num_words, reduce_only=False, multiply_only=False):
        """Montgomery-style multiply of a by b modulo n.

        Steps:
          1. ab = a * b (2 * ab_num_words words)
          2. q  = lower ab_num_words + 1 words of ab * n_coeff
          3. m  = n * q (2 * ab_num_words + 1 words)
          4. result = (ab + m) >> (_WORD_WIDTH * (ab_num_words + 1))
        NOTE(review): n_coeff is presumably -n**-1 modulo
        2**(_WORD_WIDTH * (ab_num_words + 1)); confirm against the vector
        generator.

        reduce_only   -- skip step 1 and reduce a itself (a must already have
                         2 * ab_num_words words; b is ignored)
        multiply_only -- stop after step 1 and return the full product
                         (n and n_coeff are ignored)

        The ab_num_words result words are not carry-propagated, so they may
        be wider than _WORD_WIDTH bits.
        """
        # 1.
        if reduce_only:
            ab = a
        else:
            ab_parts = self.multiplier.multiply_square(a, b, ab_num_words)
            ab_words = self.recombinator.recombine_square(ab_parts, ab_num_words)
            ab = ModExpNG_Operand(None, 2 * ab_num_words, ab_words)
            if multiply_only:
                return ab
        # 2.
        q_parts = self.multiplier.multiply_triangle(ab, n_coeff, ab_num_words)
        q_words = self.recombinator.recombine_triangle(q_parts, ab_num_words)
        q = ModExpNG_Operand(None, ab_num_words + 1, q_words)
        # 3.
        m_parts = self.multiplier.multiply_rectangle(n, q, ab_num_words)
        m_words = self.recombinator.recombine_rectangle(m_parts, ab_num_words)
        m = ModExpNG_Operand(None, 2 * ab_num_words + 1, m_words)
        # 4. add ab and m word-wise without propagating carries
        r_xwords = [ab.words[i] + m.words[i] for i in range(2 * ab_num_words)]
        r_xwords.append(m.words[2 * ab_num_words])
        # resolve carries only through the lower ab_num_words + 1 words, which
        # are shifted out; their carry-out enters the lowest result word
        cy = 0
        for i in range(ab_num_words + 1):
            cy = (r_xwords[i] + cy) >> _WORD_WIDTH
        R = r_xwords[ab_num_words + 1:]
        R[0] += cy # !!!
        return ModExpNG_Operand(None, ab_num_words, R)
if __name__ == "__main__":
    # read the test vector and set up the math model
    vector = ModExpNG_TestVector()
    worker = ModExpNG_Worker()
    # every CRT half (p or q) spans this many words
    pq_num_words = _KEY_LENGTH_HALF // _WORD_WIDTH

    # reference results from Python's built-in modular exponentiation
    m_value = vector.m.number()
    s_known = pow(m_value, vector.d.number(), vector.n.number())
    sp_known = pow(m_value, vector.dp.number(), vector.p.number())
    sq_known = pow(m_value, vector.dq.number(), vector.q.number())

    # the constant one, sized like a CRT half
    one = ModExpNG_Operand(1, _KEY_LENGTH_HALF)

    def _crt_half(prime, factor, coeff, exponent):
        """Return m ** exponent mod prime computed by the hardware model."""
        # reducing the message glues 2**-r to it as a side effect ...
        glued = worker.multiply(vector.m, None, prime, coeff, pq_num_words, reduce_only=True)
        # ... which is compensated by gluing 2**r back on
        reduced = worker.multiply(glued, factor, prime, coeff, pq_num_words)
        # enter the Montgomery domain with both one and the message
        one_z = worker.multiply(one, factor, prime, coeff, pq_num_words)
        msg_z = worker.multiply(reduced, factor, prime, coeff, pq_num_words)
        result_z = worker.exponentiate(one_z, msg_z, exponent, prime, factor, coeff, pq_num_words)
        # leave the Montgomery domain (unglue 2**r)
        return worker.multiply(one, result_z, prime, coeff, pq_num_words)

    sp = _crt_half(vector.p, vector.p_factor, vector.p_coeff, vector.dp)
    sq = _crt_half(vector.q, vector.q_factor, vector.q_coeff, vector.dq)

    print("sp is OK" if sp.number() == sp_known else "sp is WRONG!")
    print("sq is OK" if sq.number() == sq_known else "sq is WRONG!")

    # Garner's recombination: s = sq + q * ((sp - sq) * qinv mod p)
    sr = worker.subtract(sp, sq, vector.p, pq_num_words)
    sr_qinv = worker.multiply(
        worker.multiply(sr, vector.qinv, vector.p, vector.p_coeff, pq_num_words),
        vector.p_factor, vector.p, vector.p_coeff, pq_num_words)
    q_sr_qinv = worker.multiply(vector.q, sr_qinv, None, None, pq_num_words, multiply_only=True)
    s_crt = worker.add(sq, q_sr_qinv, pq_num_words)

    if s_crt.number() == s_known:
        print("s is OK")
    else:
        print("ERROR: s_crt != s_known!")