diff --git a/test/ossl_shim/ossl_shim.cc b/test/ossl_shim/ossl_shim.cc index b1067e842092c717ed504d2eecd1d4e549578a4b..90d1f1ef402706049ea865f4f1f57563c6c91f79 100644 --- a/test/ossl_shim/ossl_shim.cc +++ b/test/ossl_shim/ossl_shim.cc @@ -459,6 +459,20 @@ static int CustomExtensionParseCallback(SSL *ssl, unsigned extension_value, return 1; } +static int ServerNameCallback(SSL *ssl, int *out_alert, void *arg) { + // SNI must be accessible from the SNI callback. + const TestConfig *config = GetTestConfig(ssl); + const char *server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (server_name == nullptr || + std::string(server_name) != config->expected_server_name) { + fprintf(stderr, "servername mismatch (got %s; want %s)\n", server_name, + config->expected_server_name.c_str()); + return SSL_TLSEXT_ERR_ALERT_FATAL; + } + + return SSL_TLSEXT_ERR_OK; +} + // Connect returns a new socket connected to localhost on |port| or -1 on // error. static int Connect(uint16_t port) { @@ -645,6 +659,10 @@ static bssl::UniquePtr SetupCtx(const TestConfig *config) { sizeof(sess_id_ctx) - 1)) return nullptr; + if (!config->expected_server_name.empty()) { + SSL_CTX_set_tlsext_servername_callback(ssl_ctx.get(), ServerNameCallback); + } + return ssl_ctx; } @@ -809,7 +827,8 @@ static bool CheckHandshakeProperties(SSL *ssl, bool is_resume) { if (!config->expected_server_name.empty()) { const char *server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); - if (server_name != config->expected_server_name) { + if (server_name == nullptr || + std::string(server_name) != config->expected_server_name) { fprintf(stderr, "servername mismatch (got %s; want %s)\n", server_name, config->expected_server_name.c_str()); return false;