From d4dad4860a75459c9274ae2d5b4958634984a526 Mon Sep 17 00:00:00 2001 From: Christian <6939810+chkr1011@users.noreply.github.com> Date: Wed, 14 Sep 2022 14:15:35 +0200 Subject: [PATCH] Fix and improve session takeover process handling. --- .github/workflows/ReleaseNotes.md | 1 + MQTTnet.sln.DotSettings | 1 + Source/MQTTnet.Tests/Server/Session_Tests.cs | 38 ++++++----------- .../Implementations/MqttTcpServerListener.cs | 2 +- Source/MQTTnet/Server/Internal/MqttClient.cs | 20 ++++++++- .../Internal/MqttClientSessionsManager.cs | 42 +++++++++++-------- 6 files changed, 59 insertions(+), 45 deletions(-) diff --git a/.github/workflows/ReleaseNotes.md b/.github/workflows/ReleaseNotes.md index bd6f0335..a4f66c6c 100644 --- a/.github/workflows/ReleaseNotes.md +++ b/.github/workflows/ReleaseNotes.md @@ -1,3 +1,4 @@ * [Core] MQTT Packets being sent over web socket transport are now setting the web socket frame boundaries correctly (#1499). * [Client] Keep alive mechanism now uses the configured timeout value from the options (thanks to @Stannieman, #1495). * [Server] A DISCONNECT packet is no longer sent to MQTT clients < 5.0.0 (thanks to @logicaloud, #1506). +* [Server] Improved "take over" process handling. diff --git a/MQTTnet.sln.DotSettings b/MQTTnet.sln.DotSettings index bf4f67af..7eed6537 100644 --- a/MQTTnet.sln.DotSettings +++ b/MQTTnet.sln.DotSettings @@ -241,5 +241,6 @@ See the LICENSE file in the project root for more information. True True True + True True True \ No newline at end of file diff --git a/Source/MQTTnet.Tests/Server/Session_Tests.cs b/Source/MQTTnet.Tests/Server/Session_Tests.cs index a791c4c5..772b8484 100644 --- a/Source/MQTTnet.Tests/Server/Session_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Session_Tests.cs @@ -124,7 +124,9 @@ namespace MQTTnet.Tests.Server } [TestMethod] - public async Task Handle_Parallel_Connection_Attempts() + [DataRow(MqttProtocolVersion.V311)] + [DataRow(MqttProtocolVersion.V500)] + public async Task Handle_Parallel_Connection_Attempts(MqttProtocolVersion protocolVersion) { using (var testEnvironment = CreateTestEnvironment()) { @@ -132,7 +134,7 @@ namespace MQTTnet.Tests.Server await testEnvironment.StartServer(); - var options = new MqttClientOptionsBuilder().WithClientId("1").WithKeepAlivePeriod(TimeSpan.FromSeconds(5)); + var options = new MqttClientOptionsBuilder().WithClientId("1").WithTimeout(TimeSpan.FromSeconds(1)).WithProtocolVersion(protocolVersion).WithKeepAlivePeriod(TimeSpan.FromSeconds(5)); var hasReceive = false; @@ -146,7 +148,10 @@ namespace MQTTnet.Tests.Server // Try to connect 50 clients at the same time. var clients = await Task.WhenAll(Enumerable.Range(0, 50).Select(i => ConnectAndSubscribe(testEnvironment, options, OnReceive))); - var connectedClients = clients.Where(c => c?.IsConnected ?? false).ToList(); + + await LongTestDelay(); + + var connectedClients = clients.Where(c => c != null).Where(c => c.TryPingAsync().GetAwaiter().GetResult()).ToList(); Assert.AreEqual(1, connectedClients.Count); @@ -161,26 +166,7 @@ namespace MQTTnet.Tests.Server Assert.AreEqual(true, hasReceive); } } - - [TestMethod] - public async Task Manage_Session_MaxParallel() - { - using (var testEnvironment = CreateTestEnvironment()) - { - testEnvironment.IgnoreClientLogErrors = true; - var serverOptions = new MqttServerOptionsBuilder(); - await testEnvironment.StartServer(serverOptions); - - var options = new MqttClientOptionsBuilder().WithClientId("1"); - - var clients = await Task.WhenAll(Enumerable.Range(0, 10).Select(i => TryConnect(testEnvironment, options))); - - var connectedClients = clients.Where(c => c?.IsConnected ?? false).ToList(); - - Assert.AreEqual(1, connectedClients.Count); - } - } - + [DataTestMethod] [DataRow(MqttQualityOfServiceLevel.ExactlyOnce)] [DataRow(MqttQualityOfServiceLevel.AtLeastOnce)] @@ -360,7 +346,7 @@ namespace MQTTnet.Tests.Server } } - async Task ConnectAndSubscribe(TestEnvironment testEnvironment, MqttClientOptionsBuilder options, Action onReceive) + static async Task ConnectAndSubscribe(TestEnvironment testEnvironment, MqttClientOptionsBuilder options, Action onReceive) { try { @@ -379,13 +365,13 @@ namespace MQTTnet.Tests.Server return sendClient; } - catch (Exception) + catch { return null; } } - async Task TryConnect(TestEnvironment testEnvironment, MqttClientOptionsBuilder options) + static async Task TryConnect(TestEnvironment testEnvironment, MqttClientOptionsBuilder options) { try { diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index b2f814da..9341d275 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -134,7 +134,7 @@ namespace MQTTnet.Implementations continue; } - Task.Run(() => TryHandleClientConnectionAsync(clientSocket), cancellationToken).RunInBackground(_logger); + _ = Task.Factory.StartNew(() => TryHandleClientConnectionAsync(clientSocket), cancellationToken, TaskCreationOptions.PreferFairness, TaskScheduler.Default).ConfigureAwait(false); } catch (OperationCanceledException) { diff --git a/Source/MQTTnet/Server/Internal/MqttClient.cs b/Source/MQTTnet/Server/Internal/MqttClient.cs index 2982961c..18c3804a 100644 --- a/Source/MQTTnet/Server/Internal/MqttClient.cs +++ b/Source/MQTTnet/Server/Internal/MqttClient.cs @@ -98,7 +98,7 @@ namespace MQTTnet.Server try { - Task.Run(() => SendPacketsLoop(cancellationToken), cancellationToken).RunInBackground(_logger); + _ = Task.Factory.StartNew(() => SendPacketsLoop(cancellationToken), cancellationToken, TaskCreationOptions.PreferFairness, TaskScheduler.Default).ConfigureAwait(false); IsRunning = true; @@ -391,7 +391,20 @@ namespace MQTTnet.Server { return; } + + // Check for cancellation again because receive packet might block some time. + if (cancellationToken.IsCancellationRequested) + { + return; + } + // The TCP connection of this client may be still open but the client has already been taken over by + // a new TCP connection. So we must exit here to make sure to no longer process any message. + if (IsTakenOver || !IsRunning) + { + return; + } + var processPacket = true; if (_eventContainer.InterceptingInboundPacketEvent.HasHandlers) @@ -493,6 +506,11 @@ namespace MQTTnet.Server return; } + if (IsTakenOver || !IsRunning) + { + return; + } + try { await SendPacketAsync(packetBusItem.Packet, cancellationToken).ConfigureAwait(false); diff --git a/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs index eb2359c7..59dace26 100644 --- a/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs @@ -23,7 +23,6 @@ namespace MQTTnet.Server { readonly Dictionary _clients = new Dictionary(4096); - readonly AsyncLock _createConnectionSyncRoot = new AsyncLock(); readonly MqttServerEventContainer _eventContainer; readonly MqttNetSourceLogger _logger; readonly MqttServerOptions _options; @@ -39,6 +38,8 @@ namespace MQTTnet.Server readonly object _sessionsManagementLock = new object(); readonly HashSet _subscriberSessions = new HashSet(); + readonly SemaphoreSlim _createConnectionSyncRoot = new SemaphoreSlim(1, 1); + public MqttClientSessionsManager( MqttServerOptions options, MqttRetainedMessagesManager retainedMessagesManager, @@ -75,6 +76,8 @@ namespace MQTTnet.Server public async Task DeleteSessionAsync(string clientId) { + _logger.Verbose("Deleting session for client '{0}'.", clientId); + MqttClient connection; lock (_clients) @@ -193,7 +196,7 @@ namespace MQTTnet.Server public void Dispose() { - _createConnectionSyncRoot?.Dispose(); + _createConnectionSyncRoot.Dispose(); lock (_sessionsManagementLock) { @@ -444,7 +447,7 @@ namespace MQTTnet.Server IMqttChannelAdapter channelAdapter, ValidatingConnectionEventArgs validatingConnectionEventArgs) { - MqttClient connection; + MqttClient client; bool sessionShouldPersist; @@ -470,8 +473,9 @@ namespace MQTTnet.Server sessionShouldPersist = !connectPacket.CleanSession; } - - using (await _createConnectionSyncRoot.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + + await _createConnectionSyncRoot.WaitAsync(CancellationToken.None).ConfigureAwait(false); + try { MqttSession session; lock (_sessionsManagementLock) @@ -484,7 +488,7 @@ namespace MQTTnet.Server { if (connectPacket.CleanSession) { - _logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId); + _logger.Verbose("Deleting existing session of client '{0}' due to clean start.", connectPacket.ClientId); session = CreateSession(connectPacket.ClientId, validatingConnectionEventArgs.SessionItems, sessionShouldPersist); } else @@ -507,40 +511,44 @@ namespace MQTTnet.Server await _eventContainer.PreparingSessionEvent.InvokeAsync(preparingSessionEventArgs).ConfigureAwait(false); } - MqttClient existing; + MqttClient existingClient; lock (_clients) { - _clients.TryGetValue(connectPacket.ClientId, out existing); - connection = CreateConnection(connectPacket, channelAdapter, session); + _clients.TryGetValue(connectPacket.ClientId, out existingClient); + client = CreateClient(connectPacket, channelAdapter, session); - _clients[connectPacket.ClientId] = connection; + _clients[connectPacket.ClientId] = client; } - if (existing != null) + if (existingClient != null) { - existing.IsTakenOver = true; - await existing.StopAsync(MqttDisconnectReasonCode.SessionTakenOver).ConfigureAwait(false); + existingClient.IsTakenOver = true; + await existingClient.StopAsync(MqttDisconnectReasonCode.SessionTakenOver).ConfigureAwait(false); if (_eventContainer.ClientConnectedEvent.HasHandlers) { - var eventArgs = new ClientDisconnectedEventArgs(existing.Id, MqttClientDisconnectType.Takeover, existing.Endpoint, existing.Session.Items); + var eventArgs = new ClientDisconnectedEventArgs(existingClient.Id, MqttClientDisconnectType.Takeover, existingClient.Endpoint, existingClient.Session.Items); await _eventContainer.ClientDisconnectedEvent.InvokeAsync(eventArgs).ConfigureAwait(false); } } } + finally + { + _createConnectionSyncRoot.Release(); + } - return connection; + return client; } - MqttClient CreateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter, MqttSession session) + MqttClient CreateClient(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter, MqttSession session) { return new MqttClient(connectPacket, channelAdapter, session, _options, _eventContainer, this, _rootLogger); } MqttSession CreateSession(string clientId, IDictionary sessionItems, bool isPersistent) { - _logger.Verbose("Created a new session for client '{0}'.", clientId); + _logger.Verbose("Created new session for client '{0}'.", clientId); return new MqttSession(clientId, isPersistent, sessionItems, _options, _eventContainer, _retainedMessagesManager, this); } -- GitLab