aboutsummaryrefslogblamecommitdiff
path: root/vector/vector_util.py
blob: 951ac6e567ded5a73305e45e9e1620c531a5de3e (plain) (tree)
1
2
3
4
5
6
7
8





                                                                

                                                         



                                                                        
 






                                                                          


                                                                        








































































































































































































































































































                                                                                           
#!/usr/bin/python3
#
#
# Helper routines for ModExpNG randomized test vector generator.
#
#
# Copyright 2019 The Commons Conservancy Cryptech Project
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# - Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
#
# - Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#


import sys
import random
import subprocess
from enum import Enum, auto


class VectorPiece(Enum):
    """Key component that the current line of `openssl rsa -text` output feeds.

    VectorPieceOther marks lines that do not belong to any tracked key
    component (e.g. the banner or the public exponent line).
    """

    # Values are spelled out explicitly; they are exactly the ones auto()
    # hands out for a plain Enum (1, 2, 3, ... in declaration order).
    VectorPieceOther = 1
    VectorPieceN     = 2
    VectorPieceD     = 3
    VectorPieceP     = 4
    VectorPieceQ     = 5
    VectorPieceDP    = 6
    VectorPieceDQ    = 7
    VectorPieceQINV  = 8


class Vector:
    """RSA key material and derived parameters for one ModExpNG test vector.

    The key components (n, d, p, q, dp, dq, qinv) are accumulated as hex
    strings via _add_piece().  selfcheck() converts them to integers,
    verifies them, and stores them as public attributes together with the
    blinding pair (x, y) and the Montgomery factors/coefficients.
    """

    # public exponent (F4); used here only to derive the blinding factor y
    _f4 = 0x10001

    def __init__(self, length):
        """Create an empty vector for a *length*-bit modulus."""
        self._bits = length     # modulus size in bits
        self._n    = ""         # hex digits of the modulus
        self._d    = ""         # hex digits of the private exponent
        self._p    = ""         # hex digits of the first prime
        self._q    = ""         # hex digits of the second prime
        self._dp   = ""         # hex digits of d mod (p-1)
        self._dq   = ""         # hex digits of d mod (q-1)
        self._qinv = ""         # hex digits of the CRT coefficient

    def _add_piece(self, type, value):
        """Append a chunk of hex digits to the key component *type*.

        Colons, spaces and CR/LF characters are stripped from *value* first,
        so lines of an OpenSSL hex dump can be passed verbatim.

        Raises:
            Exception: if *type* is not one of the key-component members of
                VectorPiece (e.g. VectorPieceOther).
        """
        # NOTE: `type` shadows the builtin; the parameter name is kept so
        # keyword callers keep working.
        value = value.replace(":",   "")
        value = value.replace("\r",  "")
        value = value.replace("\n",  "")
        value = value.replace(" ",   "")

        if   type == VectorPiece.VectorPieceN:    self._n    += value
        elif type == VectorPiece.VectorPieceD:    self._d    += value
        elif type == VectorPiece.VectorPieceP:    self._p    += value
        elif type == VectorPiece.VectorPieceQ:    self._q    += value
        elif type == VectorPiece.VectorPieceDP:   self._dp   += value
        elif type == VectorPiece.VectorPieceDQ:   self._dq   += value
        elif type == VectorPiece.VectorPieceQINV: self._qinv += value
        else: raise Exception("Invalid vector piece type!")

    def _calc_mont_factor(self, length, modulus):
        """Return 2**(2*length) mod *modulus*, i.e. R**2 mod modulus for R = 2**length."""
        return pow(2, 2*length, modulus)

    def _calc_mod_coeff(self, length, modulus):
        """Return r < 2**length with modulus * r == -1 (mod 2**length).

        r is built one bit at a time as the inverse of -modulus modulo
        2**length.  *modulus* must be odd (as RSA moduli and primes are);
        otherwise no such r exists and the result is meaningless.
        """
        pwr = pow(2, length)
        pwr_mask = pwr - 1

        r = 1
        b = 1

        # nn = -modulus mod 2**length (two's complement within `length` bits)
        nn = ((modulus ^ pwr_mask) + 1) % pwr

        # Invariant: r * nn == 1 (mod 2**i) at the start of iteration i; if
        # bit i of the product is set, adding 2**i to r clears it.
        for i in range(1, length):

            b = (b << 1) % pwr
            t = (r * nn) % pwr

            if t & (1 << i): r += b

        return r

    def _calc_blind_y(self, x, modulus):
        """Return y = (x**-1)**F4 mod *modulus*, the partner of blinding factor x.

        Signing m*y and multiplying the result by x recovers m**d, which
        is what selfcheck() verifies.

        Raises:
            Exception: if x is not invertible modulo *modulus*.
        """
        x_inv = self._modinv(x, modulus)
        return pow(x_inv, self._f4, modulus)

    def _egcd(self, a, b):
        """Extended Euclid: return (g, x, y) with a*x + b*y == g.

        For non-negative a, b, g is gcd(a, b).  This is an iterative
        formulation yielding the same coefficients as the classic recursive
        one.  A recursive version needs one frame per division step, and
        for 2048-bit operands that is typically more than Python's default
        recursion limit of 1000.
        """
        # Invariants: r0 == a*x0 + b*y0 and r1 == a*x1 + b*y1.
        r0, x0, y0 = b, 0, 1
        r1, x1, y1 = a, 1, 0
        while r1 != 0:
            quot = r0 // r1
            r0, r1 = r1, r0 - quot * r1
            x0, x1 = x1, x0 - quot * x1
            y0, y1 = y1, y0 - quot * y1
        return (r0, x0, y0)

    def _modinv(self, a, m):
        """Return the inverse of *a* modulo *m*, in the range [0, m).

        Raises:
            Exception: if gcd(a, m) != 1, i.e. no inverse exists.
        """
        g, x, y = self._egcd(a, m)
        if g != 1:
            raise Exception("_modinv() failed!")
        else:
            return x % m

    def selfcheck(self, message, blinding):
        """Validate the key, store it as integers and precompute parameters.

        Sets the public attributes m, n, d, p, q, dp, dq, qinv and x.  If all
        checks pass, it also sets y, n/p/q_factor and n/p/q_coeff.  The
        checks are:
        n != 0, n == p*q, dp == d mod (p-1), dq == d mod (q-1), blinded
        signing followed by unblinding matches plain signing, and blinded
        CRT signing followed by unblinding matches plain signing.

        Args:
            message: (padded) message to sign during the checks, as an int.
            blinding: blinding factor x; must be invertible modulo n.

        Returns:
            True if every check passed.  Otherwise False, after printing an
            "ERROR: ..." line for the first failed check.

        Raises:
            ValueError: if a key component is not a valid hex string.
            Exception: if *blinding* is not invertible modulo n.
        """
        self.m    = message             # message (padded)
        self.n    = int(self._n,    16) # modulus
        self.d    = int(self._d,    16) # private key
        self.p    = int(self._p,    16) # part of modulus
        self.q    = int(self._q,    16) # part of modulus
        self.dp   = int(self._dp,   16) # smaller exponent
        self.dq   = int(self._dq,   16) # smaller exponent
        self.qinv = int(self._qinv, 16) # helper coefficient

        self.x    = blinding

        # check modulus
        if self.n == 0:
            print("ERROR: n == 0")
            return False

        if self.n != self.p * self.q:
            print("ERROR: n != (p * q)")
            return False

        # check smaller exponents
        if self.dp != (self.d % (self.p-1)):
            print("ERROR: dp != (d % (p-1))")
            return False

        if self.dq != (self.d % (self.q-1)):
            print("ERROR: dq != (d % (q-1))")
            return False

        # Derive y only now that the modulus is known to be sane.  Computing
        # it up front (as before) made the n == 0 check unreachable, because
        # the modular inverse raised first.
        self.y    = self._calc_blind_y(self.x, self.n)

        # sign to obtain known good value
        s_reference = pow(message, self.d, self.n)

        # blind message
        message_blinded = (message * self.y) % self.n

        # sign blinded message
        s_blinded = pow(message_blinded, self.d, self.n)

        # unblind signature
        s_unblinded = (s_blinded * self.x) % self.n

        # check, that x and y actually work
        if s_unblinded != s_reference:
            print("ERROR: s_unblinded != s_reference!")
            return False

        # try to do crt with the blinded message
        sp_blinded = pow(message_blinded, self.dp, self.p)
        sq_blinded = pow(message_blinded, self.dq, self.q)

        # recover full blinded signature (Garner's recombination); the
        # following "% p" also copes with a difference that is still negative
        sr_blinded = sp_blinded - sq_blinded
        if sr_blinded < 0: sr_blinded += self.p

        sr_qinv_blinded = (sr_blinded * self.qinv) % self.p

        s_crt_blinded = sq_blinded + self.q * sr_qinv_blinded

        # unblind crt signature
        s_crt_unblinded = (s_crt_blinded * self.x) % self.n

        if s_crt_unblinded != s_reference:
            print("ERROR: s_crt_unblinded != s_reference!")
            return False

        # Montgomery parameters, computed with 16 extra bits beyond the
        # operand size
        self.n_factor = self._calc_mont_factor(self._bits      + 16, self.n)
        self.p_factor = self._calc_mont_factor(self._bits // 2 + 16, self.p)
        self.q_factor = self._calc_mont_factor(self._bits // 2 + 16, self.q)

        self.n_coeff = self._calc_mod_coeff(self._bits      + 16, self.n)
        self.p_coeff = self._calc_mod_coeff(self._bits // 2 + 16, self.p)
        self.q_coeff = self._calc_mod_coeff(self._bits // 2 + 16, self.q)

        print("Test vector checked.")

        return True


def openssl_binary(usage):
    """Pick the OpenSSL executable according to the command line.

    No arguments selects the system ``openssl``.  Exactly one argument is
    taken as the binary to use.  Any other argument count prints *usage*
    and yields an empty string.

    Returns:
        The binary to run, or "" if the command line was not understood.
    """
    argc = len(sys.argv)

    if argc == 1:
        # nothing overridden: fall back to the system binary
        print("Using system OpenSSL library.")
        return "openssl"

    if argc == 2:
        # explicit binary requested
        chosen = sys.argv[1]
        print("Using OpenSSL binary '%s'..." % chosen)
        return chosen

    # unrecognised command line
    print(usage)
    return ""


def openssl_genrsa(binary, length):
    """Run ``<binary> genrsa`` to write a *length*-bit RSA key to ``<length>_randomized.key``.

    Raises:
        subprocess.CalledProcessError: if *binary* exits with a non-zero
            status.  The status must be checked because load_vector() later
            reads this same file name.  A failed run would otherwise leave
            any key from a previous run in place, and it would be used
            silently.
    """
    filename = str(length) + "_randomized.key"
    subprocess.check_call([binary, "genrsa", "-out", filename, str(length)])


def random_message(seed, length):
    """Return a reproducible pseudo-random message for a *length*-bit key.

    The global ``random`` generator is reseeded with *seed*.  Then
    ``length // 8 - 1`` random bytes are drawn and packed big-endian (the
    first byte drawn is the most significant).  The result is therefore
    below 2**(length - 8), presumably to keep it smaller than the modulus.
    """
    random.seed(seed)
    byte_count = length // 8 - 1
    octets = bytes(random.getrandbits(8) for _ in range(byte_count))
    return int.from_bytes(octets, "big")


def random_blinding(seed, length):
    """Return a reproducible pseudo-random blinding factor for a *length*-bit key.

    Reseeds the global ``random`` generator with *seed* and accumulates
    ``length // 8 - 1`` random bytes, most significant first.  The value
    is thus below 2**(length - 8).
    """
    random.seed(seed)

    value = 0
    remaining = length // 8 - 1
    while remaining > 0:
        # shift in one more byte at the low end
        value = (value << 8) | random.getrandbits(8)
        remaining -= 1

    return value


def load_vector(binary, length):
    """Parse the RSA key in ``<length>_randomized.key`` into a Vector.

    Runs ``<binary> rsa -in <length>_randomized.key -noout -text`` and walks
    its output line by line.  A header line (``modulus:``, ``prime1:``, ...)
    selects the key component that the following hex-dump lines are
    appended to.

    Raises:
        subprocess.CalledProcessError: if the openssl command fails.
        Exception: if a non-header line appears while no key component is
            selected (raised by Vector._add_piece()).
    """
    vector = Vector(length)
    piece_type = VectorPiece.VectorPieceOther

    filename = str(length) + "_randomized.key"
    openssl_command = [binary, "rsa", "-in", filename, "-noout", "-text"]
    openssl_stdout = subprocess.check_output(openssl_command).decode("utf-8").splitlines()

    # OpenSSL 1.x prints the banner as "RSA Private-Key: (...)" but OpenSSL
    # 3.x prints just "Private-Key: (...)".  Accept both, otherwise the
    # banner is handed to _add_piece() as key material and it raises.
    banners = ("RSA Private-Key:", "Private-Key:")

    for line in openssl_stdout:
        if   line.startswith(banners):            piece_type = VectorPiece.VectorPieceOther
        elif line.startswith("modulus:"):         piece_type = VectorPiece.VectorPieceN
        elif line.startswith("publicExponent:"):  piece_type = VectorPiece.VectorPieceOther
        elif line.startswith("privateExponent:"): piece_type = VectorPiece.VectorPieceD
        elif line.startswith("prime1:"):          piece_type = VectorPiece.VectorPieceP
        elif line.startswith("prime2:"):          piece_type = VectorPiece.VectorPieceQ
        elif line.startswith("exponent1:"):       piece_type = VectorPiece.VectorPieceDP
        elif line.startswith("exponent2:"):       piece_type = VectorPiece.VectorPieceDQ
        elif line.startswith("coefficient:"):     piece_type = VectorPiece.VectorPieceQINV
        else: vector._add_piece(piece_type, line)

    return vector


def save_vector(vector):
    """Write *vector* out as an importable Python module.

    The file is ``vector_<bits>_randomized.py`` in the current directory.
    It contains a single ``class Vector`` whose attributes are the
    hex-formatted message, key components, Montgomery factors/coefficients
    and blinding pair.  Vector.selfcheck() must have completed successfully
    on *vector*, since that is what sets these attributes.
    """
    filename = "vector_" + str(vector._bits) + "_randomized.py"
    print("Writing to '%s'..." % filename)

    # Attributes in output order; names are padded to 8 columns so the
    # "=" signs line up.
    fields = ("m", "n", "d", "p", "q", "dp", "dq", "qinv",
              "n_factor", "p_factor", "q_factor",
              "n_coeff", "p_coeff", "q_coeff",
              "x", "y")

    # Format everything before opening the file, so a missing attribute
    # cannot leave a half-written module behind.
    lines = ["# Generated automatically, do not edit.\n\n", "class Vector:\n"]
    lines += ["    %-8s = 0x%x\n" % (name, getattr(vector, name)) for name in fields]

    # `with` guarantees the file is closed even if writing fails.
    with open(filename, 'w') as f:
        f.writelines(lines)


#
# End of file
#