From 410e0cf1d22c67585f0a5346e62f60aa4e90fe05 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Wed, 13 Sep 2017 20:20:55 -0400 Subject: Preliminary support for parallel core RSA CRT. --- Makefile | 7 +- core.c | 39 +++++----- hal.h | 32 +++++--- hal_internal.h | 12 +++ hal_io_eim.c | 29 ------- hal_io_fmc.c | 31 -------- hal_io_i2c.c | 26 ------- modexp.c | 227 +++++++++++++++++++++++++++++++++---------------------- rpc_pkey.c | 6 +- rsa.c | 138 ++++++++++++++++++++++++++++----- tests/test-rsa.c | 17 ++++- 11 files changed, 329 insertions(+), 235 deletions(-) diff --git a/Makefile b/Makefile index ae6888d..59236af 100644 --- a/Makefile +++ b/Makefile @@ -109,12 +109,13 @@ CORE_OBJ = core.o csprng.o pbkdf2.o aes_keywrap.o modexp.o mkmif.o ${IO_OBJ} # i2c: Older I2C bus from Novena # fmc: FMC bus from dev-bridge and alpha boards +IO_OBJ = hal_io.o ifeq "${IO_BUS}" "eim" - IO_OBJ = hal_io_eim.o novena-eim.o + IO_OBJ += hal_io_eim.o novena-eim.o else ifeq "${IO_BUS}" "i2c" - IO_OBJ = hal_io_i2c.o + IO_OBJ += hal_io_i2c.o else ifeq "${IO_BUS}" "fmc" - IO_OBJ = hal_io_fmc.o + IO_OBJ += hal_io_fmc.o endif # If we're building for STM32, position-independent code leads to some diff --git a/core.c b/core.c index 8e9f2b2..32823a6 100644 --- a/core.c +++ b/core.c @@ -97,7 +97,7 @@ static int name_matches(const hal_core_t *const core, const char * const name) static const struct { const char *name; hal_addr_t extra; } gaps[] = { { "csprng", 11 * CORE_SIZE }, /* empty slots after csprng */ { "modexps6", 3 * CORE_SIZE }, /* ModexpS6 uses four slots */ - { "modexpa7", 3 * CORE_SIZE }, /* ModexpA7 uses four slots */ + { "modexpa7", 7 * CORE_SIZE }, /* ModexpA7 uses eight slots */ }; static hal_core_t *head = NULL; @@ -203,15 +203,17 @@ hal_core_t *hal_core_find(const char *name, hal_core_t *core) hal_error_t hal_core_alloc(const char *name, hal_core_t **pcore) { - hal_core_t *core; - hal_error_t err = HAL_ERROR_CORE_NOT_FOUND; + /* + * This used to allow name == NULL iff *core != NULL, but the + * semantics were fragile and in practice we always pass a name + * anyway, so simplify by requiring name != NULL, always. + */ - if (name == NULL && (pcore == NULL || *pcore == NULL)) + if (name == NULL || pcore == NULL) return HAL_ERROR_BAD_ARGUMENTS; - core = *pcore; - if (name == NULL) - name = core->info.name; + hal_error_t err = HAL_ERROR_CORE_NOT_FOUND; + hal_core_t *core = *pcore; if (core != NULL) { /* if we can reallocate the same core, do it now */ @@ -221,24 +223,23 @@ hal_error_t hal_core_alloc(const char *name, hal_core_t **pcore) hal_critical_section_end(); return HAL_OK; } - /* else fall through to search */ + /* else forget that core and fall through to search */ + *pcore = NULL; } while (1) { hal_critical_section_start(); for (core = hal_core_iterate(NULL); core != NULL; core = core->next) { - if (name_matches(core, name)) { - if (core->busy) { - err = HAL_ERROR_CORE_BUSY; - continue; - } - else { - err = HAL_OK; - *pcore = core; - core->busy = 1; - break; - } + if (!name_matches(core, name)) + continue; + if (core->busy) { + err = HAL_ERROR_CORE_BUSY; + continue; } + err = HAL_OK; + *pcore = core; + core->busy = 1; + break; } hal_critical_section_end(); if (err == HAL_ERROR_CORE_BUSY) diff --git a/hal.h b/hal.h index f7a7522..c017b2d 100644 --- a/hal.h +++ b/hal.h @@ -201,7 +201,8 @@ typedef struct hal_core hal_core_t; extern void hal_io_set_debug(int onoff); extern hal_error_t hal_io_write(const hal_core_t *core, hal_addr_t offset, const uint8_t *buf, size_t len); extern hal_error_t hal_io_read(const hal_core_t *core, hal_addr_t offset, uint8_t *buf, size_t len); -extern hal_error_t hal_io_wait(const hal_core_t *core, uint8_t status, int *count); +extern hal_error_t hal_io_wait(const hal_core_t *core, const uint8_t status, int *count); +extern hal_error_t hal_io_wait2(const hal_core_t *core1, const hal_core_t *core2, const uint8_t status, int *count); /* * Core management functions. @@ -368,19 +369,25 @@ extern hal_error_t hal_pbkdf2(hal_core_t *core, unsigned iterations_desired); /* - * Modular exponentiation. + * Modular exponentiation. This takes a ridiculous number of + * arguments of very similar types, making it easy to confuse them, + * particularly when performing two modexp operations in parallel, so + * we encapsulate the arguments in a structure. */ -extern void hal_modexp_set_debug(const int onoff); +typedef struct { + hal_core_t *core; + const uint8_t *msg; size_t msg_len; /* Message */ + const uint8_t *exp; size_t exp_len; /* Exponent */ + const uint8_t *mod; size_t mod_len; /* Modulus */ + uint8_t *result; size_t result_len; /* Result of exponentiation */ + uint8_t *coeff; size_t coeff_len; /* Modulus coefficient (r/w) */ + uint8_t *mont; size_t mont_len; /* Montgomery factor (r/w)*/ +} hal_modexp_arg_t; -extern hal_error_t hal_modexp(hal_core_t *core, - const int precalc, - const uint8_t * const msg, const size_t msg_len, /* Message */ - const uint8_t * const exp, const size_t exp_len, /* Exponent */ - const uint8_t * const mod, const size_t mod_len, /* Modulus */ - uint8_t * result, const size_t result_len, /* Result of exponentiation */ - uint8_t * coeff, const size_t coeff_len, /* Modulus coefficient (r/w) */ - uint8_t * mont, const size_t mont_len); /* Montgomery factor (r/w)*/ +extern void hal_modexp_set_debug(const int onoff); +extern hal_error_t hal_modexp( const int precalc, hal_modexp_arg_t *args); +extern hal_error_t hal_modexp2(const int precalc, hal_modexp_arg_t *args1, hal_modexp_arg_t *args2); /* * Master Key Memory Interface @@ -462,7 +469,8 @@ extern hal_error_t hal_rsa_encrypt(hal_core_t *core, const uint8_t * const input, const size_t input_len, uint8_t * output, const size_t output_len); -extern hal_error_t hal_rsa_decrypt(hal_core_t *core, +extern hal_error_t hal_rsa_decrypt(hal_core_t *core1, + hal_core_t *core2, hal_rsa_key_t *key, const uint8_t * const input, const size_t input_len, uint8_t * output, const size_t output_len); diff --git a/hal_internal.h b/hal_internal.h index a60d0b5..ac51cfb 100644 --- a/hal_internal.h +++ b/hal_internal.h @@ -103,6 +103,18 @@ static inline hal_error_t hal_io_wait_valid(const hal_core_t *core) return hal_io_wait(core, STATUS_VALID, &limit); } +static inline hal_error_t hal_io_wait_ready2(const hal_core_t *core1, const hal_core_t *core2) +{ + int limit = -1; + return hal_io_wait2(core1, core2, STATUS_READY, &limit); +} + +static inline hal_error_t hal_io_wait_valid2(const hal_core_t *core1, const hal_core_t *core2) +{ + int limit = -1; + return hal_io_wait2(core1, core2, STATUS_VALID, &limit); +} + /* * Static memory allocation on start-up. Don't use this except where * really necessary. By design, there's no way to free this, we don't diff --git a/hal_io_eim.c b/hal_io_eim.c index eabc42e..040cb2b 100644 --- a/hal_io_eim.c +++ b/hal_io_eim.c @@ -43,10 +43,6 @@ static int debug = 0; static int inited = 0; -#ifndef EIM_IO_TIMEOUT -#define EIM_IO_TIMEOUT 100000000 -#endif - static inline hal_error_t init(void) { if (inited) @@ -134,31 +130,6 @@ hal_error_t hal_io_read(const hal_core_t *core, hal_addr_t offset, uint8_t *buf, return HAL_OK; } -hal_error_t hal_io_wait(const hal_core_t *core, uint8_t status, int *count) -{ - hal_error_t err; - uint8_t buf[4]; - int i; - - if (count && *count == -1) - *count = EIM_IO_TIMEOUT; - - for (i = 1; ; ++i) { - - if (count && (*count > 0) && (i >= *count)) - return HAL_ERROR_IO_TIMEOUT; - - if ((err = hal_io_read(core, ADDR_STATUS, buf, sizeof(buf))) != HAL_OK) - return err; - - if ((buf[3] & status) != 0) { - if (count) - *count = i; - return HAL_OK; - } - } -} - /* * Local variables: * indent-tabs-mode: nil diff --git a/hal_io_fmc.c b/hal_io_fmc.c index 5ac73c4..0d49f1e 100644 --- a/hal_io_fmc.c +++ b/hal_io_fmc.c @@ -47,10 +47,6 @@ static int debug = 0; static int inited = 0; -#ifndef FMC_IO_TIMEOUT -#define FMC_IO_TIMEOUT 100000000 -#endif - static inline hal_error_t init(void) { if (!inited) { @@ -136,33 +132,6 @@ hal_error_t hal_io_read(const hal_core_t *core, hal_addr_t offset, uint8_t *buf, return HAL_OK; } -hal_error_t hal_io_wait(const hal_core_t *core, uint8_t status, int *count) -{ - hal_error_t err; - uint8_t buf[4]; - int i; - - if (count && *count == -1) - *count = FMC_IO_TIMEOUT; - - for (i = 1; ; ++i) { - - if (count && (*count > 0) && (i >= *count)) - return HAL_ERROR_IO_TIMEOUT; - - hal_task_yield(); - - if ((err = hal_io_read(core, ADDR_STATUS, buf, sizeof(buf))) != HAL_OK) - return err; - - if ((buf[3] & status) != 0) { - if (count) - *count = i; - return HAL_OK; - } - } -} - /* * Local variables: * indent-tabs-mode: nil diff --git a/hal_io_i2c.c b/hal_io_i2c.c index 018e264..8596174 100644 --- a/hal_io_i2c.c +++ b/hal_io_i2c.c @@ -301,32 +301,6 @@ hal_error_t hal_io_read(const hal_core_t *core, hal_addr_t offset, uint8_t *buf, return HAL_OK; } -hal_error_t hal_io_wait(const hal_core_t *core, uint8_t status, int *count) -{ - hal_error_t err; - uint8_t buf[4]; - int i; - - if (count && *count == -1) - *count = 10; - - for (i = 1; ; ++i) { - - if (count && (*count > 0) && (i >= *count)) - return HAL_ERROR_IO_TIMEOUT; - - if ((err = hal_io_read(core, ADDR_STATUS, buf, 4)) != HAL_OK) - return err; - - if (buf[3] & status) { - if (count) - *count = i; - return HAL_OK; - - } - } -} - /* * Local variables: * indent-tabs-mode: nil diff --git a/modexp.c b/modexp.c index 12b5789..7973258 100644 --- a/modexp.c +++ b/modexp.c @@ -157,125 +157,174 @@ static inline hal_error_t set_buffer(const hal_core_t *core, } /* - * Check a result, report on failure if debugging, pass failures up - * the chain. - */ - -#define check(_expr_) \ - do { \ - hal_error_t _err = (_expr_); \ - if (_err != HAL_OK && debug) \ - hal_log(HAL_LOG_WARN, "%s failed: %s\n", #_expr_, hal_error_string(_err)); \ - if (_err != HAL_OK) { \ - hal_core_free(core); \ - return _err; \ - } \ - } while (0) - -/* - * Run one modexp operation. + * Stuff moved out of modexp so we can run two cores in parallel more + * easily. We have to return to the jacket routine every time we kick + * a core into doing something, since only the jacket routines know + * how many cores we're running for any particular calculation. + * + * In theory we could do something clever where we don't wait for both + * cores to finish precalc before starting either of them on the main + * computation, but that way probably lies madness. */ -hal_error_t hal_modexp(hal_core_t *core, - const int precalc, - const uint8_t * const msg, const size_t msg_len, /* Message */ - const uint8_t * const exp, const size_t exp_len, /* Exponent */ - const uint8_t * const mod, const size_t mod_len, /* Modulus */ - uint8_t *result, const size_t result_len, /* Result of exponentiation */ - uint8_t *coeff, const size_t coeff_len, /* Modulus coefficient (r/w) */ - uint8_t *mont, const size_t mont_len) /* Montgomery factor (r/w)*/ +static inline hal_error_t check_args(hal_modexp_arg_t *a) { - hal_error_t err; - /* - * All pointers must be set, exponent may not be longer than + * All data pointers must be set, exponent may not be longer than * modulus, message may not be longer than twice the modulus (CRT * mode), result buffer must not be shorter than modulus, and all * input lengths must be a multiple of four bytes (the core is all * about 32-bit words). */ - if (msg == NULL || msg_len > MODEXPA7_OPERAND_BYTES || msg_len > mod_len * 2 || - exp == NULL || exp_len > MODEXPA7_OPERAND_BYTES || exp_len > mod_len || - mod == NULL || mod_len > MODEXPA7_OPERAND_BYTES || - result == NULL || result_len > MODEXPA7_OPERAND_BYTES || result_len < mod_len || - coeff == NULL || coeff_len > MODEXPA7_OPERAND_BYTES || - mont == NULL || mont_len > MODEXPA7_OPERAND_BYTES || - ((msg_len | exp_len | mod_len) & 3) != 0) + if (a == NULL || + a->msg == NULL || a->msg_len > MODEXPA7_OPERAND_BYTES || a->msg_len > a->mod_len * 2 || + a->exp == NULL || a->exp_len > MODEXPA7_OPERAND_BYTES || a->exp_len > a->mod_len || + a->mod == NULL || a->mod_len > MODEXPA7_OPERAND_BYTES || + a->result == NULL || a->result_len > MODEXPA7_OPERAND_BYTES || a->result_len < a->mod_len || + a->coeff == NULL || a->coeff_len > MODEXPA7_OPERAND_BYTES || + a->mont == NULL || a->mont_len > MODEXPA7_OPERAND_BYTES || + ((a->msg_len | a->exp_len | a->mod_len) & 3) != 0) return HAL_ERROR_BAD_ARGUMENTS; - /* - * Gonna need to think about running two modexpa7 cores in parallel - * in CRT mode for full speed signature. - */ + return HAL_OK; +} - if (((err = hal_core_alloc(MODEXPA7_NAME, &core)) != HAL_OK)) - return err; +static inline hal_error_t setup_precalc(const int precalc, hal_modexp_arg_t *a) +{ + hal_error_t err; /* - * Now that we have the core, check operand length against what it - * says it can handle. + * Check that operand size is compatabible with the core. */ uint32_t operand_max = 0; - check(get_register(core, MODEXPA7_ADDR_BUFFER_BITS, &operand_max)); + + if ((err = get_register(a->core, MODEXPA7_ADDR_BUFFER_BITS, &operand_max)) != HAL_OK) + return err; + operand_max /= 8; - if (msg_len > operand_max || - exp_len > operand_max || - mod_len > operand_max || - coeff_len > operand_max || - mont_len > operand_max) { - hal_core_free(core); + if (a->msg_len > operand_max || + a->exp_len > operand_max || + a->mod_len > operand_max || + a->coeff_len > operand_max || + a->mont_len > operand_max) return HAL_ERROR_BAD_ARGUMENTS; - } - /* Set modulus */ + /* + * Set the modulus, then initiate calculation of modulus-dependent + * speedup factors if necessary, by edge-triggering the "init" bit, + * then return to caller so it can wait for precalc. + */ + + if ((err = set_register(a->core, MODEXPA7_ADDR_MODULUS_BITS, a->mod_len * 8)) != HAL_OK || + (err = set_buffer(a->core, MODEXPA7_ADDR_MODULUS, a->mod, a->mod_len)) != HAL_OK || + (precalc && (err = hal_io_zero(a->core)) != HAL_OK) || + (precalc && (err = hal_io_init(a->core)) != HAL_OK)) + return err; + + return HAL_OK; +} + +static inline hal_error_t setup_calc(const int precalc, hal_modexp_arg_t *a) +{ + hal_error_t err; + + /* + * Select CRT mode if and only if message is longer than exponent. + */ - check(set_register(core, MODEXPA7_ADDR_MODULUS_BITS, mod_len * 8)); - check(set_buffer(core, MODEXPA7_ADDR_MODULUS, mod, mod_len)); + const uint32_t mode = a->msg_len > a->mod_len ? MODEXPA7_MODE_CRT : MODEXPA7_MODE_PLAIN; /* - * Calculate modulus-dependent speedup factors if needed. Buffer - * space is always caller's problem (because caller almost certainly - * wants to stash these values in the keystore anyway). Calculation - * is edge-triggered by "init" bit going from zero to one. + * Copy out precalc results if necessary, then load everything and + * start the calculation by edge-triggering the "next" bit. If + * everything works, return to caller so it can wait for the + * calculation to complete. */ - if (precalc) { - check(hal_io_zero(core)); - check(hal_io_init(core)); - check(hal_io_wait_ready(core)); - check(get_buffer(core, MODEXPA7_ADDR_MODULUS_COEFF_OUT, coeff, coeff_len)); - check(get_buffer(core, MODEXPA7_ADDR_MONTGOMERY_FACTOR_OUT, mont, mont_len)); - } - - /* Load modulus-dependent speedup factors (even if we just calculated them) */ - check(set_buffer(core, MODEXPA7_ADDR_MODULUS_COEFF_IN, coeff, coeff_len)); - check(set_buffer(core, MODEXPA7_ADDR_MONTGOMERY_FACTOR_IN, mont, mont_len)); - - /* Select CRT mode if and only if message is longer than exponent */ - check(set_register(core, MODEXPA7_ADDR_MODE, - (msg_len > mod_len - ? MODEXPA7_MODE_CRT - : MODEXPA7_MODE_PLAIN))); - - /* Set message and exponent */ - check(set_buffer(core, MODEXPA7_ADDR_MESSAGE, msg, msg_len)); - check(set_buffer(core, MODEXPA7_ADDR_EXPONENT, exp, exp_len)); - check(set_register(core, MODEXPA7_ADDR_EXPONENT_BITS, exp_len * 8)); - - /* Edge-trigger the "next" bit to start calculation, then wait for the result */ - check(hal_io_zero(core)); - check(hal_io_next(core)); - check(hal_io_wait_valid(core)); - - /* Extract result, clean up, then done */ - check(get_buffer(core, MODEXPA7_ADDR_RESULT, result, mod_len)); - hal_core_free(core); + if ((precalc && + (err = get_buffer(a->core, MODEXPA7_ADDR_MODULUS_COEFF_OUT, a->coeff, a->coeff_len)) != HAL_OK) || + (precalc && + (err = get_buffer(a->core, MODEXPA7_ADDR_MONTGOMERY_FACTOR_OUT, a->mont, a->mont_len)) != HAL_OK) || + (err = set_buffer(a->core, MODEXPA7_ADDR_MODULUS_COEFF_IN, a->coeff, a->coeff_len)) != HAL_OK || + (err = set_buffer(a->core, MODEXPA7_ADDR_MONTGOMERY_FACTOR_IN, a->mont, a->mont_len)) != HAL_OK || + (err = set_register(a->core, MODEXPA7_ADDR_MODE, mode)) != HAL_OK || + (err = set_buffer(a->core, MODEXPA7_ADDR_MESSAGE, a->msg, a->msg_len)) != HAL_OK || + (err = set_buffer(a->core, MODEXPA7_ADDR_EXPONENT, a->exp, a->exp_len)) != HAL_OK || + (err = set_register(a->core, MODEXPA7_ADDR_EXPONENT_BITS, a->exp_len * 8)) != HAL_OK || + (err = hal_io_zero(a->core)) != HAL_OK || + (err = hal_io_next(a->core)) != HAL_OK) + return err; + return HAL_OK; } +static inline hal_error_t extract_result(hal_modexp_arg_t *a) +{ + /* + * Extract results from the main calculation and we're done. + * Hardly seems worth making this a separate function. + */ + + return get_buffer(a->core, MODEXPA7_ADDR_RESULT, a->result, a->mod_len); +} + +/* + * Run one modexp operation. + */ + +hal_error_t hal_modexp(const int precalc, hal_modexp_arg_t *a) +{ + hal_error_t err; + + if ((err = check_args(a)) != HAL_OK) + return err; + + if ((err = hal_core_alloc(MODEXPA7_NAME, &a->core)) == HAL_OK && + (err = setup_precalc(precalc, a)) == HAL_OK && + (!precalc || + (err = hal_io_wait_ready(a->core)) == HAL_OK) && + (err = setup_calc(precalc, a)) == HAL_OK && + (err = hal_io_wait_valid(a->core)) == HAL_OK && + (err = extract_result(a)) == HAL_OK) + err = HAL_OK; + + hal_core_free(a->core); + return err; +} + +/* + * Run two modexp operations in parallel. + */ + +hal_error_t hal_modexp2(const int precalc, hal_modexp_arg_t *a1, hal_modexp_arg_t *a2) +{ + hal_error_t err; + + if ((err = check_args(a1)) != HAL_OK || + (err = check_args(a2)) != HAL_OK) + return err; + + if ((err = hal_core_alloc(MODEXPA7_NAME, &a1->core)) == HAL_OK && + (err = hal_core_alloc(MODEXPA7_NAME, &a2->core)) == HAL_OK && + (err = setup_precalc(precalc, a1)) == HAL_OK && + (err = setup_precalc(precalc, a2)) == HAL_OK && + (!precalc || + (err = hal_io_wait_ready2(a1->core, a2->core)) == HAL_OK) && + (err = setup_calc(precalc, a1)) == HAL_OK && + (err = setup_calc(precalc, a2)) == HAL_OK && + (err = hal_io_wait_valid2(a1->core, a2->core)) == HAL_OK && + (err = extract_result(a1)) == HAL_OK && + (err = extract_result(a2)) == HAL_OK) + err = HAL_OK; + + hal_core_free(a1->core); + hal_core_free(a2->core); + return err; +} + /* * Local variables: * indent-tabs-mode: nil diff --git a/rpc_pkey.c b/rpc_pkey.c index 53d3214..9d8975f 100644 --- a/rpc_pkey.c +++ b/rpc_pkey.c @@ -760,8 +760,8 @@ static hal_error_t pkey_local_sign_rsa(hal_pkey_slot_t *slot, input = signature; } - if ((err = pkcs1_5_pad(input, input_len, signature, *signature_len, 0x01)) != HAL_OK || - (err = hal_rsa_decrypt(NULL, key, signature, *signature_len, signature, *signature_len)) != HAL_OK) + if ((err = pkcs1_5_pad(input, input_len, signature, *signature_len, 0x01)) != HAL_OK || + (err = hal_rsa_decrypt(NULL, NULL, key, signature, *signature_len, signature, *signature_len)) != HAL_OK) return err; if (hal_rsa_key_needs_saving(key)) { @@ -1276,7 +1276,7 @@ static hal_error_t pkey_local_import(const hal_client_handle_t client, goto fail; } - if ((err = hal_rsa_decrypt(NULL, rsa, data, data_len, der, data_len)) != HAL_OK) + if ((err = hal_rsa_decrypt(NULL, NULL, rsa, data, data_len, der, data_len)) != HAL_OK) goto fail; if ((err = hal_get_random(NULL, kek, sizeof(kek))) != HAL_OK) diff --git a/rsa.c b/rsa.c index dace19b..44ad84e 100644 --- a/rsa.c +++ b/rsa.c @@ -233,16 +233,20 @@ static hal_error_t modexp(hal_core_t *core, uint8_t modbuf[mod_len]; uint8_t resbuf[mod_len]; + hal_modexp_arg_t args = { + .core = core, + .msg = msgbuf, .msg_len = sizeof(msgbuf), + .exp = expbuf, .exp_len = sizeof(expbuf), + .mod = modbuf, .mod_len = sizeof(modbuf), + .result = resbuf, .result_len = sizeof(resbuf), + .coeff = coeff, .coeff_len = coeff_len, + .mont = mont, .mont_len = mont_len + }; + if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK || (err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK || (err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK || - (err = hal_modexp(core, precalc, - msgbuf, sizeof(msgbuf), - expbuf, sizeof(expbuf), - modbuf, sizeof(modbuf), - resbuf, sizeof(resbuf), - coeff, coeff_len, - mont, mont_len)) != HAL_OK) + (err = hal_modexp(precalc, &args)) != HAL_OK) goto fail; fp_read_unsigned_bin(res, resbuf, sizeof(resbuf)); @@ -252,6 +256,83 @@ static hal_error_t modexp(hal_core_t *core, memset(expbuf, 0, sizeof(expbuf)); memset(modbuf, 0, sizeof(modbuf)); memset(resbuf, 0, sizeof(resbuf)); + memset(&args, 0, sizeof(args)); + return err; +} + +static hal_error_t modexp2(const int precalc, + const fp_int * const msg, + hal_core_t *core1, + const fp_int * const exp1, + const fp_int * const mod1, + fp_int * res1, + uint8_t *coeff1, const size_t coeff1_len, + uint8_t *mont1, const size_t mont1_len, + hal_core_t *core2, + const fp_int * const exp2, + const fp_int * const mod2, + fp_int * res2, + uint8_t *coeff2, const size_t coeff2_len, + uint8_t *mont2, const size_t mont2_len) +{ + hal_error_t err = HAL_OK; + + if (msg == NULL || + exp1 == NULL || mod1 == NULL || res1 == NULL || coeff1 == NULL || mont1 == NULL || + exp2 == NULL || mod2 == NULL || res2 == NULL || coeff2 == NULL || mont2 == NULL) + return HAL_ERROR_IMPOSSIBLE; + + const size_t msg_len = (fp_unsigned_bin_size(unconst_fp_int(msg)) + 3) & ~3; + const size_t exp1_len = (fp_unsigned_bin_size(unconst_fp_int(exp1)) + 3) & ~3; + const size_t mod1_len = (fp_unsigned_bin_size(unconst_fp_int(mod1)) + 3) & ~3; + const size_t exp2_len = (fp_unsigned_bin_size(unconst_fp_int(exp2)) + 3) & ~3; + const size_t mod2_len = (fp_unsigned_bin_size(unconst_fp_int(mod2)) + 3) & ~3; + + uint8_t msgbuf[msg_len]; + uint8_t expbuf1[exp1_len], modbuf1[mod1_len], resbuf1[mod1_len]; + uint8_t expbuf2[exp2_len], modbuf2[mod2_len], resbuf2[mod2_len]; + + hal_modexp_arg_t args1 = { + .core = core1, + .msg = msgbuf, .msg_len = sizeof(msgbuf), + .exp = expbuf1, .exp_len = sizeof(expbuf1), + .mod = modbuf1, .mod_len = sizeof(modbuf1), + .result = resbuf1, .result_len = sizeof(resbuf1), + .coeff = coeff1, .coeff_len = coeff1_len, + .mont = mont1, .mont_len = mont1_len + }; + + hal_modexp_arg_t args2 = { + .core = core2, + .msg = msgbuf, .msg_len = sizeof(msgbuf), + .exp = expbuf2, .exp_len = sizeof(expbuf2), + .mod = modbuf2, .mod_len = sizeof(modbuf2), + .result = resbuf2, .result_len = sizeof(resbuf2), + .coeff = coeff2, .coeff_len = coeff2_len, + .mont = mont2, .mont_len = mont2_len + }; + + if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK || + (err = unpack_fp(exp1, expbuf1, sizeof(expbuf1))) != HAL_OK || + (err = unpack_fp(mod1, modbuf1, sizeof(modbuf1))) != HAL_OK || + (err = unpack_fp(exp2, expbuf2, sizeof(expbuf2))) != HAL_OK || + (err = unpack_fp(mod2, modbuf2, sizeof(modbuf2))) != HAL_OK || + (err = hal_modexp2(precalc, &args1, &args2)) != HAL_OK) + goto fail; + + fp_read_unsigned_bin(res1, resbuf1, sizeof(resbuf1)); + fp_read_unsigned_bin(res2, resbuf2, sizeof(resbuf2)); + + fail: + memset(msgbuf, 0, sizeof(msgbuf)); + memset(expbuf1, 0, sizeof(expbuf1)); + memset(modbuf1, 0, sizeof(modbuf1)); + memset(resbuf1, 0, sizeof(resbuf1)); + memset(&args1, 0, sizeof(args1)); + memset(expbuf2, 0, sizeof(expbuf2)); + memset(modbuf2, 0, sizeof(modbuf2)); + memset(resbuf2, 0, sizeof(resbuf2)); + memset(&args2, 0, sizeof(args2)); return err; } @@ -280,6 +361,28 @@ static hal_error_t modexp(const hal_core_t *core, /* ignored */ return err; } +static hal_error_t modexp2(const int precalc, /* ignored */ + const fp_int * const msg, + hal_core_t *core1, /* ignored */ + const fp_int * const exp1, + const fp_int * const mod1, + fp_int * res1, + uint8_t *coeff1, const size_t coeff1_len, /* ignored */ + uint8_t *mont1, const size_t mont1_len, /* ignored */ + hal_core_t *core2, /* ignored */ + const fp_int * const exp2, + const fp_int * const mod2, + fp_int * res2, + uint8_t *coeff2, const size_t coeff2_len, /* ignored */ + uint8_t *mont2, const size_t mont2_len) /* ignored */ +{ + hal_error_t err = HAL_OK; + FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp1), unconst_fp_int(mod1), res1)); + FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp2), unconst_fp_int(mod2), res2)); + fail: + return err; +} + #endif /* HAL_RSA_SIGN_USE_MODEXP */ /* @@ -351,7 +454,7 @@ static hal_error_t create_blinding_factors(hal_core_t *core, hal_rsa_key_t *key, * RSA decryption via Chinese Remainder Theorem (Garner's formula). */ -static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp_int *sig) +static hal_error_t rsa_crt(hal_core_t *core1, hal_core_t *core2, hal_rsa_key_t *key, fp_int *msg, fp_int *sig) { if (key == NULL || msg == NULL || sig == NULL) return HAL_ERROR_IMPOSSIBLE; @@ -368,7 +471,7 @@ static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp * Handle blinding if requested. */ if (blinding) { - if ((err = create_blinding_factors(core, key, bf, ubf)) != HAL_OK) + if ((err = create_blinding_factors(core1, key, bf, ubf)) != HAL_OK) goto fail; FP_CHECK(fp_mulmod(msg, bf, unconst_fp_int(key->n), msg)); } @@ -376,14 +479,10 @@ static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp /* * m1 = msg ** dP mod p * m2 = msg ** dQ mod q - * - * This is just crying out to be done with parallel cores, but get - * the boring version working before jumping off that cliff. */ - if ((err = modexp(core, precalc, msg, key->dP, key->p, m1, - key->pC, sizeof(key->pC), key->pF, sizeof(key->pF))) != HAL_OK || - (err = modexp(core, precalc, msg, key->dQ, key->q, m2, - key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK) + if ((err = modexp2(precalc, msg, + core1, key->dP, key->p, m1, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF), + core2, key->dQ, key->q, m2, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK) goto fail; if (precalc) @@ -462,7 +561,8 @@ hal_error_t hal_rsa_encrypt(hal_core_t *core, return err; } -hal_error_t hal_rsa_decrypt(hal_core_t *core, +hal_error_t hal_rsa_decrypt(hal_core_t *core1, + hal_core_t *core2, hal_rsa_key_t *key, const uint8_t * const input, const size_t input_len, uint8_t * output, const size_t output_len) @@ -484,11 +584,11 @@ hal_error_t hal_rsa_decrypt(hal_core_t *core, if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) && !fp_iszero(key->dP) && !fp_iszero(key->dQ)) - err = rsa_crt(core, key, i, o); + err = rsa_crt(core1, core2, key, i, o); else { const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE); - err = modexp(core, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC), + err = modexp(core1, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC), key->nF, sizeof(key->nF)); if (err == HAL_OK && precalc) key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING; diff --git a/tests/test-rsa.c b/tests/test-rsa.c index 9ba9889..e73feea 100644 --- a/tests/test-rsa.c +++ b/tests/test-rsa.c @@ -60,8 +60,17 @@ static int test_modexp(hal_core_t *core, printf("%s test for %lu-bit RSA key\n", kind, (unsigned long) tc->size); - if (hal_modexp(core, 0, msg->val, msg->len, exp->val, exp->len, - tc->n.val, tc->n.len, result, sizeof(result), C, sizeof(C), F, sizeof(F)) != HAL_OK) + hal_modexp_arg_t args = { + .core = core, + .msg = msg->val, .msg_len = msg->len, + .exp = exp->val, .exp_len = exp->len, + .mod = tc->n.val, .mod_len = tc->n.len, + .result = result, .result_len = sizeof(result), + .coeff = C, .coeff_len = sizeof(C), + .mont = F, .mont_len = sizeof(F) + }; + + if (hal_modexp(1, &args) != HAL_OK) return printf("ModExp failed\n"), 0; if (memcmp(result, val->val, val->len)) @@ -98,7 +107,7 @@ static int test_decrypt(hal_core_t *core, uint8_t result[tc->n.len]; - if ((err = hal_rsa_decrypt(core, key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) + if ((err = hal_rsa_decrypt(core, NULL, key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) printf("RSA CRT failed: %s\n", hal_error_string(err)); const int mismatch = (err == HAL_OK && memcmp(result, tc->s.val, tc->s.len) != 0); @@ -165,7 +174,7 @@ static int test_gen(hal_core_t *core, uint8_t result[tc->n.len]; - if ((err = hal_rsa_decrypt(core, key1, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) + if ((err = hal_rsa_decrypt(core, NULL, key1, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) printf("RSA CRT failed: %s\n", hal_error_string(err)); snprintf(fn, sizeof(fn), "test-rsa-sig-%04lu.der", (unsigned long) tc->size); -- cgit v1.2.3