# minas-ithil.hactrn.net:/Users/sra/cryptech/aes-keywrap.py, 30-Apr-2015 09:10:55, sra
#
# Python prototype of an AES Key Wrap implementation, RFC 5649 flavor
# per Russ, using Cryptlib to supply the AES code.
#
# Terminology mostly follows the RFC, including variable names.
#
# Block sizes get confusing: AES Key Wrap uses 64-bit blocks, not to
# be confused with AES, which uses 128-bit blocks. In practice, this
# is less confusing than when reading the description, because we
# concatenate two 64-bit blocks just prior to performing an AES ECB
# operation, then immediately split the result back into a pair of
# 64-bit blocks.
#
# The spec uses both zero-based and one-based arrays, probably because
# that's the easiest way of coping with the extra block of ciphertext.
from cryptlib_py import *
# Explicit imports come after the star import so they take precedence
# over any same-named cryptlib_py exports.
from struct import pack, unpack
from array import array
import atexit
def bin2hex(bytes):
    """
    Render a byte string as colon-separated, two-digit lowercase hex
    octets, e.g. bin2hex("AB") == "41:42".  An empty string gives "".
    """
    octets = ["%02x" % ord(ch) for ch in bytes]
    return ":".join(octets)
def hex2bin(text):
return "".join(text.split()).translate(None, ":").decode("hex")
def start_stop(start, stop): # syntactic sugar
step = -1 if start > stop else 1
return xrange(start, stop + step, step)
class Block(long):
"""
One 64-bit block, a Python long with some extra methods.
"""
def __new__(cls, v):
# Python voodoo, nothing to see here, move along.
assert v >= 0 and v.bit_length() <= 64
return super(Block, cls).__new__(cls, v)
@classmethod
def from_bytes(cls, v):
assert isinstance(v, str) and len(v) == 8
return cls(unpack(">Q", v)[0])
def to_bytes(self):
assert self >= 0 and self.bit_length() <= 64
return pack(">Q", self)
@classmethod
def from_words(cls, hi, lo):
assert hi >= 0 and hi.bit_length() <= 32
assert lo >= 0 and lo.bit_length() <= 32
return cls((hi << 32L) + lo)
def to_words(self):
assert self >= 0 and self.bit_length() <= 64
return ((self >> 32) & 0xFFFFFFFF), (self & 0xFFFFFFFF)
class Buffer(array):
    """
    An unsigned-byte array (typecode "B") that can also be addressed as
    a sequence of 8-byte AES Key Wrap blocks, block i being bytes
    8*i through 8*i+7.
    """

    def __new__(cls, initializer = None):
        # Forward an initializer only when the caller supplied one;
        # without one the array starts out empty.
        args = ("B",) if initializer is None else ("B", initializer)
        return super(Buffer, cls).__new__(cls, *args)

    def get_block(self, i):
        """Return a copy of block i as a new array slice."""
        start = 8 * i
        return self[start:start + 8]

    def set_block(self, i, v):
        """Overwrite block i with v, which must be exactly 8 bytes long."""
        assert len(v) == 8
        start = 8 * i
        self[start:start + 8] = v
class KEK(object):
    """
    Key encryption key, based on a Cryptlib AES encryption context in
    ECB mode.

    Two calling conventions are offered:
      encrypt_block()/decrypt_block() take and return Block objects
      (64-bit integers);
      encrypt_array()/decrypt_array() take and return 8-byte arrays,
      such as those returned by Buffer.get_block().
    """

    def __init__(self, salt = None, passphrase = None, size = None, key = None, generate = False):
        """
        Create the Cryptlib AES context and load keying material into it.

        salt       -- value stored in CTXINFO_KEYING_SALT.
        passphrase -- value stored in CTXINFO_KEYING_VALUE.  If given
                      without a salt, an all-zero 8-byte salt is used,
                      which is unsafe and only acceptable in a prototype.
        size       -- key size in bits; must be a multiple of 8.
        key        -- raw key bytes, stored in CTXINFO_KEY.
        generate   -- if true, call cryptGenerateKey() on the context.
        """
        self.ctx = cryptCreateContext(CRYPT_UNUSED, CRYPT_ALGO_AES)
        # Destroy the context at interpreter exit.  atexit handlers run
        # last-registered-first, so this runs before a cryptEnd() that
        # was registered earlier (as main() does).
        atexit.register(cryptDestroyContext, self.ctx)
        self.ctx.CTXINFO_MODE = CRYPT_MODE_ECB
        if size is not None:
            assert size % 8 == 0
            # size is in bits; CTXINFO_KEYSIZE is set in bytes.  Use //
            # so this stays an integer division on any Python version.
            self.ctx.CTXINFO_KEYSIZE = size // 8
        if salt is None and passphrase is not None:
            salt = "\x00" * 8 # Totally unsafe salt value, don't use this at home kids
        if salt is not None:
            self.ctx.CTXINFO_KEYING_SALT = salt
        if passphrase is not None:
            self.ctx.CTXINFO_KEYING_VALUE = passphrase
        if key is not None:
            self.ctx.CTXINFO_KEY = key
        if generate:
            cryptGenerateKey(self.ctx)

    def _crypt_blocks(self, operation, b1, b2):
        """
        Pack Blocks b1 and b2 big-endian into one 128-bit AES block, run
        operation (cryptEncrypt or cryptDecrypt) over it in place, and
        return the result unpacked as a tuple of two Blocks.
        """
        aes_block = array("c", pack(">QQ", b1, b2))
        operation(self.ctx, aes_block)
        return tuple(Block(b) for b in unpack(">QQ", aes_block.tostring()))

    def _crypt_arrays(self, operation, b1, b2):
        """
        Concatenate 8-byte arrays b1 and b2 into one 128-bit AES block,
        run operation (cryptEncrypt or cryptDecrypt) over it in place,
        and return it split back into two 8-byte arrays.  b1 + b2 builds
        a new array, so the caller's arrays are not modified.
        """
        aes_block = b1 + b2
        operation(self.ctx, aes_block)
        return aes_block[:8], aes_block[8:]

    def encrypt_block(self, b1, b2):
        """
        Concatenate two 64-bit Blocks into a 128-bit block, encrypt it
        with AES-ECB, and return the result as a tuple of two Blocks.
        """
        return self._crypt_blocks(cryptEncrypt, b1, b2)

    def encrypt_array(self, b1, b2):
        """
        Concatenate two 8-byte arrays into a 128-bit block, encrypt it
        with AES-ECB, and return the result as two 8-byte arrays.
        """
        return self._crypt_arrays(cryptEncrypt, b1, b2)

    def decrypt_block(self, b1, b2):
        """
        Concatenate two 64-bit Blocks into a 128-bit block, decrypt it
        with AES-ECB, and return the result as a tuple of two Blocks.
        Only Block (integer) inputs are accepted; use decrypt_array()
        for arrays.
        """
        return self._crypt_blocks(cryptDecrypt, b1, b2)

    def decrypt_array(self, b1, b2):
        """
        Concatenate two 8-byte arrays into a 128-bit block, decrypt it
        with AES-ECB, and return the result as two 8-byte arrays.
        Only array inputs are accepted; use decrypt_block() for Blocks.
        """
        return self._crypt_arrays(cryptDecrypt, b1, b2)
def block_wrap_key(Q, K):
    """
    Wrap a key according to RFC 5649 section 4.1, using Block objects.

    Q is the plaintext to be wrapped, a byte string of 1 to 2**32 - 1
    octets.
    K is the KEK with which to encrypt.
    Returns C, the wrapped ciphertext, a byte string of 8 * (n + 1)
    octets, where n is the number of 64-bit blocks in the padded
    plaintext.
    Raises ValueError if Q is empty or too long for the 32-bit message
    length indicator in the AIV.
    """
    m = len(Q)
    # RFC 5649 only defines wrapping for 1 to 2**32 - 1 octets.  An empty
    # Q would otherwise reach the n == 0 case, where start_stop(1, 0)
    # counts downward and the loop indexes past the end of R.
    if m < 1 or m > 0xFFFFFFFF:
        raise ValueError("Plaintext length %d out of range for RFC 5649 key wrap" % m)
    if m % 8 != 0:
        Q += "\x00" * (8 - (m % 8))
    assert len(Q) % 8 == 0
    n = len(Q) // 8
    # P is one-based, as in the RFC; P[0] is just a placeholder.
    P = [None] + [Block.from_bytes(Q[i:i+8]) for i in xrange(0, len(Q), 8)]
    assert len(P) == n + 1
    A = Block.from_words(0xA65959A6, m) # RFC 5649 section 3 AIV
    if n == 1:
        # RFC 5649 section 4.1: a single padded block is encrypted directly
        # (AIV | P[1]) with AES-ECB, with no RFC 3394 rounds.
        C = K.encrypt_block(A, P[1])
    else:
        # RFC 3394 section 2.2.1
        R = list(P)
        for j in start_stop(0, 5):
            for i in start_stop(1, n):
                B_hi, B_lo = K.encrypt_block(A, R[i])
                A = Block(B_hi ^ (n * j + i))
                R[i] = B_lo
        R[0] = A
        C = R
    assert len(C) == n + 1
    return "".join(c.to_bytes() for c in C)
def array_wrap_key(Q, K):
    """
    Wrap a key according to RFC 5649 section 4.1, using a byte Buffer.

    Q is the plaintext to be wrapped, a byte string of 1 to 2**32 - 1
    octets.
    K is the KEK with which to encrypt.
    Returns C, the wrapped ciphertext, as a byte string.
    Raises ValueError if Q is empty or too long for the 32-bit message
    length indicator in the AIV.

    R holds the whole working state.  Block 0 is A (first the AIV, then
    the running integrity value), and blocks 1..n are the padded
    plaintext, which is encrypted in place.
    """
    m = len(Q)                                  # Plaintext length
    # RFC 5649 only defines wrapping for 1 to 2**32 - 1 octets.  An empty
    # Q would otherwise reach the n == 0 case, where start_stop(1, 0)
    # counts downward and an 8-byte block is handed to AES.  A longer Q
    # would have its length silently truncated in LSB(32,A).
    if m < 1 or m > 0xFFFFFFFF:
        raise ValueError("Plaintext length %d out of range for RFC 5649 key wrap" % m)
    R = Buffer("\xa6\x59\x59\xa6")              # Magic MSB(32,A)
    for i in xrange(24, -8, -8):
        R.append((m >> i) & 0xFF)               # Build LSB(32,A), big-endian
    R.fromstring(Q)                             # Append Q
    if m % 8 != 0:                              # Pad Q if needed
        R.fromstring("\x00" * (8 - (m % 8)))
    assert len(R) % 8 == 0
    n = (len(R) // 8) - 1
    if n == 1:
        # RFC 5649 section 4.1: a single padded block is encrypted directly
        # with AES-ECB, with no RFC 3394 rounds.
        B1, B2 = K.encrypt_array(R.get_block(0), R.get_block(1))
        R.set_block(0, B1)
        R.set_block(1, B2)
    else:
        # RFC 3394 section 2.2.1
        for j in start_stop(0, 5):
            for i in start_stop(1, n):
                B1, B2 = K.encrypt_array(R.get_block(0), R.get_block(i))
                t = n * j + i
                R.set_block(0, B1)
                R.set_block(i, B2)
                # A = MSB(64,B) ^ t: XOR big-endian t into bytes 4..7 of A.
                # t fits in 32 bits because m < 2**32 bounds n.
                R[7] ^= t & 0xFF; t >>= 8
                R[6] ^= t & 0xFF; t >>= 8
                R[5] ^= t & 0xFF; t >>= 8
                R[4] ^= t & 0xFF
    assert len(R) == (n + 1) * 8
    return R.tostring()
class UnwrapError(Exception):
    """
    Raised by the unwrap functions when a ciphertext cannot be unwrapped,
    e.g. a bad length, an unexpected magic value in the AIV, an
    out-of-range encoded length, or nonzero padding bytes.
    """
def block_unwrap_key(C, K):
    """
    Unwrap a key according to RFC 5649 section 4.2, using Block objects.

    C is the ciphertext to be unwrapped, a byte string.
    K is the KEK with which to decrypt.
    Returns Q, the unwrapped plaintext.
    Raises UnwrapError if C is not a whole number of 64-bit blocks, is
    shorter than two blocks, has the wrong magic value in the AIV, encodes
    an out-of-range length, or has nonzero padding.
    """
    if len(C) % 8 != 0:
        raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C))
    # RFC 5649 ciphertext is at least two blocks (AIV plus one data block).
    # Anything shorter would leave n < 1, where the loops and indexing
    # below run out of range instead of reporting an unwrap failure.
    if len(C) < 16:
        raise UnwrapError("Ciphertext length %d is too short, must be at least 16 octets" % len(C))
    n = (len(C) // 8) - 1
    C = [Block.from_bytes(C[i:i+8]) for i in xrange(0, len(C), 8)]
    assert len(C) == n + 1
    # P is one-based, as in the RFC; P[0] is just a placeholder.
    P = [None] * (n + 1)
    if n == 1:
        # RFC 5649 section 4.2: a single data block is decrypted directly.
        A, P[1] = K.decrypt_block(C[0], C[1])
    else:
        # RFC 3394 section 2.2.2 steps (1), (2), and part of (3)
        A = C[0]
        R = C
        for j in start_stop(5, 0):
            for i in start_stop(n, 1):
                B_hi, B_lo = K.decrypt_block(Block(A ^ (n * j + i)), R[i])
                A = B_hi
                R[i] = B_lo
        P = R
    magic, m = A.to_words()
    if magic != 0xA65959A6:
        raise UnwrapError("Magic value in AIV should have been 0xA65959A6, was 0x%08x" % magic)
    # RFC 5649 section 3: the encoded length must satisfy 8*(n-1) < m <= 8*n.
    if m <= 8 * (n - 1) or m > 8 * n:
        raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n))
    Q = "".join(p.to_bytes() for p in P[1:])
    assert len(Q) == 8 * n
    if any(q != "\x00" for q in Q[m:]):
        raise UnwrapError("Nonzero trailing bytes %s" % bin2hex(Q[m:]))
    return Q[:m]
def array_unwrap_key(C, K):
    """
    Unwrap a key according to RFC 5649 section 4.2, using a byte Buffer.

    C is the ciphertext to be unwrapped, a byte string.
    K is the KEK with which to decrypt.
    Returns Q, the unwrapped plaintext.
    Raises UnwrapError if C is not a whole number of 64-bit blocks, is
    shorter than two blocks, has the wrong magic value in the AIV, encodes
    an out-of-range length, or has nonzero padding.

    R holds the whole working state.  Block 0 is A, and blocks 1..n are
    decrypted in place.
    """
    if len(C) % 8 != 0:
        raise UnwrapError("Ciphertext length %d is not an integral number of blocks" % len(C))
    # RFC 5649 ciphertext is at least two blocks (AIV plus one data block).
    # Anything shorter would leave n < 1, where the loops below would feed
    # out-of-range or short blocks to AES instead of reporting an error.
    if len(C) < 16:
        raise UnwrapError("Ciphertext length %d is too short, must be at least 16 octets" % len(C))
    n = (len(C) // 8) - 1
    R = Buffer(C)
    if n == 1:
        # RFC 5649 section 4.2: a single data block is decrypted directly.
        B1, B2 = K.decrypt_array(R.get_block(0), R.get_block(1))
        R.set_block(0, B1)
        R.set_block(1, B2)
    else:
        # RFC 3394 section 2.2.2 steps (1), (2), and part of (3)
        for j in start_stop(5, 0):
            for i in start_stop(n, 1):
                # A ^ t: XOR big-endian t into bytes 4..7 of A before decrypting.
                t = n * j + i
                R[7] ^= t & 0xFF; t >>= 8
                R[6] ^= t & 0xFF; t >>= 8
                R[5] ^= t & 0xFF; t >>= 8
                R[4] ^= t & 0xFF
                B1, B2 = K.decrypt_array(R.get_block(0), R.get_block(i))
                R.set_block(0, B1)
                R.set_block(i, B2)
    if R[:4].tostring() != "\xa6\x59\x59\xa6":
        raise UnwrapError("Magic value in AIV should have been 0xA65959A6, was 0x%02x%02x%02x%02x" % (R[0], R[1], R[2], R[3]))
    # LSB(32,A) is the big-endian plaintext length.
    m = (((((R[4] << 8) + R[5]) << 8) + R[6]) << 8) + R[7]
    # RFC 5649 section 3: the encoded length must satisfy 8*(n-1) < m <= 8*n.
    if m <= 8 * (n - 1) or m > 8 * n:
        raise UnwrapError("Length encoded in AIV out of range: m %d, n %d" % (m, n))
    del R[:8]                                   # Drop A, keep the plaintext blocks
    assert len(R) == 8 * n
    if any(r != 0 for r in R[m:]):
        raise UnwrapError("Nonzero trailing bytes %s" % ":".join("%02x" % r for r in R[m:]))
    del R[m:]                                   # Strip the padding
    assert len(R) == m
    return R.tostring()
def loopback_test(K, I):
"""
Run one test. Inputs are KEK and a chunk of plaintext.
Test is just encrypt followed by decrypt to see if we can get
matching results without throwing any errors.
"""
print "Testing:", repr(I)
C = wrap_key(I, K)
print "Wrapped: [%d]" % len(C), bin2hex(C)
O = unwrap_key(C, K)
if I != O:
raise RuntimeError("Input and output plaintext did not match: %r <> %r" % (I, O))
print
def rfc5649_test(K, Q, C):
print "Testing: [%d]" % len(Q), bin2hex(Q)
print "Wrapped: [%d]" % len(C), bin2hex(C)
c = wrap_key(Q, K)
q = unwrap_key(C, K)
if q != Q:
raise RuntimeError("Input and output plaintext did not match: %s <> %s" % (bin2hex(Q), bin2hex(q)))
if c != C:
raise RuntimeError("Input and output ciphertext did not match: %s <> %s" % (bin2hex(C), bin2hex(c)))
print
def run_tests():
print "Test vectors from RFC 5649"
print
rfc5649_test(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")),
Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"),
C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a 5f54f373fa543b6a"))
rfc5649_test(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")),
Q = hex2bin("466f7250617369"),
C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f"))
print "Deliberately mangled test vectors to see whether we notice"
print "These *should* detect errors"
for d in (dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")),
Q = hex2bin("466f7250617368"),
C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f")),
dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8")),
Q = hex2bin("466f7250617368"),
C = hex2bin("afbeb0f07dfbf541 9200f2ccb50bb24f 0123456789abcdef")),
dict(K = KEK(size = 192, key = hex2bin("5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8")),
Q = hex2bin("c37b7e6492584340 bed1220780894115 5068f738"),
C = hex2bin("138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a"))):
print
try:
rfc5649_test(**d)
except UnwrapError as e:
print "Detected an error during unwrap: %s" % e
except RuntimeError as e:
print "Detected an error in test function: %s" % e
print
print "Loopback tests of various lengths"
print
K = KEK(size = 128, key = hex2bin("00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f"))
loopback_test(K, "!")
loopback_test(K, "!")
loopback_test(K, "Yo!")
loopback_test(K, "Hi, Mom")
loopback_test(K, "1" * (64 / 8))
loopback_test(K, "2" * (128 / 8))
loopback_test(K, "3" * (256 / 8))
loopback_test(K, "3.14159265358979323846264338327950288419716939937510")
loopback_test(K, "3.14159265358979323846264338327950288419716939937510")
loopback_test(K, "Hello! My name is Inigo Montoya. You killed my AES key wrapper. Prepare to die.")
def main():
cryptInit()
atexit.register(cryptEnd)
global wrap_key, unwrap_key
if False:
print "Testing with Block (Python long) implementation"
print
wrap_key = block_wrap_key
unwrap_key = block_unwrap_key
run_tests()
if True:
print "Testing with Python array implementation"
print
wrap_key = array_wrap_key
unwrap_key = array_unwrap_key
run_tests()
if __name__ == "__main__":
    # Run the self-tests only when executed as a script, not when imported.
    main()