From d4dad4860a75459c9274ae2d5b4958634984a526 Mon Sep 17 00:00:00 2001
From: Christian <6939810+chkr1011@users.noreply.github.com>
Date: Wed, 14 Sep 2022 14:15:35 +0200
Subject: [PATCH] Fix and improve session takeover process handling.
---
.github/workflows/ReleaseNotes.md | 1 +
MQTTnet.sln.DotSettings | 1 +
Source/MQTTnet.Tests/Server/Session_Tests.cs | 38 ++++++-----------
.../Implementations/MqttTcpServerListener.cs | 2 +-
Source/MQTTnet/Server/Internal/MqttClient.cs | 20 ++++++++-
.../Internal/MqttClientSessionsManager.cs | 42 +++++++++++--------
6 files changed, 59 insertions(+), 45 deletions(-)
diff --git a/.github/workflows/ReleaseNotes.md b/.github/workflows/ReleaseNotes.md
index bd6f0335..a4f66c6c 100644
--- a/.github/workflows/ReleaseNotes.md
+++ b/.github/workflows/ReleaseNotes.md
@@ -1,3 +1,4 @@
* [Core] MQTT Packets being sent over web socket transport are now setting the web socket frame boundaries correctly (#1499).
* [Client] Keep alive mechanism now uses the configured timeout value from the options (thanks to @Stannieman, #1495).
* [Server] A DISCONNECT packet is no longer sent to MQTT clients < 5.0.0 (thanks to @logicaloud, #1506).
+* [Server] Improved "take over" process handling.
diff --git a/MQTTnet.sln.DotSettings b/MQTTnet.sln.DotSettings
index bf4f67af..7eed6537 100644
--- a/MQTTnet.sln.DotSettings
+++ b/MQTTnet.sln.DotSettings
@@ -241,5 +241,6 @@ See the LICENSE file in the project root for more information.
True
True
True
+ True
True
True
\ No newline at end of file
diff --git a/Source/MQTTnet.Tests/Server/Session_Tests.cs b/Source/MQTTnet.Tests/Server/Session_Tests.cs
index a791c4c5..772b8484 100644
--- a/Source/MQTTnet.Tests/Server/Session_Tests.cs
+++ b/Source/MQTTnet.Tests/Server/Session_Tests.cs
@@ -124,7 +124,9 @@ namespace MQTTnet.Tests.Server
}
[TestMethod]
- public async Task Handle_Parallel_Connection_Attempts()
+ [DataRow(MqttProtocolVersion.V311)]
+ [DataRow(MqttProtocolVersion.V500)]
+ public async Task Handle_Parallel_Connection_Attempts(MqttProtocolVersion protocolVersion)
{
using (var testEnvironment = CreateTestEnvironment())
{
@@ -132,7 +134,7 @@ namespace MQTTnet.Tests.Server
await testEnvironment.StartServer();
- var options = new MqttClientOptionsBuilder().WithClientId("1").WithKeepAlivePeriod(TimeSpan.FromSeconds(5));
+ var options = new MqttClientOptionsBuilder().WithClientId("1").WithTimeout(TimeSpan.FromSeconds(1)).WithProtocolVersion(protocolVersion).WithKeepAlivePeriod(TimeSpan.FromSeconds(5));
var hasReceive = false;
@@ -146,7 +148,10 @@ namespace MQTTnet.Tests.Server
// Try to connect 50 clients at the same time.
var clients = await Task.WhenAll(Enumerable.Range(0, 50).Select(i => ConnectAndSubscribe(testEnvironment, options, OnReceive)));
- var connectedClients = clients.Where(c => c?.IsConnected ?? false).ToList();
+
+ await LongTestDelay();
+
+ var connectedClients = clients.Where(c => c != null).Where(c => c.TryPingAsync().GetAwaiter().GetResult()).ToList();
Assert.AreEqual(1, connectedClients.Count);
@@ -161,26 +166,7 @@ namespace MQTTnet.Tests.Server
Assert.AreEqual(true, hasReceive);
}
}
-
- [TestMethod]
- public async Task Manage_Session_MaxParallel()
- {
- using (var testEnvironment = CreateTestEnvironment())
- {
- testEnvironment.IgnoreClientLogErrors = true;
- var serverOptions = new MqttServerOptionsBuilder();
- await testEnvironment.StartServer(serverOptions);
-
- var options = new MqttClientOptionsBuilder().WithClientId("1");
-
- var clients = await Task.WhenAll(Enumerable.Range(0, 10).Select(i => TryConnect(testEnvironment, options)));
-
- var connectedClients = clients.Where(c => c?.IsConnected ?? false).ToList();
-
- Assert.AreEqual(1, connectedClients.Count);
- }
- }
-
+
[DataTestMethod]
[DataRow(MqttQualityOfServiceLevel.ExactlyOnce)]
[DataRow(MqttQualityOfServiceLevel.AtLeastOnce)]
@@ -360,7 +346,7 @@ namespace MQTTnet.Tests.Server
}
}
- async Task ConnectAndSubscribe(TestEnvironment testEnvironment, MqttClientOptionsBuilder options, Action onReceive)
+ static async Task ConnectAndSubscribe(TestEnvironment testEnvironment, MqttClientOptionsBuilder options, Action onReceive)
{
try
{
@@ -379,13 +365,13 @@ namespace MQTTnet.Tests.Server
return sendClient;
}
- catch (Exception)
+ catch
{
return null;
}
}
- async Task TryConnect(TestEnvironment testEnvironment, MqttClientOptionsBuilder options)
+ static async Task TryConnect(TestEnvironment testEnvironment, MqttClientOptionsBuilder options)
{
try
{
diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs
index b2f814da..9341d275 100644
--- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs
+++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs
@@ -134,7 +134,7 @@ namespace MQTTnet.Implementations
continue;
}
- Task.Run(() => TryHandleClientConnectionAsync(clientSocket), cancellationToken).RunInBackground(_logger);
+ _ = Task.Factory.StartNew(() => TryHandleClientConnectionAsync(clientSocket), cancellationToken, TaskCreationOptions.PreferFairness, TaskScheduler.Default).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
diff --git a/Source/MQTTnet/Server/Internal/MqttClient.cs b/Source/MQTTnet/Server/Internal/MqttClient.cs
index 2982961c..18c3804a 100644
--- a/Source/MQTTnet/Server/Internal/MqttClient.cs
+++ b/Source/MQTTnet/Server/Internal/MqttClient.cs
@@ -98,7 +98,7 @@ namespace MQTTnet.Server
try
{
- Task.Run(() => SendPacketsLoop(cancellationToken), cancellationToken).RunInBackground(_logger);
+ _ = Task.Factory.StartNew(() => SendPacketsLoop(cancellationToken), cancellationToken, TaskCreationOptions.PreferFairness, TaskScheduler.Default).ConfigureAwait(false);
IsRunning = true;
@@ -391,7 +391,20 @@ namespace MQTTnet.Server
{
return;
}
+
+ // Check for cancellation again because receive packet might block some time.
+ if (cancellationToken.IsCancellationRequested)
+ {
+ return;
+ }
+ // The TCP connection of this client may be still open but the client has already been taken over by
+ // a new TCP connection. So we must exit here to make sure to no longer process any message.
+ if (IsTakenOver || !IsRunning)
+ {
+ return;
+ }
+
var processPacket = true;
if (_eventContainer.InterceptingInboundPacketEvent.HasHandlers)
@@ -493,6 +506,11 @@ namespace MQTTnet.Server
return;
}
+ if (IsTakenOver || !IsRunning)
+ {
+ return;
+ }
+
try
{
await SendPacketAsync(packetBusItem.Packet, cancellationToken).ConfigureAwait(false);
diff --git a/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs
index eb2359c7..59dace26 100644
--- a/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs
+++ b/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs
@@ -23,7 +23,6 @@ namespace MQTTnet.Server
{
readonly Dictionary _clients = new Dictionary(4096);
- readonly AsyncLock _createConnectionSyncRoot = new AsyncLock();
readonly MqttServerEventContainer _eventContainer;
readonly MqttNetSourceLogger _logger;
readonly MqttServerOptions _options;
@@ -39,6 +38,8 @@ namespace MQTTnet.Server
readonly object _sessionsManagementLock = new object();
readonly HashSet _subscriberSessions = new HashSet();
+ readonly SemaphoreSlim _createConnectionSyncRoot = new SemaphoreSlim(1, 1);
+
public MqttClientSessionsManager(
MqttServerOptions options,
MqttRetainedMessagesManager retainedMessagesManager,
@@ -75,6 +76,8 @@ namespace MQTTnet.Server
public async Task DeleteSessionAsync(string clientId)
{
+ _logger.Verbose("Deleting session for client '{0}'.", clientId);
+
MqttClient connection;
lock (_clients)
@@ -193,7 +196,7 @@ namespace MQTTnet.Server
public void Dispose()
{
- _createConnectionSyncRoot?.Dispose();
+ _createConnectionSyncRoot.Dispose();
lock (_sessionsManagementLock)
{
@@ -444,7 +447,7 @@ namespace MQTTnet.Server
IMqttChannelAdapter channelAdapter,
ValidatingConnectionEventArgs validatingConnectionEventArgs)
{
- MqttClient connection;
+ MqttClient client;
bool sessionShouldPersist;
@@ -470,8 +473,9 @@ namespace MQTTnet.Server
sessionShouldPersist = !connectPacket.CleanSession;
}
-
- using (await _createConnectionSyncRoot.WaitAsync(CancellationToken.None).ConfigureAwait(false))
+
+ await _createConnectionSyncRoot.WaitAsync(CancellationToken.None).ConfigureAwait(false);
+ try
{
MqttSession session;
lock (_sessionsManagementLock)
@@ -484,7 +488,7 @@ namespace MQTTnet.Server
{
if (connectPacket.CleanSession)
{
- _logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId);
+ _logger.Verbose("Deleting existing session of client '{0}' due to clean start.", connectPacket.ClientId);
session = CreateSession(connectPacket.ClientId, validatingConnectionEventArgs.SessionItems, sessionShouldPersist);
}
else
@@ -507,40 +511,44 @@ namespace MQTTnet.Server
await _eventContainer.PreparingSessionEvent.InvokeAsync(preparingSessionEventArgs).ConfigureAwait(false);
}
- MqttClient existing;
+ MqttClient existingClient;
lock (_clients)
{
- _clients.TryGetValue(connectPacket.ClientId, out existing);
- connection = CreateConnection(connectPacket, channelAdapter, session);
+ _clients.TryGetValue(connectPacket.ClientId, out existingClient);
+ client = CreateClient(connectPacket, channelAdapter, session);
- _clients[connectPacket.ClientId] = connection;
+ _clients[connectPacket.ClientId] = client;
}
- if (existing != null)
+ if (existingClient != null)
{
- existing.IsTakenOver = true;
- await existing.StopAsync(MqttDisconnectReasonCode.SessionTakenOver).ConfigureAwait(false);
+ existingClient.IsTakenOver = true;
+ await existingClient.StopAsync(MqttDisconnectReasonCode.SessionTakenOver).ConfigureAwait(false);
if (_eventContainer.ClientConnectedEvent.HasHandlers)
{
- var eventArgs = new ClientDisconnectedEventArgs(existing.Id, MqttClientDisconnectType.Takeover, existing.Endpoint, existing.Session.Items);
+ var eventArgs = new ClientDisconnectedEventArgs(existingClient.Id, MqttClientDisconnectType.Takeover, existingClient.Endpoint, existingClient.Session.Items);
await _eventContainer.ClientDisconnectedEvent.InvokeAsync(eventArgs).ConfigureAwait(false);
}
}
}
+ finally
+ {
+ _createConnectionSyncRoot.Release();
+ }
- return connection;
+ return client;
}
- MqttClient CreateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter, MqttSession session)
+ MqttClient CreateClient(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter, MqttSession session)
{
return new MqttClient(connectPacket, channelAdapter, session, _options, _eventContainer, this, _rootLogger);
}
MqttSession CreateSession(string clientId, IDictionary sessionItems, bool isPersistent)
{
- _logger.Verbose("Created a new session for client '{0}'.", clientId);
+ _logger.Verbose("Created new session for client '{0}'.", clientId);
return new MqttSession(clientId, isPersistent, sessionItems, _options, _eventContainer, _retainedMessagesManager, this);
}
--
GitLab