diff --git a/library/pkwrite.c b/library/pkwrite.c index 2a698448be..ff6c0bfb44 100644 --- a/library/pkwrite.c +++ b/library/pkwrite.c @@ -64,22 +64,24 @@ static int pk_write_rsa_der(unsigned char **p, unsigned char *buf, { #if defined(MBEDTLS_USE_PSA_CRYPTO) if (mbedtls_pk_get_type(pk) == MBEDTLS_PK_OPAQUE) { - uint8_t tmp[PSA_EXPORT_KEY_PAIR_MAX_SIZE]; - size_t tmp_len = 0; + psa_status_t status; + size_t buf_size = (size_t) (*p - buf); + size_t key_len = 0; - if (psa_export_key(pk->priv_id, tmp, sizeof(tmp), &tmp_len) != PSA_SUCCESS) { - return MBEDTLS_ERR_PK_BAD_INPUT_DATA; - } - /* Ensure there's enough space in the provided buffer before copying data into it. */ - if (tmp_len > (size_t) (*p - buf)) { - mbedtls_platform_zeroize(tmp, sizeof(tmp)); + status = psa_export_key(pk->priv_id, buf, buf_size, &key_len); + if (status == PSA_ERROR_BUFFER_TOO_SMALL) { return MBEDTLS_ERR_ASN1_BUF_TOO_SMALL; + } else if (status != PSA_SUCCESS) { + return PSA_PK_RSA_TO_MBEDTLS_ERR(status); } - *p -= tmp_len; - memcpy(*p, tmp, tmp_len); - mbedtls_platform_zeroize(tmp, sizeof(tmp)); - return (int) tmp_len; + /* We wrote to the beginning of the buffer while + * we were supposed to write to its end. */ + *p -= key_len; + memmove(*p, buf, key_len); + mbedtls_platform_zeroize(buf, *p - buf); + + return (int) key_len; } #endif /* MBEDTLS_USE_PSA_CRYPTO */ return mbedtls_rsa_write_key(mbedtls_pk_rsa(*pk), buf, p);