aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cryptech.h10
-rw-r--r--rsa.c183
-rw-r--r--tests/test-rsa.c8
3 files changed, 131 insertions, 70 deletions
diff --git a/cryptech.h b/cryptech.h
index 6af9ce8..4b8fe17 100644
--- a/cryptech.h
+++ b/cryptech.h
@@ -628,9 +628,13 @@ extern hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type,
extern void hal_rsa_key_clear(hal_rsa_key_t key);
-extern hal_error_t hal_rsa_crt(hal_rsa_key_t key,
- const uint8_t * const m, const size_t m_len,
- uint8_t * result, const size_t result_len);
+extern hal_error_t hal_rsa_encrypt(hal_rsa_key_t key,
+ const uint8_t * const input, const size_t input_len,
+ uint8_t * output, const size_t output_len);
+
+extern hal_error_t hal_rsa_decrypt(hal_rsa_key_t key,
+ const uint8_t * const input, const size_t input_len,
+ uint8_t * output, const size_t output_len);
extern hal_error_t hal_rsa_key_gen(hal_rsa_key_t *key,
void *keybuf, const size_t keybuf_len,
diff --git a/rsa.c b/rsa.c
index 9a42563..543becc 100644
--- a/rsa.c
+++ b/rsa.c
@@ -196,6 +196,123 @@ static hal_error_t modexp_fp(fp_int *msg, fp_int *exp, fp_int *mod, fp_int *res)
return err;
}
+/*
+ * RSA decryption via Chinese Remainder Theorem (Garner's formula).
+ */
+
+static hal_error_t rsa_crt(struct rsa_key *key, fp_int *msg, fp_int *sig)
+{
+ assert(key != NULL && msg != NULL && sig != NULL);
+
+ hal_error_t err = HAL_OK;
+ fp_int t, m1, m2;
+
+ fp_init(&t);
+ fp_init(&m1);
+ fp_init(&m2);
+
+ /*
+ * m1 = msg ** dP mod p
+ * m2 = msg ** dQ mod q
+ */
+ if ((err = modexp_fp(msg, &key->dP, &key->p, &m1)) != HAL_OK ||
+ (err = modexp_fp(msg, &key->dQ, &key->q, &m2)) != HAL_OK)
+ goto fail;
+
+ /*
+ * t = m1 - m2.
+ */
+ fp_sub(&m1, &m2, &t);
+
+ /*
+ * Add zero (mod p) if needed to make t positive. If doing this
+ * once or twice doesn't help, something is very wrong.
+ */
+ if (fp_cmp_d(&t, 0) == FP_LT)
+ fp_add(&t, &key->p, &t);
+ if (fp_cmp_d(&t, 0) == FP_LT)
+ fp_add(&t, &key->p, &t);
+ if (fp_cmp_d(&t, 0) == FP_LT)
+ lose(HAL_ERROR_IMPOSSIBLE);
+
+ /*
+ * sig = (t * u mod p) * q + m2
+ */
+ FP_CHECK(fp_mulmod(&t, &key->u, &key->p, &t));
+ fp_mul(&t, &key->q, &t);
+ fp_add(&t, &m2, sig);
+
+ fail:
+ fp_zero(&t);
+ fp_zero(&m1);
+ fp_zero(&m2);
+ return err;
+}
+
+/*
+ * Public API for raw RSA encryption and decryption.
+ */
+
+hal_error_t hal_rsa_encrypt(hal_rsa_key_t key_,
+ const uint8_t * const input, const size_t input_len,
+ uint8_t * output, const size_t output_len)
+{
+ struct rsa_key *key = key_.key;
+ hal_error_t err = HAL_OK;
+
+ if (key == NULL || input == NULL || output == NULL || input_len > output_len)
+ return HAL_ERROR_BAD_ARGUMENTS;
+
+ fp_int i, o;
+ fp_init(&i);
+ fp_init(&o);
+
+ fp_read_unsigned_bin(&i, (uint8_t *) input, input_len);
+
+ if ((err = modexp_fp(&i, &key->e, &key->n, &o)) != HAL_OK ||
+ (err = unpack_fp(&o, output, output_len)) != HAL_OK)
+ goto fail;
+
+ fail:
+ fp_zero(&i);
+ fp_zero(&o);
+ return err;
+}
+
+hal_error_t hal_rsa_decrypt(hal_rsa_key_t key_,
+ const uint8_t * const input, const size_t input_len,
+ uint8_t * output, const size_t output_len)
+{
+ struct rsa_key *key = key_.key;
+ hal_error_t err = HAL_OK;
+
+ if (key == NULL || input == NULL || output == NULL || input_len > output_len)
+ return HAL_ERROR_BAD_ARGUMENTS;
+
+ fp_int i, o;
+ fp_init(&i);
+ fp_init(&o);
+
+ fp_read_unsigned_bin(&i, (uint8_t *) input, input_len);
+
+ /*
+ * Do CRT if we have all the necessary key components, otherwise
+ * just do brute force ModExp.
+ */
+
+ if (fp_iszero(&key->p) || fp_iszero(&key->q) || fp_iszero(&key->u) || fp_iszero(&key->dP) || fp_iszero(&key->dQ))
+ err = modexp_fp(&i, &key->d, &key->n, &o);
+ else
+ err = rsa_crt(key, &i, &o);
+
+ if (err != HAL_OK || (err = unpack_fp(&o, output, output_len)) != HAL_OK)
+ goto fail;
+
+ fail:
+ fp_zero(&i);
+ fp_zero(&o);
+ return err;
+}
/*
* Clear a key. We might want to do something a bit more energetic
@@ -255,74 +372,14 @@ hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type,
return HAL_ERROR_BAD_ARGUMENTS;
}
-/*
- * RSA decyrption/signature using the Chinese Remainder Theorem
- * (Garner's formula).
- */
-
-hal_error_t hal_rsa_crt(hal_rsa_key_t key_,
- const uint8_t * const m, const size_t m_len,
- uint8_t * result, const size_t result_len)
-{
- hal_error_t err = HAL_OK;
- struct rsa_key *key = key_.key;
- struct { fp_int t, msg, m1, m2; } tmp;
-
- fp_init(&tmp.t);
- fp_init(&tmp.msg);
- fp_init(&tmp.m1);
- fp_init(&tmp.m2);
-
- fp_read_unsigned_bin(&tmp.msg, (uint8_t *) m, m_len);
-
- /*
- * m1 = msg ** dP mod p
- * m2 = msg ** dQ mod q
- */
- if ((err = modexp_fp(&tmp.msg, &key->dP, &key->p, &tmp.m1)) != HAL_OK ||
- (err = modexp_fp(&tmp.msg, &key->dQ, &key->q, &tmp.m2)) != HAL_OK)
- goto fail;
-
- /*
- * t = m1 - m2.
- * Add zero (mod p) once or twice if necessary to get positive result.
- */
- fp_sub(&tmp.m1, &tmp.m2, &tmp.t);
- if (fp_cmp_d(&tmp.t, 0) == FP_LT)
- fp_add(&tmp.t, &key->p, &tmp.t);
- if (fp_cmp_d(&tmp.t, 0) == FP_LT)
- fp_add(&tmp.t, &key->p, &tmp.t);
- if (fp_cmp_d(&tmp.t, 0) == FP_LT)
- lose(HAL_ERROR_IMPOSSIBLE);
-
- /*
- * t = (t * u mod p) * q + m2
- */
- FP_CHECK(fp_mulmod(&tmp.t, &key->u, &key->p, &tmp.t));
- fp_mul(&tmp.t, &key->q, &tmp.t);
- fp_add(&tmp.t, &tmp.m2, &tmp.t);
-
- /*
- * t now holds result, write it back to caller
- */
- if ((err = unpack_fp(&tmp.t, result, result_len)) != HAL_OK)
- goto fail;
-
- /*
- * Done, fall through into cleanup.
- */
-
- fail:
- memset(&tmp, 0, sizeof(tmp));
- return err;
-}
-
static hal_error_t find_prime(unsigned prime_length, fp_int *e, fp_int *result)
{
uint8_t buffer[prime_length];
hal_error_t err;
fp_int t;
+ fp_init(&t);
+
/*
* Get random bytes, munge a few bits, and stuff into a bignum.
* Keep doing this until we find a result that's (probably) prime
@@ -547,7 +604,7 @@ static hal_error_t decode_integer(fp_int *bn,
if (der_len != NULL)
*der_len = hlen + vlen;
- if (vlen < 1)
+ if (vlen < 1 || (der[hlen] & 0x80) != 0x00)
return HAL_ERROR_ASN1_PARSE_FAILED;
fp_init(bn);
diff --git a/tests/test-rsa.c b/tests/test-rsa.c
index 08d22c5..799f1fa 100644
--- a/tests/test-rsa.c
+++ b/tests/test-rsa.c
@@ -82,7 +82,7 @@ static int test_modexp(const char * const kind,
* Run one RSA CRT test.
*/
-static int test_crt(const char * const kind, const rsa_tc_t * const tc)
+static int test_decrypt(const char * const kind, const rsa_tc_t * const tc)
{
printf("%s test for %lu-bit RSA key\n", kind, (unsigned long) tc->size);
@@ -106,7 +106,7 @@ static int test_crt(const char * const kind, const rsa_tc_t * const tc)
uint8_t result[tc->n.len];
- if ((err = hal_rsa_crt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
+ if ((err = hal_rsa_decrypt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
printf("RSA CRT failed: %s\n", hal_error_string(err));
const int mismatch = (err == HAL_OK && memcmp(result, tc->s.val, tc->s.len) != 0);
@@ -172,7 +172,7 @@ static int test_gen(const char * const kind, const rsa_tc_t * const tc)
uint8_t result[tc->n.len];
- if ((err = hal_rsa_crt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
+ if ((err = hal_rsa_decrypt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
printf("RSA CRT failed: %s\n", hal_error_string(err));
snprintf(fn, sizeof(fn), "test-rsa-sig-%04lu.der", (unsigned long) tc->size);
@@ -244,7 +244,7 @@ static int test_rsa(const rsa_tc_t * const tc)
time_check(test_modexp("Signature (ModExp)", tc, &tc->m, &tc->d, &tc->s));
/* RSA decyrption using CRT */
- time_check(test_crt("Signature (CRT)", tc));
+ time_check(test_decrypt("Signature (CRT)", tc));
/* Key generation and CRT -- not test vector, so writes key and sig to file */
time_check(test_gen("Generation and CRT", tc));