未验证 提交 53a15b63 编写于 作者: H Heng Du 提交者: GitHub

Merge pull request #1245 from xiangwangcheng/mqtt

[RIP-11] add logic of resending messages when ack-timeout
...@@ -54,6 +54,8 @@ public class MqttConfig { ...@@ -54,6 +54,8 @@ public class MqttConfig {
private long persistOffsetInterval = 2 * 1000; private long persistOffsetInterval = 2 * 1000;
private long scanAckTimeoutInterval = 1000;
public int getListenPort() { public int getListenPort() {
return listenPort; return listenPort;
} }
...@@ -149,4 +151,12 @@ public class MqttConfig { ...@@ -149,4 +151,12 @@ public class MqttConfig {
public void setPersistOffsetInterval(long persistOffsetInterval) { public void setPersistOffsetInterval(long persistOffsetInterval) {
this.persistOffsetInterval = persistOffsetInterval; this.persistOffsetInterval = persistOffsetInterval;
} }
public long getScanAckTimeoutInterval() {
return scanAckTimeoutInterval;
}
public void setScanAckTimeoutInterval(long scanAckTimeoutInterval) {
this.scanAckTimeoutInterval = scanAckTimeoutInterval;
}
} }
...@@ -31,7 +31,7 @@ public interface NnodeService { ...@@ -31,7 +31,7 @@ public interface NnodeService {
* *
* @param snodeConfig {@link SnodeConfig} * @param snodeConfig {@link SnodeConfig}
*/ */
void registerSnode(SnodeConfig snodeConfig) throws Exception; void registerSnode(SnodeConfig snodeConfig) throws Exception;
/** /**
* Update Nnode server address list. * Update Nnode server address list.
......
...@@ -23,6 +23,7 @@ import java.util.Map; ...@@ -23,6 +23,7 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import org.apache.rocketmq.common.client.Client; import org.apache.rocketmq.common.client.Client;
import org.apache.rocketmq.common.client.ClientManagerImpl; import org.apache.rocketmq.common.client.ClientManagerImpl;
import org.apache.rocketmq.common.client.Subscription; import org.apache.rocketmq.common.client.Subscription;
...@@ -43,8 +44,9 @@ public class IOTClientManagerImpl extends ClientManagerImpl { ...@@ -43,8 +44,9 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
1024); 1024);
private final ConcurrentHashMap<String/*clientId*/, Subscription> clientId2Subscription = new ConcurrentHashMap<>(1024); private final ConcurrentHashMap<String/*clientId*/, Subscription> clientId2Subscription = new ConcurrentHashMap<>(1024);
private final Map<String/*snode ip*/, MqttClient> snode2MqttClient = new HashMap<>(); private final Map<String/*snode ip*/, MqttClient> snode2MqttClient = new HashMap<>();
private final ConcurrentHashMap<String /*broker*/, ConcurrentHashMap<String /*topic@clientId*/, TreeMap<Long/*queueOffset*/, MessageExt>>> processTable = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String /*broker*/, ConcurrentHashMap<String /*rootTopic@clientId*/, TreeMap<Long/*queueOffset*/, MessageExt>>> processTable = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String /*rootTopic@clientId*/, Integer> consumeOffsetTable = new ConcurrentHashMap<>();
private final DelayQueue<InFlightPacket> inflightTimeouts = new DelayQueue<>();
public IOTClientManagerImpl() { public IOTClientManagerImpl() {
} }
...@@ -129,4 +131,8 @@ public class IOTClientManagerImpl extends ClientManagerImpl { ...@@ -129,4 +131,8 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
public ConcurrentHashMap<String, ConcurrentHashMap<String, TreeMap<Long, MessageExt>>> getProcessTable() { public ConcurrentHashMap<String, ConcurrentHashMap<String, TreeMap<Long, MessageExt>>> getProcessTable() {
return processTable; return processTable;
} }
public DelayQueue<InFlightPacket> getInflightTimeouts() {
return inflightTimeouts;
}
} }
...@@ -23,16 +23,13 @@ public class InFlightMessage { ...@@ -23,16 +23,13 @@ public class InFlightMessage {
private final Integer pushQos; private final Integer pushQos;
private final BrokerData brokerData; private final BrokerData brokerData;
private final byte[] body; private final byte[] body;
private final String messageId;
private final long queueOffset; private final long queueOffset;
public InFlightMessage(String topic, Integer pushQos, byte[] body, BrokerData brokerData, String messageId, public InFlightMessage(String topic, Integer pushQos, byte[] body, BrokerData brokerData, long queueOffset) {
long queueOffset) {
this.topic = topic; this.topic = topic;
this.pushQos = pushQos; this.pushQos = pushQos;
this.body = body; this.body = body;
this.brokerData = brokerData; this.brokerData = brokerData;
this.messageId = messageId;
this.queueOffset = queueOffset; this.queueOffset = queueOffset;
} }
...@@ -44,10 +41,6 @@ public class InFlightMessage { ...@@ -44,10 +41,6 @@ public class InFlightMessage {
return brokerData; return brokerData;
} }
public String getMessageId() {
return messageId;
}
public long getQueueOffset() { public long getQueueOffset() {
return queueOffset; return queueOffset;
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.mqtt.client;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
public class InFlightPacket implements Delayed {
private final MQTTSession client;
private final int packetId;
private long startTime;
private int resendTime = 0;
InFlightPacket(MQTTSession client, int packetId, long delayInMilliseconds) {
this.client = client;
this.packetId = packetId;
this.startTime = System.currentTimeMillis() + delayInMilliseconds;
}
@Override
public long getDelay(TimeUnit unit) {
long diff = startTime - System.currentTimeMillis();
return unit.convert(diff, TimeUnit.MILLISECONDS);
}
@Override
public int compareTo(Delayed o) {
if ((this.startTime - ((InFlightPacket) o).startTime) == 0) {
return 0;
}
if ((this.startTime - ((InFlightPacket) o).startTime) > 0) {
return 1;
} else {
return -1;
}
}
public MQTTSession getClient() {
return client;
}
public int getPacketId() {
return packetId;
}
public long getStartTime() {
return startTime;
}
public void setStartTime(long startTime) {
this.startTime = startTime;
}
public int getResendTime() {
return resendTime;
}
public void setResendTime(int resendTime) {
this.resendTime = resendTime;
}
@Override public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (!(obj instanceof InFlightPacket)) {
return false;
}
InFlightPacket packet = (InFlightPacket) obj;
return packet.getClient().equals(this.getClient()) &&
packet.getPacketId() == this.getPacketId();
}
}
\ No newline at end of file
...@@ -23,9 +23,6 @@ import java.util.Objects; ...@@ -23,9 +23,6 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import org.apache.rocketmq.common.client.Client; import org.apache.rocketmq.common.client.Client;
import org.apache.rocketmq.common.client.ClientRole; import org.apache.rocketmq.common.client.ClientRole;
...@@ -45,6 +42,7 @@ import org.apache.rocketmq.remoting.netty.NettyChannelImpl; ...@@ -45,6 +42,7 @@ import org.apache.rocketmq.remoting.netty.NettyChannelImpl;
import org.apache.rocketmq.remoting.protocol.RemotingCommand; import org.apache.rocketmq.remoting.protocol.RemotingCommand;
import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader; import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader;
import static org.apache.rocketmq.mqtt.constant.MqttConstant.FLIGHT_BEFORE_RESEND_MS;
import static org.apache.rocketmq.mqtt.constant.MqttConstant.TOPIC_CLIENTID_SEPARATOR; import static org.apache.rocketmq.mqtt.constant.MqttConstant.TOPIC_CLIENTID_SEPARATOR;
public class MQTTSession extends Client { public class MQTTSession extends Client {
...@@ -57,40 +55,9 @@ public class MQTTSession extends Client { ...@@ -57,40 +55,9 @@ public class MQTTSession extends Client {
private final DefaultMqttMessageProcessor defaultMqttMessageProcessor; private final DefaultMqttMessageProcessor defaultMqttMessageProcessor;
private final AtomicInteger inflightSlots = new AtomicInteger(10); private final AtomicInteger inflightSlots = new AtomicInteger(10);
private final Map<Integer, InFlightMessage> inflightWindow = new HashMap<>(); private final Map<Integer, InFlightMessage> inflightWindow = new HashMap<>();
private final DelayQueue<InFlightPacket> inflightTimeouts = new DelayQueue<>();
private static final int FLIGHT_BEFORE_RESEND_MS = 5_000;
private Hashtable inUsePacketIds = new Hashtable(); private Hashtable inUsePacketIds = new Hashtable();
private int nextPacketId = 0; private int nextPacketId = 0;
static class InFlightPacket implements Delayed {
final int packetId;
private long startTime;
InFlightPacket(int packetId, long delayInMilliseconds) {
this.packetId = packetId;
this.startTime = System.currentTimeMillis() + delayInMilliseconds;
}
@Override
public long getDelay(TimeUnit unit) {
long diff = startTime - System.currentTimeMillis();
return unit.convert(diff, TimeUnit.MILLISECONDS);
}
@Override
public int compareTo(Delayed o) {
if ((this.startTime - ((InFlightPacket) o).startTime) == 0) {
return 0;
}
if ((this.startTime - ((InFlightPacket) o).startTime) > 0) {
return 1;
} else {
return -1;
}
}
}
public MQTTSession(String clientId, ClientRole clientRole, Set<String> groups, boolean isConnected, public MQTTSession(String clientId, ClientRole clientRole, Set<String> groups, boolean isConnected,
boolean cleanSession, RemotingChannel remotingChannel, long lastUpdateTimestamp, boolean cleanSession, RemotingChannel remotingChannel, long lastUpdateTimestamp,
DefaultMqttMessageProcessor defaultMqttMessageProcessor) { DefaultMqttMessageProcessor defaultMqttMessageProcessor) {
...@@ -149,9 +116,10 @@ public class MQTTSession extends Client { ...@@ -149,9 +116,10 @@ public class MQTTSession extends Client {
if (inflightSlots.get() > 0) { if (inflightSlots.get() > 0) {
inflightSlots.decrementAndGet(); inflightSlots.decrementAndGet();
mqttHeader.setPacketId(getNextPacketId()); mqttHeader.setPacketId(getNextPacketId());
inflightWindow.put(mqttHeader.getPacketId(), new InFlightMessage(mqttHeader.getTopicName(), mqttHeader.getQosLevel(), messageExt.getBody(), brokerData, messageExt.getMsgId(), messageExt.getQueueOffset())); inflightWindow.put(mqttHeader.getPacketId(), new InFlightMessage(mqttHeader.getTopicName(), mqttHeader.getQosLevel(), messageExt.getBody(), brokerData, messageExt.getQueueOffset()));
// inflightTimeouts.add(new InFlightPacket(mqttHeader.getPacketId(), FLIGHT_BEFORE_RESEND_MS)); IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager();
put2processTable(((IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager()).getProcessTable(), brokerData.getBrokerName(), MqttUtil.getRootTopic(mqttHeader.getTopicName()), messageExt); iotClientManager.getInflightTimeouts().add(new InFlightPacket(this, mqttHeader.getPacketId(), FLIGHT_BEFORE_RESEND_MS));
put2processTable(iotClientManager.getProcessTable(), brokerData.getBrokerName(), MqttUtil.getRootTopic(mqttHeader.getTopicName()), messageExt);
pushMessage2Client(mqttHeader, messageExt.getBody()); pushMessage2Client(mqttHeader, messageExt.getBody());
} }
} }
...@@ -168,11 +136,12 @@ public class MQTTSession extends Client { ...@@ -168,11 +136,12 @@ public class MQTTSession extends Client {
} }
} }
inflightSlots.incrementAndGet(); inflightSlots.incrementAndGet();
((IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager()).getInflightTimeouts().remove(new InFlightPacket(this, ackPacketId, 0));
releasePacketId(ackPacketId); releasePacketId(ackPacketId);
return remove; return remove;
} }
private void pushMessage2Client(MqttHeader mqttHeader, byte[] body) { public void pushMessage2Client(MqttHeader mqttHeader, byte[] body) {
try { try {
//set remaining length //set remaining length
int remainingLength = mqttHeader.getTopicName().getBytes().length + body.length; int remainingLength = mqttHeader.getTopicName().getBytes().length + body.length;
...@@ -259,10 +228,6 @@ public class MQTTSession extends Client { ...@@ -259,10 +228,6 @@ public class MQTTSession extends Client {
return inflightWindow; return inflightWindow;
} }
public DelayQueue<InFlightPacket> getInflightTimeouts() {
return inflightTimeouts;
}
public Hashtable getInUsePacketIds() { public Hashtable getInUsePacketIds() {
return inUsePacketIds; return inUsePacketIds;
} }
......
...@@ -27,6 +27,7 @@ public class MqttConstant { ...@@ -27,6 +27,7 @@ public class MqttConstant {
public static final String SUBSCRIPTION_SEPARATOR = "/"; public static final String SUBSCRIPTION_SEPARATOR = "/";
public static final String TOPIC_CLIENTID_SEPARATOR = "@"; public static final String TOPIC_CLIENTID_SEPARATOR = "@";
public static final long DEFAULT_TIMEOUT_MILLS = 3000L; public static final long DEFAULT_TIMEOUT_MILLS = 3000L;
public static final int FLIGHT_BEFORE_RESEND_MS = 5_000;
public static final String PROPERTY_MQTT_QOS = "PROPERTY_MQTT_QOS"; public static final String PROPERTY_MQTT_QOS = "PROPERTY_MQTT_QOS";
public static final AttributeKey<Client> MQTT_CLIENT_ATTRIBUTE_KEY = AttributeKey.valueOf("mqtt.client"); public static final AttributeKey<Client> MQTT_CLIENT_ATTRIBUTE_KEY = AttributeKey.valueOf("mqtt.client");
} }
...@@ -75,7 +75,7 @@ public class MqttSubscribeMessageHandler implements MessageHandler { ...@@ -75,7 +75,7 @@ public class MqttSubscribeMessageHandler implements MessageHandler {
MqttSubscribeMessage mqttSubscribeMessage = (MqttSubscribeMessage) message; MqttSubscribeMessage mqttSubscribeMessage = (MqttSubscribeMessage) message;
MqttSubscribePayload payload = mqttSubscribeMessage.payload(); MqttSubscribePayload payload = mqttSubscribeMessage.payload();
IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) defaultMqttMessageProcessor.getIotClientManager(); IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) defaultMqttMessageProcessor.getIotClientManager();
MQTTSession client = (MQTTSession)iotClientManager.getClient(IOTClientManagerImpl.IOT_GROUP, remotingChannel); MQTTSession client = (MQTTSession) iotClientManager.getClient(IOTClientManagerImpl.IOT_GROUP, remotingChannel);
if (client == null) { if (client == null) {
log.error("Can't find associated client, the connection will be closed. remotingChannel={}, MqttMessage={}", remotingChannel.toString(), message.toString()); log.error("Can't find associated client, the connection will be closed. remotingChannel={}, MqttMessage={}", remotingChannel.toString(), message.toString());
remotingChannel.close(); remotingChannel.close();
...@@ -91,7 +91,7 @@ public class MqttSubscribeMessageHandler implements MessageHandler { ...@@ -91,7 +91,7 @@ public class MqttSubscribeMessageHandler implements MessageHandler {
remotingChannel.close(); remotingChannel.close();
return null; return null;
} }
if (isTopicWithWildcard(payload.topicSubscriptions())) { if (topicStartWithWildcard(payload.topicSubscriptions())) {
log.error("Client can not subscribe topic starts with wildcards! clientId={}, topicSubscriptions={}", client.getClientId(), payload.topicSubscriptions().toString()); log.error("Client can not subscribe topic starts with wildcards! clientId={}, topicSubscriptions={}", client.getClientId(), payload.topicSubscriptions().toString());
} }
...@@ -126,7 +126,7 @@ public class MqttSubscribeMessageHandler implements MessageHandler { ...@@ -126,7 +126,7 @@ public class MqttSubscribeMessageHandler implements MessageHandler {
subscription = clientId2Subscription.get(client.getClientId()); subscription = clientId2Subscription.get(client.getClientId());
} else { } else {
subscription = new Subscription(); subscription = new Subscription();
subscription.setCleanSession(((MQTTSession)client).isCleanSession()); subscription.setCleanSession(((MQTTSession) client).isCleanSession());
} }
ConcurrentHashMap<String, SubscriptionData> subscriptionDatas = subscription.getSubscriptionTable(); ConcurrentHashMap<String, SubscriptionData> subscriptionDatas = subscription.getSubscriptionTable();
List<Integer> grantQoss = new ArrayList<>(); List<Integer> grantQoss = new ArrayList<>();
...@@ -155,17 +155,17 @@ public class MqttSubscribeMessageHandler implements MessageHandler { ...@@ -155,17 +155,17 @@ public class MqttSubscribeMessageHandler implements MessageHandler {
private boolean isQosLegal(List<MqttTopicSubscription> mqttTopicSubscriptions) { private boolean isQosLegal(List<MqttTopicSubscription> mqttTopicSubscriptions) {
for (MqttTopicSubscription subscription : mqttTopicSubscriptions) { for (MqttTopicSubscription subscription : mqttTopicSubscriptions) {
if (!(subscription.qualityOfService().equals(MqttQoS.AT_LEAST_ONCE) || subscription.qualityOfService().equals(MqttQoS.EXACTLY_ONCE) || subscription.qualityOfService().equals(MqttQoS.AT_MOST_ONCE))) { if (MqttUtil.isQosLegal(subscription.qualityOfService())) {
return true; return true;
} }
} }
return false; return false;
} }
private boolean isTopicWithWildcard(List<MqttTopicSubscription> mqttTopicSubscriptions) { private boolean topicStartWithWildcard(List<MqttTopicSubscription> mqttTopicSubscriptions) {
for (MqttTopicSubscription subscription : mqttTopicSubscriptions) { for (MqttTopicSubscription subscription : mqttTopicSubscriptions) {
String rootTopic = MqttUtil.getRootTopic(subscription.topicName()); String rootTopic = MqttUtil.getRootTopic(subscription.topicName());
if (rootTopic.contains(MqttConstant.SUBSCRIPTION_FLAG_PLUS) || rootTopic.contains(MqttConstant.SUBSCRIPTION_FLAG_SHARP)) { if (rootTopic.contains(MqttConstant.SUBSCRIPTION_FLAG_PLUS) || rootTopic.contains(MqttConstant.SUBSCRIPTION_FLAG_SHARP) || rootTopic.isEmpty()) {
return true; return true;
} }
} }
......
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
*/ */
package org.apache.rocketmq.mqtt.service.impl; package org.apache.rocketmq.mqtt.service.impl;
import io.netty.handler.codec.mqtt.MqttMessageType;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
...@@ -29,7 +32,12 @@ import org.apache.rocketmq.common.service.ScheduledService; ...@@ -29,7 +32,12 @@ import org.apache.rocketmq.common.service.ScheduledService;
import org.apache.rocketmq.logging.InternalLogger; import org.apache.rocketmq.logging.InternalLogger;
import org.apache.rocketmq.logging.InternalLoggerFactory; import org.apache.rocketmq.logging.InternalLoggerFactory;
import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl;
import org.apache.rocketmq.mqtt.client.InFlightMessage;
import org.apache.rocketmq.mqtt.client.InFlightPacket;
import org.apache.rocketmq.mqtt.client.MQTTSession;
import org.apache.rocketmq.mqtt.constant.MqttConstant;
import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor;
import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader;
public class MqttScheduledServiceImpl implements ScheduledService { public class MqttScheduledServiceImpl implements ScheduledService {
private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME);
...@@ -67,6 +75,36 @@ public class MqttScheduledServiceImpl implements ScheduledService { ...@@ -67,6 +75,36 @@ public class MqttScheduledServiceImpl implements ScheduledService {
} }
} }
}, 0, defaultMqttMessageProcessor.getMqttConfig().getPersistOffsetInterval(), TimeUnit.MILLISECONDS); }, 0, defaultMqttMessageProcessor.getMqttConfig().getPersistOffsetInterval(), TimeUnit.MILLISECONDS);
this.mqttScheduledExecutorService.scheduleAtFixedRate(new Runnable() {
@Override public void run() {
IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) defaultMqttMessageProcessor.getIotClientManager();
Collection<InFlightPacket> expired = new ArrayList<>();
iotClientManager.getInflightTimeouts().drainTo(expired);
for (InFlightPacket notAcked : expired) {
MQTTSession client = notAcked.getClient();
if (!client.isConnected()) {
continue;
}
if (notAcked.getResendTime() > 3) {
client.getRemotingChannel().close();
continue;
}
if (client.getInflightWindow().containsKey(notAcked.getPacketId())) {
InFlightMessage inFlightMessage = client.getInflightWindow().get(notAcked.getPacketId());
MqttHeader mqttHeader = new MqttHeader();
mqttHeader.setTopicName(inFlightMessage.getTopic());
mqttHeader.setQosLevel(inFlightMessage.getPushQos());
mqttHeader.setRetain(false);
mqttHeader.setDup(true);
mqttHeader.setMessageType(MqttMessageType.PUBLISH.value());
notAcked.setStartTime(System.currentTimeMillis() + MqttConstant.FLIGHT_BEFORE_RESEND_MS);
notAcked.setResendTime(notAcked.getResendTime() + 1);
iotClientManager.getInflightTimeouts().add(notAcked);
client.pushMessage2Client(mqttHeader, inFlightMessage.getBody());
}
}
}
}, 10000, defaultMqttMessageProcessor.getMqttConfig().getScanAckTimeoutInterval(), TimeUnit.MILLISECONDS);
} }
@Override @Override
......
...@@ -68,8 +68,8 @@ public class MqttPushTask implements Runnable { ...@@ -68,8 +68,8 @@ public class MqttPushTask implements Runnable {
private BrokerData brokerData; private BrokerData brokerData;
private String rootTopic; private String rootTopic;
public MqttPushTask(DefaultMqttMessageProcessor processor, final MqttHeader mqttHeader, String rootTopic, Client client, public MqttPushTask(DefaultMqttMessageProcessor processor, final MqttHeader mqttHeader, String rootTopic,
BrokerData brokerData) { Client client, BrokerData brokerData) {
this.defaultMqttMessageProcessor = processor; this.defaultMqttMessageProcessor = processor;
this.mqttHeader = mqttHeader; this.mqttHeader = mqttHeader;
this.rootTopic = rootTopic; this.rootTopic = rootTopic;
......
...@@ -84,7 +84,7 @@ public class MqttPubackMessageHandlerTest { ...@@ -84,7 +84,7 @@ public class MqttPubackMessageHandlerTest {
MQTTSession mqttSession = Mockito.spy(new MQTTSession("client1", ClientRole.IOTCLIENT, null, true, true, remotingChannel, System.currentTimeMillis(), defaultMqttMessageProcessor)); MQTTSession mqttSession = Mockito.spy(new MQTTSession("client1", ClientRole.IOTCLIENT, null, true, true, remotingChannel, System.currentTimeMillis(), defaultMqttMessageProcessor));
Mockito.when(iotClientManager.getClient(anyString(), any(RemotingChannel.class))).thenReturn(mqttSession); Mockito.when(iotClientManager.getClient(anyString(), any(RemotingChannel.class))).thenReturn(mqttSession);
InFlightMessage inFlightMessage = Mockito.spy(new InFlightMessage("topicTest", 0, "Hello".getBytes(), null, null, 0)); InFlightMessage inFlightMessage = Mockito.spy(new InFlightMessage("topicTest", 0, "Hello".getBytes(), null, 0));
doReturn(inFlightMessage).when(mqttSession).pubAckReceived(anyInt()); doReturn(inFlightMessage).when(mqttSession).pubAckReceived(anyInt());
RemotingCommand remotingCommand = mqttPubackMessageHandler.handleMessage(mqttMessage, remotingChannel); RemotingCommand remotingCommand = mqttPubackMessageHandler.handleMessage(mqttMessage, remotingChannel);
assert remotingCommand == null; assert remotingCommand == null;
......
...@@ -17,10 +17,124 @@ ...@@ -17,10 +17,124 @@
package org.apache.rocketmq.mqtt; package org.apache.rocketmq.mqtt;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttConnectPayload;
import io.netty.handler.codec.mqtt.MqttConnectVariableHeader;
import io.netty.handler.codec.mqtt.MqttFixedHeader;
import io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader;
import io.netty.handler.codec.mqtt.MqttMessageType;
import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.codec.mqtt.MqttSubscribeMessage;
import io.netty.handler.codec.mqtt.MqttSubscribePayload;
import io.netty.handler.codec.mqtt.MqttTopicSubscription;
import io.netty.util.internal.StringUtil;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import org.apache.rocketmq.common.MqttConfig;
import org.apache.rocketmq.common.SnodeConfig;
import org.apache.rocketmq.mqtt.client.MQTTSession;
import org.apache.rocketmq.mqtt.exception.WrongMessageTypeException;
import org.apache.rocketmq.mqtt.mqtthandler.impl.MqttSubscribeMessageHandler;
import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor;
import org.apache.rocketmq.remoting.RemotingChannel;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import static org.apache.rocketmq.mqtt.client.IOTClientManagerImpl.IOT_GROUP;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class MqttSubscribeMessageHandlerTest { public class MqttSubscribeMessageHandlerTest {
@Rule
public ExpectedException exception = ExpectedException.none();
private DefaultMqttMessageProcessor defaultMqttMessageProcessor = new DefaultMqttMessageProcessor(new MqttConfig(), new SnodeConfig(), null, null, null);
@Spy
private MqttSubscribeMessageHandler mqttSubscribeMessageHandler = new MqttSubscribeMessageHandler(defaultMqttMessageProcessor);
@Mock
private RemotingChannel remotingChannel;
@Test
public void test_topicStartWithWildcard() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
Method method = MqttSubscribeMessageHandler.class.getDeclaredMethod("topicStartWithWildcard", List.class);
method.setAccessible(true);
List<MqttTopicSubscription> subscriptions1 = new ArrayList<>();
subscriptions1.add(new MqttTopicSubscription("+/test", MqttQoS.AT_MOST_ONCE));
boolean invoke1 = (boolean) method.invoke(mqttSubscribeMessageHandler, subscriptions1);
assert invoke1;
List<MqttTopicSubscription> subscriptions2 = new ArrayList<>();
subscriptions2.add(new MqttTopicSubscription("test/topic", MqttQoS.AT_MOST_ONCE));
boolean invoke2 = (boolean) method.invoke(mqttSubscribeMessageHandler, subscriptions2);
assert !invoke2;
List<MqttTopicSubscription> subscriptions3 = new ArrayList<>();
subscriptions3.add(new MqttTopicSubscription("/test/topic", MqttQoS.AT_MOST_ONCE));
boolean invoke3 = (boolean) method.invoke(mqttSubscribeMessageHandler, subscriptions3);
assert invoke3;
}
@Test
public void test_handleMessage_wrongMessageType() {
MqttConnectMessage mqttConnectMessage = new MqttConnectMessage(new MqttFixedHeader(
MqttMessageType.CONNECT, false, MqttQoS.AT_MOST_ONCE, false, 200), new MqttConnectVariableHeader(null, 4, false, false, false, 0, false, false, 50), new MqttConnectPayload("abcd", "ttest", "message".getBytes(), "user", "password".getBytes()));
exception.expect(WrongMessageTypeException.class);
mqttSubscribeMessageHandler.handleMessage(mqttConnectMessage, remotingChannel);
}
@Test
public void test_handleMessage_clientNotFound() {
List<MqttTopicSubscription> subscriptions = new ArrayList<>();
subscriptions.add(new MqttTopicSubscription("test/a", MqttQoS.AT_MOST_ONCE));
MqttSubscribeMessage mqttSubscribeMessage = new MqttSubscribeMessage(new MqttFixedHeader(
MqttMessageType.SUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 200), MqttMessageIdVariableHeader.from(1), new MqttSubscribePayload(subscriptions));
RemotingCommand remotingCommand = mqttSubscribeMessageHandler.handleMessage(mqttSubscribeMessage, remotingChannel);
assertEquals(null, defaultMqttMessageProcessor.getIotClientManager().getClient(IOT_GROUP, remotingChannel));
assert remotingCommand == null;
}
@Test
public void test_handleMessage_emptyTopicFilter() {
List<MqttTopicSubscription> subscriptions = new ArrayList<>();
MqttSubscribeMessage mqttSubscribeMessage = new MqttSubscribeMessage(new MqttFixedHeader(
MqttMessageType.SUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 200), MqttMessageIdVariableHeader.from(1), new MqttSubscribePayload(subscriptions));
MQTTSession mqttSession = Mockito.mock(MQTTSession.class);
Mockito.when(mqttSession.getRemotingChannel()).thenReturn(remotingChannel);
// Mockito.when(mqttSubscribeMessage.toString()).thenReturn("toString");
defaultMqttMessageProcessor.getIotClientManager().register(IOT_GROUP, mqttSession);
RemotingCommand remotingCommand = mqttSubscribeMessageHandler.handleMessage(mqttSubscribeMessage, remotingChannel);
assertNotNull(defaultMqttMessageProcessor.getIotClientManager().getClient(IOT_GROUP, remotingChannel));
assert remotingCommand == null;
}
@Test
public void test_MqttSubscribePayload_toString() {
List<MqttTopicSubscription> topicSubscriptions = new ArrayList<>();
topicSubscriptions.add(new MqttTopicSubscription("test/topic", MqttQoS.AT_MOST_ONCE));
StringBuilder builder = new StringBuilder(StringUtil.simpleClassName(this)).append('[');
for (int i = 0; i <= topicSubscriptions.size() - 1; i++) {
builder.append(topicSubscriptions.get(i)).append(", ");
}
if (builder.substring(builder.length() - 2).equals(", ")) {
builder.delete(builder.length() - 2, builder.length());
}
builder.append(']');
System.out.println(builder.toString());
}
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.mqtt;
import org.apache.rocketmq.mqtt.util.MqttUtil;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;
import static org.junit.Assert.assertEquals;
@RunWith(MockitoJUnitRunner.class)
public class MqttUtilTest {
@Test
public void test_getRootTopic() {
String rootTopic = MqttUtil.getRootTopic("/test/topic");
assertEquals("", rootTopic);
String rootTopic2 = MqttUtil.getRootTopic("test/topic");
assertEquals("test", rootTopic2);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册