aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modexpng_fpga_model.py335
1 files changed, 185 insertions, 150 deletions
diff --git a/modexpng_fpga_model.py b/modexpng_fpga_model.py
index d33f314..cc3e868 100644
--- a/modexpng_fpga_model.py
+++ b/modexpng_fpga_model.py
@@ -74,12 +74,15 @@ _VECTOR_CLASS = "Vector"
# ------------------
# Debugging Settings
# ------------------
+FORCE_OVERFLOW = False
DUMP_VECTORS = False
DUMP_INDICES = False
+DUMP_MACS_INPUTS = False
DUMP_MACS_CLEARING = False
DUMP_MACS_ACCUMULATION = False
DUMP_MULT_PARTS = False
DUMP_RCMB = False
+DUMP_REDUCTION = False
#
@@ -111,14 +114,14 @@ class ModExpNG_Operand():
if i > 0:
if (i % 4) == 0: print("")
else: print(" ", end='')
- print("%s[%2d] = 17'h%05x;" % (name, i, self.words[i]), end='')
+ print("%s[%2d] = 18'h%05x;" % (name, i, self.words[i]), end='')
print("")
def _init_from_words(self, words, count):
for i in range(count):
- # word must not exceed 17 bits
+ # word must not exceed 18 bits
if words[i] >= (2 ** (_WORD_WIDTH + 2)):
raise Exception("Word is too large!")
@@ -221,7 +224,7 @@ class ModExpNG_PartRecombinator():
# shift to the right
z1 = z
y1 = y + self.z0
- x1 = x + self.y0 + (self.x0 >> 16) # IMPORTANT: This carry can be up to two bits wide!!
+ x1 = x + self.y0 + (self.x0 >> _WORD_WIDTH) # IMPORTANT: This carry can be up to two bits wide!!
# save lower 16 bits of the rightmost cell
t = self.x0 & 0xffff
@@ -287,7 +290,8 @@ class ModExpNG_PartRecombinator():
# recombine the lower half (n+1 parts)
# the first tick produces null result, so we need n + 1 + 1 = n + 2
- # ticks total and should only save the result word during the last n ticks
+ # ticks total and should only save the result word during the last
+ # n + 1 ticks
self._flush_pipeline(dump)
for i in range(ab_num_words + 2):
@@ -344,29 +348,9 @@ class ModExpNG_PartRecombinator():
return words
-
- # flush recombinator pipeline
- #self._flush_pipeline(dump)
-
- # the first tick produces null result, the last part produces
- # two words, so we need 2 * n + 2 ticks total and should only save
- # the result word during the last 2 * n + 1 ticks
- #for i in range(2 * ab_num_words + 2):
-
- #next_part = parts[i] if i < (2 * ab_num_words) else 0
- #next_word = self._push_pipeline(next_part, dump)
-
- #if i > 0:
- #words.append(next_word)
-
- return words
-
class ModExpNG_WordMultiplier():
- _a_seen_17 = False
- _b_seen_17 = False
-
def __init__(self):
self._macs = list()
@@ -388,38 +372,45 @@ class ModExpNG_WordMultiplier():
if dump and DUMP_MACS_CLEARING:
print("t=%2d, col=%2d > clear > all" % (t, col))
-
def _clear_one_mac(self, x, t, col, dump):
self._macs[x] = 0
if dump and DUMP_MACS_CLEARING:
print("t=%2d, col=%2d > clear > x=%d" % (t, col, x))
-
def _clear_mac_aux(self, t, col, dump):
self._mac_aux[0] = 0
if dump and DUMP_MACS_CLEARING:
print("t= 0, col=%2d > clear > aux" % (col))
+ def _update_one_mac(self, x, t, col, a, b, dump, need_aux=False):
- def _update_one_mac(self, x, a, b):
-
- if a > 0xFFFF:
- self._a_seen_17 = True
+ if a > 0x3FFFF:
+ raise Exception("a > 0x3FFFF!")
if b > 0xFFFF:
- self._b_seen_17 = True
+ raise Exception("b > 0xFFFF!")
+ p = a * b
+ if dump and DUMP_MACS_INPUTS:
+ if x == 0: print("t=%2d, col=%2d > b=%05x > " % (t, col, b), end='')
+ if x > 0: print("; ", end='')
+ print("MAC[%d]: a=%05x" % (x, a), end='')
+ if x == (NUM_MULTS-1) and not need_aux: print("")
+
+ self._macs[x] += p
+
+ def _update_mac_aux(self, y, col, a, b, dump):
+
if a > 0x3FFFF:
raise Exception("a > 0x3FFFF!")
- if b > 0x1FFFF:
- raise Exception("b > 0x1FFFF!")
+ if b > 0xFFFF:
+ raise Exception("b > 0xFFFF!")
p = a * b
- self._macs[x] += p
-
- def _update_mac_aux(self, value):
- self._mac_aux[0] += value
+ if dump and DUMP_MACS_INPUTS:
+ print("; AUX: a=%05x" % a)
+ self._mac_aux[0] += p
def _preset_indices(self, col):
for x in range(len(self._indices)):
@@ -440,7 +431,7 @@ class ModExpNG_WordMultiplier():
def _dump_macs(self, t, col):
self._dump_macs_helper(t, col)
- def _dump_macs_aux(self, t, col):
+ def _dump_macs_with_aux(self, t, col):
self._dump_macs_helper(t, col, True)
def _dump_indices_helper(self, t, col, aux=False):
@@ -454,7 +445,7 @@ class ModExpNG_WordMultiplier():
def _dump_indices(self, t, col):
self._dump_indices_helper(t, col)
- def _dump_indices_aux(self, t, col):
+ def _dump_indices_with_aux(self, t, col):
self._dump_indices_helper(t, col, True)
def _rotate_indices(self, num_words):
@@ -473,16 +464,14 @@ class ModExpNG_WordMultiplier():
print("t=%2d, col=%2d > parts[%2d]: mac[%d] = 0x%012x" %
(time, column, part_index, mac_index, parts[part_index]))
- def _mult_store_part_aux(self, parts, time, column, part_index, mac_index, dump):
- parts[part_index] = self._mac_aux[mac_index]
+ def _mult_store_part_aux(self, parts, time, column, part_index, dump):
+ parts[part_index] = self._mac_aux[0]
if dump and DUMP_MULT_PARTS:
print("t=%2d, col=%2d > parts[%2d]: mac_aux[%d] = 0x%012x" %
- (time, column, part_index, mac_index, parts[part_index]))
+ (time, column, part_index, 0, parts[part_index]))
def multiply_square(self, a_wide, b_narrow, ab_num_words, dump=False):
- if dump: print("multiply_square()")
-
num_cols = ab_num_words // NUM_MULTS
parts = list()
@@ -490,8 +479,8 @@ class ModExpNG_WordMultiplier():
parts.append(0)
for col in range(num_cols):
-
- bt_carry = 0
+
+ b_carry = 0
for t in range(ab_num_words):
@@ -511,15 +500,15 @@ class ModExpNG_WordMultiplier():
if dump and DUMP_INDICES: self._dump_indices(t, col)
# current b-word
- bt = b_narrow.words[t] + bt_carry
- bt_carry = bt >> _WORD_WIDTH
- bt &= 0xffff
-
+ # TODO: Explain how the 18th bit carry works!!
+ bt = b_narrow.words[t] + b_carry
+ b_carry = (bt & 0x30000) >> 16
+ bt &= 0xFFFF
# multiply by a-words
for x in range(NUM_MULTS):
ax = a_wide.words[self._indices[x]]
- self._update_one_mac(x, ax, bt)
+ self._update_one_mac(x, t, col, ax, bt, dump)
if t == (col * NUM_MULTS + x):
part_index = t
@@ -540,8 +529,6 @@ class ModExpNG_WordMultiplier():
def multiply_triangle(self, a_wide, b_narrow, ab_num_words, dump=False):
- if dump: print("multiply_triangle()")
-
num_cols = ab_num_words // NUM_MULTS
parts = list()
@@ -571,7 +558,7 @@ class ModExpNG_WordMultiplier():
if t == 0: self._clear_mac_aux(t, col, dump)
# debug output
- if dump and DUMP_INDICES: self._dump_indices_aux(t, col)
+ if dump and DUMP_INDICES: self._dump_indices_with_aux(t, col)
# current b-word
bt = b_narrow.words[t]
@@ -579,7 +566,7 @@ class ModExpNG_WordMultiplier():
# multiply by a-words
for x in range(NUM_MULTS):
ax = a_wide.words[self._indices[x]]
- self._update_one_mac(x, ax, bt)
+ self._update_one_mac(x, t, col, ax, bt, dump, last_col)
if t == (col * NUM_MULTS + x):
part_index = t
@@ -588,14 +575,14 @@ class ModExpNG_WordMultiplier():
# aux multiplier
if last_col:
ax = a_wide.words[self._index_aux[0]]
- self._update_mac_aux(ax * bt)
+ self._update_mac_aux(t, col, ax, bt, dump)
if t == ab_num_words:
part_index = t
- self._mult_store_part_aux(parts, t, col, part_index, 0, dump)
+ self._mult_store_part_aux(parts, t, col, part_index, dump)
# debug output
- if dump and DUMP_MACS_ACCUMULATION: self._dump_macs_aux(t, col)
+ if dump and DUMP_MACS_ACCUMULATION: self._dump_macs_with_aux(t, col)
# shortcut
if not last_col:
@@ -605,8 +592,6 @@ class ModExpNG_WordMultiplier():
def multiply_rectangle(self, a_wide, b_narrow, ab_num_words, dump=False):
- if dump: print("multiply_rectangle()")
-
num_cols = ab_num_words // NUM_MULTS
parts = list()
@@ -638,7 +623,7 @@ class ModExpNG_WordMultiplier():
# multiply by a-words
for x in range(NUM_MULTS):
ax = a_wide.words[self._indices[x]]
- self._update_one_mac(x, ax, bt)
+ self._update_one_mac(x, t, col, ax, bt, dump)
# don't save one value for the very last time instant per column
if t < ab_num_words and t == (col * NUM_MULTS + x):
@@ -686,6 +671,7 @@ class ModExpNG_LowlevelOperator():
return (sum_c, sum_s)
def sub_words(self, a, b, b_in):
+
self._check_word(a)
self._check_word(b)
self._check_carry_borrow(b_in)
@@ -704,14 +690,12 @@ class ModExpNG_LowlevelOperator():
class ModExpNG_Worker():
- max_zzz = 0
-
def __init__(self):
self.recombinator = ModExpNG_PartRecombinator()
self.multiplier = ModExpNG_WordMultiplier()
self.lowlevel = ModExpNG_LowlevelOperator()
- def exponentiate(self, iz, bz, e, n, n_factor, n_coeff, num_words):
+ def exponentiate(self, iz, bz, e, n, n_factor, n_coeff, num_words, dump_index=-1, dump_mode=""):
# working variables
t1, t2 = iz, bz
@@ -719,19 +703,51 @@ class ModExpNG_Worker():
# length-1, length-2, length-3, ..., 1, 0 (left-to-right)
for bit in range(_WORD_WIDTH * num_words - 1, -1, -1):
- if e.number() & (1 << bit):
- p1 = self.multiply(t1, t2, n, n_coeff, num_words)
- p2 = self.multiply(t2, t2, n, n_coeff, num_words)
+ debug_dump = bit == dump_index
+
+ bit_value = (e.number() & (1 << bit)) >> bit
+
+ if debug_dump:
+ print("\rladder_mode = %d" % bit_value)
+
+ if FORCE_OVERFLOW:
+ T1X = list(t1.words)
+ for i in range(num_words):
+ if i > 0:
+ bits = T1X[i-1] & (3 << 16)
+ if bits == 0:
+ bits = T1X[i] & 3
+ T1X[i] = T1X[i] ^ bits
+ T1X[i-1] |= (bits << 16)
+
+ for i in range(num_words):
+ t1.words[i] = T1X[i]
+
+ if DUMP_VECTORS:
+ print("num_words = %d" % num_words)
+ t1.format_verilog_concat("%s_T1" % dump_mode)
+ t2.format_verilog_concat("%s_T2" % dump_mode)
+ n.format_verilog_concat("%s_N" % dump_mode)
+ n_coeff.format_verilog_concat("%s_N_COEFF" % dump_mode)
+ # force the rarely seen overflow
+
+ if bit_value:
+ p1 = self.multiply(t1, t2, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="X")
+ p2 = self.multiply(t2, t2, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="Y")
else:
- p1 = self.multiply(t1, t1, n, n_coeff, num_words)
- p2 = self.multiply(t2, t1, n, n_coeff, num_words)
+ p1 = self.multiply(t1, t1, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="X")
+ p2 = self.multiply(t2, t1, n, n_coeff, num_words, dump=debug_dump, dump_mode=dump_mode, dump_phase="Y")
t1, t2 = p1, p2
+ if debug_dump and DUMP_VECTORS:
+ t1.format_verilog_concat("%s_X" % dump_mode)
+ t2.format_verilog_concat("%s_Y" % dump_mode)
+
if (bit % 8) == 0:
pct = float((_WORD_WIDTH * num_words - bit) / (_WORD_WIDTH * num_words)) * 100.0
print("\rpct: %5.1f%%" % pct, end='')
-
+
print("")
return t1
@@ -780,16 +796,11 @@ class ModExpNG_Worker():
return ModExpNG_Operand(None, 2*ab_num_words, ab)
- def multiply(self, a, b, n, n_coeff, ab_num_words, reduce_only=False, multiply_only=False, dump=False):
+ def multiply(self, a, b, n, n_coeff, ab_num_words, reduce_only=False, multiply_only=False, dump=False, dump_mode="", dump_phase=""):
- if dump and DUMP_VECTORS:
- print("num_words = %d" % ab_num_words)
- a.format_verilog_concat("A")
- b.format_verilog_concat("B")
- n.format_verilog_concat("N")
- n_coeff.format_verilog_concat("N_COEFF")
-
- # 1.
+ # 1. AB = A * B
+ if dump: print("multiply_square(%s_%s)" % (dump_mode, dump_phase))
+
if reduce_only:
ab = a
else:
@@ -797,61 +808,92 @@ class ModExpNG_Worker():
ab_words = self.recombinator.recombine_square(ab_parts, ab_num_words, dump)
ab = ModExpNG_Operand(None, 2 * ab_num_words, ab_words)
+ if dump and DUMP_VECTORS:
+ ab.format_verilog_concat("%s_%s_AB" % (dump_mode, dump_phase))
+
if multiply_only:
return ModExpNG_Operand(None, 2*ab_num_words, ab_words)
- # 2.
+
+ # 2. Q = LSB(AB) * N_COEFF
+ if dump: print("multiply_triangle(%s_%s)" % (dump_mode, dump_phase))
+
q_parts = self.multiplier.multiply_triangle(ab, n_coeff, ab_num_words, dump)
q_words = self.recombinator.recombine_triangle(q_parts, ab_num_words, dump)
q = ModExpNG_Operand(None, ab_num_words + 1, q_words)
- # 3.
+ if dump and DUMP_VECTORS:
+ q.format_verilog_concat("%s_%s_Q" % (dump_mode, dump_phase))
+
+ # 3. M = Q * N
+ if dump: print("multiply_rectangle(%s_%s)" % (dump_mode, dump_phase))
+
m_parts = self.multiplier.multiply_rectangle(n, q, ab_num_words, dump)
m_words = self.recombinator.recombine_rectangle(m_parts, ab_num_words, dump)
m = ModExpNG_Operand(None, 2 * ab_num_words + 1, m_words)
+
+ if dump and DUMP_VECTORS:
+ m.format_verilog_concat("%s_%s_M" % (dump_mode, dump_phase))
+
+ if (m.number() != (q.number() * n.number())):
+ print("MISMATCH")
+ sys.exit()
- # 4.
- cy = 0
+
+ # 4. R = AB + M
+
+ # 4a. compute carry (actual sum is all zeroes and need not be stored)
+ r_cy = 0 # this can be up to two bits, since we're adding extended words!!
for i in range(ab_num_words + 1):
- s = ab.words[i] + m.words[i] + cy
- cy = s >> 16
+ s = ab.words[i] + m.words[i] + r_cy
+ r_cy_new = s >> 16
+ if dump and DUMP_REDUCTION:
+ print("[%2d] 0x%05x + 0x%05x + 0x%x => {0x%x, [0x%05x]}" %
+ (i, ab.words[i], m.words[i], r_cy, r_cy_new, s & 0xffff))
+
+ r_cy = r_cy_new
+
+
+ # 4b. Initialize empty result
R = list()
for i in range(ab_num_words):
R.append(0)
- R[0] = cy # !!! (cy is 2 bits, i.e. 0..3)
-
- if dump:
- if ab.words[ab_num_words + 2] > 0:
- ab.words[ab_num_words + 2] -= 1
- ab.words[ab_num_words + 1] += 0x10000
- if m.words[ab_num_words + 2] > 0:
- m.words[ab_num_words + 2] -= 1
- m.words[ab_num_words + 1] += 0x10000
-
+ # 4c. compute the actual upper part of sum (take carry into account)
for i in range(ab_num_words):
+
+ if dump and DUMP_REDUCTION:
+ print("[%2d]" % i, end='')
+
ab_word = ab.words[ab_num_words + i + 1] if i < (ab_num_words - 1) else 0
+ if dump and DUMP_REDUCTION:
+ print(" 0x%05x" % ab_word, end='')
+
m_word = m.words[ab_num_words + i + 1]
-
- R[i] += ab_word + m_word
-
- #if i == 0:
- #if R[i] > self.max_zzz:
- #self.max_zzz = R[i]
- #print("self.max_zzz = %05x" % R[i])
- #if R[i] > 0x1ffff:
- #sys.exit(123)
+ if dump and DUMP_REDUCTION:
+ print(" + 0x%05x" % m_word, end='')
+ if i == 0: R[i] = r_cy
+ else: R[i] = 0
+ if (r_cy > 3): print("\rR_CY = %d!" % r_cy)
+
+ if dump and DUMP_REDUCTION:
+ print(" + 0x%x" % R[i], end='')
+ R[i] += ab_word
+ R[i] += m_word
+ if dump and DUMP_REDUCTION:
+ print(" = 0x%05x" % R[i])
+
return ModExpNG_Operand(None, ab_num_words, R)
def reduce(self, a):
carry = 0
for x in range(len(a.words)):
a.words[x] += carry
- carry = (a.words[x] >> _WORD_WIDTH) & 1
+ carry = (a.words[x] >> _WORD_WIDTH) & 3
a.words[x] &= self.lowlevel._word_mask
@@ -894,73 +936,66 @@ if __name__ == "__main__":
# s_crt = sq + q_sr_qinv
# unblind s
# mutate blinding factors
- ip_factor = worker.multiply(i, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)
- iq_factor = worker.multiply(i, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)
-
- x_factor = worker.multiply(vector.x, vector.n_factor, vector.n, vector.n_coeff, n_num_words)
- y_factor = worker.multiply(vector.y, vector.n_factor, vector.n, vector.n_coeff, n_num_words)
-
- m_blind = worker.multiply(vector.m, y_factor, vector.n, vector.n_coeff, n_num_words)
+
+ XF = worker.multiply(vector.x, vector.n_factor, vector.n, vector.n_coeff, n_num_words) # mod_multiply (mod n)
+ YF = worker.multiply(vector.y, vector.n_factor, vector.n, vector.n_coeff, n_num_words) # mod_multiply (mod n)
- worker.reduce(m_blind)
+ XMF = worker.multiply(XF, XF, vector.n, vector.n_coeff, n_num_words) # mod_multiply (mod n)
+ YMF = worker.multiply(YF, YF, vector.n, vector.n_coeff, n_num_words) # mod_multiply (mod n)
+
+ XM = worker.multiply(i, XMF, vector.n, vector.n_coeff, n_num_words) # mod_multiply (mod n)
+ YM = worker.multiply(i, YMF, vector.n, vector.n_coeff, n_num_words) # mod_multiply (mod n)
- mp_blind_inverse_factor = worker.multiply(m_blind, None, vector.p, vector.p_coeff, pq_num_words, reduce_only=True)
- mq_blind_inverse_factor = worker.multiply(m_blind, None, vector.q, vector.q_coeff, pq_num_words, reduce_only=True)
+ MB = worker.multiply(vector.m, YF, vector.n, vector.n_coeff, n_num_words) # mod_multiply (mod n)
- mp_blind = worker.multiply(mp_blind_inverse_factor, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)
- mq_blind = worker.multiply(mq_blind_inverse_factor, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)
+ worker.reduce(MB) # just_reduce
- mp_blind_factor = worker.multiply(mp_blind, vector.p_factor, vector.p, vector.p_coeff, pq_num_words, dump=True)
- mq_blind_factor = worker.multiply(mq_blind, vector.q_factor, vector.q, vector.q_coeff, pq_num_words)
+ mp_blind_inverse_factor = worker.multiply(MB, None, vector.p, vector.p_coeff, pq_num_words, reduce_only=True) # mod_reduce (mod p)
+ mq_blind_inverse_factor = worker.multiply(MB, None, vector.q, vector.q_coeff, pq_num_words, reduce_only=True) # mod_reduce (mod q)
- sp_blind_factor = worker.exponentiate(ip_factor, mp_blind_factor, vector.dp, vector.p, vector.p_factor, vector.p_coeff, pq_num_words)
- sq_blind_factor = worker.exponentiate(iq_factor, mq_blind_factor, vector.dq, vector.q, vector.q_factor, vector.q_coeff, pq_num_words)
+ mp_blind = worker.multiply(mp_blind_inverse_factor, vector.p_factor, vector.p, vector.p_coeff, pq_num_words) # mod_multiply
+ mq_blind = worker.multiply(mq_blind_inverse_factor, vector.q_factor, vector.q, vector.q_coeff, pq_num_words) # mod_multiply
- if worker.multiplier._a_seen_17:
- print("17-bit wide A's seen.")
- else:
- print("17-bit wide A's not detected.")
+ mp_blind_factor = worker.multiply(mp_blind, vector.p_factor, vector.p, vector.p_coeff, pq_num_words) # mod_multiply
+ mq_blind_factor = worker.multiply(mq_blind, vector.q_factor, vector.q, vector.q_coeff, pq_num_words) # mod_multiply
- if worker.multiplier._b_seen_17:
- print("17-bit wide B's seen.")
- else:
- print("17-bit wide B's not detected.")
+ ip_factor = worker.multiply(i, vector.p_factor, vector.p, vector.p_coeff, pq_num_words) # mod_multiply
+ iq_factor = worker.multiply(i, vector.q_factor, vector.q, vector.q_coeff, pq_num_words) # mod_multiply
-
+ sp_blind_factor = worker.exponentiate(ip_factor, mp_blind_factor, vector.dp, vector.p, vector.p_factor, vector.p_coeff, pq_num_words, dump_index=99, dump_mode="P") # mod_multiply
+ sq_blind_factor = worker.exponentiate(iq_factor, mq_blind_factor, vector.dq, vector.q, vector.q_factor, vector.q_coeff, pq_num_words, dump_index=99, dump_mode="Q") # mod_multiply
- sp_blind = worker.multiply(i, sp_blind_factor, vector.p, vector.p_coeff, pq_num_words)
- sq_blind = worker.multiply(i, sq_blind_factor, vector.q, vector.q_coeff, pq_num_words)
+ SPB = worker.multiply(i, sp_blind_factor, vector.p, vector.p_coeff, pq_num_words) # mod_multiply
+ SQB = worker.multiply(i, sq_blind_factor, vector.q, vector.q_coeff, pq_num_words) # mod_multiply
- sr_blind = worker.subtract(sp_blind, sq_blind, vector.p, pq_num_words)
+ worker.reduce(SPB) # just_reduce
+ worker.reduce(SQB) # just_reduce
- sr_qinv_blind_inverse_factor = worker.multiply(sr_blind, vector.qinv, vector.p, vector.p_coeff, pq_num_words)
- sr_qinv_blind = worker.multiply(sr_qinv_blind_inverse_factor, vector.p_factor, vector.p, vector.p_coeff, pq_num_words)
- q_sr_qinv_blind = worker.multiply(vector.q, sr_qinv_blind, None, None, pq_num_words, multiply_only=True)
+ sr_blind = worker.subtract(SPB, SQB, vector.p, pq_num_words) # mod_subtract
- worker.reduce(q_sr_qinv_blind)
+ sr_qinv_blind_inverse_factor = worker.multiply(sr_blind, vector.qinv, vector.p, vector.p_coeff, pq_num_words) # mod_multiply
+ sr_qinv_blind = worker.multiply(sr_qinv_blind_inverse_factor, vector.p_factor, vector.p, vector.p_coeff, pq_num_words) # mod_multiply
- s_crt_blinded = worker.add(sq_blind, q_sr_qinv_blind, pq_num_words)
+ q_sr_qinv_blind = worker.multiply(vector.q, sr_qinv_blind, None, None, pq_num_words, multiply_only=True) # just_multiply
- s_crt_unblinded = worker.multiply(s_crt_blinded, x_factor, vector.n, vector.n_coeff, n_num_words)
-
- x_mutated_factor = worker.multiply(x_factor, x_factor, vector.n, vector.n_coeff, n_num_words)
- y_mutated_factor = worker.multiply(y_factor, y_factor, vector.n, vector.n_coeff, n_num_words)
+ worker.reduce(q_sr_qinv_blind) # just_reduce
+
+ SB = worker.add(SQB, q_sr_qinv_blind, pq_num_words) # just_add
- x_mutated = worker.multiply(i, x_mutated_factor, vector.n, vector.n_coeff, n_num_words)
- y_mutated = worker.multiply(i, y_mutated_factor, vector.n, vector.n_coeff, n_num_words)
+ S = worker.multiply(SB, XF, vector.n, vector.n_coeff, n_num_words) # mod_multiply
- worker.reduce(s_crt_unblinded)
- worker.reduce(x_mutated)
- worker.reduce(y_mutated)
+ worker.reduce(S) # just_reduce
+ worker.reduce(XM) # just_reduce
+ worker.reduce(YM) # just_reduce
# check
- if s_crt_unblinded.number() != s_known: print("ERROR: s_crt_unblinded != s_known!")
+ if S.number() != s_known: print("ERROR: s_crt_unblinded != s_known!")
else: print("s is OK")
- if x_mutated.number() != x_mutated_known: print("ERROR: x_mutated != x_mutated_known!")
+ if XM.number() != x_mutated_known: print("ERROR: x_mutated != x_mutated_known!")
else: print("x_mutated is OK")
- if y_mutated.number() != y_mutated_known: print("ERROR: y_mutated != y_mutated_known!")
+ if YM.number() != y_mutated_known: print("ERROR: y_mutated != y_mutated_known!")
else: print("y_mutated is OK")