From 7a89eaa086fa534a6a0aac45fa2f5865ef7839ef Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Tue, 16 Jun 2015 19:17:24 -0400 Subject: Refactor key loading code. --- cryptech.h | 31 ++++++++---- rsa.c | 152 ++++++++++++++++++++++++++++++++++--------------------- tests/test-rsa.c | 42 +++++++++------ 3 files changed, 142 insertions(+), 83 deletions(-) diff --git a/cryptech.h b/cryptech.h index 2f8abc9..48f2a75 100644 --- a/cryptech.h +++ b/cryptech.h @@ -442,9 +442,8 @@ DEFINE_HAL_ERROR(HAL_ERROR_KEYWRAP_BAD_MAGIC, "Bad magic number while unwrapping key") \ DEFINE_HAL_ERROR(HAL_ERROR_KEYWRAP_BAD_LENGTH, "Length out of range while unwrapping key") \ DEFINE_HAL_ERROR(HAL_ERROR_KEYWRAP_BAD_PADDING, "Non-zero padding detected unwrapping key") \ - DEFINE_HAL_ERROR(HAL_ERROR_CRT_FAILED, "CRT calculation failed") \ + DEFINE_HAL_ERROR(HAL_ERROR_IMPOSSIBLE, "\"Impossible\" error") \ DEFINE_HAL_ERROR(HAL_ERROR_ALLOCATION_FAILURE, "Memory allocation failed") \ - DEFINE_HAL_ERROR(HAL_ERROR_UNKNOWN_TFM_FAILURE, "Unknown libtfm failure") \ DEFINE_HAL_ERROR(HAL_ERROR_RESULT_TOO_LONG, "Result too long for buffer") \ END_OF_HAL_ERROR_LIST @@ -608,17 +607,29 @@ extern hal_error_t hal_modexp(const uint8_t * const msg, const size_t msg_len, / extern void hal_rsa_set_debug(const int onoff); -extern hal_error_t hal_rsa_crt(const uint8_t * const m, const size_t m_len, - const uint8_t * const n, const size_t n_len, - const uint8_t * const e, const size_t e_len, - const uint8_t * const d, const size_t d_len, - const uint8_t * const p, const size_t p_len, - const uint8_t * const q, const size_t q_len, - const uint8_t * const u, const size_t u_len, - uint8_t * result, const size_t result_len); +extern const size_t hal_rsa_key_t_size; + +typedef enum { RSA_PRIVATE, RSA_PUBLIC } hal_rsa_key_type_t; +typedef struct { void *key; } hal_rsa_key_t; +extern hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type, + hal_rsa_key_t *key, + void *keybuf, const size_t keybuf_len, + const uint8_t * const n, const size_t n_len, + const uint8_t * const e, const size_t e_len, + const uint8_t * const d, const size_t d_len, + const uint8_t * const p, const size_t p_len, + const uint8_t * const q, const size_t q_len, + const uint8_t * const u, const size_t u_len, + const uint8_t * const dP, const size_t dP_len, + const uint8_t * const dQ, const size_t dQ_len); +extern void hal_rsa_key_clear(hal_rsa_key_t key); + +extern hal_error_t hal_rsa_crt(hal_rsa_key_t key, + const uint8_t * const m, const size_t m_len, + uint8_t * result, const size_t result_len); #endif /* _CRYPTECH_H_ */ diff --git a/rsa.c b/rsa.c index 31c4f61..0d3ae69 100644 --- a/rsa.c +++ b/rsa.c @@ -96,10 +96,8 @@ void hal_rsa_set_debug(const int onoff) * can make memory allocation the caller's problem (well, maybe). */ -typedef enum { RSA_PRIVATE, RSA_PUBLIC } rsa_key_type_t; - -typedef struct { - rsa_key_type_t type; /* What kind of key this is */ +struct rsa_key { + hal_rsa_key_type_t type; /* What kind of key this is */ fp_int n; /* The modulus */ fp_int e; /* Public exponent */ fp_int d; /* Private exponent */ @@ -108,9 +106,9 @@ typedef struct { fp_int u; /* 1/q mod p */ fp_int dP; /* d mod (p - 1) */ fp_int dQ; /* d mod (q - 1) */ -} rsa_key_t; +}; -const size_t hal_rsa_key_t_size = sizeof(rsa_key_t); +const size_t hal_rsa_key_t_size = sizeof(struct rsa_key); /* * In the long run we want a full RSA implementation, or enough of one @@ -134,7 +132,7 @@ const size_t hal_rsa_key_t_size = sizeof(rsa_key_t); case FP_OKAY: break; \ case FP_VAL: lose(HAL_ERROR_BAD_ARGUMENTS); \ case FP_MEM: lose(HAL_ERROR_ALLOCATION_FAILURE); \ - default: lose(HAL_ERROR_UNKNOWN_TFM_FAILURE); \ + default: lose(HAL_ERROR_IMPOSSIBLE); \ } \ } while (0) @@ -162,13 +160,10 @@ static hal_error_t unpack_fp(fp_int *bn, uint8_t *buffer, const size_t length) } /* - * modexp_fp() is a function I haven't written yet which takes - * fp_int values, unwraps them, feeds the numbers into hal_modexp(), - * and wraps the result back up as a fp_int. + * Unwrap bignums into byte arrays, feeds them into hal_modexp(), and + * wrap result back up as a bignum. */ -/* modexp_fp(&tmp.msg, &key.dP, &key.p, &tmp.m1) */ - static hal_error_t modexp_fp(fp_int *msg, fp_int *exp, fp_int *mod, fp_int *res) { hal_error_t err = HAL_OK; @@ -201,81 +196,124 @@ static hal_error_t modexp_fp(fp_int *msg, fp_int *exp, fp_int *mod, fp_int *res) return err; } + /* - * CRT with the components we have. PyCrypto doesn't give us dP or - * dQ, probably because they're so easy to calculate that it's - * (almost) not worth the bother. + * Clear a key. We might want to do something a bit more energetic + * than plain old memset() eventually. */ -hal_error_t hal_rsa_crt(const uint8_t * const m, const size_t m_len, - const uint8_t * const n, const size_t n_len, - const uint8_t * const e, const size_t e_len, - const uint8_t * const d, const size_t d_len, - const uint8_t * const p, const size_t p_len, - const uint8_t * const q, const size_t q_len, - const uint8_t * const u, const size_t u_len, - uint8_t * result, const size_t result_len) +void hal_rsa_key_clear(hal_rsa_key_t key) { - hal_error_t err = HAL_OK; - rsa_key_t key; - struct { fp_int t, msg, m1, m2; } tmp; + if (key.key != NULL) + memset(key.key, 0, sizeof(struct rsa_key)); +} - key.type = RSA_PRIVATE; +/* + * Load a key from raw components. This is a simplistic version: we + * don't attempt to generate missing private key components, we just + * reject the key if it doesn't have everything we expect. + * + * In theory, the only things we'd really need for the private key if + * we were being nicer about this would be e, p, and q, as we could + * calculate everything else from them. + */ + +hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type, + hal_rsa_key_t *key_, + void *keybuf, const size_t keybuf_len, + const uint8_t * const n, const size_t n_len, + const uint8_t * const e, const size_t e_len, + const uint8_t * const d, const size_t d_len, + const uint8_t * const p, const size_t p_len, + const uint8_t * const q, const size_t q_len, + const uint8_t * const u, const size_t u_len, + const uint8_t * const dP, const size_t dP_len, + const uint8_t * const dQ, const size_t dQ_len) +{ + if (key_ == NULL || keybuf == NULL || keybuf_len < sizeof(struct rsa_key)) + return HAL_ERROR_BAD_ARGUMENTS; + + memset(keybuf, 0, keybuf_len); -#define _(x) do { fp_init(&key.x); fp_read_unsigned_bin(&key.x, (uint8_t *) x, x##_len); } while (0) - _(n); _(e); _(d); _(p); _(q); _(u); + struct rsa_key *key = keybuf; + + key->type = type; + +#define _(x) do { fp_init(&key->x); if (x == NULL) goto fail; fp_read_unsigned_bin(&key->x, (uint8_t *) x, x##_len); } while (0) + switch (type) { + case RSA_PRIVATE: + _(d); _(p); _(q); _(u); _(dP); _(dQ); + case RSA_PUBLIC: + _(n); _(e); + key_->key = key; + return HAL_OK; + } #undef _ + fail: + memset(key, 0, sizeof(*key)); + return HAL_ERROR_BAD_ARGUMENTS; +} + +/* + * RSA decyrption/signature using the Chinese Remainder Theorem + * (Garner's formula). + */ + +hal_error_t hal_rsa_crt(hal_rsa_key_t key_, + const uint8_t * const m, const size_t m_len, + uint8_t * result, const size_t result_len) +{ + hal_error_t err = HAL_OK; + struct rsa_key *key = key_.key; + struct { fp_int t, msg, m1, m2; } tmp; + fp_init(&tmp.t); fp_init(&tmp.msg); fp_init(&tmp.m1); fp_init(&tmp.m2); - /* Calculate dP = d % (p-1) and dQ = d % (q-1) */ - - fp_sub_d(&key.p, 1, &tmp.t); - FP_CHECK(fp_mod(&key.d, &tmp.t, &key.dP)); - - fp_sub_d(&key.q, 1, &tmp.t); - FP_CHECK(fp_mod(&key.d, &tmp.t, &key.dQ)); - - /* Read message to be signed/decrypted into a bignum */ - fp_read_unsigned_bin(&tmp.msg, (uint8_t *) m, m_len); - /* OK, try to perform the CRT */ - - /* m1 = msg ** dP mod p, m2 = msg ** dQ mod q */ - if ((err = modexp_fp(&tmp.msg, &key.dP, &key.p, &tmp.m1)) != HAL_OK || - (err = modexp_fp(&tmp.msg, &key.dQ, &key.q, &tmp.m2)) != HAL_OK) + /* + * m1 = msg ** dP mod p + * m2 = msg ** dQ mod q + */ + if ((err = modexp_fp(&tmp.msg, &key->dP, &key->p, &tmp.m1)) != HAL_OK || + (err = modexp_fp(&tmp.msg, &key->dQ, &key->q, &tmp.m2)) != HAL_OK) goto fail; - /* t = m1 - m2 */ + /* + * t = m1 - m2. + * Add zero (mod p) once or twice if necessary to get positive result. + */ fp_sub(&tmp.m1, &tmp.m2, &tmp.t); - - /* Add zero (mod p) if necessary to get positive result */ if (fp_cmp_d(&tmp.t, 0) == FP_LT) - fp_add(&tmp.t, &key.p, &tmp.t); + fp_add(&tmp.t, &key->p, &tmp.t); if (fp_cmp_d(&tmp.t, 0) == FP_LT) - fp_add(&tmp.t, &key.p, &tmp.t); + fp_add(&tmp.t, &key->p, &tmp.t); if (fp_cmp_d(&tmp.t, 0) == FP_LT) - lose(HAL_ERROR_CRT_FAILED); + lose(HAL_ERROR_IMPOSSIBLE); - /* t = (t * u mod p) * q + m2 */ - FP_CHECK(fp_mulmod(&tmp.t, &key.u, &key.p, &tmp.t)); - fp_mul(&tmp.t, &key.q, &tmp.t); + /* + * t = (t * u mod p) * q + m2 + */ + FP_CHECK(fp_mulmod(&tmp.t, &key->u, &key->p, &tmp.t)); + fp_mul(&tmp.t, &key->q, &tmp.t); fp_add(&tmp.t, &tmp.m2, &tmp.t); - /* Have result, write it back to caller */ + /* + * t now holds result, write it back to caller + */ if ((err = unpack_fp(&tmp.t, result, result_len)) != HAL_OK) goto fail; - /* Done, fall through into cleanup code */ + /* + * Done, fall through into cleanup. + */ fail: - memset(&key, 0, sizeof(key)); memset(&tmp, 0, sizeof(tmp)); - return err; } diff --git a/tests/test-rsa.c b/tests/test-rsa.c index 6925261..b415955 100644 --- a/tests/test-rsa.c +++ b/tests/test-rsa.c @@ -83,28 +83,38 @@ static int test_modexp(const char * const kind, static int test_crt(const char * const kind, const rsa_tc_t * const tc) { - uint8_t result[tc->n.len]; - printf("%s test for %lu-bit RSA key\n", kind, (unsigned long) tc->size); - if (hal_rsa_crt(tc->m.val, tc->m.len, - tc->n.val, tc->n.len, - tc->e.val, tc->e.len, - tc->d.val, tc->d.len, - tc->p.val, tc->p.len, - tc->q.val, tc->q.len, - tc->u.val, tc->u.len, - result, sizeof(result)) != HAL_OK) { - printf("RSA CRT failed\n"); + uint8_t keybuf[hal_rsa_key_t_size]; + hal_error_t err = HAL_OK; + hal_rsa_key_t key; + + if ((err = hal_rsa_key_load(RSA_PRIVATE, &key, keybuf, sizeof(keybuf), + tc->n.val, tc->n.len, + tc->e.val, tc->e.len, + tc->d.val, tc->d.len, + tc->p.val, tc->p.len, + tc->q.val, tc->q.len, + tc->u.val, tc->u.len, + tc->dP.val, tc->dP.len, + tc->dQ.val, tc->dQ.len)) != HAL_OK) { + printf("RSA CRT key load failed: %s\n", hal_error_string(err)); return 0; } - if (memcmp(result, tc->s.val, tc->s.len)) { - printf("MISMATCH\n"); - return 0; - } + uint8_t result[tc->n.len]; - return 1; + if ((err = hal_rsa_crt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) + printf("RSA CRT failed: %s\n", hal_error_string(err)); + + const int mismatch = (err == HAL_OK && memcmp(result, tc->s.val, tc->s.len) != 0); + + if (mismatch) + printf("MISMATCH\n"); + + hal_rsa_key_clear(key); + + return err == HAL_OK && !mismatch; } /* -- cgit v1.2.3