diff --git a/library/pkwrite.c b/library/pkwrite.c index 2a698448be..ff6c0bfb44 100644 --- a/library/pkwrite.c +++ b/library/pkwrite.c @@ -64,22 +64,24 @@ static int pk_write_rsa_der(unsigned char **p, unsigned char *buf, { #if defined(MBEDTLS_USE_PSA_CRYPTO) if (mbedtls_pk_get_type(pk) == MBEDTLS_PK_OPAQUE) { - uint8_t tmp[PSA_EXPORT_KEY_PAIR_MAX_SIZE]; - size_t tmp_len = 0; + psa_status_t status; + size_t buf_size = (size_t) (*p - buf); + size_t key_len = 0; - if (psa_export_key(pk->priv_id, tmp, sizeof(tmp), &tmp_len) != PSA_SUCCESS) { - return MBEDTLS_ERR_PK_BAD_INPUT_DATA; - } - /* Ensure there's enough space in the provided buffer before copying data into it. */ - if (tmp_len > (size_t) (*p - buf)) { - mbedtls_platform_zeroize(tmp, sizeof(tmp)); + status = psa_export_key(pk->priv_id, buf, buf_size, &key_len); + if (status == PSA_ERROR_BUFFER_TOO_SMALL) { return MBEDTLS_ERR_ASN1_BUF_TOO_SMALL; + } else if (status != PSA_SUCCESS) { + return PSA_PK_RSA_TO_MBEDTLS_ERR(status); } - *p -= tmp_len; - memcpy(*p, tmp, tmp_len); - mbedtls_platform_zeroize(tmp, sizeof(tmp)); - return (int) tmp_len; + /* We wrote to the beginning of the buffer while + * we were supposed to write to its end. */ + *p -= key_len; + memmove(*p, buf, key_len); + mbedtls_platform_zeroize(buf, *p - buf); + + return (int) key_len; } #endif /* MBEDTLS_USE_PSA_CRYPTO */ return mbedtls_rsa_write_key(mbedtls_pk_rsa(*pk), buf, p); diff --git a/tests/suites/test_suite_pkwrite.function b/tests/suites/test_suite_pkwrite.function index 491bc489aa..691125e416 100644 --- a/tests/suites/test_suite_pkwrite.function +++ b/tests/suites/test_suite_pkwrite.function @@ -66,15 +66,59 @@ static int pk_write_any_key(mbedtls_pk_context *pk, unsigned char **p, return 0; } +static int pk_write_check_context(mbedtls_pk_context *key, + int is_public_key, int is_der, + unsigned char *check_buf, size_t check_buf_len) +{ + int ret = -1; + unsigned char *buf = NULL; + int expected_error = is_der ? + MBEDTLS_ERR_ASN1_BUF_TOO_SMALL : + MBEDTLS_ERR_BASE64_BUFFER_TOO_SMALL; + + /* Test with: + * - buffer too small (all sizes) + * - buffer exactly the right size + * - buffer a bit larger - DER functions should write to the end of the + * buffer, and we can only tell the difference with a larger buffer + */ + for (size_t buf_size = 1; buf_size <= check_buf_len + 2; buf_size++) { + mbedtls_free(buf); + buf = NULL; + TEST_CALLOC(buf, buf_size); + + unsigned char *start_buf = buf; + size_t out_len = buf_size; + int expected_result = buf_size < check_buf_len ? expected_error : 0; + mbedtls_test_set_step(buf_size); + + TEST_EQUAL(pk_write_any_key(key, &start_buf, &out_len, is_public_key, + is_der), expected_result); + + if (expected_result == 0) { + TEST_MEMORY_COMPARE(start_buf, out_len, check_buf, check_buf_len); + + if (is_der) { + /* Data should be at the end of the buffer */ + TEST_ASSERT(start_buf + out_len == buf + buf_size); + } + } + } + + ret = 0; + +exit: + mbedtls_free(buf); + return ret; +} + + static void pk_write_check_common(char *key_file, int is_public_key, int is_der) { mbedtls_pk_context key; mbedtls_pk_init(&key); - unsigned char *buf = NULL; unsigned char *check_buf = NULL; - unsigned char *start_buf; - size_t buf_len, check_buf_len; - int expected_result; + size_t check_buf_len; #if defined(MBEDTLS_USE_PSA_CRYPTO) mbedtls_svc_key_id_t opaque_id = MBEDTLS_SVC_KEY_ID_INIT; psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT; @@ -100,8 +144,6 @@ static void pk_write_check_common(char *key_file, int is_public_key, int is_der) } TEST_ASSERT(check_buf_len > 0); - TEST_CALLOC(buf, check_buf_len); - if (is_public_key) { TEST_EQUAL(mbedtls_pk_parse_public_keyfile(&key, key_file), 0); } else { @@ -109,28 +151,14 @@ static void pk_write_check_common(char *key_file, int is_public_key, int is_der) mbedtls_test_rnd_std_rand, NULL), 0); } - start_buf = buf; - buf_len = check_buf_len; - if (is_der) { - expected_result = MBEDTLS_ERR_ASN1_BUF_TOO_SMALL; - } else { - expected_result = MBEDTLS_ERR_BASE64_BUFFER_TOO_SMALL; + if (pk_write_check_context(&key, is_public_key, is_der, + check_buf, check_buf_len) != 0) { + goto exit; } - /* Intentionally pass a wrong size for the provided output buffer and check - * that the writing functions fails as expected. */ - for (size_t i = 1; i < buf_len; i++) { - TEST_EQUAL(pk_write_any_key(&key, &start_buf, &i, is_public_key, - is_der), expected_result); - } - TEST_EQUAL(pk_write_any_key(&key, &start_buf, &buf_len, is_public_key, - is_der), 0); - - TEST_MEMORY_COMPARE(start_buf, buf_len, check_buf, check_buf_len); #if defined(MBEDTLS_USE_PSA_CRYPTO) /* Verify that pk_write works also for opaque private keys */ if (!is_public_key) { - memset(buf, 0, check_buf_len); /* Turn the key PK context into an opaque one. * Note: set some practical usage for the key to make get_psa_attributes() happy. */ TEST_EQUAL(mbedtls_pk_get_psa_attributes(&key, PSA_KEY_USAGE_SIGN_MESSAGE, &key_attr), 0); @@ -138,18 +166,11 @@ static void pk_write_check_common(char *key_file, int is_public_key, int is_der) mbedtls_pk_free(&key); mbedtls_pk_init(&key); TEST_EQUAL(mbedtls_pk_setup_opaque(&key, opaque_id), 0); - start_buf = buf; - buf_len = check_buf_len; - /* Intentionally pass a wrong size for the provided output buffer and check - * that the writing functions fails as expected. */ - for (size_t i = 1; i < buf_len; i++) { - TEST_EQUAL(pk_write_any_key(&key, &start_buf, &i, is_public_key, - is_der), expected_result); - } - TEST_EQUAL(pk_write_any_key(&key, &start_buf, &buf_len, is_public_key, - is_der), 0); - TEST_MEMORY_COMPARE(start_buf, buf_len, check_buf, check_buf_len); + if (pk_write_check_context(&key, is_public_key, is_der, + check_buf, check_buf_len) != 0) { + goto exit; + } } #endif /* MBEDTLS_USE_PSA_CRYPTO */ @@ -157,7 +178,6 @@ exit: #if defined(MBEDTLS_USE_PSA_CRYPTO) psa_destroy_key(opaque_id); #endif /* MBEDTLS_USE_PSA_CRYPTO */ - mbedtls_free(buf); mbedtls_free(check_buf); mbedtls_pk_free(&key); USE_PSA_DONE();