diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs index e3079e8a836f6d86b67fa8a40a90ba8359af3c12..f3a58a37164eb3ba1b58564ba3a7fb07558ffea0 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs @@ -32,52 +32,6 @@ public CertificateCallbackMapper(Func ConnectAsync(Func> callback, DnsEndPoint endPoint, HttpRequestMessage requestMessage, CancellationToken cancellationToken) - { - Stream stream; - try - { - stream = await callback(new SocketsHttpConnectionContext(endPoint, requestMessage), cancellationToken).ConfigureAwait(false); - } - catch (OperationCanceledException ex) when (ex.CancellationToken == cancellationToken) - { - throw CancellationHelper.CreateOperationCanceledException(innerException: null, cancellationToken); - } - catch (Exception ex) - { - throw CreateWrappedException(ex, endPoint.Host, endPoint.Port, cancellationToken); - } - - if (stream == null) - { - throw new HttpRequestException(SR.net_http_null_from_connect_callback); - } - - return stream; - } - - public static Stream Connect(string host, int port, CancellationToken cancellationToken) - { - // For synchronous connections, we can just create a socket and make the connection. - cancellationToken.ThrowIfCancellationRequested(); - var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); - try - { - socket.NoDelay = true; - using (cancellationToken.UnsafeRegister(static s => ((Socket)s!).Dispose(), socket)) - { - socket.Connect(new DnsEndPoint(host, port)); - } - - return new NetworkStream(socket, ownsSocket: true); - } - catch (Exception e) - { - socket.Dispose(); - throw CreateWrappedException(e, host, port, cancellationToken); - } - } - public static ValueTask EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken) { // If there's a cert validation callback, and if it came from HttpClientHandler, @@ -161,7 +115,7 @@ public static async ValueTask ConnectQuicAsync(QuicImplementatio } } - private static Exception CreateWrappedException(Exception error, string host, int port, CancellationToken cancellationToken) + internal static Exception CreateWrappedException(Exception error, string host, int port, CancellationToken cancellationToken) { return CancellationHelper.ShouldWrapInOperationCanceledException(error, cancellationToken) ? CancellationHelper.CreateOperationCanceledException(error, cancellationToken) : diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index cfd4252c3db9fe60fd4ca1daf4187708351ea78d..43fdcd46a9b5126ab719902c62f2478771240a58 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -1282,49 +1282,62 @@ private async ValueTask<(Stream?, TransportContext?, HttpResponseMessage?)> Conn } } - private static async ValueTask DefaultConnectAsync(SocketsHttpConnectionContext context, CancellationToken cancellationToken) + private async ValueTask ConnectToTcpHostAsync(string host, int port, HttpRequestMessage initialRequest, bool async, CancellationToken cancellationToken) { - Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp); - socket.NoDelay = true; + cancellationToken.ThrowIfCancellationRequested(); + var endPoint = new DnsEndPoint(host, port); + Socket? socket = null; try { - await socket.ConnectAsync(context.DnsEndPoint, cancellationToken).ConfigureAwait(false); - return new NetworkStream(socket, ownsSocket: true); - } - catch - { - socket.Dispose(); - throw; - } - } - - private static readonly Func> s_defaultConnectCallback = DefaultConnectAsync; - - private ValueTask ConnectToTcpHostAsync(string host, int port, HttpRequestMessage initialRequest, bool async, CancellationToken cancellationToken) - { - if (async) - { - Func> connectCallback = Settings._connectCallback ?? s_defaultConnectCallback; + // If a ConnectCallback was supplied, use that to establish the connection. + if (Settings._connectCallback != null) + { + ValueTask streamTask = Settings._connectCallback(new SocketsHttpConnectionContext(endPoint, initialRequest), cancellationToken); - var endPoint = new DnsEndPoint(host, port); - return ConnectHelper.ConnectAsync(connectCallback, endPoint, initialRequest, cancellationToken); - } + Stream stream; + if (async || streamTask.IsCompleted) + { + stream = await streamTask.ConfigureAwait(false); + } + else + { + // User-provided ConnectCallback is completing asynchronously but the user is making a synchronous request; if the user cares, they should + // set it up so that synchronous requests are made on a handler with a synchronously-completing ConnectCallback supplied. If in the future, + // we could add a Boolean to SocketsHttpConnectionContext (https://github.com/dotnet/runtime/issues/44876) to let the callback know whether + // this request is sync or async. For now, log it and block. + Trace($"{nameof(SocketsHttpHandler.ConnectCallback)} completing asynchronously for a synchronous request."); + stream = streamTask.AsTask().GetAwaiter().GetResult(); + } - // Synchronous path. + return stream ?? throw new HttpRequestException(SR.net_http_null_from_connect_callback); + } + else + { + // Otherwise, create and connect a socket using default settings. + socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - if (Settings._connectCallback is not null) - { - throw new NotSupportedException(SR.net_http_sync_operations_not_allowed_with_connect_callback); - } + if (async) + { + await socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false); + } + else + { + using (cancellationToken.UnsafeRegister(static s => ((Socket)s!).Dispose(), socket)) + { + socket.Connect(endPoint); + } + } - try - { - return new ValueTask(ConnectHelper.Connect(host, port, cancellationToken)); + return new NetworkStream(socket, ownsSocket: true); + } } catch (Exception ex) { - return ValueTask.FromException(ex); + socket?.Dispose(); + throw ex is OperationCanceledException oce && oce.CancellationToken == cancellationToken ? + CancellationHelper.CreateOperationCanceledException(innerException: null, cancellationToken) : + ConnectHelper.CreateWrappedException(ex, endPoint.Host, endPoint.Port, cancellationToken); } } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 186358ac425ecefac5841361c6a16309e7904bc1..b019a1e89855b37625ed4bc54bd2dbbaea6a6c15 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -2335,22 +2335,19 @@ public abstract class SocketsHttpHandlerTest_ConnectCallback : HttpClientHandler { public SocketsHttpHandlerTest_ConnectCallback(ITestOutputHelper output) : base(output) { } - [Fact] - public void ConnectCallback_SyncRequest_Fails() + [Theory] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task ConnectCallback_ContextHasCorrectProperties_Success(bool syncRequest, bool syncCallback) { - using SocketsHttpHandler handler = new SocketsHttpHandler + if (syncRequest && UseVersion > HttpVersion.Version11) { - ConnectCallback = (context, token) => default, - }; - - using HttpClient client = CreateHttpClient(handler); - - Assert.ThrowsAny(() => client.Send(new HttpRequestMessage(HttpMethod.Get, $"http://bing.com"))); - } + // Sync requests are only supported on 1.x + return; + } - [Fact] - public async Task ConnectCallback_ContextHasCorrectProperties_Success() - { await LoopbackServerFactory.CreateClientAndServerAsync( async uri => { @@ -2367,14 +2364,23 @@ public async Task ConnectCallback_ContextHasCorrectProperties_Success() Assert.Equal(uri.Port, context.DnsEndPoint.Port); Assert.Equal(requestMessage, context.InitialRequestMessage); - Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await s.ConnectAsync(context.DnsEndPoint, token); + var s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + if (syncCallback) + { + s.Connect(context.DnsEndPoint); + } + else + { + await s.ConnectAsync(context.DnsEndPoint, token); + } return new NetworkStream(s, ownsSocket: true); }; using HttpClient client = CreateHttpClient(handler); - HttpResponseMessage response = await client.SendAsync(requestMessage); + HttpResponseMessage response = await (syncRequest ? + Task.Run(() => client.Send(requestMessage)) : + client.SendAsync(requestMessage)); Assert.Equal("foo", await response.Content.ReadAsStringAsync()); }, async server =>