diff --git a/library/ssl_cli.c b/library/ssl_cli.c index a6f9e7dd1b..023fac6af8 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -204,67 +204,15 @@ static int ssl_write_renegotiation_ext( mbedtls_ssl_context *ssl, */ #if defined(MBEDTLS_SSL_PROTO_TLS1_2) && \ defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED) + static int ssl_write_sig_alg_ext( mbedtls_ssl_context *ssl, unsigned char *buf, - const unsigned char *end, size_t *olen ) + const unsigned char *end, size_t *out_len ) { unsigned char *p = buf; - size_t sig_alg_len = 0; - const int *md = mbedtls_ssl_conf_get_sig_algs( ssl->conf ); + unsigned char *supported_sig_alg; /* Start of supported_signature_algorithms */ + size_t supported_sig_alg_len = 0; /* Length of supported_signature_algorithms */ -#if defined(MBEDTLS_RSA_C) || defined(MBEDTLS_ECDSA_C) - unsigned char *sig_alg_list = buf + 6; -#endif - - *olen = 0; - - if( ssl->conf->max_minor_ver != MBEDTLS_SSL_MINOR_VERSION_3 ) - return( 0 ); - - MBEDTLS_SSL_DEBUG_MSG( 3, - ( "client hello, adding signature_algorithms extension" ) ); - - if( md == NULL ) - return( MBEDTLS_ERR_SSL_BAD_CONFIG ); - - for( ; *md != MBEDTLS_MD_NONE; md++ ) - { -#if defined(MBEDTLS_ECDSA_C) - sig_alg_len += 2; -#endif -#if defined(MBEDTLS_RSA_C) - sig_alg_len += 2; -#endif - if( sig_alg_len > MBEDTLS_SSL_MAX_SIG_HASH_ALG_LIST_LEN ) - { - MBEDTLS_SSL_DEBUG_MSG( 3, - ( "length in bytes of sig-hash-alg extension too big" ) ); - return( MBEDTLS_ERR_SSL_BAD_CONFIG ); - } - } - - /* Empty signature algorithms list, this is a configuration error. */ - if( sig_alg_len == 0 ) - return( MBEDTLS_ERR_SSL_BAD_CONFIG ); - - MBEDTLS_SSL_CHK_BUF_PTR( p, end, sig_alg_len + 6 ); - - /* - * Prepare signature_algorithms extension (TLS 1.2) - */ - sig_alg_len = 0; - - for( md = mbedtls_ssl_conf_get_sig_algs( ssl->conf ); - *md != MBEDTLS_MD_NONE; md++ ) - { -#if defined(MBEDTLS_ECDSA_C) - sig_alg_list[sig_alg_len++] = mbedtls_ssl_hash_from_md_alg( *md ); - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; -#endif -#if defined(MBEDTLS_RSA_C) - sig_alg_list[sig_alg_len++] = mbedtls_ssl_hash_from_md_alg( *md ); - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; -#endif - } + *out_len = 0; /* * enum { @@ -283,16 +231,47 @@ static int ssl_write_sig_alg_ext( mbedtls_ssl_context *ssl, unsigned char *buf, * SignatureAndHashAlgorithm * supported_signature_algorithms<2..2^16-2>; */ - MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_SIG_ALG, p, 0 ); - p += 2; - MBEDTLS_PUT_UINT16_BE( sig_alg_len + 2, p, 0 ); - p += 2; + MBEDTLS_SSL_DEBUG_MSG( 3, ( "adding signature_algorithms extension" ) ); - MBEDTLS_PUT_UINT16_BE( sig_alg_len, p, 0 ); - p += 2; + /* Check if we have space for header and length field: + * - extension_type (2 bytes) + * - extension_data_length (2 bytes) + * - supported_signature_algorithms_length (2 bytes) + */ + MBEDTLS_SSL_CHK_BUF_PTR( p, end, 6 ); + p += 6; - *olen = 6 + sig_alg_len; + /* + * Write supported_signature_algorithms + */ + supported_sig_alg = p; + for( const uint16_t *sig_alg = mbedtls_ssl_conf_get_sig_algs( ssl->conf ); + *sig_alg != MBEDTLS_TLS1_3_SIG_NONE; sig_alg++ ) + { + MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 ); + MBEDTLS_PUT_UINT16_BE( *sig_alg, p, 0 ); + p += 2; + MBEDTLS_SSL_DEBUG_MSG( 3, ( "signature scheme [%x]", *sig_alg ) ); + } + + /* Length of supported_signature_algorithms */ + supported_sig_alg_len = p - supported_sig_alg; + if( supported_sig_alg_len == 0 ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, ( "No signature algorithms defined." ) ); + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + } + + /* Write extension_type */ + MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_SIG_ALG, buf, 0 ); + /* Write extension_data_length */ + MBEDTLS_PUT_UINT16_BE( supported_sig_alg_len + 2, buf, 2 ); + /* Write length of supported_signature_algorithms */ + MBEDTLS_PUT_UINT16_BE( supported_sig_alg_len, buf, 4 ); + + /* Output the total length of signature algorithms extension. */ + *out_len = p - buf; return( 0 ); } diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 9eaeab407f..eb3550eb5b 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -2793,27 +2793,32 @@ static int ssl_write_certificate_request( mbedtls_ssl_context *ssl ) */ if( ssl->minor_ver == MBEDTLS_SSL_MINOR_VERSION_3 ) { - const int *cur; - /* * Supported signature algorithms */ - for( cur = mbedtls_ssl_conf_get_sig_algs( ssl->conf ); - *cur != MBEDTLS_MD_NONE; cur++ ) + for( const uint16_t *sig_alg = mbedtls_ssl_conf_get_sig_algs( ssl->conf ); + *sig_alg != MBEDTLS_TLS1_3_SIG_NONE; sig_alg++ ) { - unsigned char hash = mbedtls_ssl_hash_from_md_alg( *cur ); + /* High byte is hash */ + unsigned char hash = ( *sig_alg >> 8 ) & 0xff; + unsigned char sig = ( *sig_alg ) & 0xff; if( MBEDTLS_SSL_HASH_NONE == hash || mbedtls_ssl_set_calc_verify_md( ssl, hash ) ) continue; +#if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_ECDSA_C) + if( sig != MBEDTLS_SSL_SIG_RSA && sig != MBEDTLS_SSL_SIG_ECDSA ) + continue; +#elif defined(MBEDTLS_RSA_C) + if( sig != MBEDTLS_SSL_SIG_RSA ) + continue; +#elif defined(MBEDTLS_ECDSA_C) + if( sig != MBEDTLS_SSL_SIG_ECDSA ) + continue; +#endif + + MBEDTLS_PUT_UINT16_BE( *sig_alg, p, sa_len ); + sa_len += 2; -#if defined(MBEDTLS_RSA_C) - p[2 + sa_len++] = hash; - p[2 + sa_len++] = MBEDTLS_SSL_SIG_RSA; -#endif -#if defined(MBEDTLS_ECDSA_C) - p[2 + sa_len++] = hash; - p[2 + sa_len++] = MBEDTLS_SSL_SIG_ECDSA; -#endif } MBEDTLS_PUT_UINT16_BE( sa_len, p, 0 ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index d6b3baa43a..8cdeb8d8b5 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -6858,14 +6858,18 @@ int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_i int mbedtls_ssl_check_sig_hash( const mbedtls_ssl_context *ssl, mbedtls_md_type_t md ) { - const int *cur = mbedtls_ssl_conf_get_sig_algs( ssl->conf ); - if( cur == NULL ) + const uint16_t *sig_alg = mbedtls_ssl_conf_get_sig_algs( ssl->conf ); + if( sig_alg == NULL ) return( -1 ); - for( ; *cur != MBEDTLS_MD_NONE; cur++ ) - if( *cur == (int) md ) + for( ; *sig_alg != MBEDTLS_TLS1_3_SIG_NONE; sig_alg++ ) + { + mbedtls_md_type_t hash = mbedtls_ssl_md_alg_from_hash( + ( *sig_alg >> 8 ) & 0xff ); + if( hash == md ) return( 0 ); + } return( -1 ); }