aboutsummaryrefslogtreecommitdiff
path: root/rsa.c
diff options
context:
space:
mode:
Diffstat (limited to 'rsa.c')
-rw-r--r--rsa.c240
1 files changed, 222 insertions, 18 deletions
diff --git a/rsa.c b/rsa.c
index 1b5de7d..b0c34c5 100644
--- a/rsa.c
+++ b/rsa.c
@@ -121,6 +121,13 @@ void hal_rsa_set_debug(const int onoff)
debug = onoff;
}
+static int do_crt = 1;
+
+void hal_rsa_set_crt(const int onoff)
+{
+ do_crt = onoff;
+}
+
/*
* Whether we want RSA blinding.
*/
@@ -165,9 +172,9 @@ struct hal_rsa_key {
fp_int dQ[1]; /* d mod (q - 1) */
unsigned flags; /* Internal key flags */
uint8_t /* ModExpA7 speedup factors */
- nC[HAL_RSA_MAX_OPERAND_LENGTH], nF[HAL_RSA_MAX_OPERAND_LENGTH],
- pC[HAL_RSA_MAX_OPERAND_LENGTH/2], pF[HAL_RSA_MAX_OPERAND_LENGTH/2],
- qC[HAL_RSA_MAX_OPERAND_LENGTH/2], qF[HAL_RSA_MAX_OPERAND_LENGTH/2];
+ nC[HAL_RSA_MAX_OPERAND_LENGTH+4], nF[HAL_RSA_MAX_OPERAND_LENGTH],
+ pC[HAL_RSA_MAX_OPERAND_LENGTH/2+4], pF[HAL_RSA_MAX_OPERAND_LENGTH/2],
+ qC[HAL_RSA_MAX_OPERAND_LENGTH/2+4], qF[HAL_RSA_MAX_OPERAND_LENGTH/2];
};
#define RSA_FLAG_NEEDS_SAVING (1 << 0)
@@ -228,6 +235,26 @@ static hal_error_t unpack_fp(const fp_int * const bn, uint8_t *buffer, const siz
#if HAL_RSA_SIGN_USE_MODEXP
+static hal_error_t modexp_precalc(const fp_int *modulus, uint8_t *coeff, size_t coeff_len, uint8_t *factor, size_t factor_len)
+{
+ const size_t keylen = ((fp_unsigned_bin_size(unconst_fp_int(modulus)) + 3) & ~3) * 8;
+ hal_error_t err;
+
+ /* factor = (2 ** (2 * (keylen + 16))) % modulus */
+ fp_int fp_result[1];
+ fp_2expt(fp_result, 2 * (keylen + 16));
+ fp_mod(fp_result, unconst_fp_int(modulus), fp_result);
+ if ((err = unpack_fp(fp_result, factor, factor_len)) != HAL_OK)
+ return err;
+
+ /* coeff = (-modulus ** -1) % (2 ** (keylen + 16)) */
+ fp_int pwr[1];
+ fp_2expt(pwr, keylen + 16);
+ fp_neg(unconst_fp_int(modulus), fp_result);
+ fp_invmod(fp_result, pwr, fp_result);
+ return unpack_fp(fp_result, coeff, coeff_len);
+}
+
/*
* Unwrap bignums into byte arrays, feed them into hal_modexp(), and
* wrap result back up as a bignum.
@@ -248,14 +275,34 @@ static hal_error_t modexp(hal_core_t *core,
return HAL_ERROR_IMPOSSIBLE;
const size_t msg_len = (fp_unsigned_bin_size(unconst_fp_int(msg)) + 3) & ~3;
- const size_t exp_len = (fp_unsigned_bin_size(unconst_fp_int(exp)) + 3) & ~3;
- const size_t mod_len = (fp_unsigned_bin_size(unconst_fp_int(mod)) + 3) & ~3;
+ const size_t exp_len = (fp_unsigned_bin_size(unconst_fp_int(exp)) + 3) & ~3;
+ const size_t mod_len = (fp_unsigned_bin_size(unconst_fp_int(mod)) + 3) & ~3;
uint8_t msgbuf[msg_len];
uint8_t expbuf[exp_len];
uint8_t modbuf[mod_len];
uint8_t resbuf[mod_len];
+ if (hal_modexp_using_modexpng()) {
+ hal_modexpng_arg_t args = {
+ .core = core,
+ .msg = msgbuf, .msg_len = sizeof(msgbuf),
+ .exp = expbuf, .exp_len = sizeof(expbuf),
+ .mod = modbuf, .mod_len = sizeof(modbuf),
+ .result = resbuf, .result_len = sizeof(resbuf),
+ .coeff = coeff, .coeff_len = coeff_len,
+ .mont = mont, .mont_len = mont_len
+ };
+
+ if ((precalc &&
+ (err = modexp_precalc(mod, coeff, coeff_len, mont, mont_len)) != HAL_OK) ||
+ (err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK ||
+ (err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK ||
+ (err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK ||
+ (err = hal_modexpng(&args)) != HAL_OK)
+ goto fail;
+ }
+ else {
hal_modexp_arg_t args = {
.core = core,
.msg = msgbuf, .msg_len = sizeof(msgbuf),
@@ -266,11 +313,12 @@ static hal_error_t modexp(hal_core_t *core,
.mont = mont, .mont_len = mont_len
};
- if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK ||
- (err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK ||
- (err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK ||
- (err = hal_modexp(precalc, &args)) != HAL_OK)
+ if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK ||
+ (err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK ||
+ (err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK ||
+ (err = hal_modexp(precalc, &args)) != HAL_OK)
goto fail;
+ }
fp_read_unsigned_bin(res, resbuf, sizeof(resbuf));
@@ -279,7 +327,6 @@ static hal_error_t modexp(hal_core_t *core,
memset(expbuf, 0, sizeof(expbuf));
memset(modbuf, 0, sizeof(modbuf));
memset(resbuf, 0, sizeof(resbuf));
- memset(&args, 0, sizeof(args));
return err;
}
@@ -306,10 +353,10 @@ static hal_error_t modexp2(const int precalc,
return HAL_ERROR_IMPOSSIBLE;
const size_t msg_len = (fp_unsigned_bin_size(unconst_fp_int(msg)) + 3) & ~3;
- const size_t exp1_len = (fp_unsigned_bin_size(unconst_fp_int(exp1)) + 3) & ~3;
- const size_t mod1_len = (fp_unsigned_bin_size(unconst_fp_int(mod1)) + 3) & ~3;
- const size_t exp2_len = (fp_unsigned_bin_size(unconst_fp_int(exp2)) + 3) & ~3;
- const size_t mod2_len = (fp_unsigned_bin_size(unconst_fp_int(mod2)) + 3) & ~3;
+ const size_t exp1_len = (fp_unsigned_bin_size(unconst_fp_int(exp1)) + 3) & ~3;
+ const size_t mod1_len = (fp_unsigned_bin_size(unconst_fp_int(mod1)) + 3) & ~3;
+ const size_t exp2_len = (fp_unsigned_bin_size(unconst_fp_int(exp2)) + 3) & ~3;
+ const size_t mod2_len = (fp_unsigned_bin_size(unconst_fp_int(mod2)) + 3) & ~3;
uint8_t msgbuf[msg_len];
uint8_t expbuf1[exp1_len], modbuf1[mod1_len], resbuf1[mod1_len];
@@ -359,6 +406,113 @@ static hal_error_t modexp2(const int precalc,
return err;
}
+static hal_error_t modexpng(hal_core_t *core,
+ const fp_int * const msg,
+ hal_rsa_key_t *key,
+ fp_int *bf,
+ fp_int *ubf,
+ fp_int *res)
+{
+ hal_error_t err = HAL_OK;
+
+ if (msg == NULL || key == NULL || res == NULL)
+ return HAL_ERROR_IMPOSSIBLE;
+
+ if (!(key->flags & RSA_FLAG_PRECALC_N_DONE)) {
+ if ((err = modexp_precalc(key->n, key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) != HAL_OK)
+ return err;
+ key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING;
+ }
+
+ if (key->p && !(key->flags & RSA_FLAG_PRECALC_PQ_DONE)) {
+ if ((err = modexp_precalc(key->p, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF))) != HAL_OK ||
+ (err = modexp_precalc(key->q, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK)
+ return err;
+ key->flags |= RSA_FLAG_PRECALC_PQ_DONE | RSA_FLAG_NEEDS_SAVING;
+ }
+
+/* number of significant bytes in an fp_int, rounded to a multiple of 4 */
+#define fp_len(x) (fp_unsigned_bin_size(unconst_fp_int(x)) + 3) & ~3
+
+ const size_t mod_len = fp_len(key->n);
+
+ uint8_t msgbuf[fp_len(msg)];
+ uint8_t expbuf[fp_len(key->d)];
+ uint8_t modbuf[mod_len];
+ uint8_t resbuf[mod_len];
+ uint8_t p_buf[mod_len/2];
+ uint8_t q_buf[mod_len/2];
+ uint8_t u_buf[mod_len/2];
+ uint8_t dP_buf[mod_len/2];
+ uint8_t dQ_buf[mod_len/2];
+ uint8_t bf_buf[mod_len];
+ uint8_t ubf_buf[mod_len];
+
+ hal_modexpng_arg_t args = {
+ .core = core,
+ .msg = msgbuf, .msg_len = sizeof(msgbuf),
+ .exp = expbuf, .exp_len = sizeof(expbuf),
+ .mod = modbuf, .mod_len = sizeof(modbuf),
+ .result = resbuf, .result_len = sizeof(resbuf),
+ .coeff = key->nC, .coeff_len = sizeof(key->nC),
+ .mont = key->nF, .mont_len = sizeof(key->nF),
+ .p = p_buf, .p_len = sizeof(p_buf),
+ .pC = key->pC, .pC_len = sizeof(key->pC),
+ .pF = key->pF, .pF_len = sizeof(key->pF),
+ .q = q_buf, .q_len = sizeof(q_buf),
+ .qC = key->qC, .qC_len = sizeof(key->qC),
+ .qF = key->qF, .qF_len = sizeof(key->qF),
+ .dP = dP_buf, .dP_len = sizeof(dP_buf),
+ .dQ = dQ_buf, .dQ_len = sizeof(dQ_buf),
+ .qInv = u_buf, .qInv_len = sizeof(u_buf),
+ .bf = bf_buf, .bf_len = sizeof(bf_buf),
+ .ubf = ubf_buf, .ubf_len = sizeof(ubf_buf),
+ };
+
+ if (bf) {
+ if ((err = unpack_fp(bf, bf_buf, sizeof(bf_buf))) != HAL_OK ||
+ (err = unpack_fp(ubf, ubf_buf, sizeof(ubf_buf))) != HAL_OK)
+ goto fail;
+ }
+ else {
+ /* set blinding factors to (1,1) */
+ memset(bf_buf, 0, sizeof(bf_buf)); bf_buf[sizeof(bf_buf) - 1] = 1;
+ memset(ubf_buf, 0, sizeof(ubf_buf)); ubf_buf[sizeof(ubf_buf) - 1] = 1;
+ }
+
+ if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK ||
+ (err = unpack_fp(key->d, expbuf, sizeof(expbuf))) != HAL_OK ||
+ (err = unpack_fp(key->n, modbuf, sizeof(modbuf))) != HAL_OK ||
+ (err = unpack_fp(key->p, p_buf, sizeof(p_buf))) != HAL_OK ||
+ (err = unpack_fp(key->q, q_buf, sizeof(q_buf))) != HAL_OK ||
+ (err = unpack_fp(key->u, u_buf, sizeof(u_buf))) != HAL_OK ||
+ (err = unpack_fp(key->dP, dP_buf, sizeof(dP_buf))) != HAL_OK ||
+ (err = unpack_fp(key->dQ, dQ_buf, sizeof(dQ_buf))) != HAL_OK ||
+ (err = hal_modexpng(&args)) != HAL_OK)
+ goto fail;
+
+ fp_read_unsigned_bin(res, resbuf, sizeof(resbuf));
+ /* we do the blinding factor permutation in create_blinding_factors,
+ * so we don't need to read them back from the core
+ */
+
+ fail:
+ memset(msgbuf, 0, sizeof(msgbuf));
+ memset(expbuf, 0, sizeof(expbuf));
+ memset(modbuf, 0, sizeof(modbuf));
+ memset(resbuf, 0, sizeof(resbuf));
+ memset(p_buf, 0, sizeof(p_buf));
+ memset(q_buf, 0, sizeof(q_buf));
+ memset(u_buf, 0, sizeof(u_buf));
+ memset(dP_buf, 0, sizeof(dP_buf));
+ memset(dQ_buf, 0, sizeof(dQ_buf));
+ memset(bf_buf, 0, sizeof(bf_buf));
+ memset(ubf_buf, 0, sizeof(ubf_buf));
+ memset(&args, 0, sizeof(args));
+ return err;
+}
+
+
#else /* HAL_RSA_SIGN_USE_MODEXP */
/*
@@ -406,6 +560,25 @@ static hal_error_t modexp2(const int precalc, /* ignored */
return err;
}
+int hal_modexp_using_modexpng(void)
+{
+ return 0;
+}
+
+static hal_error_t modexpng(hal_core_t *core,
+ const fp_int * const msg,
+ hal_rsa_key_t *key,
+ fp_int *bf,
+ fp_int *ubf,
+ fp_int *res)
+{
+ return HAL_ERROR_FORBIDDEN;
+}
+
+static hal_error_t modexp_precalc(const fp_int *modulus, uint8_t *coeff, const size_t coeff_len, uint8_t *factor, const size_t factor_len)
+{
+ return HAL_ERROR_FORBIDDEN;
+}
#endif /* HAL_RSA_SIGN_USE_MODEXP */
/*
@@ -482,6 +655,7 @@ static hal_error_t create_blinding_factors(hal_rsa_key_t *key, fp_int *bf, fp_in
fp_read_unsigned_bin(bf, rnd, sizeof(rnd));
fp_copy(bf, ubf);
+ /* bf = ubf ** e mod n */
if ((err = modexp(NULL, precalc, bf, key->e, key->n, bf,
key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) != HAL_OK)
goto fail;
@@ -516,13 +690,25 @@ static hal_error_t rsa_crt(hal_core_t *core1, hal_core_t *core2, hal_rsa_key_t *
if (key == NULL || msg == NULL || sig == NULL)
return HAL_ERROR_IMPOSSIBLE;
- const int precalc = !(key->flags & RSA_FLAG_PRECALC_PQ_DONE);
hal_error_t err = HAL_OK;
+ fp_int bf[1] = INIT_FP_INT;
+ fp_int ubf[1] = INIT_FP_INT;
+
+ if (hal_modexp_using_modexpng()) {
+ if (blinding) {
+ if ((err = create_blinding_factors(key, bf, ubf)) != HAL_OK)
+ return err;
+ return modexpng(core1, msg, key, bf, ubf, sig);
+ }
+ else {
+ return modexpng(core1, msg, key, NULL, NULL, sig);
+ }
+ }
+
+ const int precalc = !(key->flags & RSA_FLAG_PRECALC_PQ_DONE);
fp_int t[1] = INIT_FP_INT;
fp_int m1[1] = INIT_FP_INT;
fp_int m2[1] = INIT_FP_INT;
- fp_int bf[1] = INIT_FP_INT;
- fp_int ubf[1] = INIT_FP_INT;
/*
* Handle blinding if requested.
@@ -530,6 +716,7 @@ static hal_error_t rsa_crt(hal_core_t *core1, hal_core_t *core2, hal_rsa_key_t *
if (blinding) {
if ((err = create_blinding_factors(key, bf, ubf)) != HAL_OK)
goto fail;
+ /* msg = (msg * bf) % modulus */
FP_CHECK(fp_mulmod(msg, bf, unconst_fp_int(key->n), msg));
}
@@ -571,6 +758,7 @@ static hal_error_t rsa_crt(hal_core_t *core1, hal_core_t *core2, hal_rsa_key_t *
/*
* Unblind if necessary.
*/
+ /* sig = (sig * ubf) % modulus */
if (blinding)
FP_CHECK(fp_mulmod(sig, ubf, unconst_fp_int(key->n), sig));
@@ -604,6 +792,7 @@ hal_error_t hal_rsa_encrypt(hal_core_t *core,
fp_read_unsigned_bin(i, unconst_uint8_t(input), input_len);
+ /* o = i ** e % n */
err = modexp(core, precalc, i, key->e, key->n, o,
key->nC, sizeof(key->nC), key->nF, sizeof(key->nF));
@@ -639,12 +828,17 @@ hal_error_t hal_rsa_decrypt(hal_core_t *core1,
* just do brute force ModExp.
*/
- if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) &&
+ /* These should all be set if we generated the key, and we'll reject an
+ * externally generated key if it doesn't have all the components, so I'm
+ * not sure what the point is.
+ */
+ if (do_crt && !fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) &&
!fp_iszero(key->dP) && !fp_iszero(key->dQ))
err = rsa_crt(core1, core2, key, i, o);
else {
const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE);
+ /* o = i ** d % n */
err = modexp(core1, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC),
key->nF, sizeof(key->nF));
if (err == HAL_OK && precalc)
@@ -706,6 +900,7 @@ static hal_error_t load_key(const hal_key_type_t type,
switch (type) {
case HAL_KEY_TYPE_RSA_PRIVATE:
_(d); _(p); _(q); _(u); _(dP); _(dQ);
+ /* fall through */
case HAL_KEY_TYPE_RSA_PUBLIC:
_(n); _(e);
*key_ = key;
@@ -977,6 +1172,15 @@ hal_error_t hal_rsa_key_gen(hal_core_t *core,
key->flags |= RSA_FLAG_NEEDS_SAVING;
+#if 0
+ if (hal_modexp_using_modexpng()) {
+ modexp_precalc(key->n, key->nC, sizeof(key->nC), key->nF, sizeof(key->nF));
+ modexp_precalc(key->p, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF));
+ modexp_precalc(key->q, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF));
+ key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_PRECALC_PQ_DONE;
+ }
+#endif
+
*key_ = key;
/* Fall through to cleanup */