diff --git a/library/mbedtls_utils.h b/library/mbedtls_utils.h new file mode 100644 index 0000000000..67f74786b3 --- /dev/null +++ b/library/mbedtls_utils.h @@ -0,0 +1,23 @@ +#include "mbedtls/pk.h" +#include "psa/crypto.h" + +#ifndef MBEDTLS_UTILS_H +#define MBEDTLS_UTILS_H + +/* Return the PSA algorithm associated to the given combination of "sigalg" and "hash_alg". */ +static inline psa_algorithm_t mbedtls_psa_alg_from_pk_sigalg(mbedtls_pk_sigalg_t sigalg, + psa_algorithm_t hash_alg) +{ + switch (sigalg) { + case MBEDTLS_PK_SIGALG_RSA_PKCS1V15: + return PSA_ALG_RSA_PKCS1V15_SIGN(hash_alg); + case MBEDTLS_PK_SIGALG_RSA_PSS: + return PSA_ALG_RSA_PSS(hash_alg); + case MBEDTLS_PK_SIGALG_ECDSA: + return MBEDTLS_PK_ALG_ECDSA(hash_alg); + default: + return PSA_ALG_NONE; + } +} + +#endif /* MBEDTLS_UTILS_H */ diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 36c6bf9586..be071defac 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -5605,13 +5605,15 @@ void mbedtls_ssl_config_free(mbedtls_ssl_config *conf) */ unsigned char mbedtls_ssl_sig_from_pk(mbedtls_pk_context *pk) { + psa_key_type_t key_type = mbedtls_pk_get_key_type(pk); + #if defined(MBEDTLS_RSA_C) - if (mbedtls_pk_can_do(pk, MBEDTLS_PK_RSA)) { + if (PSA_KEY_TYPE_IS_RSA(key_type)) { return MBEDTLS_SSL_SIG_RSA; } #endif #if defined(MBEDTLS_KEY_EXCHANGE_ECDSA_CERT_REQ_ANY_ALLOWED_ENABLED) - if (mbedtls_pk_can_do(pk, MBEDTLS_PK_ECDSA)) { + if (PSA_KEY_TYPE_IS_ECC(key_type)) { return MBEDTLS_SSL_SIG_ECDSA; } #endif @@ -8780,7 +8782,7 @@ int mbedtls_ssl_verify_certificate(mbedtls_ssl_context *ssl, #if defined(MBEDTLS_SSL_PROTO_TLS1_2) && \ defined(PSA_WANT_KEY_TYPE_ECC_PUBLIC_KEY) if (ssl->tls_version == MBEDTLS_SSL_VERSION_TLS1_2 && - mbedtls_pk_can_do(&chain->pk, MBEDTLS_PK_ECKEY)) { + PSA_KEY_TYPE_IS_ECC(mbedtls_pk_get_type(&chain->pk))) { if (mbedtls_ssl_check_curve(ssl, mbedtls_pk_get_ec_group_id(&chain->pk)) != 0) { MBEDTLS_SSL_DEBUG_MSG(1, ("bad certificate (EC key curve)")); ssl->session_negotiate->verify_result |= MBEDTLS_X509_BADCERT_BAD_KEY; diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c index c4f75b63de..ebcc0d56bb 100644 --- a/library/ssl_tls12_client.c +++ b/library/ssl_tls12_client.c @@ -16,6 +16,7 @@ #include "debug_internal.h" #include "mbedtls/error.h" #include "mbedtls/constant_time.h" +#include "mbedtls_utils.h" #include "psa/crypto.h" #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED) @@ -1883,6 +1884,7 @@ start_processing: unsigned char hash[MBEDTLS_MD_MAX_SIZE]; mbedtls_md_type_t md_alg = MBEDTLS_MD_NONE; + psa_algorithm_t psa_hash_alg; mbedtls_pk_sigalg_t pk_alg = MBEDTLS_PK_SIGALG_NONE; unsigned char *params = ssl->in_msg + mbedtls_ssl_hs_hdr_len(ssl); size_t params_len = (size_t) (p - params); @@ -1921,7 +1923,10 @@ start_processing: } p += 2; - if (!mbedtls_pk_can_do(peer_pk, (mbedtls_pk_type_t) pk_alg)) { + psa_hash_alg = mbedtls_md_psa_alg_from_type(md_alg); + if (!mbedtls_pk_can_do_psa(peer_pk, + mbedtls_psa_alg_from_pk_sigalg(pk_alg, psa_hash_alg), + PSA_KEY_USAGE_VERIFY_HASH)) { MBEDTLS_SSL_DEBUG_MSG(1, ("bad server key exchange message")); mbedtls_ssl_send_alert_message( @@ -1977,14 +1982,6 @@ start_processing: /* * Verify signature */ - if (!mbedtls_pk_can_do(peer_pk, (mbedtls_pk_type_t) pk_alg)) { - MBEDTLS_SSL_DEBUG_MSG(1, ("bad server key exchange message")); - mbedtls_ssl_send_alert_message( - ssl, - MBEDTLS_SSL_ALERT_LEVEL_FATAL, - MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE); - return MBEDTLS_ERR_SSL_PK_TYPE_MISMATCH; - } #if defined(MBEDTLS_SSL_ECP_RESTARTABLE_ENABLED) if (ssl->handshake->ecrs_enabled) { diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c index 1f4ac3ea79..ec4446c1b4 100644 --- a/library/ssl_tls12_server.c +++ b/library/ssl_tls12_server.c @@ -16,6 +16,7 @@ #include "mbedtls/error.h" #include "mbedtls/platform_util.h" #include "mbedtls/constant_time.h" +#include "mbedtls_utils.h" #include @@ -3324,6 +3325,7 @@ static int ssl_parse_certificate_verify(mbedtls_ssl_context *ssl) const mbedtls_ssl_ciphersuite_t *ciphersuite_info = ssl->handshake->ciphersuite_info; mbedtls_pk_context *peer_pk; + psa_algorithm_t psa_sig_alg; MBEDTLS_SSL_DEBUG_MSG(2, ("=> parse certificate verify")); @@ -3421,7 +3423,8 @@ static int ssl_parse_certificate_verify(mbedtls_ssl_context *ssl) /* * Check the certificate's key type matches the signature alg */ - if (!mbedtls_pk_can_do(peer_pk, (mbedtls_pk_type_t) pk_alg)) { + psa_sig_alg = mbedtls_psa_alg_from_pk_sigalg(pk_alg, mbedtls_md_psa_alg_from_type(md_alg)); + if (!mbedtls_pk_can_do_psa(peer_pk, psa_sig_alg, PSA_KEY_USAGE_VERIFY_HASH)) { MBEDTLS_SSL_DEBUG_MSG(1, ("sig_alg doesn't match cert key")); return MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER; } diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c index c7d3d48561..078daea352 100644 --- a/library/ssl_tls13_generic.c +++ b/library/ssl_tls13_generic.c @@ -18,6 +18,7 @@ #include "mbedtls/constant_time.h" #include "psa/crypto.h" #include "mbedtls/psa_util.h" +#include "mbedtls_utils.h" #include "ssl_tls13_invasive.h" #include "ssl_tls13_keys.h" @@ -276,7 +277,9 @@ static int ssl_tls13_parse_certificate_verify(mbedtls_ssl_context *ssl, /* * Check the certificate's key type matches the signature alg */ - if (!mbedtls_pk_can_do(&ssl->session_negotiate->peer_cert->pk, (mbedtls_pk_type_t) sig_alg)) { + if (!mbedtls_pk_can_do_psa(&ssl->session_negotiate->peer_cert->pk, + mbedtls_psa_alg_from_pk_sigalg(sig_alg, hash_alg), + PSA_KEY_USAGE_VERIFY_HASH)) { MBEDTLS_SSL_DEBUG_MSG(1, ("signature algorithm doesn't match cert key")); goto error; } diff --git a/library/x509_crt.c b/library/x509_crt.c index e18dbe777e..61dca746a3 100644 --- a/library/x509_crt.c +++ b/library/x509_crt.c @@ -43,6 +43,8 @@ #include "mbedtls/threading.h" #endif +#include "mbedtls_utils.h" + #if defined(MBEDTLS_HAVE_TIME) #if defined(_WIN32) && !defined(EFIX64) && !defined(EFI32) #ifndef WIN32_LEAN_AND_MEAN @@ -2108,6 +2110,13 @@ static int x509_crt_check_signature(const mbedtls_x509_crt *child, psa_algorithm_t hash_alg = mbedtls_md_psa_alg_from_type(child->sig_md); psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + /* Skip expensive computation on obvious mismatch */ + if (!mbedtls_pk_can_do_psa(&parent->pk, + mbedtls_psa_alg_from_pk_sigalg(child->sig_pk, hash_alg), + PSA_KEY_USAGE_VERIFY_HASH)) { + return -1; + } + status = psa_hash_compute(hash_alg, child->tbs.p, child->tbs.len, @@ -2118,11 +2127,6 @@ static int x509_crt_check_signature(const mbedtls_x509_crt *child, return MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED; } - /* Skip expensive computation on obvious mismatch */ - if (!mbedtls_pk_can_do(&parent->pk, (mbedtls_pk_type_t) child->sig_pk)) { - return -1; - } - #if defined(MBEDTLS_ECP_RESTARTABLE) if (rs_ctx != NULL && child->sig_pk == MBEDTLS_PK_SIGALG_ECDSA) { return mbedtls_pk_verify_restartable(&parent->pk, diff --git a/library/x509write_crt.c b/library/x509write_crt.c index 399c923097..8c77f10c34 100644 --- a/library/x509write_crt.c +++ b/library/x509write_crt.c @@ -392,6 +392,7 @@ int mbedtls_x509write_crt_der(mbedtls_x509write_cert *ctx, unsigned char hash[MBEDTLS_MD_MAX_SIZE]; psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; psa_algorithm_t psa_algorithm; + psa_key_type_t key_type = mbedtls_pk_get_key_type(ctx->issuer_key); size_t sub_len = 0, pub_len = 0, sig_and_oid_len = 0, sig_len; size_t len = 0; @@ -407,9 +408,9 @@ int mbedtls_x509write_crt_der(mbedtls_x509write_cert *ctx, /* There's no direct way of extracting a signature algorithm * (represented as an element of mbedtls_pk_type_t) from a PK instance. */ - if (mbedtls_pk_can_do(ctx->issuer_key, MBEDTLS_PK_RSA)) { + if (PSA_KEY_TYPE_IS_RSA(key_type)) { pk_alg = MBEDTLS_PK_SIGALG_RSA_PKCS1V15; - } else if (mbedtls_pk_can_do(ctx->issuer_key, MBEDTLS_PK_ECDSA)) { + } else if (PSA_KEY_TYPE_IS_ECC(key_type)) { pk_alg = MBEDTLS_PK_SIGALG_ECDSA; } else { return MBEDTLS_ERR_X509_INVALID_ALG; diff --git a/library/x509write_csr.c b/library/x509write_csr.c index 8a81f7ee56..22651032b1 100644 --- a/library/x509write_csr.c +++ b/library/x509write_csr.c @@ -144,6 +144,7 @@ static int x509write_csr_der_internal(mbedtls_x509write_csr *ctx, mbedtls_pk_sigalg_t pk_alg; size_t hash_len; psa_algorithm_t hash_alg = mbedtls_md_psa_alg_from_type(ctx->md_alg); + psa_key_type_t key_type = mbedtls_pk_get_key_type(ctx->key); /* Write the CSR backwards starting from the end of buf */ c = buf + size; @@ -217,9 +218,9 @@ static int x509write_csr_der_internal(mbedtls_x509write_csr *ctx, return MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED; } - if (mbedtls_pk_can_do(ctx->key, MBEDTLS_PK_RSA)) { + if (PSA_KEY_TYPE_IS_RSA(key_type)) { pk_alg = MBEDTLS_PK_SIGALG_RSA_PKCS1V15; - } else if (mbedtls_pk_can_do(ctx->key, MBEDTLS_PK_ECDSA)) { + } else if (PSA_KEY_TYPE_IS_ECC(key_type)) { pk_alg = MBEDTLS_PK_SIGALG_ECDSA; } else { return MBEDTLS_ERR_X509_INVALID_ALG;