# minas-ithil.hactrn.net:/Users/sra/cryptech/aes-keywrap.py, 30-Apr-2015 09:10:55, sra # # Python prototype of an AES Key Wrap implementation, RFC 5649 flavor # per Russ, using Cryptlib to supply the AES code. # # Terminology mostly follows the RFC, including variable names. # # Block sizes get confusing: AES Key Wrap uses 64-bit blocks, not to # be confused with AES, which uses 128-bit blocks. In practice, this # is less confusing than when reading the description, because we # concatenate two 64-bit blocks just prior to performing an AES ECB # operation, then immediately split the result back into a pair of # 64-bit blocks. # # The spec uses both zero based and one based arrays, probably because # that's the easiest way of coping with the extra block of ciphertext. from cryptlib_py import * from struct import pack, unpack import atexit def bin2hex(bytes): return ":".join("%02x" % ord(b) for b in bytes) def hex2bin(text): return "".join(text.split()).translate(None, ":").decode("hex") def start_stop(start, stop): # syntactic sugar step = -1 if start > stop else 1 return xrange(start, stop + step, step) class Block(long): """ One 64-bit block, a Python long with some extra methods. """ def __new__(cls, v): # Python voodoo, nothing to see here, move along. assert v >= 0 and v.bit_length() <= 64 return super(Block, cls).__new__(cls, v) @classmethod def from_bytes(cls, v): assert isinstance(v, str) and len(v) == 8 return cls(unpack(">Q", v)[0]) def to_bytes(self): assert self >= 0 and self.bit_length() <= 64 return pack(">Q", self) @classmethod def from_words(cls, hi, lo): assert hi >= 0 and hi.bit_length() <= 32 assert lo >= 0 and lo.bit_length() <= 32 return cls((hi << 32L) + lo) def to_words(self): assert self >= 0 and self.bit_length() <= 64 return ((self >> 32) & 0xFFFFFFFF), (self & 0xFFFFFFFF) class Buffer(array): """ Python type B array with a few extra methods. """ def __new__(cls, initializer = None): if initializer is None: return super(Buffer, cls).__new__(cls, "B") else: return super(Buffer, cls).__new__(cls, "B", initializer) def get_block(self, i): return self[8*i:8*(i+1)] def set_block(self, i, v): assert len(v) == 8 self[8*i:8*(i+1)] = v class KEK(object): """ Key encryption key, based on a Cryptlib encryption context. This can work with either Block objects or Python array. """ def __init__(self, salt = None, passphrase = None, size = None, key = None, generate = False): self.ctx = cryptCreateContext(CRYPT_UNUSED, CRYPT_ALGO_AES) atexit.register(cryptDestroyContext, self.ctx) self.ctx.CTXINFO_MODE = CRYPT_MODE_ECB if size is not None: assert size % 8 == 0 self.ctx.CTXINFO_KEYSIZE = size / 8 if salt is None and passphrase is not None: salt = "\x00" * 8 # Totally unsafe salt value, don't use this at home kids if salt is not None: self.ctx.CTXINFO_KEYING_SALT = salt if passphrase is not None: self.ctx.CTXINFO_KEYING_VALUE = passphrase if key is not None: self.ctx.CTXINFO_KEY = key if generate: cryptGenerateKey(self.ctx) def encrypt_block(self, b1, b2): """ Concatenate two 64-bit blocks into a 128-bit block, encrypt it with AES-ECB, return the result split back into 64-bit blocks. """ aes_block = array("c", pack(">QQ", b1, b2)) cryptEncrypt(self.ctx, aes_block) return tuple(Block(b) for b in unpack(">QQ", aes_block.tostring())) def encrypt_array(self, b1, b2): """ Concatenate two 64-bit blocks into a 128-bit block, encrypt it with AES-ECB, return the result split back into 64-bit blocks. """ aes_block = b1 + b2 cryptEncrypt(self.ctx, aes_block) return aes_block[:8], aes_block[8:] def decrypt_block(self, b1, b2): """ Concatenate two 64-bit blocks into a 128-bit block, decrypt it with AES-ECB, return the result split back into 64-bit blocks. Blocks can be represented either as Block objects or as 8-byte Python arrays. """ aes_block = array("c", pack(">QQ", b1, b2)) cryptDecrypt(self.ctx, aes_block) return tuple(Block(b) for b in unpack(">QQ", aes_block.tostring())) def decrypt_array(self, b1, b2): """ Concatenate two 64-bit blocks into a 128-bit block, decrypt it with AES-ECB, return the result split back into 64-bit blocks. Blocks can be represented either as Block objects or as 8-byte Python arrays. """ aes_block = b1 + b2 cryptDecrypt(self.ctx, aes_block) return aes_block[:8], aes_block[8:] def block_wrap_key(Q, K): """ Wrap a key according to RFC 5649 section 4.1. Q is the plaintext to be wrapped, a byte string. K is the KEK with which to encrypt. Returns C, the wrapped ciphertext. """ m = len(Q) if m % 8 != 0: Q += "\x00" * (8 - (m % 8)) assert len(Q) % 8 == 0 n = len(Q) / 8 P = [Block.from_bytes(Q[i:i+8]) for i in xrange(0, len(Q), 8)] assert len(P) == n P.insert(0, None) # Make P one-based A = Block.from_words(0xA65959A6, m) # RFC 5649 section 3 AIV if n == 1: C = K.encrypt_block(A, P[1]) else: # RFC 3394 section 2.2.1 R = [p for p in P] for j in start_stop(0, 5): for i in start_stop(1, n): B_hi, B_lo = K.encrypt_block(A, R[i]) A = Block(B_hi ^ (n * j + i)) R[i] = B_lo C = R C[0] = A assert len(C) == n + 1 return "".join(c.to_bytes() for c in C) def array_wrap_key(Q, K): """ Wrap a key according to RFC 5649 section 4.1. Q is the plaintext to be wrapped, a byte string. K is the KEK with which to encrypt. Returns C, the wrapped ciphertext. """ m = len(Q) # Plaintext length R = Buffer("\xa6\x59\x59\xa6") # Magic MSB(32,A) for i in xrange(24, -8, -8): R.append((m >> i) & 0xFF) # Build LSB(32,A) R.fromstring(Q) # Append Q if m % 8 != 0: # Pad Q if needed R.fromstring("\x00" * (8 - (m % 8))) assert len(R) % 8 == 0 n = (len(R) / 8) - 1 if n == 1: B1, B2 = K.encrypt_array(R.get_block(0), R.get_block(1)) R.set_block(0, B1) R.set_block(1, B2) else: # RFC 3394 section 2.2.1 for j in start_stop(0, 5): for i in start_stop(1, n): B1, B2 = K.encrypt_array(R.get_block(0), R.get_block(i)) t = n * j + i R.set_block(0, B1) R.set_block(i, B2) R[7] ^= t & 0xFF; t >>= 8 R[6] ^= t & 0xFF; t >>= 8 R[5] ^= t & 0xFF; t >>= 8 R[4] ^= t & 0xFF assert len(R) == (n + 1) * 8 return R.tostring() class UnwrapError(Exception): "Something went wrong during unwrap." def block_unwrap_key(C, K): """ Unwrap a key according to RFC 5649 section 4.2. C is the ciphertext to be unwrapped, a byte string K is the KEK with which to decrypt. Returns Q, the unwrapped plaintext. """ if len(C) % 8 != 0: raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C)) n = (len(C) / 8) - 1 C = [Block.from_bytes(C[i:i+8]) for i in xrange(0, len(C), 8)] assert len(C) == n + 1 P = [None for i in xrange(n+1)] if n == 1: A, P[1] = K.decrypt_block(C[0], C[1]) else: # RFC 3394 section 2.2.2 steps (1), (2), and part of (3) A = C[0] R = C for j in start_stop(5, 0): for i in start_stop(n, 1): B_hi, B_lo = K.decrypt_block(Block(A ^ (n * j + i)), R[i]) A = B_hi R[i] = B_lo P = R magic, m = A.to_words() if magic != 0xA65959A6: raise UnwrapError("Magic value in AIV should hae been 0xA65959A6, was 0x%08x" % magic) if m <= 8 * (n - 1) or m > 8 * n: raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n)) Q = "".join(p.to_bytes() for p in P[1:]) assert len(Q) == 8 * n if any(q != "\x00" for q in Q[m:]): raise UnwrapError("Nonzero trailing bytes %s" % bin2hex(Q[m:])) return Q[:m] def array_unwrap_key(C, K): """ Unwrap a key according to RFC 5649 section 4.2. C is the ciphertext to be unwrapped, a byte string K is the KEK with which to decrypt. Returns Q, the unwrapped plaintext. """ if len(C) % 8 != 0: raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C)) n = (len(C) / 8) - 1 R = Buffer(C) if n == 1: B1, B2 = K.decrypt_array(R.get_block(0), R.get_block(1)) R.set_block(0, B1) R.set_block(1, B2) else: # RFC 3394 section 2.2.2 steps (1), (2), and part of (3) for j in start_stop(5, 0): for i in start_stop(n, 1): t = n * j + i R[7] ^= t & 0xFF; t >>= 8 R[6] ^= t & 0xFF; t >>= 8 R[5] ^= t & 0xFF; t >>= 8 R[4] ^= t & 0xFF B1, B2 = K.decrypt_array(R.get_block(0), R.get_block(i)) R.set_block(0, B1) R.set_block(i, B2) if R[:4].tostring() != "\xa6\x59\x59\xa6": raise UnwrapError("Magic value in AIV should hae been 0xA65959A6, was 0x%02x%02x%02x%02x" % (R[0], R[1], R[2], R[3])) m = (((((R[4] << 8) + R[5]) << 8) + R[6]) << 8) + R[7] if m <= 8 * (n - 1) or m > 8 * n: raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n)) del R[:8] assert len(R) == 8 * n if any(r != 0 for r in R[m:]): raise UnwrapError("Nonzero trailing bytes %s" % ":".join("%02x" % r for r in R[m:])) del R[m:] assert len(R) == m return R.tostring() def loopback_test(K, I): """ Run one test. Inputs are KEK and a chunk of plaintext. Test is just encrypt followed by decrypt to see if we can get matching results without throwing any errors. """ print "Testing:", repr(I) C = wrap_key(I, K) print "Wrapped: [%d]" % len(C), bin2hex(C) O = unwrap_key(C, K) if I != O: raise RuntimeError("Input and output plaintext did not match: %r <> %r" % (I, O)) print def rfc5649_test(K, Q, C): print "Testing: [%d]" % len(Q), bin2hex(Q) print "Wrapped: [%d]" % len(C), bin2hex(C) c = wrap_key(Q, K) q = unwrap_key(C, K) if q != Q: raise RuntimeError("Input and output plaintext did not match: %s <> %s" % (bin2hex(Q), bin2hex(q))) if c != C: raise RuntimeError("Input and output ciphertext did not match: %s <> %s" % (bin2hex(C), bin2hex(c))) print def run_tests(): print "Test vectors from RFC 5649" print rfc5649_test(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"), C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a 5f54f373fa543b6a")) rfc5649_test(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), Q = hex2bin("466f7250617369"), C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")) print "Deliberately mangled test vectors to see whether we notice" print "These *should* detect errors" for d in (dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")), Q = hex2bin("466f7250617368"), C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")), dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")), Q = hex2bin("466f7250617368"), C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f 0123456789abcdef")), dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")), Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"), C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a"))): print try: rfc5649_test(**d) except UnwrapError as e: print "Detected an error during unwrap: %s" % e except RuntimeError as e: print "Detected an error in test function: %s" % e print print "Loopback tests of various lengths" print K = KEK(size = 128, key = hex2bin("00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f")) loopback_test(K, "!") loopback_test(K, "!") loopback_test(K, "Yo!") loopback_test(K, "Hi, Mom") loopback_test(K, "1" * (64 / 8)) loopback_test(K, "2" * (128 / 8)) loopback_test(K, "3" * (256 / 8)) loopback_test(K, "3.14159265358979323846264338327950288419716939937510") loopback_test(K, "3.14159265358979323846264338327950288419716939937510") loopback_test(K, "Hello! My name is Inigo Montoya. You killed my AES key wrapper. Prepare to die.") def main(): cryptInit() atexit.register(cryptEnd) global wrap_key, unwrap_key if False: print "Testing with Block (Python long) implementation" print wrap_key = block_wrap_key unwrap_key = block_unwrap_key run_tests() if True: print "Testing with Python array implementation" print wrap_key = array_wrap_key unwrap_key = array_unwrap_key run_tests() if __name__ == "__main__": main()