diff --git a/library/pk_wrap.c b/library/pk_wrap.c index cafcb87d0d..8c038e7db6 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -288,9 +288,6 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, psa_algorithm_t psa_md_alg, decrypt_alg; psa_status_t status; int key_len; - unsigned char buf[MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES]; - unsigned char *p = buf + sizeof(buf); - ((void) f_rng); ((void) p_rng); @@ -298,6 +295,13 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, return MBEDTLS_ERR_RSA_BAD_INPUT_DATA; } + const size_t key_bits = mbedtls_pk_get_bitlen(pk); + /* mbedtls_rsa_write_key() uses the same format as PSA export, which + * actually calls it under the hood, so we can use the PSA size macro. */ + const size_t buf_size = PSA_KEY_EXPORT_RSA_KEY_PAIR_MAX_SIZE(key_bits); + unsigned char *buf = mbedtls_calloc(1, buf_size); + + unsigned char *p = buf + buf_size; key_len = mbedtls_rsa_write_key(rsa, buf, &p); if (key_len <= 0) { return MBEDTLS_ERR_PK_BAD_INPUT_DATA; @@ -314,7 +318,7 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, psa_set_key_algorithm(&attributes, decrypt_alg); status = psa_import_key(&attributes, - buf + sizeof(buf) - key_len, key_len, + buf + buf_size - key_len, key_len, &key_id); if (status != PSA_SUCCESS) { ret = PSA_PK_TO_MBEDTLS_ERR(status); @@ -333,7 +337,7 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, ret = 0; cleanup: - mbedtls_platform_zeroize(buf, sizeof(buf)); + mbedtls_zeroize_and_free(buf, buf_size); status = psa_destroy_key(key_id); if (ret == 0 && status != PSA_SUCCESS) { ret = PSA_PK_TO_MBEDTLS_ERR(status);