aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2017-04-09 13:45:20 -0400
committerRob Austein <sra@hactrn.net>2017-04-09 13:45:20 -0400
commit3c7c7e20da805b3abea3fa4debdf255208164366 (patch)
tree82bbc6129aaaa23b13db2b7ef26d81e4d2133576
parentd0df322ae659b4a1f80ac57e9b20fa7464f0fb84 (diff)
Unit tests for key backup operations.
-rw-r--r--unit-tests.py202
1 files changed, 186 insertions, 16 deletions
diff --git a/unit-tests.py b/unit-tests.py
index 6fe5ccf..d5b6c77 100644
--- a/unit-tests.py
+++ b/unit-tests.py
@@ -42,28 +42,34 @@ import datetime
import logging
import sys
+from struct import pack, unpack
+
from libhal import *
try:
- from Crypto.Util.number import inverse
- from Crypto.PublicKey import RSA
- from Crypto.Signature import PKCS1_v1_5
- from Crypto.Hash.SHA256 import SHA256Hash as SHA256
- from Crypto.Hash.SHA384 import SHA384Hash as SHA384
- from Crypto.Hash.SHA512 import SHA512Hash as SHA512
+ from Crypto.Util.number import inverse
+ from Crypto.PublicKey import RSA
+ from Crypto.Cipher import AES
+ from Crypto.Cipher.PKCS1_v1_5 import PKCS115_Cipher
+ from Crypto.Signature.PKCS1_v1_5 import PKCS115_SigScheme
+ from Crypto.Hash.SHA256 import SHA256Hash as SHA256
+ from Crypto.Hash.SHA384 import SHA384Hash as SHA384
+ from Crypto.Hash.SHA512 import SHA512Hash as SHA512
pycrypto_loaded = True
except ImportError:
pycrypto_loaded = False
try:
- from ecdsa import der as ECDSA_DER
- from ecdsa.keys import SigningKey as ECDSA_SigningKey, VerifyingKey as ECDSA_VerifyingKey
- from ecdsa.ellipticcurve import Point
- from ecdsa.curves import NIST256p, NIST384p, NIST521p, find_curve as ECDSA_find_curve
- from ecdsa.util import oid_ecPublicKey
+ from ecdsa import der as ECDSA_DER
+ from ecdsa.keys import SigningKey as ECDSA_SigningKey
+ from ecdsa.keys import VerifyingKey as ECDSA_VerifyingKey
+ from ecdsa.ellipticcurve import Point
+ from ecdsa.curves import NIST256p, NIST384p, NIST521p
+ from ecdsa.curves import find_curve as ECDSA_find_curve
+ from ecdsa.util import oid_ecPublicKey
if not pycrypto_loaded:
- from hashlib import sha256 as SHA256, sha384 as SHA384, sha512 as SHA512
+ from hashlib import sha256 as SHA256, sha384 as SHA384, sha512 as SHA512
ecdsa_loaded = True
except ImportError:
ecdsa_loaded = False
@@ -852,6 +858,168 @@ class TestPkeyECDSAVerificationNIST(TestCaseLoggedIn):
py_hash = SHA384)
+@unittest.skipUnless(pycrypto_loaded, "Requires Python Crypto package")
+class TestPKeyBackup(TestCaseLoggedIn):
+
+ oid_rsaEncryption = "\x2A\x86\x48\x86\xF7\x0D\x01\x01\x01"
+ oid_aesKeyWrap = "\x60\x86\x48\x01\x65\x03\x04\x01\x30"
+
+ @staticmethod
+ def parse_EncryptedPrivateKeyInfo(der, oid):
+ from Crypto.Util.asn1 import DerObject, DerSequence, DerOctetString, DerObjectId
+ encryptedPrivateKeyInfo = DerSequence()
+ encryptedPrivateKeyInfo.decode(der)
+ encryptionAlgorithm = DerSequence()
+ encryptionAlgorithm.decode(encryptedPrivateKeyInfo[0])
+ algorithm = DerObjectId()
+ algorithm.decode(encryptionAlgorithm[0])
+ encryptedData = DerOctetString()
+ encryptedData.decode(encryptedPrivateKeyInfo[1])
+ if algorithm.payload != oid:
+ raise ValueError
+ return encryptedData.payload
+
+ @staticmethod
+ def encode_EncryptedPrivateKeyInfo(der, oid):
+ from Crypto.Util.asn1 import DerSequence, DerOctetString
+ return DerSequence([
+ DerSequence([chr(0x06) + chr(len(oid)) + oid]).encode(),
+ DerOctetString(der).encode()
+ ]).encode()
+
+ @staticmethod
+ def make_kek():
+ import Crypto.Random
+ return Crypto.Random.new().read(256/8)
+
+ def sig_check(self, pkey, der):
+ from Crypto.Util.asn1 import DerSequence, DerNull, DerOctetString
+ p115 = PKCS115_SigScheme(RSA.importKey(der))
+ hash = SHA256("Your mother was a hamster")
+ data = DerSequence([
+ DerSequence([hash.oid, DerNull().encode()]).encode(),
+ DerOctetString(hash.digest()).encode()
+ ]).encode()
+ sig1 = p115.sign(hash)
+ sig2 = pkey.sign(data = data)
+ self.assertEqual(sig1, sig2)
+ p115.verify(hash, sig1)
+ p115.verify(hash, sig2)
+ pkey.verify(signature = sig1, data = data)
+ pkey.verify(signature = sig2, data = data)
+
+ def test_export(self):
+ kekek = hsm.pkey_load(
+ flags = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT,
+ der = PreloadedKey.db[HAL_KEY_TYPE_RSA_PUBLIC, 1024].der)
+ self.addCleanup(kekek.delete)
+ pkey = hsm.pkey_generate_rsa(
+ keylen= 1024,
+ flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE | HAL_KEY_FLAG_EXPORTABLE)
+ self.addCleanup(pkey.delete)
+ pkcs8_der, kek_der = kekek.export_pkey(pkey)
+ kek = PKCS115_Cipher(PreloadedKey.db[HAL_KEY_TYPE_RSA_PRIVATE, 1024].obj).decrypt(
+ self.parse_EncryptedPrivateKeyInfo(kek_der, self.oid_rsaEncryption),
+ self.make_kek())
+ der = AESKeyWrapWithPadding(kek).unwrap(
+ self.parse_EncryptedPrivateKeyInfo(pkcs8_der, self.oid_aesKeyWrap))
+ self.sig_check(pkey, der)
+
+ def test_import(self):
+ kekek = hsm.pkey_generate_rsa(
+ keylen= 1024,
+ flags = HAL_KEY_FLAG_USAGE_KEYENCIPHERMENT)
+ self.addCleanup(kekek.delete)
+ kek = self.make_kek()
+ der = PreloadedKey.db[HAL_KEY_TYPE_RSA_PRIVATE, 1024].der
+ pkey = kekek.import_pkey(
+ pkcs8 = self.encode_EncryptedPrivateKeyInfo(
+ AESKeyWrapWithPadding(kek).wrap(der),
+ self.oid_aesKeyWrap),
+ kek = self.encode_EncryptedPrivateKeyInfo(
+ PKCS115_Cipher(RSA.importKey(kekek.public_key)).encrypt(kek),
+ self.oid_rsaEncryption),
+ flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE)
+ self.addCleanup(pkey.delete)
+ self.sig_check(pkey, der)
+
+
+class AESKeyWrapWithPadding(object):
+ """
+ Implementation of AES Key Wrap With Padding from RFC 5649.
+ """
+
+ class UnwrapError(Exception):
+ "Something went wrong during unwrap."
+
+ def __init__(self, key):
+ self.ctx = AES.new(key, AES.MODE_ECB)
+
+ def _encrypt(self, b1, b2):
+ aes_block = self.ctx.encrypt(b1 + b2)
+ return aes_block[:8], aes_block[8:]
+
+ def _decrypt(self, b1, b2):
+ aes_block = self.ctx.decrypt(b1 + b2)
+ return aes_block[:8], aes_block[8:]
+
+ @staticmethod
+ def _start_stop(start, stop): # Syntactic sugar
+ step = -1 if start > stop else 1
+ return xrange(start, stop + step, step)
+
+ def wrap(self, Q):
+ "RFC 5649 section 4.1."
+ m = len(Q) # Plaintext length
+ if m % 8 != 0: # Pad Q if needed
+ Q += "\x00" * (8 - (m % 8))
+ R = [pack(">LL", 0xa65959a6, m)] # Magic MSB(32,A), build LSB(32,A)
+ R.extend(Q[i : i + 8] # Append Q
+ for i in xrange(0, len(Q), 8))
+ n = len(R) - 1
+ if n == 1:
+ R[0], R[1] = self._encrypt(R[0], R[1])
+ else:
+ # RFC 3394 section 2.2.1
+ for j in self._start_stop(0, 5):
+ for i in self._start_stop(1, n):
+ R[0], R[i] = self._encrypt(R[0], R[i])
+ W0, W1 = unpack(">LL", R[0])
+ W1 ^= n * j + i
+ R[0] = pack(">LL", W0, W1)
+ assert len(R) == (n + 1) and all(len(r) == 8 for r in R)
+ return "".join(R)
+
+ def unwrap(self, C):
+ "RFC 5649 section 4.2."
+ if len(C) % 8 != 0:
+ raise self.UnwrapError("Ciphertext length {} is not an integral number of blocks"
+ .format(len(C)))
+ n = (len(C) / 8) - 1
+ R = [C[i : i + 8] for i in xrange(0, len(C), 8)]
+ if n == 1:
+ R[0], R[1] = self._decrypt(R[0], R[1])
+ else:
+ # RFC 3394 section 2.2.2 steps (1), (2), and part of (3)
+ for j in self._start_stop(5, 0):
+ for i in self._start_stop(n, 1):
+ W0, W1 = unpack(">LL", R[0])
+ W1 ^= n * j + i
+ R[0] = pack(">LL", W0, W1)
+ R[0], R[i] = self._decrypt(R[0], R[i])
+ magic, m = unpack(">LL", R[0])
+ if magic != 0xa65959a6:
+ raise self.UnwrapError("Magic value in AIV should have been 0xa65959a6, was 0x{:02x}"
+ .format(magic))
+ if m <= 8 * (n - 1) or m > 8 * n:
+ raise self.UnwrapError("Length encoded in AIV out of range: m {}, n {}".format(m, n))
+ R = "".join(R[1:])
+ assert len(R) == 8 * n
+ if any(r != "\x00" for r in R[m:]):
+ raise self.UnwrapError("Nonzero trailing bytes {}".format(R[m:].encode("hex")))
+ return R[:m]
+
+
class Pinwheel(object):
"""
Activity pinwheel, as needed.
@@ -901,14 +1069,16 @@ class PreloadedRSAKey(PreloadedKey):
if pycrypto_loaded:
k1 = RSA.importKey(pem)
k2 = k1.publickey()
- cls(HAL_KEY_TYPE_RSA_PRIVATE, keylen, k1, k1.exportKey(format = "DER", pkcs = 8), keylen = keylen)
- cls(HAL_KEY_TYPE_RSA_PUBLIC, keylen, k2, k2.exportKey(format = "DER" ), keylen = keylen)
+ cls(HAL_KEY_TYPE_RSA_PRIVATE, keylen,
+ k1, k1.exportKey(format = "DER", pkcs = 8), keylen = keylen)
+ cls(HAL_KEY_TYPE_RSA_PUBLIC, keylen,
+ k2, k2.exportKey(format = "DER" ), keylen = keylen)
def sign(self, text, hash):
- return PKCS1_v1_5.PKCS115_SigScheme(self.obj).sign(hash(text))
+ return PKCS115_SigScheme(self.obj).sign(hash(text))
def verify(self, text, hash, signature):
- return PKCS1_v1_5.PKCS115_SigScheme(self.obj).verify(hash(text), signature)
+ return PKCS115_SigScheme(self.obj).verify(hash(text), signature)
class PreloadedECKey(PreloadedKey):