diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c index 131efbe248..ebcc0d56bb 100644 --- a/library/ssl_tls12_client.c +++ b/library/ssl_tls12_client.c @@ -16,6 +16,7 @@ #include "debug_internal.h" #include "mbedtls/error.h" #include "mbedtls/constant_time.h" +#include "mbedtls_utils.h" #include "psa/crypto.h" #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED) @@ -1883,6 +1884,7 @@ start_processing: unsigned char hash[MBEDTLS_MD_MAX_SIZE]; mbedtls_md_type_t md_alg = MBEDTLS_MD_NONE; + psa_algorithm_t psa_hash_alg; mbedtls_pk_sigalg_t pk_alg = MBEDTLS_PK_SIGALG_NONE; unsigned char *params = ssl->in_msg + mbedtls_ssl_hs_hdr_len(ssl); size_t params_len = (size_t) (p - params); @@ -1921,7 +1923,10 @@ start_processing: } p += 2; - if (!mbedtls_pk_can_do(peer_pk, (mbedtls_pk_type_t) pk_alg)) { + psa_hash_alg = mbedtls_md_psa_alg_from_type(md_alg); + if (!mbedtls_pk_can_do_psa(peer_pk, + mbedtls_psa_alg_from_pk_sigalg(pk_alg, psa_hash_alg), + PSA_KEY_USAGE_VERIFY_HASH)) { MBEDTLS_SSL_DEBUG_MSG(1, ("bad server key exchange message")); mbedtls_ssl_send_alert_message( diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c index 1f4ac3ea79..c02aeeaa08 100644 --- a/library/ssl_tls12_server.c +++ b/library/ssl_tls12_server.c @@ -16,6 +16,7 @@ #include "mbedtls/error.h" #include "mbedtls/platform_util.h" #include "mbedtls/constant_time.h" +#include "mbedtls_utils.h" #include @@ -3421,7 +3422,9 @@ static int ssl_parse_certificate_verify(mbedtls_ssl_context *ssl) /* * Check the certificate's key type matches the signature alg */ - if (!mbedtls_pk_can_do(peer_pk, (mbedtls_pk_type_t) pk_alg)) { + if (!mbedtls_pk_can_do_psa(peer_pk, + mbedtls_psa_alg_from_pk_sigalg(pk_alg, PSA_ALG_ANY_HASH), + PSA_KEY_USAGE_VERIFY_HASH)) { MBEDTLS_SSL_DEBUG_MSG(1, ("sig_alg doesn't match cert key")); return MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER; } diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c index c7d3d48561..078daea352 100644 --- a/library/ssl_tls13_generic.c +++ b/library/ssl_tls13_generic.c @@ -18,6 +18,7 @@ #include "mbedtls/constant_time.h" #include "psa/crypto.h" #include "mbedtls/psa_util.h" +#include "mbedtls_utils.h" #include "ssl_tls13_invasive.h" #include "ssl_tls13_keys.h" @@ -276,7 +277,9 @@ static int ssl_tls13_parse_certificate_verify(mbedtls_ssl_context *ssl, /* * Check the certificate's key type matches the signature alg */ - if (!mbedtls_pk_can_do(&ssl->session_negotiate->peer_cert->pk, (mbedtls_pk_type_t) sig_alg)) { + if (!mbedtls_pk_can_do_psa(&ssl->session_negotiate->peer_cert->pk, + mbedtls_psa_alg_from_pk_sigalg(sig_alg, hash_alg), + PSA_KEY_USAGE_VERIFY_HASH)) { MBEDTLS_SSL_DEBUG_MSG(1, ("signature algorithm doesn't match cert key")); goto error; }