From 5e4fc533393e01e16739f450d46f739ca4b24fe8 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Thu, 18 Jun 2015 14:55:51 -0400 Subject: Refactor CRT code into public API. --- rsa.c | 183 +++++++++++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 120 insertions(+), 63 deletions(-) (limited to 'rsa.c') diff --git a/rsa.c b/rsa.c index 9a42563..543becc 100644 --- a/rsa.c +++ b/rsa.c @@ -196,6 +196,123 @@ static hal_error_t modexp_fp(fp_int *msg, fp_int *exp, fp_int *mod, fp_int *res) return err; } +/* + * RSA decryption via Chinese Remainder Theorem (Garner's formula). + */ + +static hal_error_t rsa_crt(struct rsa_key *key, fp_int *msg, fp_int *sig) +{ + assert(key != NULL && msg != NULL && sig != NULL); + + hal_error_t err = HAL_OK; + fp_int t, m1, m2; + + fp_init(&t); + fp_init(&m1); + fp_init(&m2); + + /* + * m1 = msg ** dP mod p + * m2 = msg ** dQ mod q + */ + if ((err = modexp_fp(msg, &key->dP, &key->p, &m1)) != HAL_OK || + (err = modexp_fp(msg, &key->dQ, &key->q, &m2)) != HAL_OK) + goto fail; + + /* + * t = m1 - m2. + */ + fp_sub(&m1, &m2, &t); + + /* + * Add zero (mod p) if needed to make t positive. If doing this + * once or twice doesn't help, something is very wrong. + */ + if (fp_cmp_d(&t, 0) == FP_LT) + fp_add(&t, &key->p, &t); + if (fp_cmp_d(&t, 0) == FP_LT) + fp_add(&t, &key->p, &t); + if (fp_cmp_d(&t, 0) == FP_LT) + lose(HAL_ERROR_IMPOSSIBLE); + + /* + * sig = (t * u mod p) * q + m2 + */ + FP_CHECK(fp_mulmod(&t, &key->u, &key->p, &t)); + fp_mul(&t, &key->q, &t); + fp_add(&t, &m2, sig); + + fail: + fp_zero(&t); + fp_zero(&m1); + fp_zero(&m2); + return err; +} + +/* + * Public API for raw RSA encryption and decryption. + */ + +hal_error_t hal_rsa_encrypt(hal_rsa_key_t key_, + const uint8_t * const input, const size_t input_len, + uint8_t * output, const size_t output_len) +{ + struct rsa_key *key = key_.key; + hal_error_t err = HAL_OK; + + if (key == NULL || input == NULL || output == NULL || input_len > output_len) + return HAL_ERROR_BAD_ARGUMENTS; + + fp_int i, o; + fp_init(&i); + fp_init(&o); + + fp_read_unsigned_bin(&i, (uint8_t *) input, input_len); + + if ((err = modexp_fp(&i, &key->e, &key->n, &o)) != HAL_OK || + (err = unpack_fp(&o, output, output_len)) != HAL_OK) + goto fail; + + fail: + fp_zero(&i); + fp_zero(&o); + return err; +} + +hal_error_t hal_rsa_decrypt(hal_rsa_key_t key_, + const uint8_t * const input, const size_t input_len, + uint8_t * output, const size_t output_len) +{ + struct rsa_key *key = key_.key; + hal_error_t err = HAL_OK; + + if (key == NULL || input == NULL || output == NULL || input_len > output_len) + return HAL_ERROR_BAD_ARGUMENTS; + + fp_int i, o; + fp_init(&i); + fp_init(&o); + + fp_read_unsigned_bin(&i, (uint8_t *) input, input_len); + + /* + * Do CRT if we have all the necessary key components, otherwise + * just do brute force ModExp. + */ + + if (fp_iszero(&key->p) || fp_iszero(&key->q) || fp_iszero(&key->u) || fp_iszero(&key->dP) || fp_iszero(&key->dQ)) + err = modexp_fp(&i, &key->d, &key->n, &o); + else + err = rsa_crt(key, &i, &o); + + if (err != HAL_OK || (err = unpack_fp(&o, output, output_len)) != HAL_OK) + goto fail; + + fail: + fp_zero(&i); + fp_zero(&o); + return err; +} /* * Clear a key. We might want to do something a bit more energetic @@ -255,74 +372,14 @@ hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type, return HAL_ERROR_BAD_ARGUMENTS; } -/* - * RSA decyrption/signature using the Chinese Remainder Theorem - * (Garner's formula). - */ - -hal_error_t hal_rsa_crt(hal_rsa_key_t key_, - const uint8_t * const m, const size_t m_len, - uint8_t * result, const size_t result_len) -{ - hal_error_t err = HAL_OK; - struct rsa_key *key = key_.key; - struct { fp_int t, msg, m1, m2; } tmp; - - fp_init(&tmp.t); - fp_init(&tmp.msg); - fp_init(&tmp.m1); - fp_init(&tmp.m2); - - fp_read_unsigned_bin(&tmp.msg, (uint8_t *) m, m_len); - - /* - * m1 = msg ** dP mod p - * m2 = msg ** dQ mod q - */ - if ((err = modexp_fp(&tmp.msg, &key->dP, &key->p, &tmp.m1)) != HAL_OK || - (err = modexp_fp(&tmp.msg, &key->dQ, &key->q, &tmp.m2)) != HAL_OK) - goto fail; - - /* - * t = m1 - m2. - * Add zero (mod p) once or twice if necessary to get positive result. - */ - fp_sub(&tmp.m1, &tmp.m2, &tmp.t); - if (fp_cmp_d(&tmp.t, 0) == FP_LT) - fp_add(&tmp.t, &key->p, &tmp.t); - if (fp_cmp_d(&tmp.t, 0) == FP_LT) - fp_add(&tmp.t, &key->p, &tmp.t); - if (fp_cmp_d(&tmp.t, 0) == FP_LT) - lose(HAL_ERROR_IMPOSSIBLE); - - /* - * t = (t * u mod p) * q + m2 - */ - FP_CHECK(fp_mulmod(&tmp.t, &key->u, &key->p, &tmp.t)); - fp_mul(&tmp.t, &key->q, &tmp.t); - fp_add(&tmp.t, &tmp.m2, &tmp.t); - - /* - * t now holds result, write it back to caller - */ - if ((err = unpack_fp(&tmp.t, result, result_len)) != HAL_OK) - goto fail; - - /* - * Done, fall through into cleanup. - */ - - fail: - memset(&tmp, 0, sizeof(tmp)); - return err; -} - static hal_error_t find_prime(unsigned prime_length, fp_int *e, fp_int *result) { uint8_t buffer[prime_length]; hal_error_t err; fp_int t; + fp_init(&t); + /* * Get random bytes, munge a few bits, and stuff into a bignum. * Keep doing this until we find a result that's (probably) prime @@ -547,7 +604,7 @@ static hal_error_t decode_integer(fp_int *bn, if (der_len != NULL) *der_len = hlen + vlen; - if (vlen < 1) + if (vlen < 1 || (der[hlen] & 0x80) != 0x00) return HAL_ERROR_ASN1_PARSE_FAILED; fp_init(bn); -- cgit v1.2.3