aboutsummaryrefslogtreecommitdiff
path: root/src/model/aes_keywrap.py
blob: 8b9d2b076eb58c560c548cfedbb8bc644e06e921 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
#!/usr/bin/env python
#======================================================================
#
# aes_keywrap.py
# --------------
# Python functional model of AES Key Wrap with Padding, including test cases.
# Used to generate test vectors for internal states to drive
# verification of the hardware implementation.
#
#
# Terminology, including variable names, mostly follows RFC 5649 and
# the RFC 3394 wrapping algorithm it builds on.
#
# Block sizes get confusing: AES Key Wrap uses 64-bit blocks, not to
# be confused with AES, which uses 128-bit blocks.  In practice, this
# is less confusing than when reading the description, because we
# concatenate two 64-bit blocks just prior to performing an AES ECB
# operation, then immediately split the result back into a pair of
# 64-bit blocks.
#
#
# Copyright (c) 2018, NORDUnet A/S
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# - Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
#
# - Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in the
#   documentation and/or other materials provided with the distribution.
#
# - Neither the name of the NORDUnet nor the names of its contributors may
#   be used to endorse or promote products derived from this software
#   without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#======================================================================

from struct import pack, unpack
from os import urandom
from Crypto.Cipher import AES
import unittest


# When True, wrap_key() (and the test harness) print the intermediate
# wrap state; this is how the model emits vectors for the hardware.
verbose = True


class AESKeyWrapWithPadding(object):
    """
    Implementation of AES Key Wrap With Padding from RFC 5649,
    using PyCrypto to supply the AES code.

    Plaintext and ciphertext are byte strings.  The code sticks to
    constructs that behave the same for Python 2 str and Python 3 bytes.
    """

    class UnwrapError(Exception):
        "Something went wrong during unwrap."

    # Fixed high 32 bits of the RFC 5649 alternative initial value (AIV).
    _MAGIC = 0xa65959a6

    def __init__(self, key):
        """
        key is the AES key-encryption key, a byte string; its length is
        validated by PyCrypto's AES.new().
        """
        self.key = key
        self.ctx = AES.new(key, AES.MODE_ECB)

    def _encrypt(self, b1, b2):
        # One AES-ECB block operation on two concatenated 64-bit halves,
        # split straight back into two 64-bit halves.
        aes_block = self.ctx.encrypt(b1 + b2)
        return aes_block[:8], aes_block[8:]

    def _decrypt(self, b1, b2):
        # Inverse of _encrypt().
        aes_block = self.ctx.decrypt(b1 + b2)
        return aes_block[:8], aes_block[8:]

    @staticmethod
    def _start_stop(start, stop):                    # Syntactic sugar
        """
        Inclusive range from start to stop, counting down if start > stop.

        Note that _start_stop(1, 0) yields 1, 0 rather than nothing, so it
        must only be used where the range is known to be non-empty.
        """
        step = -1 if start > stop else 1
        return range(start, stop + step, step)

    @staticmethod
    def bin2hex(bytes, sep = ":"):
        "Render a byte string as two-digit hex octets joined by sep."
        # bytearray iterates as ints on both Python 2 and 3, whereas
        # ord() over the elements of a Python 3 bytes object fails.
        return sep.join("{:02x}".format(b) for b in bytearray(bytes))

    @staticmethod
    def _xor_t(A, t):
        "Return the 64-bit block A XORed with the integer step counter t."
        return pack(">Q", unpack(">Q", A)[0] ^ t)

    def wrap_key(self, Q):
        """
        Wrap a key according to RFC 5649 section 4.1.

        Q is the plaintext to be wrapped, a byte string of 1 to 0xFFFFFFFF
        octets (its length is stored in the 32-bit low half of the AIV).

        Returns C, the wrapped ciphertext: 8 * (n + 1) bytes, where n is
        the number of 64-bit blocks in the zero-padded plaintext.

        Raises ValueError if the length of Q is out of range.
        """

        m = len(Q)                              # Plaintext length
        if not 0 < m <= 0xFFFFFFFF:
            # An empty Q would reach the multi-block loop with n == 0 and
            # index past the end of R; an oversized Q cannot be encoded
            # in the AIV.
            raise ValueError("Plaintext length {} out of range for RFC 5649 key wrap"
                             .format(m))

        if verbose:
            print("")
            print("Performing key wrap.")
            print("key:       %s" % (self.bin2hex(self.key)))
            print("plaintext: %s" % (self.bin2hex(Q)))
            print("")

        if m % 8 != 0:                          # Pad Q if needed
            Q += b"\x00" * (8 - (m % 8))
        R = [pack(">LL", self._MAGIC, m)]       # Magic MSB(32,A), build LSB(32,A)
        R.extend(Q[i : i + 8]                   # Append Q
                 for i in range(0, len(Q), 8))

        n = len(R) - 1

        if n == 1:
            # RFC 5649 section 4.1: a single padded block is encrypted
            # together with the AIV in one AES operation.
            R[0], R[1] = self._encrypt(R[0], R[1])

        else:
            # RFC 3394 section 2.2.1
            if verbose:
                print("")
                # NOTE(review): n is the number of 64-bit data blocks, yet
                # n - 1 is printed; confirm which count the hardware expects.
                print("Number of blocks to wrap: %d" % (n - 1))
                print("Blocks before wrap:")
                for i in self._start_stop(1, n):
                    print("R[%d] = %s" % (i, self.bin2hex(R[i])))
                print("A before wrap: %s" % (self.bin2hex(R[0])))
                print("")

            for j in self._start_stop(0, 5):
                for i in self._start_stop(1, n):
                    if verbose:
                        print("")
                        print("Iteration %d, %d" % (j, i))
                        print("Before encrypt: R[0] = %s  R[%d] = %s" % (self.bin2hex(R[0]), i, self.bin2hex(R[i])))

                    R[0], R[i] = self._encrypt(R[0], R[i])

                    if verbose:
                        print("After encrypt:  R[0] = %s  R[%d] = %s" % (self.bin2hex(R[0]), i, self.bin2hex(R[i])))

                    # A = MSB(64, B) ^ t, with t = (n * j) + i
                    xorval = n * j + i
                    R[0] = self._xor_t(R[0], xorval)
                    if verbose:
                        print("xorval = 0x%016x" % (xorval))

            if verbose:
                print("")
                print("Blocks after wrap:")
                for i in self._start_stop(1, n):
                    print("R[%d] = %s" % (i, self.bin2hex(R[i])))
                print("A after wrap: %s" % (self.bin2hex(R[0])))
                print("")

        assert len(R) == (n + 1) and all(len(r) == 8 for r in R)
        return b"".join(R)

    def unwrap_key(self, C):
        """
        Unwrap a key according to RFC 5649 section 4.2.

        C is the ciphertext to be unwrapped, a byte string of at least two
        64-bit blocks.

        Returns Q, the unwrapped plaintext.

        Raises UnwrapError if C has an invalid length, if the AIV magic
        or encoded length is wrong, or if the padding is not all zero.
        """

        if len(C) % 8 != 0:
            raise self.UnwrapError("Ciphertext length {} is not an integral number of blocks"
                                   .format(len(C)))

        if len(C) < 16:
            # RFC 5649 output is always the AIV block plus at least one
            # data block; shorter input would otherwise index past R.
            raise self.UnwrapError("Ciphertext length {} is too short, need at least two blocks"
                                   .format(len(C)))

        n = len(C) // 8 - 1
        R = [C[i : i + 8] for i in range(0, len(C), 8)]

        if n == 1:
            R[0], R[1] = self._decrypt(R[0], R[1])

        else:
            # RFC 3394 section 2.2.2 steps (1), (2), and part of (3)
            for j in self._start_stop(5, 0):
                for i in self._start_stop(n, 1):
                    R[0] = self._xor_t(R[0], n * j + i)
                    R[0], R[i] = self._decrypt(R[0], R[i])

        magic, m = unpack(">LL", R[0])

        if magic != self._MAGIC:
            raise self.UnwrapError("Magic value in AIV should have been 0xa65959a6, was 0x{:08x}"
                                   .format(magic))

        # RFC 5649 section 3: 8 * (n - 1) < m <= 8 * n
        if m <= 8 * (n - 1) or m > 8 * n:
            raise self.UnwrapError("Length encoded in AIV out of range: m {}, n {}".format(m, n))

        R = b"".join(R[1:])
        assert len(R) ==  8 * n

        padding = R[m:]
        if padding != b"\x00" * len(padding):
            raise self.UnwrapError("Nonzero trailing bytes {}".format(self.bin2hex(padding, "")))

        return R[:m]


if __name__ == "__main__":

    # Test code from here down

    class TestAESKeyWrapWithPadding(unittest.TestCase):

        @staticmethod
        def bin2hex(bytes, sep = ":"):
            return sep.join("{:02x}".format(ord(b)) for b in bytes)

        @staticmethod
        def hex2bin(text):
            return text.translate(None, ": \t\n\r").decode("hex")

        def loopback_test(self, I):
            K = AESKeyWrapWithPadding(self.hex2bin("00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f"))
            C = K.wrap_key(I)
            O = K.unwrap_key(C)
            self.assertEqual(I, O, "Input and output plaintext did not match: {!r} <> {!r}".format(I, O))

        def rfc5649_test(self, K, Q, C):
            K = AESKeyWrapWithPadding(key = self.hex2bin(K))
            Q = self.hex2bin(Q)
            C = self.hex2bin(C)
            c = K.wrap_key(Q)
            if verbose:
                print("Wrapped result: %s" % (self.bin2hex(c)))

            q = K.unwrap_key(C)
            self.assertEqual(q, Q, "Input and output plaintext did not match: {} <> {}".format(self.bin2hex(Q), self.bin2hex(q)))
            self.assertEqual(c, C, "Input and output ciphertext did not match: {} <> {}".format(self.bin2hex(C), self.bin2hex(c)))


#        def test_rfc5649_1(self):
#            self.rfc5649_test(K = "5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8",
#                              Q = "c37b7e6492584340 bed1220780894115 5068f738",
#                              C = "138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a 5f54f373fa543b6a")
#
#        def test_rfc5649_2(self):
#            self.rfc5649_test(K = "5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8",
#                              Q = "466f7250617369",
#                              C = "afbeb0f07dfbf541 9200f2ccb50bb24f")
#
#
#        def test_mangled_1(self):
#            self.assertRaises(AESKeyWrapWithPadding.UnwrapError, self.rfc5649_test,
#                              K = "5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8",
#                              Q = "466f7250617368",
#                              C = "afbeb0f07dfbf541 9200f2ccb50bb24f")
#
#        def test_mangled_2(self):
#            self.assertRaises(AESKeyWrapWithPadding.UnwrapError, self.rfc5649_test,
#                              K = "5840df6e29b02af0 ab493b705bf16ea1 ae8338f4dcc176a8",
#                              Q = "466f7250617368",
#                              C = "afbeb0f07dfbf541 9200f2ccb50bb24f 0123456789abcdef")
#
#        def test_mangled_3(self):
#            self.assertRaises(AESKeyWrapWithPadding.UnwrapError, self.rfc5649_test,
#                              K = "5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8",
#                              Q = "c37b7e6492584340 bed1220780894115 5068f738",
#                              C = "138bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a")
#
#
#        # This one should fail. But it doesn't. Que pasa?!?
#        def test_mangled_4(self):
#            self.assertRaises(AESKeyWrapWithPadding.UnwrapError, self.rfc5649_test,
#                              K = "5840df6e29b02af1 ab493b705bf16ea1 ae8338f4dcc176a8",
#                              Q = "c37b7e6492584340 bed1220780894115 5068f738",
#                              C = "238bdeaa9b8fa7fc 61f97742e72248ee 5ae6ae5360d1ae6a")


        # Test vectors from NISTs set of test vectors for SP800-38F KWP algorithm.
        # 128 bit key.
#        def test_kwp_ae_128_1(self):
#            self.rfc5649_test(K = "7efb9b3964de316e 7245c86186d98b5f",
#                              Q = "3e",
#                              C = "116a4054c13b7fea de9c22aa57b3caed")
#
#        def test_kwp_ae_128_2(self):
#            self.rfc5649_test(K = "45c770fc26717507 2d70a38269c54685",
#                              Q = "cc5fb15a17795c34",
#                              C = "78ffa3f03b65c55b 812f355730af71ac")
#
#        def test_kwp_ae_128_3(self):
#            self.rfc5649_test(K = "853e2bac0f1e6298 67acea0d2b3c087e",
#                              Q = "49575527bc59530f be",
#                              C = "b43781062eb0317e b2dec6329f2d64de 1c33d85570d57db6")

        def test_kwp_ae_128_4(self):
            self.rfc5649_test(K = "c03db3cc1416dcd1 c069a195a8d77e3d",
                              Q = "46f87f58cdda4200 f53d99ce2e49bdb7 6212511fe0cd4d0 b5f37a27d45a288",
                              C = "57e3b6699c6e8177 59a69492bb7e2cd0 0160d2ebef9bf4d 4eb16fbf798f134 0f6df6558a4fb84cd0")


#        def test_kwp_ae_256_1(self):
#            self.rfc5649_test(K = "2800f18237cf8d2b a1dfe361784fd751 9b0fdb0ec73e2ab1 c0b966b9173fc5b5",
#                              Q = "ad",
#                              C = "c1eccf2d077a385e 67aaeb35552c893c")
#
#        def test_kwp_ae_256_2(self):
#            self.rfc5649_test(K = "1c997c2bb5a15a45 93e337b3249675d55 7467417917f6bc51 65c9af6a3e29504",
#                              Q = "3e3eafc50cd4e939",
#                              C = "163eb9e7dbc8ed00 86dffbc6ab00e329")
#
#        def test_kwp_ae_256_3(self):
#            self.rfc5649_test(K = "8df1533f99be6fe6 0f951057fed1daccd 14bd4e34118f24af 677bbf46bf11fe7",
#                              Q = "fb36b1f3907fb5ed ce",
#                              C = "6974d7bae0221b4e d91336c26af77e327 61f6024d8bbf292")
#
#        def test_kwp_ae_256_4(self):
#            self.rfc5649_test(K = "dea4667d911b5c9e c996cdb35da0e29bc 996cbfb0e0a56bac 12fccc334d732eb",
#                              Q = "25d58d437a56a733 2a18541333201f992 9fccde11b06844c1 9ba1ca224cfd6",
#                              C = "86d4e258391f15d7 d4f0ab3e15d6f45e6 5dd2f8caf4c67209 63bb8970fc2f3a4 a58dc74674347ec9")


#        def test_loopback_1(self):
#            self.loopback_test("!")
#
#        def test_loopback_2(self):
#            self.loopback_test("Yo!")
#
#        def test_loopback_3(self):
#            self.loopback_test("Hi, Mom")
#
#        def test_loopback_4(self):
#            self.loopback_test("1" * (64 / 8))
#
#        def test_loopback_5(self):
#            self.loopback_test("2" * (128 / 8))
#
#        def test_loopback_6(self):
#            self.loopback_test("3" * (256 / 8))
#
#        def test_loopback_7(self):
#            self.loopback_test("3.14159265358979323846264338327950288419716939937510")
#
#        def test_loopback_8(self):
#            self.loopback_test("3.14159265358979323846264338327950288419716939937510")
#
#        def test_loopback_9(self):
#            self.loopback_test("Hello!  My name is Inigo Montoya. You killed my AES key wrapper. Prepare to die.")
#
#        def test_joachim_loopback(self):
#            I = "31:32:33"
#            K = AESKeyWrapWithPadding(urandom(256/8))
#            C = K.wrap_key(I)
#            O = K.unwrap_key(C)
#            self.assertEqual(I, O, "Input and output plaintext did not match: {!r} <> {!r}".format(I, O))
#

    unittest.main(verbosity = 9)

#======================================================================
# EOF aes_keywrap.py
#======================================================================