aboutsummaryrefslogtreecommitdiff
path: root/rsa.c
diff options
context:
space:
mode:
Diffstat (limited to 'rsa.c')
-rw-r--r--rsa.c102
1 files changed, 71 insertions, 31 deletions
diff --git a/rsa.c b/rsa.c
index eeb611c..90a878f 100644
--- a/rsa.c
+++ b/rsa.c
@@ -70,7 +70,6 @@
#include <stdlib.h>
#include <stddef.h>
#include <string.h>
-#include <assert.h>
#include "hal.h"
#include "hal_internal.h"
@@ -94,6 +93,15 @@
#endif
/*
+ * How big to make the buffers for the modulus coefficient and
+ * Montgomery factor. This will almost certainly want tuning.
+ */
+
+#ifndef HAL_RSA_MAX_OPERAND_LENGTH
+#define HAL_RSA_MAX_OPERAND_LENGTH (4096 / 8)
+#endif
+
+/*
* Whether we want debug output.
*/
@@ -123,7 +131,7 @@ void hal_rsa_set_blinding(const int onoff)
*/
struct hal_rsa_key {
- hal_key_type_t type; /* What kind of key this is */
+ hal_key_type_t type; /* What kind of key this is */
fp_int n[1]; /* The modulus */
fp_int e[1]; /* Public exponent */
fp_int d[1]; /* Private exponent */
@@ -132,8 +140,17 @@ struct hal_rsa_key {
fp_int u[1]; /* 1/q mod p */
fp_int dP[1]; /* d mod (p - 1) */
fp_int dQ[1]; /* d mod (q - 1) */
+ unsigned flags; /* Internal key flags */
+ uint8_t /* ModExpA7 speedup factors */
+ nC[HAL_RSA_MAX_OPERAND_LENGTH], nF[HAL_RSA_MAX_OPERAND_LENGTH],
+ pC[HAL_RSA_MAX_OPERAND_LENGTH/2], pF[HAL_RSA_MAX_OPERAND_LENGTH/2],
+ qC[HAL_RSA_MAX_OPERAND_LENGTH/2], qF[HAL_RSA_MAX_OPERAND_LENGTH/2];
};
+#define RSA_FLAG_PRECALC_N_DONE (1 << 0)
+#define RSA_FLAG_PRECALC_P_DONE (1 << 1)
+#define RSA_FLAG_PRECALC_Q_DONE (1 << 2)
+
const size_t hal_rsa_key_t_size = sizeof(hal_rsa_key_t);
/*
@@ -158,7 +175,7 @@ const size_t hal_rsa_key_t_size = sizeof(hal_rsa_key_t);
case FP_OKAY: break; \
case FP_VAL: lose(HAL_ERROR_BAD_ARGUMENTS); \
case FP_MEM: lose(HAL_ERROR_ALLOCATION_FAILURE); \
- default: lose(HAL_ERROR_IMPOSSIBLE); \
+ default: lose(HAL_ERROR_IMPOSSIBLE); \
} \
} while (0)
@@ -171,7 +188,8 @@ static hal_error_t unpack_fp(const fp_int * const bn, uint8_t *buffer, const siz
{
hal_error_t err = HAL_OK;
- assert(bn != NULL && buffer != NULL);
+ if (bn == NULL || buffer == NULL)
+ return HAL_ERROR_IMPOSSIBLE;
const size_t bytes = fp_unsigned_bin_size(unconst_fp_int(bn));
@@ -193,22 +211,18 @@ static hal_error_t unpack_fp(const fp_int * const bn, uint8_t *buffer, const siz
*/
static hal_error_t modexp(hal_core_t *core,
- const fp_int * msg,
+ const int precalc_done,
+ const fp_int * const msg,
const fp_int * const exp,
const fp_int * const mod,
- fp_int *res)
+ fp_int *res,
+ uint8_t *coeff, const size_t coeff_len,
+ uint8_t *mont, const size_t mont_len)
{
hal_error_t err = HAL_OK;
- assert(msg != NULL && exp != NULL && mod != NULL && res != NULL);
-
- fp_int reduced_msg[1] = INIT_FP_INT;
-
- if (fp_cmp_mag(unconst_fp_int(msg), unconst_fp_int(mod)) != FP_LT) {
- fp_init(reduced_msg);
- fp_mod(unconst_fp_int(msg), unconst_fp_int(mod), reduced_msg);
- msg = reduced_msg;
- }
+ if (msg == NULL || exp == NULL || mod == NULL || res == NULL || coeff == NULL || mont == NULL)
+ return HAL_ERROR_IMPOSSIBLE;
const size_t msg_len = (fp_unsigned_bin_size(unconst_fp_int(msg)) + 3) & ~3;
const size_t exp_len = (fp_unsigned_bin_size(unconst_fp_int(exp)) + 3) & ~3;
@@ -222,11 +236,13 @@ static hal_error_t modexp(hal_core_t *core,
if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK ||
(err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK ||
(err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK ||
- (err = hal_modexp(core,
+ (err = hal_modexp(core, precalc_done,
msgbuf, sizeof(msgbuf),
expbuf, sizeof(expbuf),
modbuf, sizeof(modbuf),
- resbuf, sizeof(resbuf))) != HAL_OK)
+ resbuf, sizeof(resbuf),
+ coeff, coeff_len,
+ mont, mont_len)) != HAL_OK)
goto fail;
fp_read_unsigned_bin(res, resbuf, sizeof(resbuf));
@@ -249,10 +265,14 @@ static hal_error_t modexp(hal_core_t *core,
*/
static hal_error_t modexp(const hal_core_t *core, /* ignored */
+ const int precalc_done, /* ignored */
const fp_int * const msg,
const fp_int * const exp,
const fp_int * const mod,
- fp_int *res)
+ fp_int *res,
+ uint8_t *coeff, const size_t coeff_len, /* ignored */
+ uint8_t *mont, const size_t mont_len) /* ignored */
+
{
hal_error_t err = HAL_OK;
FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp), unconst_fp_int(mod), res));
@@ -281,7 +301,12 @@ static hal_error_t modexp(const hal_core_t *core, /* ignored */
int fp_exptmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
{
- return modexp(NULL, a, b, c, d) == HAL_OK ? FP_OKAY : FP_VAL;
+ const size_t len = (fp_unsigned_bin_size(unconst_fp_int(b)) + 3) & ~3;
+ uint8_t C[len], F[len];
+ const hal_error_t err = modexp(NULL, 0, a, b, c, d, C, sizeof(C), F, sizeof(F));
+ memset(C, 0, sizeof(C));
+ memset(F, 0, sizeof(F));
+ return err == HAL_OK ? FP_OKAY : FP_VAL;
}
#endif /* HAL_RSA_SIGN_USE_MODEXP && HAL_RSA_KEYGEN_USE_MODEXP */
@@ -294,7 +319,8 @@ int fp_exptmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
static hal_error_t create_blinding_factors(hal_core_t *core, const hal_rsa_key_t * const key, fp_int *bf, fp_int *ubf)
{
- assert(key != NULL && bf != NULL && ubf != NULL);
+ if (key == NULL || bf == NULL || ubf == NULL)
+ return HAL_ERROR_IMPOSSIBLE;
uint8_t rnd[fp_unsigned_bin_size(unconst_fp_int(key->n))];
hal_error_t err = HAL_OK;
@@ -306,9 +332,12 @@ static hal_error_t create_blinding_factors(hal_core_t *core, const hal_rsa_key_t
fp_read_unsigned_bin(bf, rnd, sizeof(rnd));
fp_copy(bf, ubf);
- if ((err = modexp(core, bf, key->e, key->n, bf)) != HAL_OK)
+ if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), bf, key->e, key->n, bf,
+ key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) != HAL_OK)
goto fail;
+ key->flags |= RSA_FLAG_PRECALC_N_DONE;
+
FP_CHECK(fp_invmod(ubf, unconst_fp_int(key->n), ubf));
fail:
@@ -322,7 +351,8 @@ static hal_error_t create_blinding_factors(hal_core_t *core, const hal_rsa_key_t
static hal_error_t rsa_crt(hal_core_t *core, const hal_rsa_key_t * const key, fp_int *msg, fp_int *sig)
{
- assert(key != NULL && msg != NULL && sig != NULL);
+ if (key == NULL || msg == NULL || sig == NULL)
+ return HAL_ERROR_IMPOSSIBLE;
hal_error_t err = HAL_OK;
fp_int t[1] = INIT_FP_INT;
@@ -343,11 +373,18 @@ static hal_error_t rsa_crt(hal_core_t *core, const hal_rsa_key_t * const key, fp
/*
* m1 = msg ** dP mod p
* m2 = msg ** dQ mod q
+ *
+ * This is just crying out to be done with parallel cores, but get
+ * the boring version working before jumping off that cliff.
*/
- if ((err = modexp(core, msg, key->dP, key->p, m1)) != HAL_OK ||
- (err = modexp(core, msg, key->dQ, key->q, m2)) != HAL_OK)
+ if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_P_DONE),
+ msg, key->dP, key->p, m1, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF))) != HAL_OK ||
+ (err = modexp(core, (key->flags & RSA_FLAG_PRECALC_Q_DONE),
+ msg, key->dQ, key->q, m2, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK)
goto fail;
+ key->flags |= RSA_FLAG_PRECALC_P_DONE | RSA_FLAG_PRECALC_Q_DONE;
+
/*
* t = m1 - m2.
*/
@@ -406,11 +443,12 @@ hal_error_t hal_rsa_encrypt(hal_core_t *core,
fp_read_unsigned_bin(i, unconst_uint8_t(input), input_len);
- if ((err = modexp(core, i, key->e, key->n, o)) != HAL_OK ||
- (err = unpack_fp(o, output, output_len)) != HAL_OK)
- goto fail;
+ if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), i, key->e, key->n, o,
+ key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) == HAL_OK) {
+ key->flags |= RSA_FLAG_PRECALC_N_DONE;
+ err = unpack_fp(o, output, output_len);
+ }
- fail:
fp_zero(i);
fp_zero(o);
return err;
@@ -436,11 +474,13 @@ hal_error_t hal_rsa_decrypt(hal_core_t *core,
* just do brute force ModExp.
*/
- if (fp_iszero(key->p) || fp_iszero(key->q) || fp_iszero(key->u) || fp_iszero(key->dP) || fp_iszero(key->dQ))
- err = modexp(core, i, key->d, key->n, o);
- else
+ if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) && !fp_iszero(key->dP) && !fp_iszero(key->dQ))
err = rsa_crt(core, key, i, o);
+ else if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), i, key->d, key->n, o,
+ key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) == HAL_OK)
+ key->flags |= RSA_FLAG_PRECALC_N_DONE;
+
if (err != HAL_OK || (err = unpack_fp(o, output, output_len)) != HAL_OK)
goto fail;