From 8ff9d4131bf79b36551c2ed995881a88fb9c0a61 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Tue, 12 Sep 2017 10:04:55 -0400 Subject: Untested ASN.1 support for ModExpA7 private speedup factors. --- asn1_internal.h | 10 ++++ hal.h | 7 ++- modexp.c | 4 +- rsa.c | 147 ++++++++++++++++++++++++++++++++++++++++++++++---------- 4 files changed, 140 insertions(+), 28 deletions(-) diff --git a/asn1_internal.h b/asn1_internal.h index fe2f293..3de8bd6 100644 --- a/asn1_internal.h +++ b/asn1_internal.h @@ -151,6 +151,16 @@ extern hal_error_t hal_asn1_decode_pkcs8_encryptedprivatekeyinfo(const uint8_t * extern hal_error_t hal_asn1_guess_key_type(hal_key_type_t *type, hal_curve_name_t *curve, const uint8_t *const der, const size_t der_len); +/* + * Peek ahead for an OPTIONAL attribute. + */ + +static inline int hal_asn1_peek(const uint8_t tag, + const uint8_t * const der, size_t der_max) +{ + return der != NULL && der_max > 0 && der[0] == tag; +} + #endif /* _HAL_ASN1_INTERNAL_H_ */ /* diff --git a/hal.h b/hal.h index 74f35fa..b7eae72 100644 --- a/hal.h +++ b/hal.h @@ -374,7 +374,7 @@ extern hal_error_t hal_pbkdf2(hal_core_t *core, extern void hal_modexp_set_debug(const int onoff); extern hal_error_t hal_modexp(hal_core_t *core, - const int precalc_done, + const int precalc, const uint8_t * const msg, const size_t msg_len, /* Message */ const uint8_t * const exp, const size_t exp_len, /* Exponent */ const uint8_t * const mod, const size_t mod_len, /* Modulus */ @@ -476,6 +476,9 @@ extern hal_error_t hal_rsa_key_gen(hal_core_t *core, extern hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, uint8_t *der, size_t *der_len, const size_t der_max); +extern hal_error_t hal_rsa_private_key_to_der_extra(const hal_rsa_key_t * const key, + uint8_t *der, size_t *der_len, const size_t der_max); + extern size_t hal_rsa_private_key_to_der_len(const hal_rsa_key_t * const key); extern hal_error_t hal_rsa_private_key_from_der(hal_rsa_key_t **key, @@ -491,6 +494,8 @@ extern hal_error_t hal_rsa_public_key_from_der(hal_rsa_key_t **key, void *keybuf, const size_t keybuf_len, const uint8_t * const der, const size_t der_len); +extern int hal_rsa_key_needs_saving(const hal_rsa_key_t * const key); + /* * ECDSA. */ diff --git a/modexp.c b/modexp.c index 7ff7b21..12b5789 100644 --- a/modexp.c +++ b/modexp.c @@ -177,7 +177,7 @@ static inline hal_error_t set_buffer(const hal_core_t *core, */ hal_error_t hal_modexp(hal_core_t *core, - const int precalc_done, + const int precalc, const uint8_t * const msg, const size_t msg_len, /* Message */ const uint8_t * const exp, const size_t exp_len, /* Exponent */ const uint8_t * const mod, const size_t mod_len, /* Modulus */ @@ -242,7 +242,7 @@ hal_error_t hal_modexp(hal_core_t *core, * is edge-triggered by "init" bit going from zero to one. */ - if (!precalc_done) { + if (precalc) { check(hal_io_zero(core)); check(hal_io_init(core)); check(hal_io_wait_ready(core)); diff --git a/rsa.c b/rsa.c index 9cc940c..e414e93 100644 --- a/rsa.c +++ b/rsa.c @@ -147,9 +147,9 @@ struct hal_rsa_key { qC[HAL_RSA_MAX_OPERAND_LENGTH/2], qF[HAL_RSA_MAX_OPERAND_LENGTH/2]; }; -#define RSA_FLAG_PRECALC_N_DONE (1 << 0) -#define RSA_FLAG_PRECALC_P_DONE (1 << 1) -#define RSA_FLAG_PRECALC_Q_DONE (1 << 2) +#define RSA_FLAG_NEEDS_SAVING (1 << 0) +#define RSA_FLAG_PRECALC_N_DONE (1 << 1) +#define RSA_FLAG_PRECALC_PQ_DONE (1 << 2) const size_t hal_rsa_key_t_size = sizeof(hal_rsa_key_t); @@ -211,7 +211,7 @@ static hal_error_t unpack_fp(const fp_int * const bn, uint8_t *buffer, const siz */ static hal_error_t modexp(hal_core_t *core, - const int precalc_done, + const int precalc, const fp_int * const msg, const fp_int * const exp, const fp_int * const mod, @@ -236,7 +236,7 @@ static hal_error_t modexp(hal_core_t *core, if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK || (err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK || (err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK || - (err = hal_modexp(core, precalc_done, + (err = hal_modexp(core, precalc, msgbuf, sizeof(msgbuf), expbuf, sizeof(expbuf), modbuf, sizeof(modbuf), @@ -265,7 +265,7 @@ static hal_error_t modexp(hal_core_t *core, */ static hal_error_t modexp(const hal_core_t *core, /* ignored */ - const int precalc_done, /* ignored */ + const int precalc, /* ignored */ const fp_int * const msg, const fp_int * const exp, const fp_int * const mod, @@ -322,6 +322,7 @@ static hal_error_t create_blinding_factors(hal_core_t *core, hal_rsa_key_t *key, if (key == NULL || bf == NULL || ubf == NULL) return HAL_ERROR_IMPOSSIBLE; + const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE); uint8_t rnd[fp_unsigned_bin_size(unconst_fp_int(key->n))]; hal_error_t err = HAL_OK; @@ -332,11 +333,12 @@ static hal_error_t create_blinding_factors(hal_core_t *core, hal_rsa_key_t *key, fp_read_unsigned_bin(bf, rnd, sizeof(rnd)); fp_copy(bf, ubf); - if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), bf, key->e, key->n, bf, + if ((err = modexp(core, precalc, bf, key->e, key->n, bf, key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) != HAL_OK) goto fail; - key->flags |= RSA_FLAG_PRECALC_N_DONE; + if (precalc) + key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING; FP_CHECK(fp_invmod(ubf, unconst_fp_int(key->n), ubf)); @@ -354,6 +356,7 @@ static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp if (key == NULL || msg == NULL || sig == NULL) return HAL_ERROR_IMPOSSIBLE; + const int precalc = !(key->flags & RSA_FLAG_PRECALC_PQ_DONE); hal_error_t err = HAL_OK; fp_int t[1] = INIT_FP_INT; fp_int m1[1] = INIT_FP_INT; @@ -377,13 +380,14 @@ static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp * This is just crying out to be done with parallel cores, but get * the boring version working before jumping off that cliff. */ - if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_P_DONE), - msg, key->dP, key->p, m1, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF))) != HAL_OK || - (err = modexp(core, (key->flags & RSA_FLAG_PRECALC_Q_DONE), - msg, key->dQ, key->q, m2, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK) + if ((err = modexp(core, precalc, msg, key->dP, key->p, m1, + key->pC, sizeof(key->pC), key->pF, sizeof(key->pF))) != HAL_OK || + (err = modexp(core, precalc, msg, key->dQ, key->q, m2, + key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK) goto fail; - key->flags |= RSA_FLAG_PRECALC_P_DONE | RSA_FLAG_PRECALC_Q_DONE; + if (precalc) + key->flags |= RSA_FLAG_PRECALC_PQ_DONE | RSA_FLAG_NEEDS_SAVING; /* * t = m1 - m2. @@ -438,16 +442,20 @@ hal_error_t hal_rsa_encrypt(hal_core_t *core, if (key == NULL || input == NULL || output == NULL || input_len > output_len) return HAL_ERROR_BAD_ARGUMENTS; + const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE); fp_int i[1] = INIT_FP_INT; fp_int o[1] = INIT_FP_INT; fp_read_unsigned_bin(i, unconst_uint8_t(input), input_len); - if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), i, key->e, key->n, o, - key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) == HAL_OK) { - key->flags |= RSA_FLAG_PRECALC_N_DONE; + err = modexp(core, precalc, i, key->e, key->n, o, + key->nC, sizeof(key->nC), key->nF, sizeof(key->nF)); + + if (err == HAL_OK && precalc) + key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING; + + if (err == HAL_OK) err = unpack_fp(o, output, output_len); - } fp_zero(i); fp_zero(o); @@ -474,12 +482,17 @@ hal_error_t hal_rsa_decrypt(hal_core_t *core, * just do brute force ModExp. */ - if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) && !fp_iszero(key->dP) && !fp_iszero(key->dQ)) + if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) && + !fp_iszero(key->dP) && !fp_iszero(key->dQ)) err = rsa_crt(core, key, i, o); - else if ((err = modexp(core, (key->flags & RSA_FLAG_PRECALC_N_DONE), i, key->d, key->n, o, - key->nC, sizeof(key->nC), key->nF, sizeof(key->nF))) == HAL_OK) - key->flags |= RSA_FLAG_PRECALC_N_DONE; + else { + const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE); + err = modexp(core, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC), + key->nF, sizeof(key->nF)); + if (err == HAL_OK && precalc) + key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING; + } if (err != HAL_OK || (err = unpack_fp(o, output, output_len)) != HAL_OK) goto fail; @@ -802,6 +815,8 @@ hal_error_t hal_rsa_key_gen(hal_core_t *core, FP_CHECK(fp_mod(key->d, q_1, key->dQ)); /* dQ = d % (q-1) */ FP_CHECK(fp_invmod(key->q, key->p, key->u)); /* u = (1/q) % p */ + key->flags |= RSA_FLAG_NEEDS_SAVING; + *key_ = key; /* Fall through to cleanup */ @@ -814,11 +829,27 @@ hal_error_t hal_rsa_key_gen(hal_core_t *core, return err; } +/* + * Whether a key contains new data that need saving (newly generated + * key, updated speedup components, whatever). + */ + +int hal_rsa_key_needs_saving(const hal_rsa_key_t * const key) +{ + return key != NULL && (key->flags & RSA_FLAG_NEEDS_SAVING); +} + /* * Just enough ASN.1 to read and write PKCS #1.5 RSAPrivateKey syntax * (RFC 2313 section 7.2) wrapped in a PKCS #8 PrivateKeyInfo (RFC 5208). * * RSAPrivateKey fields in the required order. + * + * The "extra" fields are additional key components specific to the + * systolic modexpa7 core. We represent these in ASN.1 as OPTIONAL + * fields using IMPLICIT PRIVATE tags, since this is neither + * standardized nor meaningful to anybody else. Underlying encoding + * is INTEGER or OCTET STRING (currently the latter). */ #define RSAPrivateKey_fields \ @@ -832,8 +863,17 @@ hal_error_t hal_rsa_key_gen(hal_core_t *core, _(key->dQ); \ _(key->u); -hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, - uint8_t *der, size_t *der_len, const size_t der_max) +#define RSAPrivateKey_extra_fields \ + _(ASN1_PRIVATE + 0, nC, RSA_FLAG_PRECALC_N_DONE); \ + _(ASN1_PRIVATE + 1, nF, RSA_FLAG_PRECALC_N_DONE); \ + _(ASN1_PRIVATE + 2, pC, RSA_FLAG_PRECALC_PQ_DONE); \ + _(ASN1_PRIVATE + 3, pF, RSA_FLAG_PRECALC_PQ_DONE); \ + _(ASN1_PRIVATE + 4, qC, RSA_FLAG_PRECALC_PQ_DONE); \ + _(ASN1_PRIVATE + 5, qF, RSA_FLAG_PRECALC_PQ_DONE); + +hal_error_t hal_rsa_private_key_to_der_extra_maybe(const hal_rsa_key_t * const key, + const int include_extra, + uint8_t *der, size_t *der_len, const size_t der_max) { hal_error_t err = HAL_OK; @@ -848,10 +888,24 @@ hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, size_t hlen = 0, vlen = 0; -#define _(x) { size_t n; if ((err = hal_asn1_encode_integer(x, NULL, &n, der_max - vlen)) != HAL_OK) return err; vlen += n; } +#define _(x) { size_t n = 0; if ((err = hal_asn1_encode_integer(x, NULL, &n, der_max - vlen)) != HAL_OK) return err; vlen += n; } RSAPrivateKey_fields; #undef _ +#define _(x,y,z) \ + if ((key->flags & z) != 0) { \ + size_t n = 0; \ + if ((err = hal_asn1_encode_HEADER(x, sizeof(key->y), NULL, \ + &n, 0)) != HAL_OK) \ + return err; \ + vlen += n + sizeof(key->y); \ + } + + if (include_extra) { + RSAPrivateKey_extra_fields; + } +#undef _ + if ((err = hal_asn1_encode_header(ASN1_SEQUENCE, vlen, NULL, &hlen, 0)) != HAL_OK) return err; @@ -872,14 +926,41 @@ hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, uint8_t *d = der + hlen; memset(d, 0, vlen); -#define _(x) { size_t n; if ((err = hal_asn1_encode_integer(x, d, &n, vlen)) != HAL_OK) return err; d += n; vlen -= n; } +#define _(x) { size_t n = 0; if ((err = hal_asn1_encode_integer(x, d, &n, vlen)) != HAL_OK) return err; d += n; vlen -= n; } RSAPrivateKey_fields; #undef _ +#define _(x,y,z) \ + if ((key->flags & z) != 0) { \ + size_t n = 0; \ + if ((err = hal_asn1_encode_header(x, sizeof(key->y), d, \ + &n, vlen)) != HAL_OK) \ + return err; \ + d += n + sizeof(key->y); \ + vlen -= n + sizeof(key->y); \ + } + + if (include_extra) { + RSAPrivateKey_extra_fields; + } +#undef _ + return hal_asn1_encode_pkcs8_privatekeyinfo(hal_asn1_oid_rsaEncryption, hal_asn1_oid_rsaEncryption_len, NULL, 0, der, d - der, der, der_len, der_max); } +hal_error_t hal_rsa_private_key_to_der(const hal_rsa_key_t * const key, + uint8_t *der, size_t *der_len, const size_t der_max) +{ + return hal_rsa_private_key_to_der_extra_maybe(key, 0, der, der_len, der_max); +} + +hal_error_t hal_rsa_private_key_to_der_extra(const hal_rsa_key_t * const key, + uint8_t *der, size_t *der_len, const size_t der_max) +{ + return hal_rsa_private_key_to_der_extra_maybe(key, 1, der, der_len, der_max); +} + size_t hal_rsa_private_key_to_der_len(const hal_rsa_key_t * const key) { size_t len = 0; @@ -925,6 +1006,22 @@ hal_error_t hal_rsa_private_key_from_der(hal_rsa_key_t **key_, RSAPrivateKey_fields; #undef _ +#define _(x,y,z) \ + if (hal_asn1_peek(d, vlen, x) { \ + size_t hl = 0, vl = 0; \ + if ((err = hal_asn1_decode_header(x, d, vlen, &hl, &vl)) != HAL_OK) \ + return err; \ + if (vl > sizeof(key->y)) \ + return HAL_ERROR_ASN1_PARSE_FAILED; \ + memcpy(key->y, d + hl, vl); \ + key->flags |= z; \ + d += hl + vl; \ + vlen -= hl + vl; \ + } + + RSAPrivateKey_extra_fields; +#undef _ + if (d != privkey + privkey_len || !fp_iszero(version)) return HAL_ERROR_ASN1_PARSE_FAILED; -- cgit v1.2.3