aboutsummaryrefslogtreecommitdiff
path: root/src/model/keywrap.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/model/keywrap.py')
-rwxr-xr-xsrc/model/keywrap.py125
1 files changed, 101 insertions, 24 deletions
diff --git a/src/model/keywrap.py b/src/model/keywrap.py
index 08aac9a..747f406 100755
--- a/src/model/keywrap.py
+++ b/src/model/keywrap.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+#!/usr/bin/env python
# -*- coding: utf-8 -*-
#=======================================================================
#
@@ -43,7 +43,9 @@
# Python module imports.
#-------------------------------------------------------------------
import sys
-import aes
+import Crypto.Random
+from Crypto.Cipher import AES
+from struct import pack, unpack
#-------------------------------------------------------------------
@@ -52,37 +54,112 @@ import aes
VERBOSE = True
+
+#-------------------------------------------------------------------
+# AESKeyWrapWithPadding
+#-------------------------------------------------------------------
+class AESKeyWrapWithPadding(object):
+ """
+ Implementation of AES Key Wrap With Padding from RFC 5649.
+ """
+
+ class UnwrapError(Exception):
+ "Something went wrong during unwrap."
+
+ def __init__(self, key):
+ self.ctx = AES.new(key, AES.MODE_ECB)
+
+ def _encrypt(self, b1, b2):
+ aes_block = self.ctx.encrypt(b1 + b2)
+ return aes_block[:8], aes_block[8:]
+
+ def _decrypt(self, b1, b2):
+ aes_block = self.ctx.decrypt(b1 + b2)
+ return aes_block[:8], aes_block[8:]
+
+ @staticmethod
+ def _start_stop(start, stop): # Syntactic sugar
+ step = -1 if start > stop else 1
+ return xrange(start, stop + step, step)
+
+ @staticmethod
+ def _xor(R0, t):
+ return pack(">Q", unpack(">Q", R0)[0] ^ t)
+
+ def wrap(self, Q):
+ "RFC 5649 section 4.1."
+ m = len(Q) # Plaintext length
+ if m % 8 != 0: # Pad Q if needed
+ Q += "\x00" * (8 - (m % 8))
+ R = [pack(">LL", 0xa65959a6, m)] # Magic MSB(32,A), build LSB(32,A)
+ R.extend(Q[i : i + 8] # Append Q
+ for i in xrange(0, len(Q), 8))
+ n = len(R) - 1
+ if n == 1:
+ R[0], R[1] = self._encrypt(R[0], R[1])
+ else:
+ # RFC 3394 section 2.2.1
+ for j in self._start_stop(0, 5):
+ for i in self._start_stop(1, n):
+ R[0], R[i] = self._encrypt(R[0], R[i])
+ R[0] = self._xor(R[0], n * j + i)
+ assert len(R) == (n + 1) and all(len(r) == 8 for r in R)
+ return "".join(R)
+
+ def unwrap(self, C):
+ "RFC 5649 section 4.2."
+ if len(C) % 8 != 0:
+ raise self.UnwrapError("Ciphertext length {} is not an integral number of blocks"
+ .format(len(C)))
+ n = (len(C) / 8) - 1
+ R = [C[i : i + 8] for i in xrange(0, len(C), 8)]
+ if n == 1:
+ R[0], R[1] = self._decrypt(R[0], R[1])
+ else:
+ # RFC 3394 section 2.2.2 steps (1), (2), and part of (3)
+ for j in self._start_stop(5, 0):
+ for i in self._start_stop(n, 1):
+ R[0] = self._xor(R[0], n * j + i)
+ R[0], R[i] = self._decrypt(R[0], R[i])
+ magic, m = unpack(">LL", R[0])
+ if magic != 0xa65959a6:
+ raise self.UnwrapError("Magic value in AIV should have been 0xa65959a6, was 0x{:02x}"
+ .format(magic))
+ if m <= 8 * (n - 1) or m > 8 * n:
+ raise self.UnwrapError("Length encoded in AIV out of range: m {}, n {}".format(m, n))
+ R = "".join(R[1:])
+ assert len(R) == 8 * n
+ if any(r != "\x00" for r in R[m:]):
+ raise self.UnwrapError("Nonzero trailing bytes {}".format(R[m:].encode("hex")))
+ return R[:m]
+
+
#-------------------------------------------------------------------
-# aes_test
+# wrap_test1
#
-# Check that the AES functionality is available and works
-# as expected.
+# First, simplest test from NIST test vectors.
#-------------------------------------------------------------------
-def keywrap_aes_test():
- nist_aes128_key = (0x2b7e1516, 0x28aed2a6, 0xabf71588, 0x09cf4f3c)
- nist_plaintext0 = (0x6bc1bee2, 0x2e409f96, 0xe93d7e11, 0x7393172a)
- nist_exp128_0 = (0x3ad77bb4, 0x0d7a3660, 0xa89ecaf3, 0x2466ef97)
- enc_result128_0 = aes.aes_encipher_block(nist_aes128_key, nist_plaintext0)
-
- print("Test 0 for AES-128.")
- print("Key:")
- aes.print_key(nist_aes128_key)
- print("Block in:")
- aes.print_block(nist_plaintext0)
- print("Expected block out:")
- aes.print_block(nist_exp128_0)
- print("Got block out:")
- aes.print_block(enc_result128_0)
- print("")
+def wrap_test1():
+ my_key = Crypto.Random.new().read(256/8)
+ my_keywrap = AESKeyWrapWithPadding(my_key)
+
+ my_plaintext = "\x31\x32\x33"
+ my_wrap = my_keywrap.wrap(my_plaintext)
+ print(type(my_wrap))
+ my_unwrap = my_keywrap.wrap(my_wrap)
+ print(type(my_unwrap))
+ print("plaintext: %s wrapped: %s unwrapped: %s" %
+ (my_plaintext, my_wrap, my_unwrap))
#-------------------------------------------------------------------
#-------------------------------------------------------------------
def main():
- print("Testing the KEY WRAP model")
- print("===========================")
+ print("Testing the Key Wrap Python model")
+ print("=================================")
print
- keywrap_aes_test()
+ # keywrap_aes_test()
+ wrap_test1()
#-------------------------------------------------------------------