diff options
Diffstat (limited to 'rsa.c')
-rw-r--r-- | rsa.c | 551 |
1 files changed, 464 insertions, 87 deletions
@@ -70,7 +70,6 @@ #include <stdlib.h> #include <stddef.h> #include <string.h> -#include <assert.h> #include "hal.h" #include "hal_internal.h" @@ -78,12 +77,15 @@ #include "asn1_internal.h" /* - * Whether to use ModExp core. It works, but at the moment it's so - * slow that a full test run can take more than an hour. + * Whether to use ModExp core. It works, but it's painfully slow. */ -#ifndef HAL_RSA_USE_MODEXP -#define HAL_RSA_USE_MODEXP 1 +#ifndef HAL_RSA_SIGN_USE_MODEXP +#define HAL_RSA_SIGN_USE_MODEXP 1 +#endif + +#ifndef HAL_RSA_KEYGEN_USE_MODEXP +#define HAL_RSA_KEYGEN_USE_MODEXP 0 #endif #if defined(RPC_CLIENT) && RPC_CLIENT != RPC_CLIENT_LOCAL @@ -91,6 +93,15 @@ #endif /* + * How big to make the buffers for the modulus coefficient and + * Montgomery factor. This will almost certainly want tuning. + */ + +#ifndef HAL_RSA_MAX_OPERAND_LENGTH +#define HAL_RSA_MAX_OPERAND_LENGTH MODEXPA7_OPERAND_BYTES +#endif + +/* * Whether we want debug output. */ @@ -120,7 +131,7 @@ void hal_rsa_set_blinding(const int onoff) */ struct hal_rsa_key { - hal_key_type_t type; /* What kind of key this is */ + hal_key_type_t type; /* What kind of key this is */ fp_int n[1]; /* The modulus */ fp_int e[1]; /* Public exponent */ fp_int d[1]; /* Private exponent */ @@ -129,8 +140,17 @@ struct hal_rsa_key { fp_int u[1]; /* 1/q mod p */ fp_int dP[1]; /* d mod (p - 1) */ fp_int dQ[1]; /* d mod (q - 1) */ + unsigned flags; /* Internal key flags */ + uint8_t /* ModExpA7 speedup factors */ + nC[HAL_RSA_MAX_OPERAND_LENGTH], nF[HAL_RSA_MAX_OPERAND_LENGTH], + pC[HAL_RSA_MAX_OPERAND_LENGTH/2], pF[HAL_RSA_MAX_OPERAND_LENGTH/2], + qC[HAL_RSA_MAX_OPERAND_LENGTH/2], qF[HAL_RSA_MAX_OPERAND_LENGTH/2]; }; +#define RSA_FLAG_NEEDS_SAVING (1 << 0) +#define RSA_FLAG_PRECALC_N_DONE (1 << 1) +#define RSA_FLAG_PRECALC_PQ_DONE (1 << 2) + const size_t hal_rsa_key_t_size = sizeof(hal_rsa_key_t); /* @@ -155,7 +175,7 @@ const size_t hal_rsa_key_t_size = sizeof(hal_rsa_key_t); case FP_OKAY: break; \ case FP_VAL: 
lose(HAL_ERROR_BAD_ARGUMENTS); \ case FP_MEM: lose(HAL_ERROR_ALLOCATION_FAILURE); \ - default: lose(HAL_ERROR_IMPOSSIBLE); \ + default: lose(HAL_ERROR_IMPOSSIBLE); \ } \ } while (0) @@ -168,7 +188,8 @@ static hal_error_t unpack_fp(const fp_int * const bn, uint8_t *buffer, const siz { hal_error_t err = HAL_OK; - assert(bn != NULL && buffer != NULL); + if (bn == NULL || buffer == NULL) + return HAL_ERROR_IMPOSSIBLE; const size_t bytes = fp_unsigned_bin_size(unconst_fp_int(bn)); @@ -182,47 +203,50 @@ static hal_error_t unpack_fp(const fp_int * const bn, uint8_t *buffer, const siz return err; } -#if HAL_RSA_USE_MODEXP +#if HAL_RSA_SIGN_USE_MODEXP /* * Unwrap bignums into byte arrays, feed them into hal_modexp(), and * wrap result back up as a bignum. */ -static hal_error_t modexp(const hal_core_t *core, - const fp_int * msg, +static hal_error_t modexp(hal_core_t *core, + const int precalc, + const fp_int * const msg, const fp_int * const exp, const fp_int * const mod, - fp_int *res) + fp_int *res, + uint8_t *coeff, const size_t coeff_len, + uint8_t *mont, const size_t mont_len) { hal_error_t err = HAL_OK; - assert(msg != NULL && exp != NULL && mod != NULL && res != NULL); - - fp_int reduced_msg[1] = INIT_FP_INT; - - if (fp_cmp_mag(unconst_fp_int(msg), unconst_fp_int(mod)) != FP_LT) { - fp_init(reduced_msg); - fp_mod(unconst_fp_int(msg), unconst_fp_int(mod), reduced_msg); - msg = reduced_msg; - } + if (msg == NULL || exp == NULL || mod == NULL || res == NULL || coeff == NULL || mont == NULL) + return HAL_ERROR_IMPOSSIBLE; + const size_t msg_len = (fp_unsigned_bin_size(unconst_fp_int(msg)) + 3) & ~3; const size_t exp_len = (fp_unsigned_bin_size(unconst_fp_int(exp)) + 3) & ~3; const size_t mod_len = (fp_unsigned_bin_size(unconst_fp_int(mod)) + 3) & ~3; - uint8_t msgbuf[mod_len]; + uint8_t msgbuf[msg_len]; uint8_t expbuf[exp_len]; uint8_t modbuf[mod_len]; uint8_t resbuf[mod_len]; + hal_modexp_arg_t args = { + .core = core, + .msg = msgbuf, .msg_len = sizeof(msgbuf), + .exp 
= expbuf, .exp_len = sizeof(expbuf), + .mod = modbuf, .mod_len = sizeof(modbuf), + .result = resbuf, .result_len = sizeof(resbuf), + .coeff = coeff, .coeff_len = coeff_len, + .mont = mont, .mont_len = mont_len + }; + if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK || (err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK || (err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK || - (err = hal_modexp(core, - msgbuf, sizeof(msgbuf), - expbuf, sizeof(expbuf), - modbuf, sizeof(modbuf), - resbuf, sizeof(resbuf))) != HAL_OK) + (err = hal_modexp(precalc, &args)) != HAL_OK) goto fail; fp_read_unsigned_bin(res, resbuf, sizeof(resbuf)); @@ -231,37 +255,105 @@ static hal_error_t modexp(const hal_core_t *core, memset(msgbuf, 0, sizeof(msgbuf)); memset(expbuf, 0, sizeof(expbuf)); memset(modbuf, 0, sizeof(modbuf)); + memset(resbuf, 0, sizeof(resbuf)); + memset(&args, 0, sizeof(args)); return err; } -/* - * Wrapper to let us export our modexp function as a replacement for - * TFM's, to avoid dragging in all of the TFM montgomery code when we - * use TFM's Miller-Rabin test code. - * - * This code is here rather than in a separate module because of the - * error handling: TFM's error codes aren't really capable of - * expressing all the things that could go wrong here. - */ - -int fp_exptmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d) +static hal_error_t modexp2(const int precalc, + const fp_int * const msg, + hal_core_t *core1, + const fp_int * const exp1, + const fp_int * const mod1, + fp_int * res1, + uint8_t *coeff1, const size_t coeff1_len, + uint8_t *mont1, const size_t mont1_len, + hal_core_t *core2, + const fp_int * const exp2, + const fp_int * const mod2, + fp_int * res2, + uint8_t *coeff2, const size_t coeff2_len, + uint8_t *mont2, const size_t mont2_len) { - return modexp(NULL, a, b, c, d) == HAL_OK ? 
FP_OKAY : FP_VAL; + hal_error_t err = HAL_OK; + + if (msg == NULL || + exp1 == NULL || mod1 == NULL || res1 == NULL || coeff1 == NULL || mont1 == NULL || + exp2 == NULL || mod2 == NULL || res2 == NULL || coeff2 == NULL || mont2 == NULL) + return HAL_ERROR_IMPOSSIBLE; + + const size_t msg_len = (fp_unsigned_bin_size(unconst_fp_int(msg)) + 3) & ~3; + const size_t exp1_len = (fp_unsigned_bin_size(unconst_fp_int(exp1)) + 3) & ~3; + const size_t mod1_len = (fp_unsigned_bin_size(unconst_fp_int(mod1)) + 3) & ~3; + const size_t exp2_len = (fp_unsigned_bin_size(unconst_fp_int(exp2)) + 3) & ~3; + const size_t mod2_len = (fp_unsigned_bin_size(unconst_fp_int(mod2)) + 3) & ~3; + + uint8_t msgbuf[msg_len]; + uint8_t expbuf1[exp1_len], modbuf1[mod1_len], resbuf1[mod1_len]; + uint8_t expbuf2[exp2_len], modbuf2[mod2_len], resbuf2[mod2_len]; + + hal_modexp_arg_t args1 = { + .core = core1, + .msg = msgbuf, .msg_len = sizeof(msgbuf), + .exp = expbuf1, .exp_len = sizeof(expbuf1), + .mod = modbuf1, .mod_len = sizeof(modbuf1), + .result = resbuf1, .result_len = sizeof(resbuf1), + .coeff = coeff1, .coeff_len = coeff1_len, + .mont = mont1, .mont_len = mont1_len + }; + + hal_modexp_arg_t args2 = { + .core = core2, + .msg = msgbuf, .msg_len = sizeof(msgbuf), + .exp = expbuf2, .exp_len = sizeof(expbuf2), + .mod = modbuf2, .mod_len = sizeof(modbuf2), + .result = resbuf2, .result_len = sizeof(resbuf2), + .coeff = coeff2, .coeff_len = coeff2_len, + .mont = mont2, .mont_len = mont2_len + }; + + if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK || + (err = unpack_fp(exp1, expbuf1, sizeof(expbuf1))) != HAL_OK || + (err = unpack_fp(mod1, modbuf1, sizeof(modbuf1))) != HAL_OK || + (err = unpack_fp(exp2, expbuf2, sizeof(expbuf2))) != HAL_OK || + (err = unpack_fp(mod2, modbuf2, sizeof(modbuf2))) != HAL_OK || + (err = hal_modexp2(precalc, &args1, &args2)) != HAL_OK) + goto fail; + + fp_read_unsigned_bin(res1, resbuf1, sizeof(resbuf1)); + fp_read_unsigned_bin(res2, resbuf2, sizeof(resbuf2)); + 
+ fail: + memset(msgbuf, 0, sizeof(msgbuf)); + memset(expbuf1, 0, sizeof(expbuf1)); + memset(modbuf1, 0, sizeof(modbuf1)); + memset(resbuf1, 0, sizeof(resbuf1)); + memset(&args1, 0, sizeof(args1)); + memset(expbuf2, 0, sizeof(expbuf2)); + memset(modbuf2, 0, sizeof(modbuf2)); + memset(resbuf2, 0, sizeof(resbuf2)); + memset(&args2, 0, sizeof(args2)); + return err; } -#else /* HAL_RSA_USE_MODEXP */ +#else /* HAL_RSA_SIGN_USE_MODEXP */ /* - * Workaround to let us use TFM's software implementation of modular - * exponentiation when we want to test other things and don't want to - * wait for the slow FPGA implementation. + * Use libtfm's software implementation of modular exponentiation. + * Now that the ModExpA7 core performs about as well as the software + * implementation, there's probably no need to use this, but we're + * still tuning things, so leave the hook here for now. */ static hal_error_t modexp(const hal_core_t *core, /* ignored */ + const int precalc, /* ignored */ const fp_int * const msg, const fp_int * const exp, const fp_int * const mod, - fp_int *res) + fp_int *res, + uint8_t *coeff, const size_t coeff_len, /* ignored */ + uint8_t *mont, const size_t mont_len) /* ignored */ + { hal_error_t err = HAL_OK; FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp), unconst_fp_int(mod), res)); @@ -269,7 +361,58 @@ static hal_error_t modexp(const hal_core_t *core, /* ignored */ return err; } -#endif /* HAL_RSA_USE_MODEXP */ +static hal_error_t modexp2(const int precalc, /* ignored */ + const fp_int * const msg, + hal_core_t *core1, /* ignored */ + const fp_int * const exp1, + const fp_int * const mod1, + fp_int * res1, + uint8_t *coeff1, const size_t coeff1_len, /* ignored */ + uint8_t *mont1, const size_t mont1_len, /* ignored */ + hal_core_t *core2, /* ignored */ + const fp_int * const exp2, + const fp_int * const mod2, + fp_int * res2, + uint8_t *coeff2, const size_t coeff2_len, /* ignored */ + uint8_t *mont2, const size_t mont2_len) /* ignored */ +{ + 
hal_error_t err = HAL_OK; + FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp1), unconst_fp_int(mod1), res1)); + FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp2), unconst_fp_int(mod2), res2)); + fail: + return err; +} + +#endif /* HAL_RSA_SIGN_USE_MODEXP */ + +/* + * Wrapper to let us export our modexp function as a replacement for + * libtfm's when running libtfm's Miller-Rabin test code. + * + * At the moment, the libtfm software implementation performs + * disproportionately better than our core does for the specific case + * of Miller-Rabin tests, for reasons we don't really understand. + * So there's not much point in enabling this, except as a test to + * confirm this behavior. + * + * This code is here rather than in a separate module because of the + * error handling: libtfm's error codes aren't really capable of + * expressing all the things that could go wrong here. + */ + +#if HAL_RSA_SIGN_USE_MODEXP && HAL_RSA_KEYGEN_USE_MODEXP + +int fp_exptmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d) +{ + const size_t len = (fp_unsigned_bin_size(unconst_fp_int(b)) + 3) & ~3; + uint8_t C[len], F[len]; + const hal_error_t err = modexp(NULL, 0, a, b, c, d, C, sizeof(C), F, sizeof(F)); + memset(C, 0, sizeof(C)); + memset(F, 0, sizeof(F)); + return err == HAL_OK ? FP_OKAY : FP_VAL; +} + +#endif /* HAL_RSA_SIGN_USE_MODEXP && HAL_RSA_KEYGEN_USE_MODEXP */ /* * Create blinding factors. There are various schemes for amortizing @@ -277,10 +420,12 @@ static hal_error_t modexp(const hal_core_t *core, /* ignored */ * try. Come back to this if it looks like a bottleneck. 
*/ -static hal_error_t create_blinding_factors(const hal_core_t *core, const hal_rsa_key_t * const key, fp_int *bf, fp_int *ubf) +static hal_error_t create_blinding_factors(hal_core_t *core, hal_rsa_key_t *key, fp_int *bf, fp_int *ubf) { - assert(key != NULL && bf != NULL && ubf != NULL); + if (key == NULL || bf == NULL || ubf == NULL) + return HAL_ERROR_IMPOSSIBLE; + const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE); uint8_t rnd[fp_unsigned_bin_size(unconst_fp_int(key->n))]; hal_error_t err = HAL_OK; @@ -291,9 +436,13 @@ static hal_error_t create_blinding_factors(const hal_core_t *core, const hal_rsa fp_read_unsigned_bin(bf, rnd, sizeof(rnd)); fp_copy(bf, ubf); - if ((err = modexp(core, bf, key->e, key->n, bf)) != HAL_OK) + if ((err = modexp(core, precalc, bf, key->e, key->n, bf, + key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) != HAL_OK) goto fail; + if (precalc) + key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING; + FP_CHECK(fp_invmod(ubf, unconst_fp_int(key->n), ubf)); fail: @@ -305,10 +454,12 @@ static hal_error_t create_blinding_factors(const hal_core_t *core, const hal_rsa * RSA decryption via Chinese Remainder Theorem (Garner's formula). */ -static hal_error_t rsa_crt(const hal_core_t *core, const hal_rsa_key_t * const key, fp_int *msg, fp_int *sig) +static hal_error_t rsa_crt(hal_core_t *core1, hal_core_t *core2, hal_rsa_key_t *key, fp_int *msg, fp_int *sig) { - assert(key != NULL && msg != NULL && sig != NULL); + if (key == NULL || msg == NULL || sig == NULL) + return HAL_ERROR_IMPOSSIBLE; + const int precalc = !(key->flags & RSA_FLAG_PRECALC_PQ_DONE); hal_error_t err = HAL_OK; fp_int t[1] = INIT_FP_INT; fp_int m1[1] = INIT_FP_INT; @@ -320,7 +471,7 @@ static hal_error_t rsa_crt(const hal_core_t *core, const hal_rsa_key_t * const k * Handle blinding if requested. 
*/ if (blinding) { - if ((err = create_blinding_factors(core, key, bf, ubf)) != HAL_OK) + if ((err = create_blinding_factors(core1, key, bf, ubf)) != HAL_OK) goto fail; FP_CHECK(fp_mulmod(msg, bf, unconst_fp_int(key->n), msg)); } @@ -329,10 +480,14 @@ static hal_error_t rsa_crt(const hal_core_t *core, const hal_rsa_key_t * const k * m1 = msg ** dP mod p * m2 = msg ** dQ mod q */ - if ((err = modexp(core, msg, key->dP, key->p, m1)) != HAL_OK || - (err = modexp(core, msg, key->dQ, key->q, m2)) != HAL_OK) + if ((err = modexp2(precalc, msg, + core1, key->dP, key->p, m1, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF), + core2, key->dQ, key->q, m2, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK) goto fail; + if (precalc) + key->flags |= RSA_FLAG_PRECALC_PQ_DONE | RSA_FLAG_NEEDS_SAVING; + /* * t = m1 - m2. */ @@ -376,8 +531,8 @@ static hal_error_t rsa_crt(const hal_core_t *core, const hal_rsa_key_t * const k * to the caller. */ -hal_error_t hal_rsa_encrypt(const hal_core_t *core, - const hal_rsa_key_t * const key, +hal_error_t hal_rsa_encrypt(hal_core_t *core, + hal_rsa_key_t *key, const uint8_t * const input, const size_t input_len, uint8_t * output, const size_t output_len) { @@ -386,23 +541,29 @@ hal_error_t hal_rsa_encrypt(const hal_core_t *core, if (key == NULL || input == NULL || output == NULL || input_len > output_len) return HAL_ERROR_BAD_ARGUMENTS; + const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE); fp_int i[1] = INIT_FP_INT; fp_int o[1] = INIT_FP_INT; fp_read_unsigned_bin(i, unconst_uint8_t(input), input_len); - if ((err = modexp(core, i, key->e, key->n, o)) != HAL_OK || - (err = unpack_fp(o, output, output_len)) != HAL_OK) - goto fail; + err = modexp(core, precalc, i, key->e, key->n, o, + key->nC, sizeof(key->nC), key->nF, sizeof(key->nF)); + + if (err == HAL_OK && precalc) + key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING; + + if (err == HAL_OK) + err = unpack_fp(o, output, output_len); - fail: fp_zero(i); 
fp_zero(o); return err; } -hal_error_t hal_rsa_decrypt(const hal_core_t *core, - const hal_rsa_key_t * const key, +hal_error_t hal_rsa_decrypt(hal_core_t *core1, + hal_core_t *core2, + hal_rsa_key_t *key, const uint8_t * const input, const size_t input_len, uint8_t * output, const size_t output_len) { @@ -421,10 +582,17 @@ hal_error_t hal_rsa_decrypt(const hal_core_t *core, * just do brute force ModExp. */ - if (fp_iszero(key->p) || fp_iszero(key->q) || fp_iszero(key->u) || fp_iszero(key->dP) || fp_iszero(key->dQ)) - err = modexp(core, i, key->d, key->n, o); - else - err = rsa_crt(core, key, i, o); + if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) && + !fp_iszero(key->dP) && !fp_iszero(key->dQ)) + err = rsa_crt(core1, core2, key, i, o); + + else { + const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE); + err = modexp(core1, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC), + key->nF, sizeof(key->nF)); + if (err == HAL_OK && precalc) + key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING; + } if (err != HAL_OK || (err = unpack_fp(o, output, output_len)) != HAL_OK) goto fail; @@ -583,29 +751,120 @@ hal_error_t hal_rsa_key_get_public_exponent(const hal_rsa_key_t * const key, /* * Generate a prime factor for an RSA keypair. * - * Get random bytes, munge a few bits, and stuff into a bignum. Keep - * doing this until we find a result that's (probably) prime and for - * which result - 1 is relatively prime with respect to e. + * Get random bytes, munge a few bits, and stuff into a bignum to + * construct our initial candidate. + * + * Initialize table of remainders when dividing candidate by each + * entry in corresponding table of small primes. We'd have to perform + * these tests in any case for any successful candidate, and doing it + * up front lets us amortize the cost over the entire search, so we do + * this unconditionally before entering the search loop. 
+ * + * If all of the remainders were non-zero, run the requisite number of + * Miller-Rabin tests using the first few entries from that same table + * of small primes as the test values. If we get past Miller-Rabin, + * the candidate is (probably) prime, to a confidence level which we + * can tune by adjusting the number of Miller-Rabin tests. + * + * For RSA, we also need (result - 1) to be relatively prime with + * respect to the public exponent. If a (probable) prime passes that + * test, we have a winner. + * + * If any of the above tests failed, we increment the candidate and + * all remainders by two, then loop back to the remainder test. This + * is where the table pays off: incrementing remainders is really + * cheap, and since most composite numbers fail the small primes test, + * making that cheap makes the whole loop run significantly faster. + * + * General approach suggested by HAC note 4.51. Range of small prime + * table and default number of Miller-Rabin tests suggested by Schneier. 
*/ +#ifndef HAL_RSA_MILLER_RABIN_TESTS +#define HAL_RSA_MILLER_RABIN_TESTS (5) +#endif + +static const uint16_t small_prime[] = { + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, + 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, + 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, + 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, + 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, + 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, + 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, + 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, + 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, + 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, + 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, + 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, + 971, 977, 983, 991, 997, 1009, 1013, 1019, 1021, 1031, 1033, 1039, + 1049, 1051, 1061, 1063, 1069, 1087, 1091, 1093, 1097, 1103, 1109, + 1117, 1123, 1129, 1151, 1153, 1163, 1171, 1181, 1187, 1193, 1201, + 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, + 1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, + 1373, 1381, 1399, 1409, 1423, 1427, 1429, 1433, 1439, 1447, 1451, + 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493, 1499, 1511, 1523, + 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597, 1601, + 1607, 1609, 1613, 1619, 1621, 1627, 1637, 1657, 1663, 1667, 1669, + 1693, 1697, 1699, 1709, 1721, 1723, 1733, 1741, 1747, 1753, 1759, + 1777, 1783, 1787, 1789, 1801, 1811, 1823, 1831, 1847, 1861, 1867, + 1871, 1873, 1877, 1879, 1889, 1901, 1907, 1913, 1931, 1933, 1949, + 1951, 1973, 1979, 1987, 1993, 1997, 1999 +}; + static hal_error_t find_prime(const unsigned prime_length, const fp_int * const e, fp_int *result) { + uint16_t remainder[sizeof(small_prime)/sizeof(*small_prime)]; uint8_t 
buffer[prime_length]; - hal_error_t err; fp_int t[1] = INIT_FP_INT; + hal_error_t err; - do { - if ((err = hal_get_random(NULL, buffer, sizeof(buffer))) != HAL_OK) - return err; - buffer[0 ] |= 0xc0; - buffer[sizeof(buffer) - 1] |= 0x01; - fp_read_unsigned_bin(result, buffer, sizeof(buffer)); + if ((err = hal_get_random(NULL, buffer, sizeof(buffer))) != HAL_OK) + return err; + + buffer[0] &= ~0x01; /* Headroom for search */ + buffer[0] |= 0xc0; /* Result large enough */ + buffer[sizeof(buffer) - 1] |= 0x01; /* Candidates are odd */ + + fp_read_unsigned_bin(result, buffer, sizeof(buffer)); + memset(buffer, 0, sizeof(buffer)); + + for (size_t i = 0; i < sizeof(small_prime)/sizeof(*small_prime); i++) { + fp_digit d; + fp_mod_d(result, small_prime[i], &d); + remainder[i] = d; + } + + for (;;) { + int possible = 1; + + for (size_t i = 0; i < sizeof(small_prime)/sizeof(*small_prime); i++) + possible &= remainder[i] != 0; - } while (!fp_isprime(result) || - (fp_sub_d(result, 1, t), fp_gcd(t, unconst_fp_int(e), t), fp_cmp_d(t, 1) != FP_EQ)); + for (size_t i = 0; possible && i < HAL_RSA_MILLER_RABIN_TESTS; i++) { + fp_set(t, small_prime[i]); + fp_prime_miller_rabin(result, t, &possible); + } + if (possible) { + fp_sub_d(result, 1, t); + fp_gcd(t, unconst_fp_int(e), t); + possible = fp_cmp_d(t, 1) == FP_EQ; + } + + if (possible) + break; + + fp_add_d(result, 2, result); + + for (size_t i = 0; i < sizeof(small_prime)/sizeof(*small_prime); i++) + if ((remainder[i] += 2) >= small_prime[i]) + remainder[i] -= small_prime[i]; + } + + memset(remainder, 0, sizeof(remainder)); fp_zero(t); return HAL_OK; } @@ -614,7 +873,7 @@ static hal_error_t find_prime(const unsigned prime_length, * Generate a new RSA keypair. 
*/ -hal_error_t hal_rsa_key_gen(const hal_core_t *core, +hal_error_t hal_rsa_key_gen(hal_core_t *core, hal_rsa_key_t **key_, void *keybuf, const size_t keybuf_len, const unsigned key_length, @@ -659,6 +918,8 @@ hal_error_t hal_rsa_key_gen(const hal_core_t *core, FP_CHECK(fp_mod(key->d, q_1, key->dQ)); /* dQ = d % (q-1) */ FP_CHECK(fp_invmod(key->q, key->p, key->u)); /* u = (1/q) % p */ + key->flags |= RSA_FLAG_NEEDS_SAVING; + *key_ = key; /* Fall through to cleanup */ @@ -672,10 +933,26 @@ hal_error_t hal_rsa_key_gen(const hal_core_t *core, } /* + * Whether a key contains new data that need saving (newly generated + * key, updated speedup components, whatever). + */ + +int hal_rsa_key_needs_saving(const hal_rsa_key_t * const key) +{ + return key != NULL && (key->flags & RSA_FLAG_NEEDS_SAVING); +} + +/* * Just enough ASN.1 to read and write PKCS #1.5 RSAPrivateKey syntax * (RFC 2313 section 7.2) wrapped in a PKCS #8 PrivateKeyInfo (RFC 5208). * * RSAPrivateKey fields in the required order. + * + * The "extra" fields are additional key components specific to the + * systolic modexpa7 core. We represent these in ASN.1 as OPTIONAL + * fields using IMPLICIT PRIVATE tags, since this is neither + * standardized nor meaningful to anybody else. Underlying encoding + * is INTEGER or OCTET STRING (currently the latter). 
*/ #define RSAPrivateKey_fields \ @@ -689,8 +966,17 @@ hal_error_t hal_rsa_key_gen(const hal_core_t *core, _(key->dQ); \ _(key->u); -hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, - uint8_t *der, size_t *der_len, const size_t der_max) +#define RSAPrivateKey_extra_fields \ + _(ASN1_PRIVATE + 0, nC, RSA_FLAG_PRECALC_N_DONE); \ + _(ASN1_PRIVATE + 1, nF, RSA_FLAG_PRECALC_N_DONE); \ + _(ASN1_PRIVATE + 2, pC, RSA_FLAG_PRECALC_PQ_DONE); \ + _(ASN1_PRIVATE + 3, pF, RSA_FLAG_PRECALC_PQ_DONE); \ + _(ASN1_PRIVATE + 4, qC, RSA_FLAG_PRECALC_PQ_DONE); \ + _(ASN1_PRIVATE + 5, qF, RSA_FLAG_PRECALC_PQ_DONE); + +hal_error_t hal_rsa_private_key_to_der_internal(const hal_rsa_key_t * const key, + const int include_extra, + uint8_t *der, size_t *der_len, const size_t der_max) { hal_error_t err = HAL_OK; @@ -705,10 +991,32 @@ hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, size_t hlen = 0, vlen = 0; -#define _(x) { size_t n; if ((err = hal_asn1_encode_integer(x, NULL, &n, der_max - vlen)) != HAL_OK) return err; vlen += n; } +#define _(x) \ + { \ + size_t n = 0; \ + err = hal_asn1_encode_integer(x, NULL, &n, der_max - vlen); \ + if (err != HAL_OK) \ + return err; \ + vlen += n; \ + } + RSAPrivateKey_fields; #undef _ +#define _(x,y,z) \ + if ((key->flags & z) != 0) { \ + size_t n = 0; \ + if ((err = hal_asn1_encode_header(x, sizeof(key->y), NULL, \ + &n, 0)) != HAL_OK) \ + return err; \ + vlen += n + sizeof(key->y); \ + } + + if (include_extra) { + RSAPrivateKey_extra_fields; + } +#undef _ + if ((err = hal_asn1_encode_header(ASN1_SEQUENCE, vlen, NULL, &hlen, 0)) != HAL_OK) return err; @@ -729,18 +1037,51 @@ hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, uint8_t *d = der + hlen; memset(d, 0, vlen); -#define _(x) { size_t n; if ((err = hal_asn1_encode_integer(x, d, &n, vlen)) != HAL_OK) return err; d += n; vlen -= n; } +#define _(x) \ + { \ + size_t n = 0; \ + err = hal_asn1_encode_integer(x, d, &n, vlen); \ + if (err 
!= HAL_OK) \ + return err; \ + d += n; \ + vlen -= n; \ + } + RSAPrivateKey_fields; #undef _ +#define _(x,y,z) \ + if ((key->flags & z) != 0) { \ + size_t n = 0; \ + if ((err = hal_asn1_encode_header(x, sizeof(key->y), d, \ + &n, vlen)) != HAL_OK) \ + return err; \ + d += n; \ + vlen -= n; \ + memcpy(d, key->y, sizeof(key->y)); \ + d += sizeof(key->y); \ + vlen -= sizeof(key->y); \ + } + + if (include_extra) { + RSAPrivateKey_extra_fields; + } +#undef _ + return hal_asn1_encode_pkcs8_privatekeyinfo(hal_asn1_oid_rsaEncryption, hal_asn1_oid_rsaEncryption_len, NULL, 0, der, d - der, der, der_len, der_max); } -size_t hal_rsa_private_key_to_der_len(const hal_rsa_key_t * const key) +hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, + uint8_t *der, size_t *der_len, const size_t der_max) +{ + return hal_rsa_private_key_to_der_internal(key, 0, der, der_len, der_max); +} + +hal_error_t hal_rsa_private_key_to_der_extra(const hal_rsa_key_t * const key, + uint8_t *der, size_t *der_len, const size_t der_max) { - size_t len = 0; - return hal_rsa_private_key_to_der(key, NULL, &len, 0) == HAL_OK ? 
len : 0; + return hal_rsa_private_key_to_der_internal(key, 1, der, der_len, der_max); } hal_error_t hal_rsa_private_key_from_der(hal_rsa_key_t **key_, @@ -778,12 +1119,48 @@ hal_error_t hal_rsa_private_key_from_der(hal_rsa_key_t **key_, fp_int version[1] = INIT_FP_INT; -#define _(x) { size_t n; if ((err = hal_asn1_decode_integer(x, d, &n, vlen)) != HAL_OK) return err; d += n; vlen -= n; } +#define _(x) \ + { \ + size_t n; \ + err = hal_asn1_decode_integer(x, d, &n, vlen); \ + if (err != HAL_OK) \ + return err; \ + d += n; \ + vlen -= n; \ + } + RSAPrivateKey_fields; #undef _ - if (d != privkey + privkey_len || !fp_iszero(version)) +#define _(x,y,z) \ + if (hal_asn1_peek(x, d, vlen)) { \ + size_t hl = 0, vl = 0; \ + if ((err = hal_asn1_decode_header(x, d, vlen, &hl, &vl)) != HAL_OK) \ + return err; \ + if (vl > sizeof(key->y)) { \ + hal_log(HAL_LOG_DEBUG, "extra factor %s too big (%lu > %lu)", \ + #y, (unsigned long) vl, (unsigned long) sizeof(key->y)); \ + return HAL_ERROR_ASN1_PARSE_FAILED; \ + } \ + memcpy(key->y, d + hl, vl); \ + key->flags |= z; \ + d += hl + vl; \ + vlen -= hl + vl; \ + } + + RSAPrivateKey_extra_fields; +#undef _ + + if (d != privkey + privkey_len) { + hal_log(HAL_LOG_DEBUG, "not at end of buffer (0x%lx != 0x%lx)", + (unsigned long) d, (unsigned long) privkey + privkey_len); return HAL_ERROR_ASN1_PARSE_FAILED; + } + + if (!fp_iszero(version)) { + hal_log(HAL_LOG_DEBUG, "nonzero version"); + return HAL_ERROR_ASN1_PARSE_FAILED; + } *key_ = key; |