提交 f0ad359c 编写于 作者: C chengxiangwang

add logic of resending msg when acktimeout

上级 caf14397
...@@ -54,6 +54,8 @@ public class MqttConfig { ...@@ -54,6 +54,8 @@ public class MqttConfig {
private long persistOffsetInterval = 2 * 1000; private long persistOffsetInterval = 2 * 1000;
private long scanAckTimeoutInterval = 1000;
public int getListenPort() { public int getListenPort() {
return listenPort; return listenPort;
} }
...@@ -149,4 +151,12 @@ public class MqttConfig { ...@@ -149,4 +151,12 @@ public class MqttConfig {
public void setPersistOffsetInterval(long persistOffsetInterval) { public void setPersistOffsetInterval(long persistOffsetInterval) {
this.persistOffsetInterval = persistOffsetInterval; this.persistOffsetInterval = persistOffsetInterval;
} }
public long getScanAckTimeoutInterval() {
return scanAckTimeoutInterval;
}
public void setScanAckTimeoutInterval(long scanAckTimeoutInterval) {
this.scanAckTimeoutInterval = scanAckTimeoutInterval;
}
} }
...@@ -23,6 +23,7 @@ import java.util.Map; ...@@ -23,6 +23,7 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import org.apache.rocketmq.common.client.Client; import org.apache.rocketmq.common.client.Client;
import org.apache.rocketmq.common.client.ClientManagerImpl; import org.apache.rocketmq.common.client.ClientManagerImpl;
import org.apache.rocketmq.common.client.Subscription; import org.apache.rocketmq.common.client.Subscription;
...@@ -43,8 +44,9 @@ public class IOTClientManagerImpl extends ClientManagerImpl { ...@@ -43,8 +44,9 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
1024); 1024);
private final ConcurrentHashMap<String/*clientId*/, Subscription> clientId2Subscription = new ConcurrentHashMap<>(1024); private final ConcurrentHashMap<String/*clientId*/, Subscription> clientId2Subscription = new ConcurrentHashMap<>(1024);
private final Map<String/*snode ip*/, MqttClient> snode2MqttClient = new HashMap<>(); private final Map<String/*snode ip*/, MqttClient> snode2MqttClient = new HashMap<>();
private final ConcurrentHashMap<String /*broker*/, ConcurrentHashMap<String /*topic@clientId*/, TreeMap<Long/*queueOffset*/, MessageExt>>> processTable = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String /*broker*/, ConcurrentHashMap<String /*rootTopic@clientId*/, TreeMap<Long/*queueOffset*/, MessageExt>>> processTable = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String /*rootTopic@clientId*/, Integer> consumeOffsetTable = new ConcurrentHashMap<>();
private final DelayQueue<InFlightPacket> inflightTimeouts = new DelayQueue<>();
public IOTClientManagerImpl() { public IOTClientManagerImpl() {
} }
...@@ -129,4 +131,8 @@ public class IOTClientManagerImpl extends ClientManagerImpl { ...@@ -129,4 +131,8 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
public ConcurrentHashMap<String, ConcurrentHashMap<String, TreeMap<Long, MessageExt>>> getProcessTable() { public ConcurrentHashMap<String, ConcurrentHashMap<String, TreeMap<Long, MessageExt>>> getProcessTable() {
return processTable; return processTable;
} }
public DelayQueue<InFlightPacket> getInflightTimeouts() {
return inflightTimeouts;
}
} }
...@@ -23,16 +23,13 @@ public class InFlightMessage { ...@@ -23,16 +23,13 @@ public class InFlightMessage {
private final Integer pushQos; private final Integer pushQos;
private final BrokerData brokerData; private final BrokerData brokerData;
private final byte[] body; private final byte[] body;
private final String messageId;
private final long queueOffset; private final long queueOffset;
public InFlightMessage(String topic, Integer pushQos, byte[] body, BrokerData brokerData, String messageId, public InFlightMessage(String topic, Integer pushQos, byte[] body, BrokerData brokerData, long queueOffset) {
long queueOffset) {
this.topic = topic; this.topic = topic;
this.pushQos = pushQos; this.pushQos = pushQos;
this.body = body; this.body = body;
this.brokerData = brokerData; this.brokerData = brokerData;
this.messageId = messageId;
this.queueOffset = queueOffset; this.queueOffset = queueOffset;
} }
...@@ -44,10 +41,6 @@ public class InFlightMessage { ...@@ -44,10 +41,6 @@ public class InFlightMessage {
return brokerData; return brokerData;
} }
public String getMessageId() {
return messageId;
}
public long getQueueOffset() { public long getQueueOffset() {
return queueOffset; return queueOffset;
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.mqtt.client;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
public class InFlightPacket implements Delayed {
private final MQTTSession client;
private final int packetId;
private long startTime;
private int resendTime = 0;
InFlightPacket(MQTTSession client, int packetId, long delayInMilliseconds) {
this.client = client;
this.packetId = packetId;
this.startTime = System.currentTimeMillis() + delayInMilliseconds;
}
@Override
public long getDelay(TimeUnit unit) {
long diff = startTime - System.currentTimeMillis();
return unit.convert(diff, TimeUnit.MILLISECONDS);
}
@Override
public int compareTo(Delayed o) {
if ((this.startTime - ((InFlightPacket) o).startTime) == 0) {
return 0;
}
if ((this.startTime - ((InFlightPacket) o).startTime) > 0) {
return 1;
} else {
return -1;
}
}
public MQTTSession getClient() {
return client;
}
public int getPacketId() {
return packetId;
}
public long getStartTime() {
return startTime;
}
public void setStartTime(long startTime) {
this.startTime = startTime;
}
public int getResendTime() {
return resendTime;
}
public void setResendTime(int resendTime) {
this.resendTime = resendTime;
}
@Override public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (!(obj instanceof InFlightPacket)) {
return false;
}
InFlightPacket packet = (InFlightPacket) obj;
return packet.getClient().equals(this.getClient()) &&
packet.getPacketId() == this.getPacketId();
}
}
\ No newline at end of file
...@@ -23,9 +23,6 @@ import java.util.Objects; ...@@ -23,9 +23,6 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import org.apache.rocketmq.common.client.Client; import org.apache.rocketmq.common.client.Client;
import org.apache.rocketmq.common.client.ClientRole; import org.apache.rocketmq.common.client.ClientRole;
...@@ -45,6 +42,7 @@ import org.apache.rocketmq.remoting.netty.NettyChannelImpl; ...@@ -45,6 +42,7 @@ import org.apache.rocketmq.remoting.netty.NettyChannelImpl;
import org.apache.rocketmq.remoting.protocol.RemotingCommand; import org.apache.rocketmq.remoting.protocol.RemotingCommand;
import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader; import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader;
import static org.apache.rocketmq.mqtt.constant.MqttConstant.FLIGHT_BEFORE_RESEND_MS;
import static org.apache.rocketmq.mqtt.constant.MqttConstant.TOPIC_CLIENTID_SEPARATOR; import static org.apache.rocketmq.mqtt.constant.MqttConstant.TOPIC_CLIENTID_SEPARATOR;
public class MQTTSession extends Client { public class MQTTSession extends Client {
...@@ -57,40 +55,9 @@ public class MQTTSession extends Client { ...@@ -57,40 +55,9 @@ public class MQTTSession extends Client {
private final DefaultMqttMessageProcessor defaultMqttMessageProcessor; private final DefaultMqttMessageProcessor defaultMqttMessageProcessor;
private final AtomicInteger inflightSlots = new AtomicInteger(10); private final AtomicInteger inflightSlots = new AtomicInteger(10);
private final Map<Integer, InFlightMessage> inflightWindow = new HashMap<>(); private final Map<Integer, InFlightMessage> inflightWindow = new HashMap<>();
private final DelayQueue<InFlightPacket> inflightTimeouts = new DelayQueue<>();
private static final int FLIGHT_BEFORE_RESEND_MS = 5_000;
private Hashtable inUsePacketIds = new Hashtable(); private Hashtable inUsePacketIds = new Hashtable();
private int nextPacketId = 0; private int nextPacketId = 0;
static class InFlightPacket implements Delayed {
final int packetId;
private long startTime;
InFlightPacket(int packetId, long delayInMilliseconds) {
this.packetId = packetId;
this.startTime = System.currentTimeMillis() + delayInMilliseconds;
}
@Override
public long getDelay(TimeUnit unit) {
long diff = startTime - System.currentTimeMillis();
return unit.convert(diff, TimeUnit.MILLISECONDS);
}
@Override
public int compareTo(Delayed o) {
if ((this.startTime - ((InFlightPacket) o).startTime) == 0) {
return 0;
}
if ((this.startTime - ((InFlightPacket) o).startTime) > 0) {
return 1;
} else {
return -1;
}
}
}
public MQTTSession(String clientId, ClientRole clientRole, Set<String> groups, boolean isConnected, public MQTTSession(String clientId, ClientRole clientRole, Set<String> groups, boolean isConnected,
boolean cleanSession, RemotingChannel remotingChannel, long lastUpdateTimestamp, boolean cleanSession, RemotingChannel remotingChannel, long lastUpdateTimestamp,
DefaultMqttMessageProcessor defaultMqttMessageProcessor) { DefaultMqttMessageProcessor defaultMqttMessageProcessor) {
...@@ -149,9 +116,10 @@ public class MQTTSession extends Client { ...@@ -149,9 +116,10 @@ public class MQTTSession extends Client {
if (inflightSlots.get() > 0) { if (inflightSlots.get() > 0) {
inflightSlots.decrementAndGet(); inflightSlots.decrementAndGet();
mqttHeader.setPacketId(getNextPacketId()); mqttHeader.setPacketId(getNextPacketId());
inflightWindow.put(mqttHeader.getPacketId(), new InFlightMessage(mqttHeader.getTopicName(), mqttHeader.getQosLevel(), messageExt.getBody(), brokerData, messageExt.getMsgId(), messageExt.getQueueOffset())); inflightWindow.put(mqttHeader.getPacketId(), new InFlightMessage(mqttHeader.getTopicName(), mqttHeader.getQosLevel(), messageExt.getBody(), brokerData, messageExt.getQueueOffset()));
// inflightTimeouts.add(new InFlightPacket(mqttHeader.getPacketId(), FLIGHT_BEFORE_RESEND_MS)); IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager();
put2processTable(((IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager()).getProcessTable(), brokerData.getBrokerName(), MqttUtil.getRootTopic(mqttHeader.getTopicName()), messageExt); iotClientManager.getInflightTimeouts().add(new InFlightPacket(this, mqttHeader.getPacketId(), FLIGHT_BEFORE_RESEND_MS));
put2processTable(iotClientManager.getProcessTable(), brokerData.getBrokerName(), MqttUtil.getRootTopic(mqttHeader.getTopicName()), messageExt);
pushMessage2Client(mqttHeader, messageExt.getBody()); pushMessage2Client(mqttHeader, messageExt.getBody());
} }
} }
...@@ -168,11 +136,12 @@ public class MQTTSession extends Client { ...@@ -168,11 +136,12 @@ public class MQTTSession extends Client {
} }
} }
inflightSlots.incrementAndGet(); inflightSlots.incrementAndGet();
((IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager()).getInflightTimeouts().remove(new InFlightPacket(this, ackPacketId, 0));
releasePacketId(ackPacketId); releasePacketId(ackPacketId);
return remove; return remove;
} }
private void pushMessage2Client(MqttHeader mqttHeader, byte[] body) { public void pushMessage2Client(MqttHeader mqttHeader, byte[] body) {
try { try {
//set remaining length //set remaining length
int remainingLength = mqttHeader.getTopicName().getBytes().length + body.length; int remainingLength = mqttHeader.getTopicName().getBytes().length + body.length;
...@@ -259,10 +228,6 @@ public class MQTTSession extends Client { ...@@ -259,10 +228,6 @@ public class MQTTSession extends Client {
return inflightWindow; return inflightWindow;
} }
public DelayQueue<InFlightPacket> getInflightTimeouts() {
return inflightTimeouts;
}
public Hashtable getInUsePacketIds() { public Hashtable getInUsePacketIds() {
return inUsePacketIds; return inUsePacketIds;
} }
......
...@@ -27,6 +27,7 @@ public class MqttConstant { ...@@ -27,6 +27,7 @@ public class MqttConstant {
public static final String SUBSCRIPTION_SEPARATOR = "/"; public static final String SUBSCRIPTION_SEPARATOR = "/";
public static final String TOPIC_CLIENTID_SEPARATOR = "@"; public static final String TOPIC_CLIENTID_SEPARATOR = "@";
public static final long DEFAULT_TIMEOUT_MILLS = 3000L; public static final long DEFAULT_TIMEOUT_MILLS = 3000L;
public static final int FLIGHT_BEFORE_RESEND_MS = 5_000;
public static final String PROPERTY_MQTT_QOS = "PROPERTY_MQTT_QOS"; public static final String PROPERTY_MQTT_QOS = "PROPERTY_MQTT_QOS";
public static final AttributeKey<Client> MQTT_CLIENT_ATTRIBUTE_KEY = AttributeKey.valueOf("mqtt.client"); public static final AttributeKey<Client> MQTT_CLIENT_ATTRIBUTE_KEY = AttributeKey.valueOf("mqtt.client");
} }
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
*/ */
package org.apache.rocketmq.mqtt.service.impl; package org.apache.rocketmq.mqtt.service.impl;
import io.netty.handler.codec.mqtt.MqttMessageType;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
...@@ -29,7 +32,12 @@ import org.apache.rocketmq.common.service.ScheduledService; ...@@ -29,7 +32,12 @@ import org.apache.rocketmq.common.service.ScheduledService;
import org.apache.rocketmq.logging.InternalLogger; import org.apache.rocketmq.logging.InternalLogger;
import org.apache.rocketmq.logging.InternalLoggerFactory; import org.apache.rocketmq.logging.InternalLoggerFactory;
import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl;
import org.apache.rocketmq.mqtt.client.InFlightMessage;
import org.apache.rocketmq.mqtt.client.InFlightPacket;
import org.apache.rocketmq.mqtt.client.MQTTSession;
import org.apache.rocketmq.mqtt.constant.MqttConstant;
import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor;
import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader;
public class MqttScheduledServiceImpl implements ScheduledService { public class MqttScheduledServiceImpl implements ScheduledService {
private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME);
...@@ -67,6 +75,36 @@ public class MqttScheduledServiceImpl implements ScheduledService { ...@@ -67,6 +75,36 @@ public class MqttScheduledServiceImpl implements ScheduledService {
} }
} }
}, 0, defaultMqttMessageProcessor.getMqttConfig().getPersistOffsetInterval(), TimeUnit.MILLISECONDS); }, 0, defaultMqttMessageProcessor.getMqttConfig().getPersistOffsetInterval(), TimeUnit.MILLISECONDS);
this.mqttScheduledExecutorService.scheduleAtFixedRate(new Runnable() {
@Override public void run() {
IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) defaultMqttMessageProcessor.getIotClientManager();
Collection<InFlightPacket> expired = new ArrayList<>();
iotClientManager.getInflightTimeouts().drainTo(expired);
for (InFlightPacket notAcked : expired) {
MQTTSession client = notAcked.getClient();
if (!client.isConnected()) {
continue;
}
if (notAcked.getResendTime() > 3) {
client.getRemotingChannel().close();
continue;
}
if (client.getInflightWindow().containsKey(notAcked.getPacketId())) {
InFlightMessage inFlightMessage = client.getInflightWindow().get(notAcked.getPacketId());
MqttHeader mqttHeader = new MqttHeader();
mqttHeader.setTopicName(inFlightMessage.getTopic());
mqttHeader.setQosLevel(inFlightMessage.getPushQos());
mqttHeader.setRetain(false);
mqttHeader.setDup(true);
mqttHeader.setMessageType(MqttMessageType.PUBLISH.value());
notAcked.setStartTime(System.currentTimeMillis() + MqttConstant.FLIGHT_BEFORE_RESEND_MS);
notAcked.setResendTime(notAcked.getResendTime() + 1);
iotClientManager.getInflightTimeouts().add(notAcked);
client.pushMessage2Client(mqttHeader, inFlightMessage.getBody());
}
}
}
}, 10000, defaultMqttMessageProcessor.getMqttConfig().getScanAckTimeoutInterval(), TimeUnit.MILLISECONDS);
} }
@Override @Override
......
...@@ -68,8 +68,8 @@ public class MqttPushTask implements Runnable { ...@@ -68,8 +68,8 @@ public class MqttPushTask implements Runnable {
private BrokerData brokerData; private BrokerData brokerData;
private String rootTopic; private String rootTopic;
public MqttPushTask(DefaultMqttMessageProcessor processor, final MqttHeader mqttHeader, String rootTopic, Client client, public MqttPushTask(DefaultMqttMessageProcessor processor, final MqttHeader mqttHeader, String rootTopic,
BrokerData brokerData) { Client client, BrokerData brokerData) {
this.defaultMqttMessageProcessor = processor; this.defaultMqttMessageProcessor = processor;
this.mqttHeader = mqttHeader; this.mqttHeader = mqttHeader;
this.rootTopic = rootTopic; this.rootTopic = rootTopic;
......
...@@ -84,7 +84,7 @@ public class MqttPubackMessageHandlerTest { ...@@ -84,7 +84,7 @@ public class MqttPubackMessageHandlerTest {
MQTTSession mqttSession = Mockito.spy(new MQTTSession("client1", ClientRole.IOTCLIENT, null, true, true, remotingChannel, System.currentTimeMillis(), defaultMqttMessageProcessor)); MQTTSession mqttSession = Mockito.spy(new MQTTSession("client1", ClientRole.IOTCLIENT, null, true, true, remotingChannel, System.currentTimeMillis(), defaultMqttMessageProcessor));
Mockito.when(iotClientManager.getClient(anyString(), any(RemotingChannel.class))).thenReturn(mqttSession); Mockito.when(iotClientManager.getClient(anyString(), any(RemotingChannel.class))).thenReturn(mqttSession);
InFlightMessage inFlightMessage = Mockito.spy(new InFlightMessage("topicTest", 0, "Hello".getBytes(), null, null, 0)); InFlightMessage inFlightMessage = Mockito.spy(new InFlightMessage("topicTest", 0, "Hello".getBytes(), null, 0));
doReturn(inFlightMessage).when(mqttSession).pubAckReceived(anyInt()); doReturn(inFlightMessage).when(mqttSession).pubAckReceived(anyInt());
RemotingCommand remotingCommand = mqttPubackMessageHandler.handleMessage(mqttMessage, remotingChannel); RemotingCommand remotingCommand = mqttPubackMessageHandler.handleMessage(mqttMessage, remotingChannel);
assert remotingCommand == null; assert remotingCommand == null;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册