diff --git a/library/pk.c b/library/pk.c index b158547613..254ed64d3e 100644 --- a/library/pk.c +++ b/library/pk.c @@ -681,20 +681,26 @@ static int import_pair_into_psa(const mbedtls_pk_context *pk, if (psa_get_key_type(attributes) != PSA_KEY_TYPE_RSA_KEY_PAIR) { return MBEDTLS_ERR_PK_TYPE_MISMATCH; } - unsigned char key_buffer[ - PSA_KEY_EXPORT_RSA_KEY_PAIR_MAX_SIZE(PSA_VENDOR_RSA_MAX_KEY_BITS)]; - unsigned char *const key_end = key_buffer + sizeof(key_buffer); + size_t key_bits = psa_get_key_bits(attributes); + size_t key_buffer_size = PSA_KEY_EXPORT_RSA_KEY_PAIR_MAX_SIZE(key_bits); + unsigned char *key_buffer = mbedtls_calloc(1, key_buffer_size); + if (key_buffer == NULL) { + return MBEDTLS_ERR_PK_ALLOC_FAILED; + } + unsigned char *const key_end = key_buffer + key_buffer_size; unsigned char *key_data = key_end; int ret = mbedtls_rsa_write_key(mbedtls_pk_rsa(*pk), key_buffer, &key_data); if (ret < 0) { - return ret; + goto cleanup_rsa; } size_t key_length = key_end - key_data; ret = PSA_PK_TO_MBEDTLS_ERR(psa_import_key(attributes, key_data, key_length, key_id)); - mbedtls_platform_zeroize(key_data, key_length); +cleanup_rsa: + mbedtls_platform_zeroize(key_buffer, key_buffer_size); + mbedtls_free(key_buffer); return ret; } #endif /* MBEDTLS_RSA_C */