diff options
-rw-r--r-- | cryptech.h | 10 | ||||
-rw-r--r-- | rsa.c | 183 | ||||
-rw-r--r-- | tests/test-rsa.c | 8 |
3 files changed, 131 insertions, 70 deletions
@@ -628,9 +628,13 @@ extern hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type, extern void hal_rsa_key_clear(hal_rsa_key_t key); -extern hal_error_t hal_rsa_crt(hal_rsa_key_t key, - const uint8_t * const m, const size_t m_len, - uint8_t * result, const size_t result_len); +extern hal_error_t hal_rsa_encrypt(hal_rsa_key_t key, + const uint8_t * const input, const size_t input_len, + uint8_t * output, const size_t output_len); + +extern hal_error_t hal_rsa_decrypt(hal_rsa_key_t key, + const uint8_t * const input, const size_t input_len, + uint8_t * output, const size_t output_len); extern hal_error_t hal_rsa_key_gen(hal_rsa_key_t *key, void *keybuf, const size_t keybuf_len, @@ -196,6 +196,123 @@ static hal_error_t modexp_fp(fp_int *msg, fp_int *exp, fp_int *mod, fp_int *res) return err; } +/* + * RSA decryption via Chinese Remainder Theorem (Garner's formula). + */ + +static hal_error_t rsa_crt(struct rsa_key *key, fp_int *msg, fp_int *sig) +{ + assert(key != NULL && msg != NULL && sig != NULL); + + hal_error_t err = HAL_OK; + fp_int t, m1, m2; + + fp_init(&t); + fp_init(&m1); + fp_init(&m2); + + /* + * m1 = msg ** dP mod p + * m2 = msg ** dQ mod q + */ + if ((err = modexp_fp(msg, &key->dP, &key->p, &m1)) != HAL_OK || + (err = modexp_fp(msg, &key->dQ, &key->q, &m2)) != HAL_OK) + goto fail; + + /* + * t = m1 - m2. + */ + fp_sub(&m1, &m2, &t); + + /* + * Add zero (mod p) if needed to make t positive. If doing this + * once or twice doesn't help, something is very wrong. + */ + if (fp_cmp_d(&t, 0) == FP_LT) + fp_add(&t, &key->p, &t); + if (fp_cmp_d(&t, 0) == FP_LT) + fp_add(&t, &key->p, &t); + if (fp_cmp_d(&t, 0) == FP_LT) + lose(HAL_ERROR_IMPOSSIBLE); + + /* + * sig = (t * u mod p) * q + m2 + */ + FP_CHECK(fp_mulmod(&t, &key->u, &key->p, &t)); + fp_mul(&t, &key->q, &t); + fp_add(&t, &m2, sig); + + fail: + fp_zero(&t); + fp_zero(&m1); + fp_zero(&m2); + return err; +} + +/* + * Public API for raw RSA encryption and decryption. + */ + +hal_error_t hal_rsa_encrypt(hal_rsa_key_t key_, + const uint8_t * const input, const size_t input_len, + uint8_t * output, const size_t output_len) +{ + struct rsa_key *key = key_.key; + hal_error_t err = HAL_OK; + + if (key == NULL || input == NULL || output == NULL || input_len > output_len) + return HAL_ERROR_BAD_ARGUMENTS; + + fp_int i, o; + fp_init(&i); + fp_init(&o); + + fp_read_unsigned_bin(&i, (uint8_t *) input, input_len); + + if ((err = modexp_fp(&i, &key->e, &key->n, &o)) != HAL_OK || + (err = unpack_fp(&o, output, output_len)) != HAL_OK) + goto fail; + + fail: + fp_zero(&i); + fp_zero(&o); + return err; +} + +hal_error_t hal_rsa_decrypt(hal_rsa_key_t key_, + const uint8_t * const input, const size_t input_len, + uint8_t * output, const size_t output_len) +{ + struct rsa_key *key = key_.key; + hal_error_t err = HAL_OK; + + if (key == NULL || input == NULL || output == NULL || input_len > output_len) + return HAL_ERROR_BAD_ARGUMENTS; + + fp_int i, o; + fp_init(&i); + fp_init(&o); + + fp_read_unsigned_bin(&i, (uint8_t *) input, input_len); + + /* + * Do CRT if we have all the necessary key components, otherwise + * just do brute force ModExp. + */ + + if (fp_iszero(&key->p) || fp_iszero(&key->q) || fp_iszero(&key->u) || fp_iszero(&key->dP) || fp_iszero(&key->dQ)) + err = modexp_fp(&i, &key->d, &key->n, &o); + else + err = rsa_crt(key, &i, &o); + + if (err != HAL_OK || (err = unpack_fp(&o, output, output_len)) != HAL_OK) + goto fail; + + fail: + fp_zero(&i); + fp_zero(&o); + return err; +} /* * Clear a key. We might want to do something a bit more energetic @@ -255,74 +372,14 @@ hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type, return HAL_ERROR_BAD_ARGUMENTS; } -/* - * RSA decyrption/signature using the Chinese Remainder Theorem - * (Garner's formula). - */ - -hal_error_t hal_rsa_crt(hal_rsa_key_t key_, - const uint8_t * const m, const size_t m_len, - uint8_t * result, const size_t result_len) -{ - hal_error_t err = HAL_OK; - struct rsa_key *key = key_.key; - struct { fp_int t, msg, m1, m2; } tmp; - - fp_init(&tmp.t); - fp_init(&tmp.msg); - fp_init(&tmp.m1); - fp_init(&tmp.m2); - - fp_read_unsigned_bin(&tmp.msg, (uint8_t *) m, m_len); - - /* - * m1 = msg ** dP mod p - * m2 = msg ** dQ mod q - */ - if ((err = modexp_fp(&tmp.msg, &key->dP, &key->p, &tmp.m1)) != HAL_OK || - (err = modexp_fp(&tmp.msg, &key->dQ, &key->q, &tmp.m2)) != HAL_OK) - goto fail; - - /* - * t = m1 - m2. - * Add zero (mod p) once or twice if necessary to get positive result. - */ - fp_sub(&tmp.m1, &tmp.m2, &tmp.t); - if (fp_cmp_d(&tmp.t, 0) == FP_LT) - fp_add(&tmp.t, &key->p, &tmp.t); - if (fp_cmp_d(&tmp.t, 0) == FP_LT) - fp_add(&tmp.t, &key->p, &tmp.t); - if (fp_cmp_d(&tmp.t, 0) == FP_LT) - lose(HAL_ERROR_IMPOSSIBLE); - - /* - * t = (t * u mod p) * q + m2 - */ - FP_CHECK(fp_mulmod(&tmp.t, &key->u, &key->p, &tmp.t)); - fp_mul(&tmp.t, &key->q, &tmp.t); - fp_add(&tmp.t, &tmp.m2, &tmp.t); - - /* - * t now holds result, write it back to caller - */ - if ((err = unpack_fp(&tmp.t, result, result_len)) != HAL_OK) - goto fail; - - /* - * Done, fall through into cleanup. - */ - - fail: - memset(&tmp, 0, sizeof(tmp)); - return err; -} - static hal_error_t find_prime(unsigned prime_length, fp_int *e, fp_int *result) { uint8_t buffer[prime_length]; hal_error_t err; fp_int t; + fp_init(&t); + /* * Get random bytes, munge a few bits, and stuff into a bignum. * Keep doing this until we find a result that's (probably) prime @@ -547,7 +604,7 @@ static hal_error_t decode_integer(fp_int *bn, if (der_len != NULL) *der_len = hlen + vlen; - if (vlen < 1) + if (vlen < 1 || (der[hlen] & 0x80) != 0x00) return HAL_ERROR_ASN1_PARSE_FAILED; fp_init(bn); diff --git a/tests/test-rsa.c b/tests/test-rsa.c index 08d22c5..799f1fa 100644 --- a/tests/test-rsa.c +++ b/tests/test-rsa.c @@ -82,7 +82,7 @@ static int test_modexp(const char * const kind, * Run one RSA CRT test. */ -static int test_crt(const char * const kind, const rsa_tc_t * const tc) +static int test_decrypt(const char * const kind, const rsa_tc_t * const tc) { printf("%s test for %lu-bit RSA key\n", kind, (unsigned long) tc->size); @@ -106,7 +106,7 @@ static int test_crt(const char * const kind, const rsa_tc_t * const tc) uint8_t result[tc->n.len]; - if ((err = hal_rsa_crt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) + if ((err = hal_rsa_decrypt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) printf("RSA CRT failed: %s\n", hal_error_string(err)); const int mismatch = (err == HAL_OK && memcmp(result, tc->s.val, tc->s.len) != 0); @@ -172,7 +172,7 @@ static int test_gen(const char * const kind, const rsa_tc_t * const tc) uint8_t result[tc->n.len]; - if ((err = hal_rsa_crt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) + if ((err = hal_rsa_decrypt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK) printf("RSA CRT failed: %s\n", hal_error_string(err)); snprintf(fn, sizeof(fn), "test-rsa-sig-%04lu.der", (unsigned long) tc->size); @@ -244,7 +244,7 @@ static int test_rsa(const rsa_tc_t * const tc) time_check(test_modexp("Signature (ModExp)", tc, &tc->m, &tc->d, &tc->s)); /* RSA decyrption using CRT */ - time_check(test_crt("Signature (CRT)", tc)); + time_check(test_decrypt("Signature (CRT)", tc)); /* Key generation and CRT -- not test vector, so writes key and sig to file */ time_check(test_gen("Generation and CRT", tc)); |