aboutsummaryrefslogtreecommitdiff
path: root/rsa.c
diff options
context:
space:
mode:
Diffstat (limited to 'rsa.c')
-rw-r--r--rsa.c138
1 files changed, 119 insertions, 19 deletions
diff --git a/rsa.c b/rsa.c
index dace19b..44ad84e 100644
--- a/rsa.c
+++ b/rsa.c
@@ -233,16 +233,20 @@ static hal_error_t modexp(hal_core_t *core,
uint8_t modbuf[mod_len];
uint8_t resbuf[mod_len];
+ hal_modexp_arg_t args = {
+ .core = core,
+ .msg = msgbuf, .msg_len = sizeof(msgbuf),
+ .exp = expbuf, .exp_len = sizeof(expbuf),
+ .mod = modbuf, .mod_len = sizeof(modbuf),
+ .result = resbuf, .result_len = sizeof(resbuf),
+ .coeff = coeff, .coeff_len = coeff_len,
+ .mont = mont, .mont_len = mont_len
+ };
+
if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK ||
(err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK ||
(err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK ||
- (err = hal_modexp(core, precalc,
- msgbuf, sizeof(msgbuf),
- expbuf, sizeof(expbuf),
- modbuf, sizeof(modbuf),
- resbuf, sizeof(resbuf),
- coeff, coeff_len,
- mont, mont_len)) != HAL_OK)
+ (err = hal_modexp(precalc, &args)) != HAL_OK)
goto fail;
fp_read_unsigned_bin(res, resbuf, sizeof(resbuf));
@@ -252,6 +256,83 @@ static hal_error_t modexp(hal_core_t *core,
memset(expbuf, 0, sizeof(expbuf));
memset(modbuf, 0, sizeof(modbuf));
memset(resbuf, 0, sizeof(resbuf));
+ memset(&args, 0, sizeof(args));
+ return err;
+}
+
+static hal_error_t modexp2(const int precalc,
+ const fp_int * const msg,
+ hal_core_t *core1,
+ const fp_int * const exp1,
+ const fp_int * const mod1,
+ fp_int * res1,
+ uint8_t *coeff1, const size_t coeff1_len,
+ uint8_t *mont1, const size_t mont1_len,
+ hal_core_t *core2,
+ const fp_int * const exp2,
+ const fp_int * const mod2,
+ fp_int * res2,
+ uint8_t *coeff2, const size_t coeff2_len,
+ uint8_t *mont2, const size_t mont2_len)
+{
+ hal_error_t err = HAL_OK;
+
+ if (msg == NULL ||
+ exp1 == NULL || mod1 == NULL || res1 == NULL || coeff1 == NULL || mont1 == NULL ||
+ exp2 == NULL || mod2 == NULL || res2 == NULL || coeff2 == NULL || mont2 == NULL)
+ return HAL_ERROR_IMPOSSIBLE;
+
+ const size_t msg_len = (fp_unsigned_bin_size(unconst_fp_int(msg)) + 3) & ~3;
+ const size_t exp1_len = (fp_unsigned_bin_size(unconst_fp_int(exp1)) + 3) & ~3;
+ const size_t mod1_len = (fp_unsigned_bin_size(unconst_fp_int(mod1)) + 3) & ~3;
+ const size_t exp2_len = (fp_unsigned_bin_size(unconst_fp_int(exp2)) + 3) & ~3;
+ const size_t mod2_len = (fp_unsigned_bin_size(unconst_fp_int(mod2)) + 3) & ~3;
+
+ uint8_t msgbuf[msg_len];
+ uint8_t expbuf1[exp1_len], modbuf1[mod1_len], resbuf1[mod1_len];
+ uint8_t expbuf2[exp2_len], modbuf2[mod2_len], resbuf2[mod2_len];
+
+ hal_modexp_arg_t args1 = {
+ .core = core1,
+ .msg = msgbuf, .msg_len = sizeof(msgbuf),
+ .exp = expbuf1, .exp_len = sizeof(expbuf1),
+ .mod = modbuf1, .mod_len = sizeof(modbuf1),
+ .result = resbuf1, .result_len = sizeof(resbuf1),
+ .coeff = coeff1, .coeff_len = coeff1_len,
+ .mont = mont1, .mont_len = mont1_len
+ };
+
+ hal_modexp_arg_t args2 = {
+ .core = core2,
+ .msg = msgbuf, .msg_len = sizeof(msgbuf),
+ .exp = expbuf2, .exp_len = sizeof(expbuf2),
+ .mod = modbuf2, .mod_len = sizeof(modbuf2),
+ .result = resbuf2, .result_len = sizeof(resbuf2),
+ .coeff = coeff2, .coeff_len = coeff2_len,
+ .mont = mont2, .mont_len = mont2_len
+ };
+
+ if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK ||
+ (err = unpack_fp(exp1, expbuf1, sizeof(expbuf1))) != HAL_OK ||
+ (err = unpack_fp(mod1, modbuf1, sizeof(modbuf1))) != HAL_OK ||
+ (err = unpack_fp(exp2, expbuf2, sizeof(expbuf2))) != HAL_OK ||
+ (err = unpack_fp(mod2, modbuf2, sizeof(modbuf2))) != HAL_OK ||
+ (err = hal_modexp2(precalc, &args1, &args2)) != HAL_OK)
+ goto fail;
+
+ fp_read_unsigned_bin(res1, resbuf1, sizeof(resbuf1));
+ fp_read_unsigned_bin(res2, resbuf2, sizeof(resbuf2));
+
+ fail:
+ memset(msgbuf, 0, sizeof(msgbuf));
+ memset(expbuf1, 0, sizeof(expbuf1));
+ memset(modbuf1, 0, sizeof(modbuf1));
+ memset(resbuf1, 0, sizeof(resbuf1));
+ memset(&args1, 0, sizeof(args1));
+ memset(expbuf2, 0, sizeof(expbuf2));
+ memset(modbuf2, 0, sizeof(modbuf2));
+ memset(resbuf2, 0, sizeof(resbuf2));
+ memset(&args2, 0, sizeof(args2));
return err;
}
@@ -280,6 +361,28 @@ static hal_error_t modexp(const hal_core_t *core, /* ignored */
return err;
}
+static hal_error_t modexp2(const int precalc, /* ignored */
+ const fp_int * const msg,
+ hal_core_t *core1, /* ignored */
+ const fp_int * const exp1,
+ const fp_int * const mod1,
+ fp_int * res1,
+ uint8_t *coeff1, const size_t coeff1_len, /* ignored */
+ uint8_t *mont1, const size_t mont1_len, /* ignored */
+ hal_core_t *core2, /* ignored */
+ const fp_int * const exp2,
+ const fp_int * const mod2,
+ fp_int * res2,
+ uint8_t *coeff2, const size_t coeff2_len, /* ignored */
+ uint8_t *mont2, const size_t mont2_len) /* ignored */
+{
+ hal_error_t err = HAL_OK;
+ FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp1), unconst_fp_int(mod1), res1));
+ FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp2), unconst_fp_int(mod2), res2));
+ fail:
+ return err;
+}
+
#endif /* HAL_RSA_SIGN_USE_MODEXP */
/*
@@ -351,7 +454,7 @@ static hal_error_t create_blinding_factors(hal_core_t *core, hal_rsa_key_t *key,
* RSA decryption via Chinese Remainder Theorem (Garner's formula).
*/
-static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp_int *sig)
+static hal_error_t rsa_crt(hal_core_t *core1, hal_core_t *core2, hal_rsa_key_t *key, fp_int *msg, fp_int *sig)
{
if (key == NULL || msg == NULL || sig == NULL)
return HAL_ERROR_IMPOSSIBLE;
@@ -368,7 +471,7 @@ static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp
* Handle blinding if requested.
*/
if (blinding) {
- if ((err = create_blinding_factors(core, key, bf, ubf)) != HAL_OK)
+ if ((err = create_blinding_factors(core1, key, bf, ubf)) != HAL_OK)
goto fail;
FP_CHECK(fp_mulmod(msg, bf, unconst_fp_int(key->n), msg));
}
@@ -376,14 +479,10 @@ static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp
/*
* m1 = msg ** dP mod p
* m2 = msg ** dQ mod q
- *
- * This is just crying out to be done with parallel cores, but get
- * the boring version working before jumping off that cliff.
*/
- if ((err = modexp(core, precalc, msg, key->dP, key->p, m1,
- key->pC, sizeof(key->pC), key->pF, sizeof(key->pF))) != HAL_OK ||
- (err = modexp(core, precalc, msg, key->dQ, key->q, m2,
- key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK)
+ if ((err = modexp2(precalc, msg,
+ core1, key->dP, key->p, m1, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF),
+ core2, key->dQ, key->q, m2, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK)
goto fail;
if (precalc)
@@ -462,7 +561,8 @@ hal_error_t hal_rsa_encrypt(hal_core_t *core,
return err;
}
-hal_error_t hal_rsa_decrypt(hal_core_t *core,
+hal_error_t hal_rsa_decrypt(hal_core_t *core1,
+ hal_core_t *core2,
hal_rsa_key_t *key,
const uint8_t * const input, const size_t input_len,
uint8_t * output, const size_t output_len)
@@ -484,11 +584,11 @@ hal_error_t hal_rsa_decrypt(hal_core_t *core,
if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) &&
!fp_iszero(key->dP) && !fp_iszero(key->dQ))
- err = rsa_crt(core, key, i, o);
+ err = rsa_crt(core1, core2, key, i, o);
else {
const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE);
- err = modexp(core, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC),
+ err = modexp(core1, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC),
key->nF, sizeof(key->nF));
if (err == HAL_OK && precalc)
key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING;