diff options
-rw-r--r-- | unit-tests.py | 202 |
1 files changed, 186 insertions, 16 deletions
diff --git a/unit-tests.py b/unit-tests.py index 6fe5ccf..d5b6c77 100644 --- a/unit-tests.py +++ b/unit-tests.py @@ -42,28 +42,34 @@ import datetime import logging import sys +from struct import pack, unpack + from libhal import * try: - from Crypto.Util.number import inverse - from Crypto.PublicKey import RSA - from Crypto.Signature import PKCS1_v1_5 - from Crypto.Hash.SHA256 import SHA256Hash as SHA256 - from Crypto.Hash.SHA384 import SHA384Hash as SHA384 - from Crypto.Hash.SHA512 import SHA512Hash as SHA512 + from Crypto.Util.number import inverse + from Crypto.PublicKey import RSA + from Crypto.Cipher import AES + from Crypto.Cipher.PKCS1_v1_5 import PKCS115_Cipher + from Crypto.Signature.PKCS1_v1_5 import PKCS115_SigScheme + from Crypto.Hash.SHA256 import SHA256Hash as SHA256 + from Crypto.Hash.SHA384 import SHA384Hash as SHA384 + from Crypto.Hash.SHA512 import SHA512Hash as SHA512 pycrypto_loaded = True except ImportError: pycrypto_loaded = False try: - from ecdsa import der as ECDSA_DER - from ecdsa.keys import SigningKey as ECDSA_SigningKey, VerifyingKey as ECDSA_VerifyingKey - from ecdsa.ellipticcurve import Point - from ecdsa.curves import NIST256p, NIST384p, NIST521p, find_curve as ECDSA_find_curve - from ecdsa.util import oid_ecPublicKey + from ecdsa import der as ECDSA_DER + from ecdsa.keys import SigningKey as ECDSA_SigningKey + from ecdsa.keys import VerifyingKey as ECDSA_VerifyingKey + from ecdsa.ellipticcurve import Point + from ecdsa.curves import NIST256p, NIST384p, NIST521p + from ecdsa.curves import find_curve as ECDSA_find_curve + from ecdsa.util import oid_ecPublicKey if not pycrypto_loaded: - from hashlib import sha256 as SHA256, sha384 as SHA384, sha512 as SHA512 + from hashlib import sha256 as SHA256, sha384 as SHA384, sha512 as SHA512 ecdsa_loaded = True except ImportError: ecdsa_loaded = False @@ -852,6 +858,168 @@ class TestPkeyECDSAVerificationNIST(TestCaseLoggedIn): py_hash = SHA384) +@unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package") +class TestPKeyBackup(TestCaseLoggedIn): + + oid_rsaEncryption = "\x2A\x86\x48\x86\xF7\x0D\x01\x01\x01" + oid_aesKeyWrap = "\x60\x86\x48\x01\x65\x03\x04\x01\x30" + + @staticmethod + def parse_EncryptedPrivateKeyInfo(der, oid): + from Crypto.Util.asn1 import DerObject, DerSequence, DerOctetString, DerObjectId + encryptedPrivateKeyInfo = DerSequence() + encryptedPrivateKeyInfo.decode(der) + encryptionAlgorithm = DerSequence() + encryptionAlgorithm.decode(encryptedPrivateKeyInfo[0]) + algorithm = DerObjectId() + algorithm.decode(encryptionAlgorithm[0]) + encryptedData = DerOctetString() + encryptedData.decode(encryptedPrivateKeyInfo[1]) + if algorithm.payload != oid: + raise ValueError + return encryptedData.payload + + @staticmethod + def encode_EncryptedPrivateKeyInfo(der, oid): + from Crypto.Util.asn1 import DerSequence, DerOctetString + return DerSequence([ + DerSequence([chr(0x06) + chr(len(oid)) + oid]).encode(), + DerOctetString(der).encode() + ]).encode() + + @staticmethod + def make_kek(): + import Crypto.Random + return Crypto.Random.new().read(256/8) + + def sig_check(self, pkey, der): + from Crypto.Util.asn1 import DerSequence, DerNull, DerOctetString + p115 = PKCS115_SigScheme(RSA.importKey(der)) + hash = SHA256("Your mother was a hamster") + data = DerSequence([ + DerSequence([hash.oid, DerNull().encode()]).encode(), + DerOctetString(hash.digest()).encode() + ]).encode() + sig1 = p115.sign(hash) + sig2 = pkey.sign(data = data) + self.assertEqual(sig1, sig2) + p115.verify(hash, sig1) + p115.verify(hash, sig2) + pkey.verify(signature = sig1, data = data) + pkey.verify(signature = sig2, data = data) + + def test_export(self): + kekek = hsm.pkey_load( + flags = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT, + der = PreloadedKey.db[HAL_KEY_TYPE_RSA_PUBLIC, 1024].der) + self.addCleanup(kekek.delete) + pkey = hsm.pkey_generate_rsa( + keylen= 1024, + flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE | HAL_KEY_FLAG_EXPORTABLE) + self.addCleanup(pkey.delete) + pkcs8_der, kek_der = kekek.export_pkey(pkey) + kek = PKCS115_Cipher(PreloadedKey.db[HAL_KEY_TYPE_RSA_PRIVATE, 1024].obj).decrypt( + self.parse_EncryptedPrivateKeyInfo(kek_der, self.oid_rsaEncryption), + self.make_kek()) + der = AESKeyWrapWithPadding(kek).unwrap( + self.parse_EncryptedPrivateKeyInfo(pkcs8_der, self.oid_aesKeyWrap)) + self.sig_check(pkey, der) + + def test_import(self): + kekek = hsm.pkey_generate_rsa( + keylen= 1024, + flags = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT) + self.addCleanup(kekek.delete) + kek = self.make_kek() + der = PreloadedKey.db[HAL_KEY_TYPE_RSA_PRIVATE, 1024].der + pkey = kekek.import_pkey( + pkcs8 = self.encode_EncryptedPrivateKeyInfo( + AESKeyWrapWithPadding(kek).wrap(der), + self.oid_aesKeyWrap), + kek = self.encode_EncryptedPrivateKeyInfo( + PKCS115_Cipher(RSA.importKey(kekek.public_key)).encrypt(kek), + self.oid_rsaEncryption), + flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE) + self.addCleanup(pkey.delete) + self.sig_check(pkey, der) + + +class AESKeyWrapWithPadding(object): + """ + Implementation of AES Key Wrap With Padding from RFC 5649. + """ + + class UnwrapError(Exception): + "Something went wrong during unwrap." + + def __init__(self, key): + self.ctx = AES.new(key, AES.MODE_ECB) + + def _encrypt(self, b1, b2): + aes_block = self.ctx.encrypt(b1 + b2) + return aes_block[:8], aes_block[8:] + + def _decrypt(self, b1, b2): + aes_block = self.ctx.decrypt(b1 + b2) + return aes_block[:8], aes_block[8:] + + @staticmethod + def _start_stop(start, stop): # Syntactic sugar + step = -1 if start > stop else 1 + return xrange(start, stop + step, step) + + def wrap(self, Q): + "RFC 5649 section 4.1." + m = len(Q) # Plaintext length + if m % 8 != 0: # Pad Q if needed + Q += "\x00" * (8 - (m % 8)) + R = [pack(">LL", 0xa65959a6, m)] # Magic MSB(32,A), build LSB(32,A) + R.extend(Q[i : i + 8] # Append Q + for i in xrange(0, len(Q), 8)) + n = len(R) - 1 + if n == 1: + R[0], R[1] = self._encrypt(R[0], R[1]) + else: + # RFC 3394 section 2.2.1 + for j in self._start_stop(0, 5): + for i in self._start_stop(1, n): + R[0], R[i] = self._encrypt(R[0], R[i]) + W0, W1 = unpack(">LL", R[0]) + W1 ^= n * j + i + R[0] = pack(">LL", W0, W1) + assert len(R) == (n + 1) and all(len(r) == 8 for r in R) + return "".join(R) + + def unwrap(self, C): + "RFC 5649 section 4.2." + if len(C) % 8 != 0: + raise self.UnwrapError("Ciphertext length {} is not an integral number of blocks" + .format(len(C))) + n = (len(C) / 8) - 1 + R = [C[i : i + 8] for i in xrange(0, len(C), 8)] + if n == 1: + R[0], R[1] = self._decrypt(R[0], R[1]) + else: + # RFC 3394 section 2.2.2 steps (1), (2), and part of (3) + for j in self._start_stop(5, 0): + for i in self._start_stop(n, 1): + W0, W1 = unpack(">LL", R[0]) + W1 ^= n * j + i + R[0] = pack(">LL", W0, W1) + R[0], R[i] = self._decrypt(R[0], R[i]) + magic, m = unpack(">LL", R[0]) + if magic != 0xa65959a6: + raise self.UnwrapError("Magic value in AIV should have been 0xa65959a6, was 0x{:02x}" + .format(magic)) + if m <= 8 * (n - 1) or m > 8 * n: + raise self.UnwrapError("Length encoded in AIV out of range: m {}, n {}".format(m, n)) + R = "".join(R[1:]) + assert len(R) == 8 * n + if any(r != "\x00" for r in R[m:]): + raise self.UnwrapError("Nonzero trailing bytes {}".format(R[m:].encode("hex"))) + return R[:m] + + class Pinwheel(object): """ Activity pinwheel, as needed. @@ -901,14 +1069,16 @@ class PreloadedRSAKey(PreloadedKey): if pycrypto_loaded: k1 = RSA.importKey(pem) k2 = k1.publickey() - cls(HAL_KEY_TYPE_RSA_PRIVATE, keylen, k1, k1.exportKey(format = "DER", pkcs = 8), keylen = keylen) - cls(HAL_KEY_TYPE_RSA_PUBLIC, keylen, k2, k2.exportKey(format = "DER" ), keylen = keylen) + cls(HAL_KEY_TYPE_RSA_PRIVATE, keylen, + k1, k1.exportKey(format = "DER", pkcs = 8), keylen = keylen) + cls(HAL_KEY_TYPE_RSA_PUBLIC, keylen, + k2, k2.exportKey(format = "DER" ), keylen = keylen) def sign(self, text, hash): - return PKCS1_v1_5.PKCS115_SigScheme(self.obj).sign(hash(text)) + return PKCS115_SigScheme(self.obj).sign(hash(text)) def verify(self, text, hash, signature): - return PKCS1_v1_5.PKCS115_SigScheme(self.obj).verify(hash(text), signature) + return PKCS115_SigScheme(self.obj).verify(hash(text), signature) class PreloadedECKey(PreloadedKey): |