From 15033649baba034cd66ddf4df7c6c3033690c7d2 Mon Sep 17 00:00:00 2001 From: Justin Kotalik Date: Wed, 25 Nov 2020 20:23:43 -0800 Subject: [PATCH] Close accept loop when closing connection for Quic (#44885) --- .../Implementations/MsQuic/MsQuicConnection.cs | 10 ++++------ .../Implementations/MsQuic/MsQuicListener.cs | 11 +++++++++-- .../Quic/Implementations/MsQuic/MsQuicStream.cs | 4 ++-- .../tests/FunctionalTests/MsQuicTests.cs | 17 +++++++++++++++++ .../tests/FunctionalTests/QuicListenerTests.cs | 5 +++-- .../tests/FunctionalTests/QuicStreamTests.cs | 2 -- 6 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index fcc975f021d..ee5dba0ef41 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -161,21 +161,21 @@ private uint HandleEventShutdownInitiatedByTransport(ref ConnectionEvent connect _connectTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(ex)); } - _acceptQueue.Writer.Complete(); - return MsQuicStatusCodes.Success; } private uint HandleEventShutdownInitiatedByPeer(ref ConnectionEvent connectionEvent) { _abortErrorCode = connectionEvent.Data.ShutdownInitiatedByPeer.ErrorCode; - _acceptQueue.Writer.Complete(); return MsQuicStatusCodes.Success; } private uint HandleEventShutdownComplete(ref ConnectionEvent connectionEvent) { _shutdownTcs.SetResult(MsQuicStatusCodes.Success); + + // Stop accepting new streams. + _acceptQueue?.Writer.Complete(); return MsQuicStatusCodes.Success; } @@ -291,7 +291,7 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d private void SetCallbackHandler() { - Debug.Assert(!_handle.IsAllocated); + Debug.Assert(!_handle.IsAllocated, "callback handler allocated already"); _handle = GCHandle.Alloc(this); MsQuicApi.Api.SetCallbackHandlerDelegate( @@ -310,8 +310,6 @@ private void SetCallbackHandler() ErrorCode); QuicExceptionHelpers.ThrowIfFailed(status, "Failed to shutdown connection."); - Debug.Assert(_shutdownTcs.Task.IsCompleted == false); - return new ValueTask(_shutdownTcs.Task); } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs index 60b17a6ad0e..192085132b9 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs @@ -34,7 +34,7 @@ internal sealed class MsQuicListener : QuicListenerProvider, IDisposable private QuicListenerOptions _options; private volatile bool _disposed; private IPEndPoint _listenEndPoint; - + private bool _started; private readonly Channel _acceptConnectionQueue; internal MsQuicListener(QuicListenerOptions options) @@ -120,6 +120,13 @@ internal override void Start() { ThrowIfDisposed(); + // protect against double starts. + if (_started) + { + throw new QuicException("Cannot start Listener multiple times"); + } + + _started = true; SetCallbackHandler(); SOCKADDR_INET address = MsQuicAddressHelpers.IPEndPointToINet(_listenEndPoint); @@ -202,7 +209,7 @@ private void StopAcceptingConnections() internal void SetCallbackHandler() { - Debug.Assert(!_handle.IsAllocated); + Debug.Assert(!_handle.IsAllocated, "listener allocated"); _handle = GCHandle.Alloc(this); MsQuicApi.Api.SetCallbackHandlerDelegate( diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 5379f1ff1ce..66ea94d218b 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -71,7 +71,7 @@ internal sealed class MsQuicStream : QuicStreamProvider // Creates a new MsQuicStream internal MsQuicStream(MsQuicConnection connection, QUIC_STREAM_OPEN_FLAG flags, IntPtr nativeObjPtr, bool inbound) { - Debug.Assert(connection != null); + Debug.Assert(connection != null, "Connection null"); _ptr = nativeObjPtr; @@ -936,7 +936,7 @@ private void SetCallbackHandler() /// private void StartLocalStream() { - Debug.Assert(!_started); + Debug.Assert(!_started, "start local stream"); uint status = MsQuicApi.Api.StreamStartDelegate( _ptr, (uint)QUIC_STREAM_START_FLAG.ASYNC); diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 21388ed3702..2614a7a49d7 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -177,6 +177,23 @@ public async Task CallDifferentWriteMethodsWorks() Assert.Equal(24, res); } + [Fact] + public async Task CloseAsync_ByServer_AcceptThrows() + { + await RunClientServer( + clientConnection => + { + return Task.CompletedTask; + }, + async serverConnection => + { + var acceptTask = serverConnection.AcceptStreamAsync(); + await serverConnection.CloseAsync(errorCode: 0); + // make sure + await Assert.ThrowsAsync(() => acceptTask.AsTask()); + }); + } + private static ReadOnlySequence CreateReadOnlySequenceFromBytes(byte[] data) { List segments = new List diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicListenerTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicListenerTests.cs index a1e9bf3b72b..3fbfcaf75c6 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicListenerTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicListenerTests.cs @@ -21,10 +21,11 @@ public async Task Listener_Backlog_Success() using QuicListener listener = CreateQuicListener(); using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - await clientConnection.ConnectAsync(); + var clientStreamTask = clientConnection.ConnectAsync(); using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); - }).TimeoutAfter(millisecondsTimeout: 5_000); + await clientStreamTask; + }).TimeoutAfter(millisecondsTimeout: 6_000); } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index e339ac1a30a..3da7e7e9bf8 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -305,8 +305,6 @@ public async Task LargeDataSentAndReceived() } } - - [Fact] public async Task TestStreams() { -- GitLab