diff options
author | Rob Austein <sra@hactrn.net> | 2015-05-04 18:07:02 -0400 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2015-05-04 18:07:02 -0400 |
commit | bb9f12696e626db1d17298a0eeefda00d44eb94f (patch) | |
tree | b320772c4915f428efb4cac86f62ede39d09fedd | |
parent | 865fffeafdc6622285a2dd31e17999965569312a (diff) |
Add code to log internal state, and argparse goo to control verbosity.
-rw-r--r-- | aes_keywrap.py | 359 |
1 files changed, 216 insertions, 143 deletions
diff --git a/aes_keywrap.py b/aes_keywrap.py index a191e3e..75aaa88 100644 --- a/aes_keywrap.py +++ b/aes_keywrap.py @@ -1,8 +1,10 @@ -# minas-ithil.hactrn.net:/Users/sra/cryptech/aes-keywrap.py, 30-Apr-2015 09:10:55, sra -# -# Python prototype of an AES Key Wrap implementation, RFC 5649 flavor -# per Russ, using Cryptlib to supply the AES code. -# +#!/usr/bin/env python + +""" +Python prototype of an AES Key Wrap implementation, RFC 5649 flavor +per Russ, using Cryptlib to supply the AES code. +""" + # Terminology mostly follows the RFC, including variable names. # # Block sizes get confusing: AES Key Wrap uses 64-bit blocks, not to @@ -11,21 +13,20 @@ # concatenate two 64-bit blocks just prior to performing an AES ECB # operation, then immediately split the result back into a pair of # 64-bit blocks. -# -# The spec uses both zero based and one based arrays, probably because -# that's the easiest way of coping with the extra block of ciphertext. from cryptlib_py import * from struct import pack, unpack import atexit +verbose = False + -def bin2hex(bytes): - return ":".join("%02x" % ord(b) for b in bytes) +def bin2hex(bytes, sep = ":"): + return sep.join("%02x" % ord(b) for b in bytes) def hex2bin(text): - return "".join(text.split()).translate(None, ":").decode("hex") + return text.translate(None, ": \t\n\r").decode("hex") def start_stop(start, stop): # syntactic sugar @@ -39,7 +40,6 @@ class Block(long): """ def __new__(cls, v): - # Python voodoo, nothing to see here, move along. assert v >= 0 and v.bit_length() <= 64 return super(Block, cls).__new__(cls, v) @@ -62,60 +62,59 @@ class Block(long): assert self >= 0 and self.bit_length() <= 64 return ((self >> 32) & 0xFFFFFFFF), (self & 0xFFFFFFFF) + def to_hex(self): + assert self >= 0 and self.bit_length() <= 64 + return "%016x" % self + class Buffer(array): """ Python type B array with a few extra methods. """ - def __new__(cls, initializer = None): - if initializer is None: - return super(Buffer, cls).__new__(cls, "B") - else: - return super(Buffer, cls).__new__(cls, "B", initializer) + def __new__(cls, *initializer): + return super(Buffer, cls).__new__(cls, "B", *initializer) def get_block(self, i): - return self[8*i:8*(i+1)] + return self.__class__(self[8*i:8*(i+1)]) def set_block(self, i, v): assert len(v) == 8 self[8*i:8*(i+1)] = v + def get_hex(self, i = None): + return bin2hex(self if i is None else self.get_block(i)) + class KEK(object): """ Key encryption key, based on a Cryptlib encryption context. - This can work with either Block objects or Python array. + This can work with either Block objects or Python arrays. + + Since this is a test tool used with known static keys in an attempt + to produce known results, we use a totally unsafe keying method. + Don't try this at home, kids. """ - def __init__(self, salt = None, passphrase = None, size = None, key = None, generate = False): + def __init__(self, key): self.ctx = cryptCreateContext(CRYPT_UNUSED, CRYPT_ALGO_AES) atexit.register(cryptDestroyContext, self.ctx) self.ctx.CTXINFO_MODE = CRYPT_MODE_ECB - if size is not None: - assert size % 8 == 0 - self.ctx.CTXINFO_KEYSIZE = size / 8 - if salt is None and passphrase is not None: - salt = "\x00" * 8 # Totally unsafe salt value, don't use this at home kids - if salt is not None: - self.ctx.CTXINFO_KEYING_SALT = salt - if passphrase is not None: - self.ctx.CTXINFO_KEYING_VALUE = passphrase - if key is not None: - self.ctx.CTXINFO_KEY = key - if generate: - cryptGenerateKey(self.ctx) - - def encrypt_block(self, b1, b2): + self.ctx.CTXINFO_KEY = key + + def encrypt_block(self, i1, i2): """ Concatenate two 64-bit blocks into a 128-bit block, encrypt it with AES-ECB, return the result split back into 64-bit blocks. """ - aes_block = array("c", pack(">QQ", b1, b2)) + aes_block = array("B", pack(">QQ", i1, i2)) cryptEncrypt(self.ctx, aes_block) - return tuple(Block(b) for b in unpack(">QQ", aes_block.tostring())) + o1, o2 = tuple(Block(b) for b in unpack(">QQ", aes_block.tostring())) + if verbose: + print " Encrypt: %s | %s => %s | %s" % tuple(b.to_hex() for b in (i1, i2, o1, o2)) + return o1, o2 def encrypt_array(self, b1, b2): """ @@ -125,33 +124,30 @@ class KEK(object): aes_block = b1 + b2 cryptEncrypt(self.ctx, aes_block) - return aes_block[:8], aes_block[8:] + return Buffer(aes_block[:8]), Buffer(aes_block[8:]) - def decrypt_block(self, b1, b2): + def decrypt_block(self, i1, i2): """ Concatenate two 64-bit blocks into a 128-bit block, decrypt it with AES-ECB, return the result split back into 64-bit blocks. - - Blocks can be represented either as Block objects or as 8-byte - Python arrays. """ - aes_block = array("c", pack(">QQ", b1, b2)) + aes_block = array("B", pack(">QQ", i1, i2)) cryptDecrypt(self.ctx, aes_block) - return tuple(Block(b) for b in unpack(">QQ", aes_block.tostring())) + o1, o2 = tuple(Block(b) for b in unpack(">QQ", aes_block.tostring())) + if verbose: + print " Decrypt: %s | %s => %s | %s" % tuple(b.to_hex() for b in (i1, i2, o1, o2)) + return o1, o2 def decrypt_array(self, b1, b2): """ Concatenate two 64-bit blocks into a 128-bit block, decrypt it with AES-ECB, return the result split back into 64-bit blocks. - - Blocks can be represented either as Block objects or as 8-byte - Python arrays. """ aes_block = b1 + b2 cryptDecrypt(self.ctx, aes_block) - return aes_block[:8], aes_block[8:] + return Buffer(aes_block[:8]), Buffer(aes_block[8:]) def block_wrap_key(Q, K): @@ -163,8 +159,17 @@ def block_wrap_key(Q, K): K is the KEK with which to encrypt. Returns C, the wrapped ciphertext. + + This implementation is based on Python long integers and includes + code to log internal state in verbose mode. """ + if verbose: + def log_registers(): + print " A: ", A.to_hex() + for r in xrange(1, n+1): + print " R[%3d]" % r, R[r].to_hex() + m = len(Q) if m % 8 != 0: Q += "\x00" * (8 - (m % 8)) @@ -175,22 +180,37 @@ def block_wrap_key(Q, K): assert len(P) == n P.insert(0, None) # Make P one-based - A = Block.from_words(0xA65959A6, m) # RFC 5649 section 3 AIV + A = Block.from_words(0xA65959A6, m) # RFC 5649 section 3 AIV + R = P # Alias to follow the spec + if verbose: + print " Starting wrap, n =", n + if n == 1: + if verbose: + log_registers() C = K.encrypt_block(A, P[1]) else: # RFC 3394 section 2.2.1 - R = [p for p in P] for j in start_stop(0, 5): for i in start_stop(1, n): + t = n * j + i + if verbose: + print " i = %d, j = %d, t = 0x%x" % (i, j, t) + log_registers() B_hi, B_lo = K.encrypt_block(A, R[i]) - A = Block(B_hi ^ (n * j + i)) + A = Block(B_hi ^ t) R[i] = B_lo C = R C[0] = A + if verbose: + print " Finishing wrap" + for i in xrange(len(C)): + print " C[%3d]" % i, C[i].to_hex() + print + assert len(C) == n + 1 return "".join(c.to_bytes() for c in C) @@ -204,6 +224,8 @@ def array_wrap_key(Q, K): K is the KEK with which to encrypt. Returns C, the wrapped ciphertext. + + This implementation is based on Python byte arrays. """ m = len(Q) # Plaintext length @@ -252,8 +274,17 @@ def block_unwrap_key(C, K): K is the KEK with which to decrypt. Returns Q, the unwrapped plaintext. + + This implementation is based on Python long integers and includes + code to log internal state in verbose mode. """ + if verbose: + def log_registers(): + print " A: ", A.to_hex() + for r in xrange(1, n+1): + print " R[%3d]" % r, R[r].to_hex() + if len(C) % 8 != 0: raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C)) @@ -261,26 +292,40 @@ def block_unwrap_key(C, K): C = [Block.from_bytes(C[i:i+8]) for i in xrange(0, len(C), 8)] assert len(C) == n + 1 - P = [None for i in xrange(n+1)] + P = R = C # Lots of names for the same list of blocks + A = C[0] + + if verbose: + print " Starting unwrap, n =", n if n == 1: - A, P[1] = K.decrypt_block(C[0], C[1]) + if verbose: + log_registers() + A, R[1] = K.decrypt_block(A, R[1]) else: # RFC 3394 section 2.2.2 steps (1), (2), and part of (3) - A = C[0] - R = C for j in start_stop(5, 0): for i in start_stop(n, 1): - B_hi, B_lo = K.decrypt_block(Block(A ^ (n * j + i)), R[i]) + t = n * j + i + if verbose: + print " i = %d, j = %d, t = 0x%x" % (i, j, t) + log_registers() + B_hi, B_lo = K.decrypt_block(Block(A ^ t), R[i]) A = B_hi R[i] = B_lo - P = R + + if verbose: + print " Finishing unwrap" + print " A: ", A.to_hex() + for i in xrange(1, len(P)): + print " P[%3d]" % i, P[i].to_hex() + print magic, m = A.to_words() if magic != 0xA65959A6: - raise UnwrapError("Magic value in AIV should hae been 0xA65959A6, was 0x%08x" % magic) + raise UnwrapError("Magic value in AIV should hae been a65959a6, was %08x" % magic) if m <= 8 * (n - 1) or m > 8 * n: raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n)) @@ -303,6 +348,8 @@ def array_unwrap_key(C, K): K is the KEK with which to decrypt. Returns Q, the unwrapped plaintext. + + This implementation is based on Python byte arrays. """ if len(C) % 8 != 0: @@ -330,7 +377,7 @@ def array_unwrap_key(C, K): R.set_block(i, B2) if R[:4].tostring() != "\xa6\x59\x59\xa6": - raise UnwrapError("Magic value in AIV should hae been 0xA65959A6, was 0x%02x%02x%02x%02x" % (R[0], R[1], R[2], R[3])) + raise UnwrapError("Magic value in AIV should hae been a65959a6, was %02x%02x%02x%02x" % (R[0], R[1], R[2], R[3])) m = (((((R[4] << 8) + R[5]) << 8) + R[6]) << 8) + R[7] @@ -348,104 +395,130 @@ def array_unwrap_key(C, K): return R.tostring() -def loopback_test(K, I): - """ - Run one test. Inputs are KEK and a chunk of plaintext. +if __name__ == "__main__": - Test is just encrypt followed by decrypt to see if we can get - matching results without throwing any errors. - """ + # Test code from here down + + def loopback_test(K, I): + """ + Loopback test, just encrypt followed by decrypt to see if we can + get matching results without throwing any errors. + """ + + print "Testing:", repr(I) + C = wrap_key(I, K) + print "Wrapped: [%d]" % len(C), bin2hex(C) + O = unwrap_key(C, K) + if I != O: + raise RuntimeError("Input and output plaintext did not match: %r <> %r" % (I, O)) + print + + + def rfc5649_test(K, Q, C): + """ + Test vectors as in RFC 5649 or similar. + """ + + print "Testing: [%d]" % len(Q), bin2hex(Q) + c = wrap_key(Q, K) + + print "Wrapped: [%d]" % len(C), bin2hex(C) + q = unwrap_key(C, K) + + if q != Q: + raise RuntimeError("Input and output plaintext did not match: %s <> %s" % (bin2hex(Q), bin2hex(q))) + + if c != C: + raise RuntimeError("Input and output ciphertext did not match: %s <> %s" % (bin2hex(C), bin2hex(c))) - print "Testing:", repr(I) - C = wrap_key(I, K) - print "Wrapped: [%d]" % len(C), bin2hex(C) - O = unwrap_key(C, K) - if I != O: - raise RuntimeError("Input and output plaintext did not match: %r <> %r" % (I, O)) - print - - -def rfc5649_test(K, Q, C): - print "Testing: [%d]" % len(Q), bin2hex(Q) - print "Wrapped: [%d]" % len(C), bin2hex(C) - c = wrap_key(Q, K) - q = unwrap_key(C, K) - if q != Q: - raise RuntimeError("Input and output plaintext did not match: %s <> %s" % (bin2hex(Q), bin2hex(q))) - if c != C: - raise RuntimeError("Input and output ciphertext did not match: %s <> %s" % (bin2hex(C), bin2hex(c))) - print - - -def run_tests(): - - print "Test vectors from RFC 5649" - print - - rfc5649_test(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"), - C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a 5f54f373fa543b6a")) - - rfc5649_test(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("466f7250617369"), - C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")) - - print "Deliberately mangled test vectors to see whether we notice" - print "These *should* detect errors" - - for d in (dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("466f7250617368"), - C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")), - dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("466f7250617368"), - C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f 0123456789abcdef")), - dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), - Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"), - C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a"))): print - try: - rfc5649_test(**d) - except UnwrapError as e: - print "Detected an error during unwrap: %s" % e - except RuntimeError as e: - print "Detected an error in test function: %s" % e - - print - print "Loopback tests of various lengths" - print - - K = KEK(size = 128, key = hex2bin("00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f")) - loopback_test(K, "!") - loopback_test(K, "!") - loopback_test(K, "Yo!") - loopback_test(K, "Hi, Mom") - loopback_test(K, "1" * (64 / 8)) - loopback_test(K, "2" * (128 / 8)) - loopback_test(K, "3" * (256 / 8)) - loopback_test(K, "3.14159265358979323846264338327950288419716939937510") - loopback_test(K, "3.14159265358979323846264338327950288419716939937510") - loopback_test(K, "Hello! My name is Inigo Montoya. You killed my AES key wrapper. Prepare to die.") - - -def main(): + + + def run_tests(): + """ + Run all tests for a particular implementation. + """ + + if args.rfc5649_test_vectors: + print "Test vectors from RFC 5649" + print + + rfc5649_test(K = KEK(hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), + Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"), + C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a 5f54f373fa543b6a")) + + rfc5649_test(K = KEK(hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), + Q = hex2bin("466f7250617369"), + C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")) + + if args.mangled_tests: + print "Deliberately mangled test vectors to see whether we notice" + print "These *should* detect errors" + for d in (dict(K = KEK(hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")), + Q = hex2bin("466f7250617368"), + C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")), + dict(K = KEK(key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")), + Q = hex2bin("466f7250617368"), + C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f 0123456789abcdef")), + dict(K = KEK(key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), + Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"), + C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a"))): + print + try: + rfc5649_test(**d) + except UnwrapError as e: + print "Detected an error during unwrap: %s" % e + except RuntimeError as e: + print "Detected an error in test function: %s" % e + print + + if args.loopback_tests: + print "Loopback tests of various lengths" + print + K = KEK(hex2bin("00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f")) + loopback_test(K, "!") + loopback_test(K, "!") + loopback_test(K, "Yo!") + loopback_test(K, "Hi, Mom") + loopback_test(K, "1" * (64 / 8)) + loopback_test(K, "2" * (128 / 8)) + loopback_test(K, "3" * (256 / 8)) + loopback_test(K, "3.14159265358979323846264338327950288419716939937510") + loopback_test(K, "3.14159265358979323846264338327950288419716939937510") + loopback_test(K, "Hello! My name is Inigo Montoya. You killed my AES key wrapper. Prepare to die.") + + + # Main (test) program + + from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + + parser = ArgumentParser(description = __doc__, formatter_class = ArgumentDefaultsHelpFormatter) + parser.add_argument("-v", "--verbose", action = "store_true", + help = "bark more") + parser.add_argument("-r", "--rfc5649-test-vectors", action = "store_false", + help = "RFC 5649 test vectors") + parser.add_argument("-m", "--mangled-tests", action = "store_true", + help = "test against deliberately mangled test vectors") + parser.add_argument("-l", "--loopback-tests", action = "store_true", + help = "ad hoc collection of loopback tests") + parser.add_argument("under_test", nargs = "?", choices = ("array", "long", "both"), default = "long", + help = "implementation to test") + args = parser.parse_args() + verbose = args.verbose + cryptInit() atexit.register(cryptEnd) - global wrap_key, unwrap_key - if False: + if args.under_test in ("long", "both"): print "Testing with Block (Python long) implementation" print wrap_key = block_wrap_key unwrap_key = block_unwrap_key run_tests() - if True: + if args.under_test in ("array", "both"): print "Testing with Python array implementation" print wrap_key = array_wrap_key unwrap_key = array_unwrap_key run_tests() - - -if __name__ == "__main__": - main() |