aboutsummaryrefslogblamecommitdiff
path: root/vector/vector_util.py
blob: 413cb61c617083b486d8af85bb1e2427b13ef744 (plain) (tree)




































































































































































































































































                                                                                          
#!/usr/bin/python3
#
#
# Helper routines for ModExpNG randomized test vector generator.
#
#
# Copyright (c) 2019, NORDUnet A/S
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# - Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
#
# - Neither the name of the NORDUnet nor the names of its contributors may
#   be used to endorse or promote products derived from this software
#   without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#


import sys
import random
import subprocess
from enum import Enum, auto


class VectorPiece(Enum):
    """Key component that a line of ``openssl rsa -text`` output belongs to.

    Values are the same 1..8 sequence ``auto()`` would assign, in the same
    order. VectorPieceX marks output that is not stored in a Vector.
    """

    VectorPieceX    = 1  # header / public exponent: not collected
    VectorPieceN    = 2  # modulus
    VectorPieceD    = 3  # private exponent
    VectorPieceP    = 4  # first prime
    VectorPieceQ    = 5  # second prime
    VectorPieceDP   = 6  # d mod (p-1)
    VectorPieceDQ   = 7  # d mod (q-1)
    VectorPieceQINV = 8  # CRT coefficient


class Vector:
    """One RSA private key plus the derived constants exported as a test vector.

    Key components arrive as hex text through _add_piece() and are turned
    into integers (and validated) by selfcheck(), which also computes the
    Montgomery factors and modulus coefficients for n, p and q.
    """

    def __init__(self, length):
        # key length in bits
        self._bits = length

        # hex digit strings, accumulated by _add_piece()
        self._n    = ""
        self._d    = ""
        self._p    = ""
        self._q    = ""
        self._dp   = ""
        self._dq   = ""
        self._qinv = ""

    def _add_piece(self, type, value):
        """Append the hex digits in ``value`` to the component selected by ``type``.

        Colons, spaces, CR and LF are removed first, so one line of the
        ``openssl rsa -text`` hex dump can be passed in unchanged.

        Raises Exception for VectorPieceX or any other unsupported type.
        """
        value = value.replace(":",  "")
        value = value.replace("\r", "")
        value = value.replace("\n", "")
        value = value.replace(" ",  "")

        if   type == VectorPiece.VectorPieceN:    self._n    += value
        elif type == VectorPiece.VectorPieceD:    self._d    += value
        elif type == VectorPiece.VectorPieceP:    self._p    += value
        elif type == VectorPiece.VectorPieceQ:    self._q    += value
        elif type == VectorPiece.VectorPieceDP:   self._dp   += value
        elif type == VectorPiece.VectorPieceDQ:   self._dq   += value
        elif type == VectorPiece.VectorPieceQINV: self._qinv += value
        else: raise Exception("Invalid vector piece type!")

    def _calc_mont_factor(self, length, modulus):
        """Return 2**(2*length) mod modulus, i.e. R**2 mod modulus for R = 2**length."""
        return pow(2, 2*length, modulus)

    def _calc_mod_coeff(self, length, modulus):
        """Return r < 2**length such that r * modulus == -1 (mod 2**length).

        In other words r = -modulus**-1 mod 2**length, built one bit at a
        time. ``modulus`` must be odd, otherwise no such inverse exists.
        """
        pwr = pow(2, length)
        pwr_mask = pwr - 1

        r = 1
        b = 1

        # nn = -modulus mod 2**length (two's complement within `length` bits)
        nn = ((modulus ^ pwr_mask) + 1) % pwr

        # invariant entering iteration i: r * nn == 1 (mod 2**i);
        # nn is odd, so adding 2**i to r flips bit i of r * nn when needed
        for i in range(1, length):

            b = (b << 1) % pwr
            t = (r * nn) % pwr

            if t & (1 << i): r += b

        return r

    def selfcheck(self, message):
        """Validate the parsed key against ``message`` and derive extra constants.

        The integer components (self.m, self.n, self.d, self.p, self.q,
        self.dp, self.dq, self.qinv) are stored up front. The checks are
        n == p*q, dp == d mod (p-1), dq == d mod (q-1) and
        q * qinv == 1 (mod p), followed by a comparison of the plain
        signature of ``message`` against the CRT one. On the first failed
        check an error is printed and False is returned. On success the
        Montgomery factors (n/p/q_factor) and modulus coefficients
        (n/p/q_coeff) are stored and True is returned.

        Raises ValueError if any hex component is still empty.
        """

        self.m    = message             # message (padded)
        self.n    = int(self._n,    16) # modulus
        self.d    = int(self._d,    16) # private key
        self.p    = int(self._p,    16) # part of modulus
        self.q    = int(self._q,    16) # part of modulus
        self.dp   = int(self._dp,   16) # smaller exponent
        self.dq   = int(self._dq,   16) # smaller exponent
        self.qinv = int(self._qinv, 16) # helper coefficient

        # check modulus
        if self.n == 0:
            print("ERROR: n == 0")
            return False

        if self.n != self.p * self.q:
            print("ERROR: n != (p * q)")
            return False

        # check smaller exponents
        if self.dp != (self.d % (self.p-1)):
            print("ERROR: dp != (d % (p-1))")
            return False

        if self.dq != (self.d % (self.q-1)):
            print("ERROR: dq != (d % (q-1))")
            return False

        # check the CRT coefficient explicitly: the signature comparison
        # below cannot detect a wrong qinv whenever sp == sq
        if (self.q * self.qinv) % self.p != 1:
            print("ERROR: (q * qinv) % p != 1")
            return False

        # sign
        s = pow(message, self.d, self.n)

        # try to do crt
        sp = pow(message, self.dp, self.p)
        sq = pow(message, self.dq, self.q)

        sr = sp - sq
        if sr < 0: sr += self.p

        srqinv = (sr * self.qinv) % self.p

        s_crt = sq + self.q * srqinv

        if s_crt != s:
            print("ERROR: s_crt != s!")
            return False

        # operand widths are bits (+16) for n and bits/2 (+16) for p and q;
        # the extra 16 bits presumably match the ModExpNG core's operand
        # layout -- NOTE(review): confirm against the hardware spec
        self.n_factor = self._calc_mont_factor(self._bits      + 16, self.n)
        self.p_factor = self._calc_mont_factor(self._bits // 2 + 16, self.p)
        self.q_factor = self._calc_mont_factor(self._bits // 2 + 16, self.q)

        self.n_coeff = self._calc_mod_coeff(self._bits      + 16, self.n)
        self.p_coeff = self._calc_mod_coeff(self._bits // 2 + 16, self.p)
        self.q_coeff = self._calc_mod_coeff(self._bits // 2 + 16, self.q)

        print("Test vector checked.")

        return True


def openssl_binary(usage):
    """Choose which OpenSSL executable to run, based on the command line.

    No arguments selects ``openssl`` from the search path; exactly one
    argument is taken as the path of a specific binary. Any other argument
    count prints ``usage`` and returns an empty string to signal failure.
    """

    argc = len(sys.argv)

    # no override requested: rely on the system binary
    if argc == 1:
        print("Using system OpenSSL library.")
        return "openssl"

    # explicit binary given as the only argument
    if argc == 2:
        chosen = sys.argv[1]
        print("Using OpenSSL binary '" + chosen + "'...")
        return chosen

    # unrecognized command line: show help, report "no binary"
    print(usage)
    return ""


def openssl_genrsa(binary, length):
    """Generate a fresh ``length``-bit RSA key into ``<length>_randomized.key``.

    Runs ``<binary> genrsa`` with the key file written to the current
    directory. Raises subprocess.CalledProcessError if OpenSSL exits with a
    non-zero status. Without that check, a failed run could leave a stale
    key from an earlier invocation for load_vector() to pick up silently.
    """

    filename = str(length) + "_randomized.key"

    # check_call rather than call: do not ignore the exit status
    subprocess.check_call([binary, "genrsa", "-out", filename, str(length)])


def random_message(seed, length):
    """Return a reproducible pseudo-random message for a ``length``-bit key.

    Re-seeds the global ``random`` generator with ``seed`` and draws
    (length // 8 - 1) bytes, the first draw becoming the most significant
    byte. The result therefore fits in length - 8 bits; 0 when no bytes
    are drawn.
    """

    random.seed(seed)

    byte_count = length // 8 - 1

    # one getrandbits(8) draw per byte, assembled big-endian
    octets = bytes(random.getrandbits(8) for _ in range(byte_count))

    return int.from_bytes(octets, "big")


def load_vector(binary, length):
    """Parse the key written by openssl_genrsa() into a Vector.

    Runs ``<binary> rsa -in <length>_randomized.key -noout -text`` and feeds
    every hex-dump line that follows a component label into the matching
    Vector piece. The header line is accepted both in the "RSA Private-Key:"
    form and in the shorter "Private-Key:" form printed by newer OpenSSL
    releases. Blank lines are ignored.

    Raises subprocess.CalledProcessError if OpenSSL exits with an error, and
    Exception (from Vector._add_piece) if a data line appears where no key
    component is expected.
    """

    # label prefix -> piece that the following lines belong to;
    # VectorPieceX marks labels whose content is not collected
    labels = (
        ("RSA Private-Key:", VectorPiece.VectorPieceX),
        ("Private-Key:",     VectorPiece.VectorPieceX),
        ("modulus:",         VectorPiece.VectorPieceN),
        ("publicExponent:",  VectorPiece.VectorPieceX),
        ("privateExponent:", VectorPiece.VectorPieceD),
        ("prime1:",          VectorPiece.VectorPieceP),
        ("prime2:",          VectorPiece.VectorPieceQ),
        ("exponent1:",       VectorPiece.VectorPieceDP),
        ("exponent2:",       VectorPiece.VectorPieceDQ),
        ("coefficient:",     VectorPiece.VectorPieceQINV),
    )

    vector = Vector(length)
    piece_type = VectorPiece.VectorPieceX

    filename = str(length) + "_randomized.key"
    openssl_command = [binary, "rsa", "-in", filename, "-noout", "-text"]
    openssl_stdout = subprocess.check_output(openssl_command).decode("utf-8").splitlines()

    for line in openssl_stdout:
        # a blank line carries no digits; skip it rather than letting
        # _add_piece reject it while no component is selected
        if not line.strip():
            continue

        for label, label_type in labels:
            if line.startswith(label):
                piece_type = label_type
                break
        else:
            vector._add_piece(piece_type, line)

    return vector


def save_vector(vector):
    """Write a checked vector as the Python module ``vector_<bits>_randomized.py``.

    ``vector`` must already have passed Vector.selfcheck(), which sets the
    integer attributes written here. Each value becomes a hex literal on a
    class attribute of a generated ``Vector`` class.
    """

    # attribute names, in output order
    fields = ("m", "n", "d", "p", "q", "dp", "dq", "qinv",
              "n_factor", "p_factor", "q_factor",
              "n_coeff", "p_coeff", "q_coeff")

    filename = "vector_" + str(vector._bits) + "_randomized.py"
    print("Writing to '%s'..." % filename)

    # Render everything before opening the file, so that a missing attribute
    # (e.g. selfcheck() not run) fails without truncating or half-writing it.
    lines = ["# Generated automatically, do not edit.\n\n", "class Vector:\n"]
    lines += ["    %-8s = 0x%x\n" % (name, getattr(vector, name)) for name in fields]

    # `with` guarantees the handle is closed even if writing fails
    with open(filename, 'w') as f:
        f.writelines(lines)


#
# End of file
#