diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index 68671fb98549a381e5726c246dbc1af082f7d289..3fb532ab4c85a0e264d62f02768ba371834f38f3 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -472,7 +472,7 @@ private void CheckForHttp11ConnectionInjection() { Debug.Assert(HasSyncObjLock); - if (!_http11RequestQueue.TryPeekRequest(out HttpRequestMessage? request)) + if (!_http11RequestQueue.TryPeekUncanceledRequest(this, out HttpRequestMessage? request)) { return; } @@ -701,7 +701,7 @@ private void CheckForHttp2ConnectionInjection() { Debug.Assert(HasSyncObjLock); - if (!_http2RequestQueue.TryPeekRequest(out HttpRequestMessage? request)) + if (!_http2RequestQueue.TryPeekUncanceledRequest(this, out HttpRequestMessage? request)) { return; } @@ -2273,12 +2273,23 @@ public bool TryDequeueWaiter([MaybeNullWhen(false)] out TaskCompletionSourceWith return false; } - public bool TryPeekRequest([MaybeNullWhen(false)] out HttpRequestMessage request) + public bool TryPeekUncanceledRequest(HttpConnectionPool pool, [MaybeNullWhen(false)] out HttpRequestMessage request) { - if (_queue is not null && _queue.TryPeek(out QueueItem item)) + if (_queue is not null) { - request = item.Request; - return true; + while (_queue.TryPeek(out QueueItem item)) + { + if (item.Waiter.Task.IsCanceled) + { + if (NetEventSource.Log.IsEnabled()) pool.Trace("Discarding canceled request from queue."); + _queue.Dequeue(); + } + else + { + request = item.Request; + return true; + } + } } request = null; diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs index 0c62b3bc3c0c7b13ed354eeb81b85f334c1d2c7d..0fdd4c97771d7895678166b9279da79217ec9f44 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs @@ -12,6 +12,7 @@ namespace System.Net.Http.Functional.Tests { + [ConditionalClass(typeof(SocketsHttpHandler), nameof(SocketsHttpHandler.IsSupported))] public abstract class SocketsHttpHandler_Cancellation_Test : HttpClientHandler_Cancellation_Test { protected SocketsHttpHandler_Cancellation_Test(ITestOutputHelper output) : base(output) { } @@ -196,6 +197,70 @@ public async Task ConnectionFailure_AfterInitialRequestCancelled_SecondRequestSu options: new GenericLoopbackOptions() { UseSsl = useSsl }); } + [Fact] + public async Task RequestsCanceled_NoConnectionAttemptForCanceledRequests() + { + if (UseVersion == HttpVersion.Version30) + { + // HTTP3 does not support ConnectCallback + return; + } + + bool seenRequest1 = false; + bool seenRequest2 = false; + bool seenRequest3 = false; + + var uri = new Uri("https://example.com"); + HttpRequestMessage request1 = CreateRequest(HttpMethod.Get, uri, UseVersion, exactVersion: true); + HttpRequestMessage request2 = CreateRequest(HttpMethod.Get, uri, UseVersion, exactVersion: true); + HttpRequestMessage request3 = CreateRequest(HttpMethod.Get, uri, UseVersion, exactVersion: true); + + TaskCompletionSource connectCallbackEntered = new(TaskCreationOptions.RunContinuationsAsynchronously); + TaskCompletionSource connectCallbackGate = new(TaskCreationOptions.RunContinuationsAsynchronously); + + using HttpClientHandler handler = CreateHttpClientHandler(); + handler.MaxConnectionsPerServer = 1; + GetUnderlyingSocketsHttpHandler(handler).ConnectCallback = async (context, cancellation) => + { + if (context.InitialRequestMessage == request1) seenRequest1 = true; + if (context.InitialRequestMessage == request2) seenRequest2 = true; + if (context.InitialRequestMessage == request3) seenRequest3 = true; + + connectCallbackEntered.TrySetResult(); + + await connectCallbackGate.Task.WaitAsync(TestHelper.PassingTestTimeout); + + throw new Exception("No connection"); + }; + using HttpClient client = CreateHttpClient(handler); + + Task request1Task = client.SendAsync(TestAsync, request1); + await connectCallbackEntered.Task.WaitAsync(TestHelper.PassingTestTimeout); + Assert.True(seenRequest1); + + using var request2Cts = new CancellationTokenSource(); + Task request2Task = client.SendAsync(TestAsync, request2, request2Cts.Token); + Assert.False(seenRequest2); + + Task request3Task = client.SendAsync(TestAsync, request3); + Assert.False(seenRequest2); + Assert.False(seenRequest3); + + request2Cts.Cancel(); + + await Assert.ThrowsAsync(() => request2Task).WaitAsync(TestHelper.PassingTestTimeout); + Assert.False(seenRequest2); + Assert.False(seenRequest3); + + connectCallbackGate.SetResult(); + + await Assert.ThrowsAsync(() => request1Task).WaitAsync(TestHelper.PassingTestTimeout); + await Assert.ThrowsAsync(() => request3Task).WaitAsync(TestHelper.PassingTestTimeout); + + Assert.False(seenRequest2); + Assert.True(seenRequest3); + } + [OuterLoop("Incurs significant delay")] [Fact] public async Task Expect100Continue_WaitsExpectedPeriodOfTimeBeforeSendingContent() diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 78c6f36c7a00cb285ba38bb2388a2377b5f7e787..f243807cbf90d0cb985c6ab9227077c68a9c038a 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -1044,12 +1044,12 @@ public sealed class SocketsHttpHandlerTest_Cookies_Http11 : HttpClientHandlerTes public SocketsHttpHandlerTest_Cookies_Http11(ITestOutputHelper output) : base(output) { } } + [ConditionalClass(typeof(SocketsHttpHandler), nameof(SocketsHttpHandler.IsSupported))] public sealed class SocketsHttpHandler_HttpClientHandler_Http11_Cancellation_Test : SocketsHttpHandler_Cancellation_Test { public SocketsHttpHandler_HttpClientHandler_Http11_Cancellation_Test(ITestOutputHelper output) : base(output) { } [Fact] - [SkipOnPlatform(TestPlatforms.Browser, "ConnectTimeout is not supported on Browser")] public void ConnectTimeout_Default() { using (var handler = new SocketsHttpHandler()) @@ -1062,7 +1062,6 @@ public void ConnectTimeout_Default() [InlineData(0)] [InlineData(-2)] [InlineData(int.MaxValue + 1L)] - [SkipOnPlatform(TestPlatforms.Browser, "ConnectTimeout is not supported on Browser")] public void ConnectTimeout_InvalidValues(long ms) { using (var handler = new SocketsHttpHandler()) @@ -1076,7 +1075,6 @@ public void ConnectTimeout_InvalidValues(long ms) [InlineData(1)] [InlineData(int.MaxValue - 1)] [InlineData(int.MaxValue)] - [SkipOnPlatform(TestPlatforms.Browser, "ConnectTimeout is not supported on Browser")] public void ConnectTimeout_ValidValues_Roundtrip(long ms) { using (var handler = new SocketsHttpHandler()) @@ -1087,7 +1085,6 @@ public void ConnectTimeout_ValidValues_Roundtrip(long ms) } [Fact] - [SkipOnPlatform(TestPlatforms.Browser, "ConnectTimeout is not supported on Browser")] public void ConnectTimeout_SetAfterUse_Throws() { using (var handler = new SocketsHttpHandler()) @@ -1101,7 +1098,6 @@ public void ConnectTimeout_SetAfterUse_Throws() } [Fact] - [SkipOnPlatform(TestPlatforms.Browser, "ConnectTimeout is not supported on Browser")] public void Expect100ContinueTimeout_Default() { using (var handler = new SocketsHttpHandler()) @@ -1113,7 +1109,6 @@ public void Expect100ContinueTimeout_Default() [Theory] [InlineData(-2)] [InlineData(int.MaxValue + 1L)] - [SkipOnPlatform(TestPlatforms.Browser, "ConnectTimeout is not supported on Browser")] public void Expect100ContinueTimeout_InvalidValues(long ms) { using (var handler = new SocketsHttpHandler()) @@ -1127,7 +1122,6 @@ public void Expect100ContinueTimeout_InvalidValues(long ms) [InlineData(1)] [InlineData(int.MaxValue - 1)] [InlineData(int.MaxValue)] - [SkipOnPlatform(TestPlatforms.Browser, "ConnectTimeout is not supported on Browser")] public void Expect100ContinueTimeout_ValidValues_Roundtrip(long ms) { using (var handler = new SocketsHttpHandler()) @@ -1138,7 +1132,6 @@ public void Expect100ContinueTimeout_ValidValues_Roundtrip(long ms) } [Fact] - [SkipOnPlatform(TestPlatforms.Browser, "ConnectTimeout is not supported on Browser")] public void Expect100ContinueTimeout_SetAfterUse_Throws() { using (var handler = new SocketsHttpHandler())