diff options
-rw-r--r-- | aes_keywrap.py | 594 |
1 files changed, 146 insertions, 448 deletions
diff --git a/aes_keywrap.py b/aes_keywrap.py index 1d0be29..bb6968a 100644 --- a/aes_keywrap.py +++ b/aes_keywrap.py @@ -1,8 +1,8 @@ #!/usr/bin/env python """ -Python prototype of an AES Key Wrap implementation, RFC 5649 flavor -per Russ, using PyCrypto to supply the AES code. +Python implementation of RFC 5649 AES Key Wrap With Padding, +using PyCrypto to supply the AES code. """ # Terminology mostly follows the RFC, including variable names. @@ -14,505 +14,203 @@ per Russ, using PyCrypto to supply the AES code. # operation, then immediately split the result back into a pair of # 64-bit blocks. - -from struct import pack, unpack -from Crypto.Cipher import AES -from array import array - -verbose = False - - -def bin2hex(bytes, sep = ":"): - return sep.join("%02x" % ord(b) for b in bytes) - -def hex2bin(text): - return text.translate(None, ": \t\n\r").decode("hex") - - -def start_stop(start, stop): # syntactic sugar - step = -1 if start > stop else 1 - return xrange(start, stop + step, step) - - -class Block(long): - """ - One 64-bit block, a Python long with some extra methods. - """ - - def __new__(cls, v): - assert v >= 0 and v.bit_length() <= 64 - return super(Block, cls).__new__(cls, v) - - @classmethod - def from_bytes(cls, v): - assert isinstance(v, str) and len(v) == 8 - return cls(unpack(">Q", v)[0]) - - def to_bytes(self): - assert self >= 0 and self.bit_length() <= 64 - return pack(">Q", self) - - @classmethod - def from_words(cls, hi, lo): - assert hi >= 0 and hi.bit_length() <= 32 - assert lo >= 0 and lo.bit_length() <= 32 - return cls((hi << 32L) + lo) - - def to_words(self): - assert self >= 0 and self.bit_length() <= 64 - return ((self >> 32) & 0xFFFFFFFF), (self & 0xFFFFFFFF) - - def to_hex(self): - assert self >= 0 and self.bit_length() <= 64 - return "%016x" % self - - -class Buffer(array): - """ - Python type B array with a few extra methods. - """ - - def __new__(cls, *initializer): - return super(Buffer, cls).__new__(cls, "B", *initializer) - - def get_block(self, i): - return self.__class__(self[8*i:8*(i+1)]) - - def set_block(self, i, v): - assert len(v) == 8 - self[8*i:8*(i+1)] = v - - def get_hex(self, i = None): - return bin2hex(self if i is None else self.get_block(i)) - - -class KEK(object): - """ - Key encryption key, based on a PyCrypto encryption context. - - This can work with either Block objects or Python arrays. - - Since this is a test tool used with known static keys in an attempt - to produce known results, we use a totally unsafe keying method. - Don't try this at home, kids. - """ - - def __init__(self, key): - self.ctx = AES.new(key, AES.MODE_ECB) - - def encrypt_block(self, i1, i2): - """ - Concatenate two 64-bit blocks into a 128-bit block, encrypt it - with AES-ECB, return the result split back into 64-bit blocks. - """ - - aes_block = pack(">QQ", i1, i2) - aes_block = self.ctx.encrypt(aes_block) - o1, o2 = tuple(Block(b) for b in unpack(">QQ", aes_block)) - if verbose: - print " Encrypt: %s | %s => %s | %s" % tuple(b.to_hex() for b in (i1, i2, o1, o2)) - return o1, o2 - - def encrypt_array(self, b1, b2): - """ - Concatenate two 64-bit blocks into a 128-bit block, encrypt it - with AES-ECB, return the result split back into 64-bit blocks. - """ - - aes_block = b1 + b2 - aes_block = self.ctx.encrypt(aes_block.tostring()) - return Buffer(aes_block[:8]), Buffer(aes_block[8:]) - - def decrypt_block(self, i1, i2): - """ - Concatenate two 64-bit blocks into a 128-bit block, decrypt it - with AES-ECB, return the result split back into 64-bit blocks. +class AESKeyWrapWithPadding(object): """ - - aes_block = pack(">QQ", i1, i2) - aes_block = self.ctx.decrypt(aes_block) - o1, o2 = tuple(Block(b) for b in unpack(">QQ", aes_block)) - if verbose: - print " Decrypt: %s | %s => %s | %s" % tuple(b.to_hex() for b in (i1, i2, o1, o2)) - return o1, o2 - - def decrypt_array(self, b1, b2): - """ - Concatenate two 64-bit blocks into a 128-bit block, decrypt it - with AES-ECB, return the result split back into 64-bit blocks. + Implementation of AES Key Wrap With Padding from RFC 5649. """ - aes_block = b1 + b2 - aes_block = self.ctx.decrypt(aes_block.tostring()) - return Buffer(aes_block[:8]), Buffer(aes_block[8:]) - - -def block_wrap_key(Q, K): - """ - Wrap a key according to RFC 5649 section 4.1. - - Q is the plaintext to be wrapped, a byte string. - - K is the KEK with which to encrypt. - - Returns C, the wrapped ciphertext. - - This implementation is based on Python long integers and includes - code to log internal state in verbose mode. - """ - - if verbose: - def log_registers(): - print " A: ", A.to_hex() - for r in xrange(1, n+1): - print " R[%3d]" % r, R[r].to_hex() + class UnwrapError(Exception): + "Something went wrong during unwrap." - m = len(Q) - if m % 8 != 0: - Q += "\x00" * (8 - (m % 8)) - assert len(Q) % 8 == 0 + def __init__(self, key): + from Crypto.Cipher import AES + self.ctx = AES.new(key, AES.MODE_ECB) - n = len(Q) / 8 - P = [Block.from_bytes(Q[i:i+8]) for i in xrange(0, len(Q), 8)] - assert len(P) == n + def _encrypt(self, b1, b2): + aes_block = self.ctx.encrypt(b1 + b2) + return aes_block[:8], aes_block[8:] - P.insert(0, None) # Make P one-based - A = Block.from_words(0xA65959A6, m) # RFC 5649 section 3 AIV - R = P # Alias to follow the spec - - if verbose: - print " Starting wrap, n =", n + def _decrypt(self, b1, b2): + aes_block = self.ctx.decrypt(b1 + b2) + return aes_block[:8], aes_block[8:] - if n == 1: - if verbose: - log_registers() - C = K.encrypt_block(A, P[1]) + @staticmethod + def _start_stop(start, stop): # Syntactic sugar + step = -1 if start > stop else 1 + return xrange(start, stop + step, step) - else: - # RFC 3394 section 2.2.1 - for j in start_stop(0, 5): - for i in start_stop(1, n): - t = n * j + i - if verbose: - print " i = %d, j = %d, t = 0x%x" % (i, j, t) - log_registers() - B_hi, B_lo = K.encrypt_block(A, R[i]) - A = Block(B_hi ^ t) - R[i] = B_lo - C = R - C[0] = A - if verbose: - print " Finishing wrap" - for i in xrange(len(C)): - print " C[%3d]" % i, C[i].to_hex() - print - - assert len(C) == n + 1 - return "".join(c.to_bytes() for c in C) - - -def array_wrap_key(Q, K): - """ - Wrap a key according to RFC 5649 section 4.1. - - Q is the plaintext to be wrapped, a byte string. - - K is the KEK with which to encrypt. - - Returns C, the wrapped ciphertext. - - This implementation is based on Python byte arrays. - """ - - m = len(Q) # Plaintext length - R = Buffer("\xa6\x59\x59\xa6") # Magic MSB(32,A) - for i in xrange(24, -8, -8): - R.append((m >> i) & 0xFF) # Build LSB(32,A) - R.fromstring(Q) # Append Q - if m % 8 != 0: # Pad Q if needed - R.fromstring("\x00" * (8 - (m % 8))) - - assert len(R) % 8 == 0 - n = (len(R) / 8) - 1 + def wrap_key(self, Q): + """ + Wrap a key according to RFC 5649 section 4.1. - if n == 1: - B1, B2 = K.encrypt_array(R.get_block(0), R.get_block(1)) - R.set_block(0, B1) - R.set_block(1, B2) + Q is the plaintext to be wrapped, a byte string. - else: - # RFC 3394 section 2.2.1 - for j in start_stop(0, 5): - for i in start_stop(1, n): - B1, B2 = K.encrypt_array(R.get_block(0), R.get_block(i)) - t = n * j + i - R.set_block(0, B1) - R.set_block(i, B2) - R[7] ^= t & 0xFF; t >>= 8 - R[6] ^= t & 0xFF; t >>= 8 - R[5] ^= t & 0xFF; t >>= 8 - R[4] ^= t & 0xFF + Returns C, the wrapped ciphertext. + """ - assert len(R) == (n + 1) * 8 - return R.tostring() + from struct import pack, unpack + m = len(Q) # Plaintext length + if m % 8 != 0: # Pad Q if needed + Q += "\x00" * (8 - (m % 8)) + R = [pack(">LL", 0xa65959a6, m)] # Magic MSB(32,A), build LSB(32,A) + R.extend(Q[i : i + 8] # Append Q + for i in xrange(0, len(Q), 8)) -class UnwrapError(Exception): - "Something went wrong during unwrap." - - -def block_unwrap_key(C, K): - """ - Unwrap a key according to RFC 5649 section 4.2. - - C is the ciphertext to be unwrapped, a byte string - - K is the KEK with which to decrypt. - - Returns Q, the unwrapped plaintext. - - This implementation is based on Python long integers and includes - code to log internal state in verbose mode. - """ - - if verbose: - def log_registers(): - print " A: ", A.to_hex() - for r in xrange(1, n+1): - print " R[%3d]" % r, R[r].to_hex() + n = len(R) - 1 - if len(C) % 8 != 0: - raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C)) + if n == 1: + R[0], R[1] = self._encrypt(R[0], R[1]) - n = (len(C) / 8) - 1 - C = [Block.from_bytes(C[i:i+8]) for i in xrange(0, len(C), 8)] - assert len(C) == n + 1 + else: + # RFC 3394 section 2.2.1 + for j in self._start_stop(0, 5): + for i in self._start_stop(1, n): + R[0], R[i] = self._encrypt(R[0], R[i]) + W0, W1 = unpack(">LL", R[0]) + W1 ^= n * j + i + R[0] = pack(">LL", W0, W1) - P = R = C # Lots of names for the same list of blocks - A = C[0] - - if verbose: - print " Starting unwrap, n =", n + assert len(R) == (n + 1) and all(len(r) == 8 for r in R) + return "".join(R) - if n == 1: - if verbose: - log_registers() - A, R[1] = K.decrypt_block(A, R[1]) - else: - # RFC 3394 section 2.2.2 steps (1), (2), and part of (3) - for j in start_stop(5, 0): - for i in start_stop(n, 1): - t = n * j + i - if verbose: - print " i = %d, j = %d, t = 0x%x" % (i, j, t) - log_registers() - B_hi, B_lo = K.decrypt_block(Block(A ^ t), R[i]) - A = B_hi - R[i] = B_lo + def unwrap_key(self, C): + """ + Unwrap a key according to RFC 5649 section 4.2. - if verbose: - print " Finishing unwrap" - print " A: ", A.to_hex() - for i in xrange(1, len(P)): - print " P[%3d]" % i, P[i].to_hex() - print + C is the ciphertext to be unwrapped, a byte string - magic, m = A.to_words() + Returns Q, the unwrapped plaintext. + """ - if magic != 0xA65959A6: - raise UnwrapError("Magic value in AIV should hae been a65959a6, was %08x" % magic) + from struct import pack, unpack - if m <= 8 * (n - 1) or m > 8 * n: - raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n)) + if len(C) % 8 != 0: + raise self.UnwrapError("Ciphertext length {} is not an integral number of blocks" + .format(len(C))) - Q = "".join(p.to_bytes() for p in P[1:]) - assert len(Q) == 8 * n + n = (len(C) / 8) - 1 + R = [C[i : i + 8] for i in xrange(0, len(C), 8)] - if any(q != "\x00" for q in Q[m:]): - raise UnwrapError("Nonzero trailing bytes %s" % bin2hex(Q[m:])) + if n == 1: + R[0], R[1] = self._decrypt(R[0], R[1]) - return Q[:m] + else: + # RFC 3394 section 2.2.2 steps (1), (2), and part of (3) + for j in self._start_stop(5, 0): + for i in self._start_stop(n, 1): + W0, W1 = unpack(">LL", R[0]) + W1 ^= n * j + i + R[0] = pack(">LL", W0, W1) + R[0], R[i] = self._decrypt(R[0], R[i]) + magic, m = unpack(">LL", R[0]) -def array_unwrap_key(C, K): - """ - Unwrap a key according to RFC 5649 section 4.2. + if magic != 0xa65959a6: + raise self.UnwrapError("Magic value in AIV should have been 0xa65959a6, was 0x{:02x}" + .format(magic)) - C is the ciphertext to be unwrapped, a byte string + if m <= 8 * (n - 1) or m > 8 * n: + raise self.UnwrapError("Length encoded in AIV out of range: m {}, n {}".format(m, n)) - K is the KEK with which to decrypt. + R = "".join(R[1:]) + assert len(R) == 8 * n - Returns Q, the unwrapped plaintext. + if any(r != "\x00" for r in R[m:]): + raise self.UnwrapError("Nonzero trailing bytes {}".format(R[m:].encode("hex"))) - This implementation is based on Python byte arrays. - """ + return R[:m] - if len(C) % 8 != 0: - raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C)) - n = (len(C) / 8) - 1 - R = Buffer(C) - - if n == 1: - B1, B2 = K.decrypt_array(R.get_block(0), R.get_block(1)) - R.set_block(0, B1) - R.set_block(1, B2) +if __name__ == "__main__": - else: - # RFC 3394 section 2.2.2 steps (1), (2), and part of (3) - for j in start_stop(5, 0): - for i in start_stop(n, 1): - t = n * j + i - R[7] ^= t & 0xFF; t >>= 8 - R[6] ^= t & 0xFF; t >>= 8 - R[5] ^= t & 0xFF; t >>= 8 - R[4] ^= t & 0xFF - B1, B2 = K.decrypt_array(R.get_block(0), R.get_block(i)) - R.set_block(0, B1) - R.set_block(i, B2) + # Test code from here down - if R[:4].tostring() != "\xa6\x59\x59\xa6": - raise UnwrapError("Magic value in AIV should hae been a65959a6, was %02x%02x%02x%02x" % (R[0], R[1], R[2], R[3])) + import unittest - m = (((((R[4] << 8) + R[5]) << 8) + R[6]) << 8) + R[7] + class TestAESKeyWrapWithPadding(unittest.TestCase): - if m <= 8 * (n - 1) or m > 8 * n: - raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n)) + @staticmethod + def bin2hex(bytes, sep = ":"): + return sep.join("{:02x}".format(ord(b)) for b in bytes) - del R[:8] - assert len(R) == 8 * n + @staticmethod + def hex2bin(text): + return text.translate(None, ": \t\n\r").decode("hex") - if any(r != 0 for r in R[m:]): - raise UnwrapError("Nonzero trailing bytes %s" % ":".join("%02x" % r for r in R[m:])) + def loopback_test(self, I): + K = AESKeyWrapWithPadding(self.hex2bin("00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f")) + C = K.wrap_key(I) + O = K.unwrap_key(C) + if I != O: + raise RuntimeError("Input and output plaintext did not match: {!r} <> {!r}".format(I, O)) - del R[m:] - assert len(R) == m - return R.tostring() + def rfc5649_test(self, K, Q, C): + K = AESKeyWrapWithPadding(key = self.hex2bin(K)) + Q = self.hex2bin(Q) + C = self.hex2bin(C) + c = K.wrap_key(Q) + q = K.unwrap_key(C) + if q != Q: + raise RuntimeError("Input and output plaintext did not match: {} <> {}".format(self.bin2hex(Q), self.bin2hex(q))) + if c != C: + raise RuntimeError("Input and output ciphertext did not match: {} <> {}".format(self.bin2hex(C), self.bin2hex(c))) + def test_rfc5649_1(self): + self.rfc5649_test(K = "5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8", + Q = "c37b7e6492584340 bed1220780894115 5068f738", + C = "138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a 5f54f373fa543b6a") -if __name__ == "__main__": + def test_rfc5649_2(self): + self.rfc5649_test(K = "5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8", + Q = "466f7250617369", + C = "afbeb0f07dfbf541 9200f2ccb50bb24f") - # Test code from here down + def test_mangled_1(self): + self.assertRaises(AESKeyWrapWithPadding.UnwrapError, self.rfc5649_test, + K = "5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8", + Q = "466f7250617368", + C = "afbeb0f07dfbf541 9200f2ccb50bb24f") - def loopback_test(K, I): - """ - Loopback test, just encrypt followed by decrypt to see if we can - get matching results without throwing any errors. - """ + def test_mangled_2(self): + self.assertRaises(AESKeyWrapWithPadding.UnwrapError, self.rfc5649_test, + K = "5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8", + Q = "466f7250617368", + C = "afbeb0f07dfbf541 9200f2ccb50bb24f 0123456789abcdef") - print "Testing:", repr(I) - C = wrap_key(I, K) - print "Wrapped: [%d]" % len(C), bin2hex(C) - O = unwrap_key(C, K) - if I != O: - raise RuntimeError("Input and output plaintext did not match: %r <> %r" % (I, O)) - print + def test_mangled_3(self): + self.assertRaises(AESKeyWrapWithPadding.UnwrapError, self.rfc5649_test, + K = "5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8", + Q = "c37b7e6492584340 bed1220780894115 5068f738", + C = "138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a") + def test_loopback_1(self): + self.loopback_test("!") - def rfc5649_test(K, Q, C): - """ - Test vectors as in RFC 5649 or similar. - """ + def test_loopback_2(self): + self.loopback_test("Yo!") - print "Testing: [%d]" % len(Q), bin2hex(Q) - c = wrap_key(Q, K) + def test_loopback_3(self): + self.loopback_test("Hi, Mom") - print "Wrapped: [%d]" % len(C), bin2hex(C) - q = unwrap_key(C, K) + def test_loopback_4(self): + self.loopback_test("1" * (64 / 8)) - if q != Q: - raise RuntimeError("Input and output plaintext did not match: %s <> %s" % (bin2hex(Q), bin2hex(q))) + def test_loopback_5(self): + self.loopback_test("2" * (128 / 8)) - if c != C: - raise RuntimeError("Input and output ciphertext did not match: %s <> %s" % (bin2hex(C), bin2hex(c))) + def test_loopback_6(self): + self.loopback_test("3" * (256 / 8)) - print + def test_loopback_7(self): + self.loopback_test("3.14159265358979323846264338327950288419716939937510") + def test_loopback_8(self): + self.loopback_test("3.14159265358979323846264338327950288419716939937510") - def run_tests(): - """ - Run all tests for a particular implementation. - """ + def test_loopback_9(self): + self.loopback_test("Hello! My name is Inigo Montoya. You killed my AES key wrapper. Prepare to die.") - if args.rfc5649_test_vectors: - print "Test vectors from RFC 5649" - print - - rfc5649_test(K = KEK(hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"), - C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a 5f54f373fa543b6a")) - - rfc5649_test(K = KEK(hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("466f7250617369"), - C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")) - - if args.mangled_tests: - print "Deliberately mangled test vectors to see whether we notice" - print "These *should* detect errors" - for d in (dict(K = KEK(hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("466f7250617368"), - C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")), - dict(K = KEK(key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("466f7250617368"), - C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f 0123456789abcdef")), - dict(K = KEK(key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"), - C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a"))): - print - try: - rfc5649_test(**d) - except UnwrapError as e: - print "Detected an error during unwrap: %s" % e - except RuntimeError as e: - print "Detected an error in test function: %s" % e - print - - if args.loopback_tests: - print "Loopback tests of various lengths" - print - K = KEK(hex2bin("00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f")) - loopback_test(K, "!") - loopback_test(K, "!") - loopback_test(K, "Yo!") - loopback_test(K, "Hi, Mom") - loopback_test(K, "1" * (64 / 8)) - loopback_test(K, "2" * (128 / 8)) - loopback_test(K, "3" * (256 / 8)) - loopback_test(K, "3.14159265358979323846264338327950288419716939937510") - loopback_test(K, "3.14159265358979323846264338327950288419716939937510") - loopback_test(K, "Hello! My name is Inigo Montoya. You killed my AES key wrapper. Prepare to die.") - - - # Main (test) program - - from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter - - parser = ArgumentParser(description = __doc__, formatter_class = ArgumentDefaultsHelpFormatter) - parser.add_argument("-v", "--verbose", action = "store_true", - help = "bark more") - parser.add_argument("-r", "--rfc5649-test-vectors", action = "store_false", - help = "RFC 5649 test vectors") - parser.add_argument("-m", "--mangled-tests", action = "store_true", - help = "test against deliberately mangled test vectors") - parser.add_argument("-l", "--loopback-tests", action = "store_true", - help = "ad hoc collection of loopback tests") - parser.add_argument("under_test", nargs = "?", choices = ("array", "long", "both"), default = "long", - help = "implementation to test") - args = parser.parse_args() - verbose = args.verbose - - if args.under_test in ("long", "both"): - print "Testing with Block (Python long) implementation" - print - wrap_key = block_wrap_key - unwrap_key = block_unwrap_key - run_tests() - - if args.under_test in ("array", "both"): - print "Testing with Python array implementation" - print - wrap_key = array_wrap_key - unwrap_key = array_unwrap_key - run_tests() + unittest.main() |