diff options
Diffstat (limited to 'rsa.c')
-rw-r--r-- | rsa.c | 102 |
1 files changed, 71 insertions, 31 deletions
@@ -70,7 +70,6 @@ #include <stdlib.h> #include <stddef.h> #include <string.h> -#include <assert.h> #include "hal.h" #include "hal_internal.h" @@ -94,6 +93,15 @@ #endif /* + * How big to make the buffers for the modulus coefficient and + * Montgomery factor. This will almost certainly want tuning. + */ + +#ifndef HAL_RSA_MAX_OPERAND_LENGTH +#define HAL_RSA_MAX_OPERAND_LENGTH (4096 / 8) +#endif + +/* * Whether we want debug output. */ @@ -123,7 +131,7 @@ void hal_rsa_set_blinding(const int onoff) */ struct hal_rsa_key { - hal_key_type_t type; /* What kind of key this is */ + hal_key_type_t type; /* What kind of key this is */ fp_int n[1]; /* The modulus */ fp_int e[1]; /* Public exponent */ fp_int d[1]; /* Private exponent */ @@ -132,8 +140,17 @@ struct hal_rsa_key { fp_int u[1]; /* 1/q mod p */ fp_int dP[1]; /* d mod (p - 1) */ fp_int dQ[1]; /* d mod (q - 1) */ + unsigned flags; /* Internal key flags */ + uint8_t /* ModExpA7 speedup factors */ + nC[HAL_RSA_MAX_OPERAND_LENGTH], nF[HAL_RSA_MAX_OPERAND_LENGTH], + pC[HAL_RSA_MAX_OPERAND_LENGTH/2], pF[HAL_RSA_MAX_OPERAND_LENGTH/2], + qC[HAL_RSA_MAX_OPERAND_LENGTH/2], qF[HAL_RSA_MAX_OPERAND_LENGTH/2]; }; +#define RSA_FLAG_PRECALC_N_DONE (1 << 0) +#define RSA_FLAG_PRECALC_P_DONE (1 << 1) +#define RSA_FLAG_PRECALC_Q_DONE (1 << 2) + const size_t hal_rsa_key_t_size = sizeof(hal_rsa_key_t); /* @@ -158,7 +175,7 @@ const size_t hal_rsa_key_t_size = sizeof(hal_rsa_key_t); case FP_OKAY: break; \ case FP_VAL: lose(HAL_ERROR_BAD_ARGUMENTS); \ case FP_MEM: lose(HAL_ERROR_ALLOCATION_FAILURE); \ - default: lose(HAL_ERROR_IMPOSSIBLE); \ + default: lose(HAL_ERROR_IMPOSSIBLE); \ } \ } while (0) @@ -171,7 +188,8 @@ static hal_error_t unpack_fp(const fp_int * const bn, uint8_t *buffer, const siz { hal_error_t err = HAL_OK; - assert(bn != NULL && buffer != NULL); + if (bn == NULL || buffer == NULL) + return HAL_ERROR_IMPOSSIBLE; const size_t bytes = fp_unsigned_bin_size(unconst_fp_int(bn)); @@ -193,22 +211,18 @@ static hal_error_t unpack_fp(const fp_int * const bn, uint8_t *buffer, const siz */ static hal_error_t modexp(hal_core_t *core, - const fp_int * msg, + const int precalc_done, + const fp_int * const msg, const fp_int * const exp, const fp_int * const mod, - fp_int *res) + fp_int *res, + uint8_t *coeff, const size_t coeff_len, + uint8_t *mont, const size_t mont_len) { hal_error_t err = HAL_OK; - assert(msg != NULL && exp != NULL && mod != NULL && res != NULL); - - fp_int reduced_msg[1] = INIT_FP_INT; - - if (fp_cmp_mag(unconst_fp_int(msg), unconst_fp_int(mod)) != FP_LT) { - fp_init(reduced_msg); - fp_mod(unconst_fp_int(msg), unconst_fp_int(mod), reduced_msg); - msg = reduced_msg; - } + if (msg == NULL || exp == NULL || mod == NULL || res == NULL || coeff == NULL || mont == NULL) + return HAL_ERROR_IMPOSSIBLE; const size_t msg_len = (fp_unsigned_bin_size(unconst_fp_int(msg)) + 3) & ~3; const size_t exp_len = (fp_unsigned_bin_size(unconst_fp_int(exp)) + 3) & ~3; @@ -222,11 +236,13 @@ static hal_error_t modexp(hal_core_t *core, if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK || (err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK || (err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK || - (err = hal_modexp(core, + (err = hal_modexp(core, precalc_done, msgbuf, sizeof(msgbuf), expbuf, sizeof(expbuf), modbuf, sizeof(modbuf), - resbuf, sizeof(resbuf))) != HAL_OK) + resbuf, sizeof(resbuf), + coeff, coeff_len, + mont, mont_len)) != HAL_OK) goto fail; fp_read_unsigned_bin(res, resbuf, sizeof(resbuf)); @@ -249,10 +265,14 @@ static hal_error_t modexp(hal_core_t *core, */ static hal_error_t modexp(const hal_core_t *core, /* ignored */ + const int precalc_done, /* ignored */ const fp_int * const msg, const fp_int * const exp, const fp_int * const mod, - fp_int *res) + fp_int *res, + uint8_t *coeff, const size_t coeff_len, /* ignored */ + uint8_t *mont, const size_t mont_len) /* ignored */ + { hal_error_t err = HAL_OK; FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp), unconst_fp_int(mod), res)); @@ -281,7 +301,12 @@ static hal_error_t modexp(const hal_core_t *core, /* ignored */ int fp_exptmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d) { - return modexp(NULL, a, b, c, d) == HAL_OK ? FP_OKAY : FP_VAL; + const size_t len = (fp_unsigned_bin_size(unconst_fp_int(b)) + 3) & ~3; + uint8_t C[len], F[len]; + const hal_error_t err = modexp(NULL, 0, a, b, c, d, C, sizeof(C), F, sizeof(F)); + memset(C, 0, sizeof(C)); + memset(F, 0, sizeof(F)); + return err == HAL_OK ? FP_OKAY : FP_VAL; } #endif /* HAL_RSA_SIGN_USE_MODEXP && HAL_RSA_KEYGEN_USE_MODEXP */ @@ -294,7 +319,8 @@ int fp_exptmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d) static hal_error_t create_blinding_factors(hal_core_t *core, const hal_rsa_key_t * const key, fp_int *bf, fp_int *ubf) { - assert(key != NULL && bf != NULL && ubf != NULL); + if (key == NULL || bf == NULL || ubf == NULL) + return HAL_ERROR_IMPOSSIBLE; uint8_t rnd[fp_unsigned_bin_size(unconst_fp_int(key->n))]; hal_error_t err = HAL_OK; @@ -306,9 +332,12 @@ static hal_error_t create_blinding_factors(hal_core_t *core, const hal_rsa_key_t fp_read_unsigned_bin(bf, rnd, sizeof(rnd)); fp_copy(bf, ubf); - if ((err = modexp(core, bf, key->e, key->n, bf)) != HAL_OK) + if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), bf, key->e, key->n, bf, + key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) != HAL_OK) goto fail; + key->flags |= RSA_FLAG_PRECALC_N_DONE; + FP_CHECK(fp_invmod(ubf, unconst_fp_int(key->n), ubf)); fail: @@ -322,7 +351,8 @@ static hal_error_t create_blinding_factors(hal_core_t *core, const hal_rsa_key_t static hal_error_t rsa_crt(hal_core_t *core, const hal_rsa_key_t * const key, fp_int *msg, fp_int *sig) { - assert(key != NULL && msg != NULL && sig != NULL); + if (key == NULL || msg == NULL || sig == NULL) + return HAL_ERROR_IMPOSSIBLE; hal_error_t err = HAL_OK; fp_int t[1] = INIT_FP_INT; @@ -343,11 +373,18 @@ static hal_error_t rsa_crt(hal_core_t *core, const hal_rsa_key_t * const key, fp /* * m1 = msg ** dP mod p * m2 = msg ** dQ mod q + * + * This is just crying out to be done with parallel cores, but get + * the boring version working before jumping off that cliff. */ - if ((err = modexp(core, msg, key->dP, key->p, m1)) != HAL_OK || - (err = modexp(core, msg, key->dQ, key->q, m2)) != HAL_OK) + if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_P_DONE), + msg, key->dP, key->p, m1, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF))) != HAL_OK || + (err = modexp(core, (key->flags & RSA_FLAG_PRECALC_Q_DONE), + msg, key->dQ, key->q, m2, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK) goto fail; + key->flags |= RSA_FLAG_PRECALC_P_DONE | RSA_FLAG_PRECALC_Q_DONE; + /* * t = m1 - m2. */ @@ -406,11 +443,12 @@ hal_error_t hal_rsa_encrypt(hal_core_t *core, fp_read_unsigned_bin(i, unconst_uint8_t(input), input_len); - if ((err = modexp(core, i, key->e, key->n, o)) != HAL_OK || - (err = unpack_fp(o, output, output_len)) != HAL_OK) - goto fail; + if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), i, key->e, key->n, o, + key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) == HAL_OK) { + key->flags |= RSA_FLAG_PRECALC_N_DONE; + err = unpack_fp(o, output, output_len); + } - fail: fp_zero(i); fp_zero(o); return err; @@ -436,11 +474,13 @@ hal_error_t hal_rsa_decrypt(hal_core_t *core, * just do brute force ModExp. */ - if (fp_iszero(key->p) || fp_iszero(key->q) || fp_iszero(key->u) || fp_iszero(key->dP) || fp_iszero(key->dQ)) - err = modexp(core, i, key->d, key->n, o); - else + if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) && !fp_iszero(key->dP) && !fp_iszero(key->dQ)) err = rsa_crt(core, key, i, o); + else if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), i, key->d, key->n, o, + key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) == HAL_OK) + key->flags |= RSA_FLAG_PRECALC_N_DONE; + if (err != HAL_OK || (err = unpack_fp(o, output, output_len)) != HAL_OK) goto fail; |