aboutsummaryrefslogtreecommitdiff
path: root/aes_keywrap.py
diff options
context:
space:
mode:
Diffstat (limited to 'aes_keywrap.py')
-rw-r--r--aes_keywrap.py451
1 files changed, 451 insertions, 0 deletions
diff --git a/aes_keywrap.py b/aes_keywrap.py
new file mode 100644
index 0000000..a191e3e
--- /dev/null
+++ b/aes_keywrap.py
@@ -0,0 +1,451 @@
+# minas-ithil.hactrn.net:/Users/sra/cryptech/aes-keywrap.py, 30-Apr-2015 09:10:55, sra
+#
+# Python prototype of an AES Key Wrap implementation, RFC 5649 flavor
+# per Russ, using Cryptlib to supply the AES code.
+#
+# Terminology mostly follows the RFC, including variable names.
+#
+# Block sizes get confusing: AES Key Wrap uses 64-bit blocks, not to
+# be confused with AES, which uses 128-bit blocks. In practice, this
+# is less confusing than when reading the description, because we
+# concatenate two 64-bit blocks just prior to performing an AES ECB
+# operation, then immediately split the result back into a pair of
+# 64-bit blocks.
+#
+# The spec uses both zero based and one based arrays, probably because
+# that's the easiest way of coping with the extra block of ciphertext.
+
+
+from cryptlib_py import *
+from struct import pack, unpack
+import atexit
+
+
+def bin2hex(bytes):
+ return ":".join("%02x" % ord(b) for b in bytes)
+
+def hex2bin(text):
+ return "".join(text.split()).translate(None, ":").decode("hex")
+
+
+def start_stop(start, stop): # syntactic sugar
+ step = -1 if start > stop else 1
+ return xrange(start, stop + step, step)
+
+
+class Block(long):
+ """
+ One 64-bit block, a Python long with some extra methods.
+ """
+
+ def __new__(cls, v):
+ # Python voodoo, nothing to see here, move along.
+ assert v >= 0 and v.bit_length() <= 64
+ return super(Block, cls).__new__(cls, v)
+
+ @classmethod
+ def from_bytes(cls, v):
+ assert isinstance(v, str) and len(v) == 8
+ return cls(unpack(">Q", v)[0])
+
+ def to_bytes(self):
+ assert self >= 0 and self.bit_length() <= 64
+ return pack(">Q", self)
+
+ @classmethod
+ def from_words(cls, hi, lo):
+ assert hi >= 0 and hi.bit_length() <= 32
+ assert lo >= 0 and lo.bit_length() <= 32
+ return cls((hi << 32L) + lo)
+
+ def to_words(self):
+ assert self >= 0 and self.bit_length() <= 64
+ return ((self >> 32) & 0xFFFFFFFF), (self & 0xFFFFFFFF)
+
+
+class Buffer(array):
+ """
+ Python type B array with a few extra methods.
+ """
+
+ def __new__(cls, initializer = None):
+ if initializer is None:
+ return super(Buffer, cls).__new__(cls, "B")
+ else:
+ return super(Buffer, cls).__new__(cls, "B", initializer)
+
+ def get_block(self, i):
+ return self[8*i:8*(i+1)]
+
+ def set_block(self, i, v):
+ assert len(v) == 8
+ self[8*i:8*(i+1)] = v
+
+
+class KEK(object):
+ """
+ Key encryption key, based on a Cryptlib encryption context.
+
+ This can work with either Block objects or Python array.
+ """
+
+ def __init__(self, salt = None, passphrase = None, size = None, key = None, generate = False):
+ self.ctx = cryptCreateContext(CRYPT_UNUSED, CRYPT_ALGO_AES)
+ atexit.register(cryptDestroyContext, self.ctx)
+ self.ctx.CTXINFO_MODE = CRYPT_MODE_ECB
+ if size is not None:
+ assert size % 8 == 0
+ self.ctx.CTXINFO_KEYSIZE = size / 8
+ if salt is None and passphrase is not None:
+ salt = "\x00" * 8 # Totally unsafe salt value, don't use this at home kids
+ if salt is not None:
+ self.ctx.CTXINFO_KEYING_SALT = salt
+ if passphrase is not None:
+ self.ctx.CTXINFO_KEYING_VALUE = passphrase
+ if key is not None:
+ self.ctx.CTXINFO_KEY = key
+ if generate:
+ cryptGenerateKey(self.ctx)
+
+ def encrypt_block(self, b1, b2):
+ """
+ Concatenate two 64-bit blocks into a 128-bit block, encrypt it
+ with AES-ECB, return the result split back into 64-bit blocks.
+ """
+
+ aes_block = array("c", pack(">QQ", b1, b2))
+ cryptEncrypt(self.ctx, aes_block)
+ return tuple(Block(b) for b in unpack(">QQ", aes_block.tostring()))
+
+ def encrypt_array(self, b1, b2):
+ """
+ Concatenate two 64-bit blocks into a 128-bit block, encrypt it
+ with AES-ECB, return the result split back into 64-bit blocks.
+ """
+
+ aes_block = b1 + b2
+ cryptEncrypt(self.ctx, aes_block)
+ return aes_block[:8], aes_block[8:]
+
+ def decrypt_block(self, b1, b2):
+ """
+ Concatenate two 64-bit blocks into a 128-bit block, decrypt it
+ with AES-ECB, return the result split back into 64-bit blocks.
+
+ Blocks can be represented either as Block objects or as 8-byte
+ Python arrays.
+ """
+
+ aes_block = array("c", pack(">QQ", b1, b2))
+ cryptDecrypt(self.ctx, aes_block)
+ return tuple(Block(b) for b in unpack(">QQ", aes_block.tostring()))
+
+ def decrypt_array(self, b1, b2):
+ """
+ Concatenate two 64-bit blocks into a 128-bit block, decrypt it
+ with AES-ECB, return the result split back into 64-bit blocks.
+
+ Blocks can be represented either as Block objects or as 8-byte
+ Python arrays.
+ """
+
+ aes_block = b1 + b2
+ cryptDecrypt(self.ctx, aes_block)
+ return aes_block[:8], aes_block[8:]
+
+
+def block_wrap_key(Q, K):
+ """
+ Wrap a key according to RFC 5649 section 4.1.
+
+ Q is the plaintext to be wrapped, a byte string.
+
+ K is the KEK with which to encrypt.
+
+ Returns C, the wrapped ciphertext.
+ """
+
+ m = len(Q)
+ if m % 8 != 0:
+ Q += "\x00" * (8 - (m % 8))
+ assert len(Q) % 8 == 0
+
+ n = len(Q) / 8
+ P = [Block.from_bytes(Q[i:i+8]) for i in xrange(0, len(Q), 8)]
+ assert len(P) == n
+
+ P.insert(0, None) # Make P one-based
+ A = Block.from_words(0xA65959A6, m) # RFC 5649 section 3 AIV
+
+ if n == 1:
+ C = K.encrypt_block(A, P[1])
+
+ else:
+ # RFC 3394 section 2.2.1
+ R = [p for p in P]
+ for j in start_stop(0, 5):
+ for i in start_stop(1, n):
+ B_hi, B_lo = K.encrypt_block(A, R[i])
+ A = Block(B_hi ^ (n * j + i))
+ R[i] = B_lo
+ C = R
+ C[0] = A
+
+ assert len(C) == n + 1
+ return "".join(c.to_bytes() for c in C)
+
+
+def array_wrap_key(Q, K):
+ """
+ Wrap a key according to RFC 5649 section 4.1.
+
+ Q is the plaintext to be wrapped, a byte string.
+
+ K is the KEK with which to encrypt.
+
+ Returns C, the wrapped ciphertext.
+ """
+
+ m = len(Q) # Plaintext length
+ R = Buffer("\xa6\x59\x59\xa6") # Magic MSB(32,A)
+ for i in xrange(24, -8, -8):
+ R.append((m >> i) & 0xFF) # Build LSB(32,A)
+ R.fromstring(Q) # Append Q
+ if m % 8 != 0: # Pad Q if needed
+ R.fromstring("\x00" * (8 - (m % 8)))
+
+ assert len(R) % 8 == 0
+ n = (len(R) / 8) - 1
+
+ if n == 1:
+ B1, B2 = K.encrypt_array(R.get_block(0), R.get_block(1))
+ R.set_block(0, B1)
+ R.set_block(1, B2)
+
+ else:
+ # RFC 3394 section 2.2.1
+ for j in start_stop(0, 5):
+ for i in start_stop(1, n):
+ B1, B2 = K.encrypt_array(R.get_block(0), R.get_block(i))
+ t = n * j + i
+ R.set_block(0, B1)
+ R.set_block(i, B2)
+ R[7] ^= t & 0xFF; t >>= 8
+ R[6] ^= t & 0xFF; t >>= 8
+ R[5] ^= t & 0xFF; t >>= 8
+ R[4] ^= t & 0xFF
+
+ assert len(R) == (n + 1) * 8
+ return R.tostring()
+
+
+class UnwrapError(Exception):
+ "Something went wrong during unwrap."
+
+
+def block_unwrap_key(C, K):
+ """
+ Unwrap a key according to RFC 5649 section 4.2.
+
+ C is the ciphertext to be unwrapped, a byte string
+
+ K is the KEK with which to decrypt.
+
+ Returns Q, the unwrapped plaintext.
+ """
+
+ if len(C) % 8 != 0:
+ raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C))
+
+ n = (len(C) / 8) - 1
+ C = [Block.from_bytes(C[i:i+8]) for i in xrange(0, len(C), 8)]
+ assert len(C) == n + 1
+
+ P = [None for i in xrange(n+1)]
+
+ if n == 1:
+ A, P[1] = K.decrypt_block(C[0], C[1])
+
+ else:
+ # RFC 3394 section 2.2.2 steps (1), (2), and part of (3)
+ A = C[0]
+ R = C
+ for j in start_stop(5, 0):
+ for i in start_stop(n, 1):
+ B_hi, B_lo = K.decrypt_block(Block(A ^ (n * j + i)), R[i])
+ A = B_hi
+ R[i] = B_lo
+ P = R
+
+ magic, m = A.to_words()
+
+ if magic != 0xA65959A6:
+ raise UnwrapError("Magic value in AIV should hae been 0xA65959A6, was 0x%08x" % magic)
+
+ if m <= 8 * (n - 1) or m > 8 * n:
+ raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n))
+
+ Q = "".join(p.to_bytes() for p in P[1:])
+ assert len(Q) == 8 * n
+
+ if any(q != "\x00" for q in Q[m:]):
+ raise UnwrapError("Nonzero trailing bytes %s" % bin2hex(Q[m:]))
+
+ return Q[:m]
+
+
+def array_unwrap_key(C, K):
+ """
+ Unwrap a key according to RFC 5649 section 4.2.
+
+ C is the ciphertext to be unwrapped, a byte string
+
+ K is the KEK with which to decrypt.
+
+ Returns Q, the unwrapped plaintext.
+ """
+
+ if len(C) % 8 != 0:
+ raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C))
+
+ n = (len(C) / 8) - 1
+ R = Buffer(C)
+
+ if n == 1:
+ B1, B2 = K.decrypt_array(R.get_block(0), R.get_block(1))
+ R.set_block(0, B1)
+ R.set_block(1, B2)
+
+ else:
+ # RFC 3394 section 2.2.2 steps (1), (2), and part of (3)
+ for j in start_stop(5, 0):
+ for i in start_stop(n, 1):
+ t = n * j + i
+ R[7] ^= t & 0xFF; t >>= 8
+ R[6] ^= t & 0xFF; t >>= 8
+ R[5] ^= t & 0xFF; t >>= 8
+ R[4] ^= t & 0xFF
+ B1, B2 = K.decrypt_array(R.get_block(0), R.get_block(i))
+ R.set_block(0, B1)
+ R.set_block(i, B2)
+
+ if R[:4].tostring() != "\xa6\x59\x59\xa6":
+ raise UnwrapError("Magic value in AIV should hae been 0xA65959A6, was 0x%02x%02x%02x%02x" % (R[0], R[1], R[2], R[3]))
+
+ m = (((((R[4] << 8) + R[5]) << 8) + R[6]) << 8) + R[7]
+
+ if m <= 8 * (n - 1) or m > 8 * n:
+ raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n))
+
+ del R[:8]
+ assert len(R) == 8 * n
+
+ if any(r != 0 for r in R[m:]):
+ raise UnwrapError("Nonzero trailing bytes %s" % ":".join("%02x" % r for r in R[m:]))
+
+ del R[m:]
+ assert len(R) == m
+ return R.tostring()
+
+
+def loopback_test(K, I):
+ """
+ Run one test. Inputs are KEK and a chunk of plaintext.
+
+ Test is just encrypt followed by decrypt to see if we can get
+ matching results without throwing any errors.
+ """
+
+ print "Testing:", repr(I)
+ C = wrap_key(I, K)
+ print "Wrapped: [%d]" % len(C), bin2hex(C)
+ O = unwrap_key(C, K)
+ if I != O:
+ raise RuntimeError("Input and output plaintext did not match: %r <> %r" % (I, O))
+ print
+
+
+def rfc5649_test(K, Q, C):
+ print "Testing: [%d]" % len(Q), bin2hex(Q)
+ print "Wrapped: [%d]" % len(C), bin2hex(C)
+ c = wrap_key(Q, K)
+ q = unwrap_key(C, K)
+ if q != Q:
+ raise RuntimeError("Input and output plaintext did not match: %s <> %s" % (bin2hex(Q), bin2hex(q)))
+ if c != C:
+ raise RuntimeError("Input and output ciphertext did not match: %s <> %s" % (bin2hex(C), bin2hex(c)))
+ print
+
+
+def run_tests():
+
+ print "Test vectors from RFC 5649"
+ print
+
+ rfc5649_test(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")),
+ Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"),
+ C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a 5f54f373fa543b6a"))
+
+ rfc5649_test(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")),
+ Q = hex2bin("466f7250617369"),
+ C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f"))
+
+ print "Deliberately mangled test vectors to see whether we notice"
+ print "These *should* detect errors"
+
+ for d in (dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")),
+ Q = hex2bin("466f7250617368"),
+ C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")),
+ dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")),
+ Q = hex2bin("466f7250617368"),
+ C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f 0123456789abcdef")),
+ dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")),
+ Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"),
+ C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a"))):
+ print
+ try:
+ rfc5649_test(**d)
+ except UnwrapError as e:
+ print "Detected an error during unwrap: %s" % e
+ except RuntimeError as e:
+ print "Detected an error in test function: %s" % e
+
+ print
+ print "Loopback tests of various lengths"
+ print
+
+ K = KEK(size = 128, key = hex2bin("00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f"))
+ loopback_test(K, "!")
+ loopback_test(K, "!")
+ loopback_test(K, "Yo!")
+ loopback_test(K, "Hi, Mom")
+ loopback_test(K, "1" * (64 / 8))
+ loopback_test(K, "2" * (128 / 8))
+ loopback_test(K, "3" * (256 / 8))
+ loopback_test(K, "3.14159265358979323846264338327950288419716939937510")
+ loopback_test(K, "3.14159265358979323846264338327950288419716939937510")
+ loopback_test(K, "Hello! My name is Inigo Montoya. You killed my AES key wrapper. Prepare to die.")
+
+
+def main():
+ cryptInit()
+ atexit.register(cryptEnd)
+ global wrap_key, unwrap_key
+
+ if False:
+ print "Testing with Block (Python long) implementation"
+ print
+ wrap_key = block_wrap_key
+ unwrap_key = block_unwrap_key
+ run_tests()
+
+ if True:
+ print "Testing with Python array implementation"
+ print
+ wrap_key = array_wrap_key
+ unwrap_key = array_unwrap_key
+ run_tests()
+
+
+if __name__ == "__main__":
+ main()