提交 e010ca74 编写于 作者: C chengxiangwang

polish data structure of IOTClientManager;add message forward logic

上级 4f8cd91c
......@@ -51,7 +51,7 @@ import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@RunWith(MockitoJUnitRunner.Silent.class)
public class DefaultMQPullConsumerTest {
@Spy
private MQClientInstance mQClientFactory = MQClientManager.getInstance().getAndCreateMQClientInstance(new ClientConfig());
......
......@@ -73,7 +73,7 @@ import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
@RunWith(MockitoJUnitRunner.Silent.class)
public class DefaultMQPushConsumerTest {
private String consumerGroup;
private String topic = "FooBar";
......
......@@ -46,7 +46,7 @@
<dependencies>
<dependency>
<groupId>org.apache.rocketmq</groupId>
<groupId>${project.groupId}</groupId>
<artifactId>rocketmq-remoting</artifactId>
</dependency>
<dependency>
......@@ -54,5 +54,5 @@
<artifactId>snakeyaml</artifactId>
</dependency>
</dependencies>
</project>
......@@ -60,7 +60,6 @@ public class MqttConfig {
this.listenPort = listenPort;
}
public boolean isAclEnable() {
return aclEnable;
}
......
......@@ -99,5 +99,9 @@
<groupId>${project.groupId}</groupId>
<artifactId>rocketmq-broker</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.paho</groupId>
<artifactId>org.eclipse.paho.client.mqttv3</artifactId>
</dependency>
</dependencies>
</project>
......@@ -16,6 +16,7 @@
*/
package org.apache.rocketmq.mqtt.client;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
......@@ -25,10 +26,10 @@ import org.apache.rocketmq.common.client.Client;
import org.apache.rocketmq.common.client.ClientManagerImpl;
import org.apache.rocketmq.common.client.Subscription;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.protocol.heartbeat.SubscriptionData;
import org.apache.rocketmq.logging.InternalLogger;
import org.apache.rocketmq.logging.InternalLoggerFactory;
import org.apache.rocketmq.remoting.RemotingChannel;
import org.eclipse.paho.client.mqttv3.MqttClient;
public class IOTClientManagerImpl extends ClientManagerImpl {
......@@ -36,9 +37,12 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
public static final String IOT_GROUP = "IOT_GROUP";
private final ConcurrentHashMap<String/*root topic*/, ConcurrentHashMap<Client, Set<SubscriptionData>>> topic2SubscriptionTable = new ConcurrentHashMap<>(
// private final ConcurrentHashMap<String/*root topic*/, ConcurrentHashMap<Client, Set<SubscriptionData>>> topic2SubscriptionTable = new ConcurrentHashMap<>(
// 1024);
private final ConcurrentHashMap<String/*root topic*/, Set<Client>> topic2Clients = new ConcurrentHashMap<>(
1024);
private final ConcurrentHashMap<String/*clientId*/, Subscription> clientId2Subscription = new ConcurrentHashMap<>(1024);
private final Map<String/*snode ip*/, MqttClient> snode2MqttClient = new HashMap<>();
public IOTClientManagerImpl() {
}
......@@ -78,7 +82,7 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
}
public void cleanSessionState(String clientId) {
clientId2Subscription.remove(clientId);
/* clientId2Subscription.remove(clientId);
for (Iterator<Map.Entry<String, ConcurrentHashMap<Client, Set<SubscriptionData>>>> iterator = topic2SubscriptionTable.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry<String, ConcurrentHashMap<Client, Set<SubscriptionData>>> next = iterator.next();
for (Iterator<Map.Entry<Client, Set<SubscriptionData>>> iterator1 = next.getValue().entrySet().iterator(); iterator1.hasNext(); ) {
......@@ -91,7 +95,18 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
if (next.getValue() == null || next.getValue().size() == 0) {
iterator.remove();
}
}*/
for (Iterator<Map.Entry<String, Set<Client>>> iterator = topic2Clients.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry<String, Set<Client>> next = iterator.next();
Iterator<Client> iterator1 = next.getValue().iterator();
while (iterator1.hasNext()) {
if (iterator1.next().getClientId().equals(clientId)) {
iterator1.remove();
}
}
}
clientId2Subscription.remove(clientId);
//remove offline messages
}
......@@ -99,8 +114,12 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
return clientId2Subscription.get(clientId);
}
public ConcurrentHashMap<String, ConcurrentHashMap<Client, Set<SubscriptionData>>> getTopic2SubscriptionTable() {
return topic2SubscriptionTable;
/* public ConcurrentHashMap<String, ConcurrentHashMap<Client, Set<SubscriptionData>>> getTopic2SubscriptionTable() {
return topic2SubscriptionTable;
}*/
public ConcurrentHashMap<String/*root topic*/, Set<Client>> getTopic2Clients() {
return topic2Clients;
}
public ConcurrentHashMap<String, Subscription> getClientId2Subscription() {
......@@ -110,4 +129,8 @@ public class IOTClientManagerImpl extends ClientManagerImpl {
public void initSubscription(String clientId, Subscription subscription) {
clientId2Subscription.put(clientId, subscription);
}
public Map<String, MqttClient> getSnode2MqttClient() {
return snode2MqttClient;
}
}
......@@ -17,13 +17,25 @@
package org.apache.rocketmq.mqtt.mqtthandler;
import io.netty.handler.codec.mqtt.MqttFixedHeader;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttMessageType;
import io.netty.handler.codec.mqtt.MqttQoS;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.rocketmq.common.client.Client;
import org.apache.rocketmq.common.client.Subscription;
import org.apache.rocketmq.common.exception.MQClientException;
import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl;
import org.apache.rocketmq.mqtt.util.MqttUtil;
import org.apache.rocketmq.remoting.RemotingChannel;
import org.apache.rocketmq.remoting.exception.RemotingConnectException;
import org.apache.rocketmq.remoting.exception.RemotingSendRequestException;
import org.apache.rocketmq.remoting.exception.RemotingTimeoutException;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;
import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader;
public interface MessageHandler {
......@@ -32,5 +44,46 @@ public interface MessageHandler {
*
* @param message
*/
RemotingCommand handleMessage(MqttMessage message, RemotingChannel remotingChannel) throws InterruptedException, RemotingTimeoutException, RemotingSendRequestException, RemotingConnectException, MQClientException;
RemotingCommand handleMessage(MqttMessage message,
RemotingChannel remotingChannel) throws InterruptedException, RemotingTimeoutException, RemotingSendRequestException, RemotingConnectException, MQClientException;
default Set<Client> findCurrentNodeClientsTobePublish(String topic, IOTClientManagerImpl iotClientManager) {
//find those clients publishing the message to
ConcurrentHashMap<String, Set<Client>> topic2Clients = iotClientManager.getTopic2Clients();
ConcurrentHashMap<String, Subscription> clientId2Subscription = iotClientManager.getClientId2Subscription();
Set<Client> clientsTobePush = new HashSet<>();
if (topic2Clients.containsKey(MqttUtil.getRootTopic(topic))) {
Set<Client> clients = topic2Clients.get(MqttUtil.getRootTopic(topic));
for (Client client : clients) {
Subscription subscription = clientId2Subscription.get(client.getClientId());
Enumeration<String> keys = subscription.getSubscriptionTable().keys();
while (keys.hasMoreElements()) {
String topicFilter = keys.nextElement();
if (MqttUtil.isMatch(topicFilter, topic)) {
clientsTobePush.add(client);
}
}
}
}
return clientsTobePush;
}
default RemotingCommand doResponse(MqttFixedHeader fixedHeader) {
if (fixedHeader.qosLevel().value() > 0) {
RemotingCommand command = RemotingCommand.createResponseCommand(MqttHeader.class);
MqttHeader mqttHeader = (MqttHeader) command.readCustomHeader();
if (fixedHeader.qosLevel().equals(MqttQoS.AT_MOST_ONCE)) {
mqttHeader.setMessageType(MqttMessageType.PUBACK.value());
mqttHeader.setDup(false);
mqttHeader.setQosLevel(MqttQoS.AT_MOST_ONCE.value());
mqttHeader.setRetain(false);
mqttHeader.setRemainingLength(2);
mqttHeader.setPacketId(0);
} else if (fixedHeader.qosLevel().equals(MqttQoS.AT_LEAST_ONCE)) {
//PUBREC/PUBREL/PUBCOMP
}
return command;
}
return null;
}
}
......@@ -17,21 +17,29 @@
package org.apache.rocketmq.mqtt.mqtthandler.impl;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.mqtt.MqttFixedHeader;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttPublishMessage;
import io.netty.handler.codec.mqtt.MqttPublishVariableHeader;
import io.netty.handler.codec.mqtt.MqttQoS;
import java.util.Set;
import org.apache.rocketmq.common.client.Client;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.logging.InternalLogger;
import org.apache.rocketmq.logging.InternalLoggerFactory;
import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl;
import org.apache.rocketmq.mqtt.mqtthandler.MessageHandler;
import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor;
import org.apache.rocketmq.mqtt.processor.InnerMqttMessageProcessor;
import org.apache.rocketmq.remoting.RemotingChannel;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;
public class MqttMessageForwarder implements MessageHandler {
private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME);
private final DefaultMqttMessageProcessor defaultMqttMessageProcessor;
private final InnerMqttMessageProcessor innerMqttMessageProcessor;
public MqttMessageForwarder(DefaultMqttMessageProcessor processor) {
this.defaultMqttMessageProcessor = processor;
public MqttMessageForwarder(InnerMqttMessageProcessor processor) {
this.innerMqttMessageProcessor = processor;
}
/**
......@@ -41,6 +49,17 @@ public class MqttMessageForwarder implements MessageHandler {
* @return whether the message is handled successfully
*/
@Override public RemotingCommand handleMessage(MqttMessage message, RemotingChannel remotingChannel) {
return null;
MqttPublishMessage mqttPublishMessage = (MqttPublishMessage) message;
MqttFixedHeader fixedHeader = mqttPublishMessage.fixedHeader();
MqttPublishVariableHeader variableHeader = mqttPublishMessage.variableHeader();
if (fixedHeader.qosLevel().equals(MqttQoS.AT_MOST_ONCE)) {
ByteBuf payload = mqttPublishMessage.payload();
//Publish message to clients
Set<Client> clientsTobePublish = findCurrentNodeClientsTobePublish(variableHeader.topicName(), (IOTClientManagerImpl) this.innerMqttMessageProcessor.getIotClientManager());
innerMqttMessageProcessor.getDefaultMqttMessageProcessor().getMqttPushService().pushMessageQos0(variableHeader.topicName(), payload, clientsTobePublish);
}else if(fixedHeader.qosLevel().equals(MqttQoS.AT_LEAST_ONCE)){
//TODO
}
return doResponse(fixedHeader);
}
}
......@@ -118,8 +118,8 @@ public class MqttSubscribeMessageHandler implements MessageHandler {
//do the logic when client sends subscribe packet.
//1.update clientId2Subscription
ConcurrentHashMap<String, Subscription> clientId2Subscription = iotClientManager.getClientId2Subscription();
ConcurrentHashMap<String, ConcurrentHashMap<Client, Set<SubscriptionData>>> topic2SubscriptionTable = iotClientManager.getTopic2SubscriptionTable();
Subscription subscription = null;
ConcurrentHashMap<String, Set<Client>> topic2Clients = iotClientManager.getTopic2Clients();
Subscription subscription;
if (clientId2Subscription.containsKey(client.getClientId())) {
subscription = clientId2Subscription.get(client.getClientId());
} else {
......@@ -133,23 +133,17 @@ public class MqttSubscribeMessageHandler implements MessageHandler {
grantQoss.add(actualQos);
SubscriptionData subscriptionData = new MqttSubscriptionData(mqttTopicSubscription.qualityOfService().value(), client.getClientId(), mqttTopicSubscription.topicName());
subscriptionDatas.put(mqttTopicSubscription.topicName(), subscriptionData);
//2.update topic2SubscriptionTable
//2.update topic2ClientIds
String rootTopic = MqttUtil.getRootTopic(mqttTopicSubscription.topicName());
ConcurrentHashMap<Client, Set<SubscriptionData>> client2SubscriptionData = topic2SubscriptionTable.get(rootTopic);
if (client2SubscriptionData == null || client2SubscriptionData.size() == 0) {
client2SubscriptionData = new ConcurrentHashMap<>();
ConcurrentHashMap<Client, Set<SubscriptionData>> prev = topic2SubscriptionTable.putIfAbsent(rootTopic, client2SubscriptionData);
if (topic2Clients.contains(rootTopic)) {
final Set<Client> clientIds = topic2Clients.get(rootTopic);
clientIds.add(client);
} else {
Set<Client> clients = new HashSet<>();
clients.add(client);
Set<Client> prev = topic2Clients.putIfAbsent(rootTopic, clients);
if (prev != null) {
client2SubscriptionData = prev;
}
Set<SubscriptionData> subscriptionDataSet = client2SubscriptionData.get(client);
if (subscriptionDataSet == null) {
subscriptionDataSet = new HashSet<>();
Set<SubscriptionData> prevSubscriptionDataSet = client2SubscriptionData.putIfAbsent(client, subscriptionDataSet);
if (prevSubscriptionDataSet != null) {
subscriptionDataSet = prevSubscriptionDataSet;
}
subscriptionDataSet.add(subscriptionData);
prev.add(client);
}
}
}
......
......@@ -25,6 +25,7 @@ import io.netty.handler.codec.mqtt.MqttUnsubscribeMessage;
import io.netty.handler.codec.mqtt.MqttUnsubscribePayload;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.rocketmq.common.client.Client;
......@@ -99,27 +100,32 @@ public class MqttUnsubscribeMessagHandler implements MessageHandler {
private void doUnsubscribe(Client client, List<String> topics, IOTClientManagerImpl iotClientManager) {
ConcurrentHashMap<String, Subscription> clientId2Subscription = iotClientManager.getClientId2Subscription();
ConcurrentHashMap<String, ConcurrentHashMap<Client, Set<SubscriptionData>>> topic2SubscriptionTable = iotClientManager.getTopic2SubscriptionTable();
ConcurrentHashMap<String, Set<Client>> topic2Clients = iotClientManager.getTopic2Clients();
Subscription subscription = clientId2Subscription.get(client.getClientId());
for (String topicFilter : topics) {
//1.update clientId2Subscription
if (clientId2Subscription.containsKey(client.getClientId())) {
Subscription subscription = clientId2Subscription.get(client.getClientId());
//1.update clientId2Subscription
if (clientId2Subscription.containsKey(client.getClientId())) {
for (String topicFilter : topics) {
subscription.getSubscriptionTable().remove(topicFilter);
}
//2.update topic2SubscriptionTable
String rootTopic = MqttUtil.getRootTopic(topicFilter);
ConcurrentHashMap<Client, Set<SubscriptionData>> client2SubscriptionData = topic2SubscriptionTable.get(rootTopic);
if (client2SubscriptionData != null) {
Set<SubscriptionData> subscriptionDataSet = client2SubscriptionData.get(client);
if (subscriptionDataSet != null) {
Iterator<SubscriptionData> iterator = subscriptionDataSet.iterator();
while (iterator.hasNext()) {
if (iterator.next().getTopic().equals(topicFilter))
iterator.remove();
}
}
for (Iterator<Map.Entry<String, Set<Client>>> iterator = topic2Clients.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry<String, Set<Client>> next = iterator.next();
String rootTopic = next.getKey();
boolean needRemove = true;
for (Map.Entry<String, SubscriptionData> entry : subscription.getSubscriptionTable().entrySet()) {
if (MqttUtil.getRootTopic(entry.getKey()).equals(rootTopic)) {
needRemove = false;
break;
}
}
if (needRemove) {
next.getValue().remove(client);
}
if (next.getValue().size() == 0) {
iterator.remove();
}
}
}
}
......@@ -30,7 +30,6 @@ import io.netty.handler.codec.mqtt.MqttPublishVariableHeader;
import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.codec.mqtt.MqttSubscribeMessage;
import io.netty.handler.codec.mqtt.MqttSubscribePayload;
import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.Map;
import org.apache.rocketmq.common.MqttConfig;
......@@ -61,7 +60,6 @@ import org.apache.rocketmq.mqtt.service.impl.WillMessageServiceImpl;
import org.apache.rocketmq.remoting.RemotingChannel;
import org.apache.rocketmq.remoting.RemotingServer;
import org.apache.rocketmq.remoting.RequestProcessor;
import org.apache.rocketmq.remoting.exception.RemotingCommandException;
import org.apache.rocketmq.remoting.exception.RemotingConnectException;
import org.apache.rocketmq.remoting.exception.RemotingSendRequestException;
import org.apache.rocketmq.remoting.exception.RemotingTimeoutException;
......@@ -85,7 +83,8 @@ public class DefaultMqttMessageProcessor implements RequestProcessor {
private EnodeService enodeService;
private NnodeService nnodeService;
public DefaultMqttMessageProcessor(MqttConfig mqttConfig, SnodeConfig snodeConfig, RemotingServer mqttRemotingServer,
public DefaultMqttMessageProcessor(MqttConfig mqttConfig, SnodeConfig snodeConfig,
RemotingServer mqttRemotingServer,
EnodeService enodeService, NnodeService nnodeService) {
this.mqttConfig = mqttConfig;
this.snodeConfig = snodeConfig;
......@@ -119,7 +118,7 @@ public class DefaultMqttMessageProcessor implements RequestProcessor {
@Override
public RemotingCommand processRequest(RemotingChannel remotingChannel, RemotingCommand message)
throws RemotingCommandException, UnsupportedEncodingException, InterruptedException, RemotingTimeoutException, MQClientException, RemotingSendRequestException, RemotingConnectException {
throws InterruptedException, RemotingTimeoutException, MQClientException, RemotingSendRequestException, RemotingConnectException {
MqttHeader mqttHeader = (MqttHeader) message.readCustomHeader();
MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.valueOf(mqttHeader.getMessageType()),
mqttHeader.isDup(), MqttQoS.valueOf(mqttHeader.getQosLevel()), mqttHeader.isRetain(),
......@@ -132,7 +131,6 @@ public class DefaultMqttMessageProcessor implements RequestProcessor {
mqttHeader.isHasPassword(), mqttHeader.isWillRetain(),
mqttHeader.getWillQos(), mqttHeader.isWillFlag(),
mqttHeader.isCleanSession(), mqttHeader.getKeepAliveTimeSeconds());
// MqttConnectPayload mqttConnectPayload = (MqttConnectPayload) message.getPayload();
MqttConnectPayload mqttConnectPayload = (MqttConnectPayload) MqttEncodeDecodeUtil.decode(message.getBody(), MqttConnectPayload.class);
mqttMessage = new MqttConnectMessage(fixedHeader, mqttConnectVariableHeader, mqttConnectPayload);
break;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.mqtt.processor;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.mqtt.MqttFixedHeader;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttMessageType;
import io.netty.handler.codec.mqtt.MqttPublishMessage;
import io.netty.handler.codec.mqtt.MqttPublishVariableHeader;
import io.netty.handler.codec.mqtt.MqttQoS;
import org.apache.rocketmq.common.MqttConfig;
import org.apache.rocketmq.common.SnodeConfig;
import org.apache.rocketmq.common.client.ClientManager;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.exception.MQClientException;
import org.apache.rocketmq.common.service.EnodeService;
import org.apache.rocketmq.common.service.NnodeService;
import org.apache.rocketmq.logging.InternalLogger;
import org.apache.rocketmq.logging.InternalLoggerFactory;
import org.apache.rocketmq.mqtt.mqtthandler.impl.MqttMessageForwarder;
import org.apache.rocketmq.mqtt.service.WillMessageService;
import org.apache.rocketmq.mqtt.service.impl.MqttPushServiceImpl;
import org.apache.rocketmq.remoting.RemotingChannel;
import org.apache.rocketmq.remoting.RemotingServer;
import org.apache.rocketmq.remoting.RequestProcessor;
import org.apache.rocketmq.remoting.exception.RemotingConnectException;
import org.apache.rocketmq.remoting.exception.RemotingSendRequestException;
import org.apache.rocketmq.remoting.exception.RemotingTimeoutException;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;
import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader;
import static io.netty.handler.codec.mqtt.MqttMessageType.PUBLISH;
public class InnerMqttMessageProcessor implements RequestProcessor {
private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME);
private final DefaultMqttMessageProcessor defaultMqttMessageProcessor;
private WillMessageService willMessageService;
private MqttPushServiceImpl mqttPushService;
private ClientManager iotClientManager;
private RemotingServer innerMqttRemotingServer;
private MqttConfig mqttConfig;
private SnodeConfig snodeConfig;
private EnodeService enodeService;
private NnodeService nnodeService;
private MqttMessageForwarder mqttMessageForwarder;
public InnerMqttMessageProcessor(DefaultMqttMessageProcessor defaultMqttMessageProcessor, RemotingServer innerMqttRemotingServer) {
this.defaultMqttMessageProcessor = defaultMqttMessageProcessor;
this.willMessageService = this.defaultMqttMessageProcessor.getWillMessageService();
this.mqttPushService = this.defaultMqttMessageProcessor.getMqttPushService();
this.iotClientManager = this.defaultMqttMessageProcessor.getIotClientManager();
this.innerMqttRemotingServer = innerMqttRemotingServer;
this.enodeService = this.defaultMqttMessageProcessor.getEnodeService();
this.nnodeService = this.defaultMqttMessageProcessor.getNnodeService();
this.mqttMessageForwarder = new MqttMessageForwarder(this);
}
@Override
public RemotingCommand processRequest(RemotingChannel remotingChannel, RemotingCommand message)
throws InterruptedException, RemotingTimeoutException, MQClientException, RemotingSendRequestException, RemotingConnectException {
MqttHeader mqttHeader = (MqttHeader) message.readCustomHeader();
if(mqttHeader.getMessageType().equals(PUBLISH)){
MqttFixedHeader fixedHeader = new MqttFixedHeader(MqttMessageType.valueOf(mqttHeader.getMessageType()),
mqttHeader.isDup(), MqttQoS.valueOf(mqttHeader.getQosLevel()), mqttHeader.isRetain(),
mqttHeader.getRemainingLength());
MqttPublishVariableHeader mqttPublishVariableHeader = new MqttPublishVariableHeader(mqttHeader.getTopicName(), mqttHeader.getPacketId());
MqttMessage mqttMessage = new MqttPublishMessage(fixedHeader, mqttPublishVariableHeader, Unpooled.copiedBuffer(message.getBody()));
return mqttMessageForwarder.handleMessage(mqttMessage, remotingChannel);
}else{
return defaultMqttMessageProcessor.processRequest(remotingChannel, message);
}
}
@Override
public boolean rejectRequest() {
return false;
}
public WillMessageService getWillMessageService() {
return willMessageService;
}
public MqttPushServiceImpl getMqttPushService() {
return mqttPushService;
}
public ClientManager getIotClientManager() {
return iotClientManager;
}
public MqttConfig getMqttConfig() {
return mqttConfig;
}
public void setMqttConfig(MqttConfig mqttConfig) {
this.mqttConfig = mqttConfig;
}
public SnodeConfig getSnodeConfig() {
return snodeConfig;
}
public void setSnodeConfig(SnodeConfig snodeConfig) {
this.snodeConfig = snodeConfig;
}
public EnodeService getEnodeService() {
return enodeService;
}
public void setEnodeService(EnodeService enodeService) {
this.enodeService = enodeService;
}
public NnodeService getNnodeService() {
return nnodeService;
}
public void setNnodeService(NnodeService nnodeService) {
this.nnodeService = nnodeService;
}
public DefaultMqttMessageProcessor getDefaultMqttMessageProcessor() {
return defaultMqttMessageProcessor;
}
}
......@@ -19,11 +19,8 @@ package org.apache.rocketmq.mqtt.service.impl;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.mqtt.MqttMessageType;
import io.netty.util.ReferenceCountUtil;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
......@@ -31,14 +28,11 @@ import org.apache.rocketmq.common.MqttConfig;
import org.apache.rocketmq.common.client.Client;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.protocol.RequestCode;
import org.apache.rocketmq.common.protocol.heartbeat.SubscriptionData;
import org.apache.rocketmq.common.utils.ThreadUtils;
import org.apache.rocketmq.logging.InternalLogger;
import org.apache.rocketmq.logging.InternalLoggerFactory;
import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl;
import org.apache.rocketmq.mqtt.constant.MqttConstant;
import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor;
import org.apache.rocketmq.mqtt.util.MqttUtil;
import org.apache.rocketmq.remoting.RemotingChannel;
import org.apache.rocketmq.remoting.netty.NettyChannelHandlerContextImpl;
import org.apache.rocketmq.remoting.netty.NettyChannelImpl;
......@@ -49,7 +43,7 @@ public class MqttPushServiceImpl {
private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME);
private ExecutorService pushMqttMessageExecutorService;
private DefaultMqttMessageProcessor defaultMqttMessageProcessor;
private static DefaultMqttMessageProcessor defaultMqttMessageProcessor;
public MqttPushServiceImpl(DefaultMqttMessageProcessor defaultMqttMessageProcessor, MqttConfig mqttConfig) {
this.defaultMqttMessageProcessor = defaultMqttMessageProcessor;
......@@ -63,21 +57,23 @@ public class MqttPushServiceImpl {
false);
}
public class MqttPushTask implements Runnable {
static class MqttPushTask implements Runnable {
private AtomicBoolean canceled = new AtomicBoolean(false);
private final ByteBuf message;
private final String topic;
private final Integer qos;
private boolean retain;
private Integer packetId;
private Client client;
public MqttPushTask(final String topic, final ByteBuf message, final Integer qos, boolean retain,
Integer packetId) {
Integer packetId, Client client) {
this.message = message;
this.topic = topic;
this.qos = qos;
this.retain = retain;
this.packetId = packetId;
this.client = client;
}
@Override
......@@ -86,32 +82,14 @@ public class MqttPushServiceImpl {
try {
RemotingCommand requestCommand = buildRequestCommand(topic, qos, retain, packetId);
//find those clients publishing the message to
IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) defaultMqttMessageProcessor.getIotClientManager();
ConcurrentHashMap<String, ConcurrentHashMap<Client, Set<SubscriptionData>>> topic2SubscriptionTable = iotClientManager.getTopic2SubscriptionTable();
Set<Client> clients = new HashSet<>();
if (topic2SubscriptionTable.containsKey(MqttUtil.getRootTopic(topic))) {
ConcurrentHashMap<Client, Set<SubscriptionData>> client2SubscriptionDatas = topic2SubscriptionTable.get(MqttUtil.getRootTopic(topic));
for (Map.Entry<Client, Set<SubscriptionData>> entry : client2SubscriptionDatas.entrySet()) {
Set<SubscriptionData> subscriptionDatas = entry.getValue();
for (SubscriptionData subscriptionData : subscriptionDatas) {
if (MqttUtil.isMatch(subscriptionData.getTopic(), topic)) {
clients.add(entry.getKey());
break;
}
}
}
}
for (Client client : clients) {
RemotingChannel remotingChannel = client.getRemotingChannel();
if (client.getRemotingChannel() instanceof NettyChannelHandlerContextImpl) {
remotingChannel = new NettyChannelImpl(((NettyChannelHandlerContextImpl) client.getRemotingChannel()).getChannelHandlerContext().channel());
}
byte[] body = new byte[message.readableBytes()];
message.readBytes(body);
requestCommand.setBody(body);
defaultMqttMessageProcessor.getMqttRemotingServer().push(remotingChannel, requestCommand, MqttConstant.DEFAULT_TIMEOUT_MILLS);
RemotingChannel remotingChannel = client.getRemotingChannel();
if (client.getRemotingChannel() instanceof NettyChannelHandlerContextImpl) {
remotingChannel = new NettyChannelImpl(((NettyChannelHandlerContextImpl) client.getRemotingChannel()).getChannelHandlerContext().channel());
}
byte[] body = new byte[message.readableBytes()];
message.readBytes(body);
requestCommand.setBody(body);
defaultMqttMessageProcessor.getMqttRemotingServer().push(remotingChannel, requestCommand, MqttConstant.DEFAULT_TIMEOUT_MILLS);
} catch (Exception ex) {
log.warn("Exception was thrown when pushing MQTT message to topic: {}, exception={}", topic, ex.getMessage());
} finally {
......@@ -147,15 +125,21 @@ public class MqttPushServiceImpl {
}
public void pushMessageQos0(final String topic, final ByteBuf message) {
MqttPushTask pushTask = new MqttPushTask(topic, message, 0, false, 0);
pushMqttMessageExecutorService.submit(pushTask);
public void pushMessageQos0(final String topic, final ByteBuf message, Set<Client> clientsTobePublish) {
//For clientIds connected to the current snode
for (Client client : clientsTobePublish) {
MqttPushTask pushTask = new MqttPushTask(topic, message, 0, false, 0, client);
pushMqttMessageExecutorService.submit(pushTask);
}
}
public void pushMessageQos1(final String topic, final ByteBuf message, final Integer qos, boolean retain,
Integer packetId) {
MqttPushTask pushTask = new MqttPushTask(topic, message, qos, retain, packetId);
pushMqttMessageExecutorService.submit(pushTask);
public void pushMessageQos1(final String topic, final ByteBuf message, boolean retain, Integer packetId,
Set<Client> clientsTobePublish) {
for (Client client : clientsTobePublish) {
MqttPushTask pushTask = new MqttPushTask(topic, message, 1, retain, packetId, client);
pushMqttMessageExecutorService.submit(pushTask);
}
}
public void shutdown() {
......
......@@ -30,10 +30,17 @@ import org.apache.rocketmq.common.ThreadFactoryImpl;
import org.apache.rocketmq.common.client.ClientManager;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.protocol.RequestCode;
import org.apache.rocketmq.common.service.ClientService;
import org.apache.rocketmq.common.service.EnodeService;
import org.apache.rocketmq.common.service.MetricsService;
import org.apache.rocketmq.common.service.NnodeService;
import org.apache.rocketmq.common.service.PushService;
import org.apache.rocketmq.common.service.ScheduledService;
import org.apache.rocketmq.common.utils.ThreadUtils;
import org.apache.rocketmq.logging.InternalLogger;
import org.apache.rocketmq.logging.InternalLoggerFactory;
import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor;
import org.apache.rocketmq.mqtt.processor.InnerMqttMessageProcessor;
import org.apache.rocketmq.remoting.ClientConfig;
import org.apache.rocketmq.remoting.RemotingClient;
import org.apache.rocketmq.remoting.RemotingClientFactory;
......@@ -62,12 +69,6 @@ import org.apache.rocketmq.snode.processor.ConsumerManageProcessor;
import org.apache.rocketmq.snode.processor.HeartbeatProcessor;
import org.apache.rocketmq.snode.processor.PullMessageProcessor;
import org.apache.rocketmq.snode.processor.SendMessageProcessor;
import org.apache.rocketmq.common.service.ClientService;
import org.apache.rocketmq.common.service.EnodeService;
import org.apache.rocketmq.common.service.MetricsService;
import org.apache.rocketmq.common.service.NnodeService;
import org.apache.rocketmq.common.service.PushService;
import org.apache.rocketmq.common.service.ScheduledService;
import org.apache.rocketmq.snode.service.impl.ClientServiceImpl;
import org.apache.rocketmq.snode.service.impl.LocalEnodeServiceImpl;
import org.apache.rocketmq.snode.service.impl.MetricsServiceImpl;
......@@ -91,6 +92,7 @@ public class SnodeController {
private RemotingServer snodeServer;
private RemotingClient mqttRemotingClient;
private RemotingServer mqttRemotingServer;
private RemotingServer innerMqttRemotingServer;
private ExecutorService sendMessageExecutor;
private ExecutorService handleMqttMessageExecutor;
private ExecutorService heartbeatExecutor;
......@@ -101,7 +103,6 @@ public class SnodeController {
private ScheduledService scheduledService;
private ClientManager producerManager;
private ClientManager consumerManager;
// private ClientManager iotClientManager;
private SubscriptionManager subscriptionManager;
private ClientHousekeepingService clientHousekeepingService;
private SubscriptionGroupManager subscriptionGroupManager;
......@@ -111,6 +112,7 @@ public class SnodeController {
private PullMessageProcessor pullMessageProcessor;
private HeartbeatProcessor heartbeatProcessor;
private DefaultMqttMessageProcessor defaultMqttMessageProcessor;
private InnerMqttMessageProcessor innerMqttMessageProcessor;
private InterceptorGroup remotingServerInterceptorGroup;
private InterceptorGroup consumeMessageInterceptorGroup;
private InterceptorGroup sendMessageInterceptorGroup;
......@@ -118,14 +120,12 @@ public class SnodeController {
private ClientService clientService;
private SlowConsumerService slowConsumerService;
private MetricsService metricsService;
// private WillMessageService willMessageService;
// private MqttPushServiceImpl mqttPushService;
private final ScheduledExecutorService scheduledExecutorService = Executors
.newSingleThreadScheduledExecutor(new ThreadFactoryImpl(
"SnodeControllerScheduledThread"));
public SnodeController(SnodeConfig snodeConfig, MqttConfig mqttConfig) {
public SnodeController(SnodeConfig snodeConfig, MqttConfig mqttConfig) throws CloneNotSupportedException {
this.nettyClientConfig = snodeConfig.getNettyClientConfig();
this.nettyServerConfig = snodeConfig.getNettyServerConfig();
this.mqttServerConfig = mqttConfig.getMqttServerConfig();
......@@ -155,6 +155,14 @@ public class SnodeController {
this.mqttRemotingServer.init(this.mqttServerConfig, this.clientHousekeepingService);
this.mqttRemotingServer.registerInterceptorGroup(this.remotingServerInterceptorGroup);
}
this.innerMqttRemotingServer = RemotingServerFactory.getInstance().createRemotingServer(
RemotingUtil.MQTT_PROTOCOL);
ServerConfig innerMqttServerConfig = (ServerConfig)mqttServerConfig.clone();
innerMqttServerConfig.setListenPort(mqttServerConfig.getListenPort() - 1);
if (this.innerMqttRemotingServer != null) {
this.innerMqttRemotingServer.init(innerMqttServerConfig, this.clientHousekeepingService);
this.innerMqttRemotingServer.registerInterceptorGroup(this.remotingServerInterceptorGroup);
}
this.sendMessageExecutor = ThreadUtils.newThreadPoolExecutor(
snodeConfig.getSnodeSendMessageMinPoolSize(),
snodeConfig.getSnodeSendMessageMaxPoolSize(),
......@@ -212,7 +220,8 @@ public class SnodeController {
this.sendMessageProcessor = new SendMessageProcessor(this);
this.heartbeatProcessor = new HeartbeatProcessor(this);
this.pullMessageProcessor = new PullMessageProcessor(this);
this.defaultMqttMessageProcessor = new DefaultMqttMessageProcessor(this.mqttConfig, mqttRemotingServer, enodeService, nnodeService);
this.defaultMqttMessageProcessor = new DefaultMqttMessageProcessor(this.mqttConfig, this.snodeConfig, mqttRemotingServer, enodeService, nnodeService);
this.innerMqttMessageProcessor = new InnerMqttMessageProcessor(this.defaultMqttMessageProcessor, innerMqttRemotingServer);
this.pushService = new PushServiceImpl(this);
this.clientService = new ClientServiceImpl(this);
this.subscriptionManager = new SubscriptionManagerImpl();
......@@ -352,6 +361,9 @@ public class SnodeController {
if (mqttRemotingServer != null) {
this.mqttRemotingServer.registerProcessor(RequestCode.MQTT_MESSAGE, defaultMqttMessageProcessor, handleMqttMessageExecutor);
}
if (innerMqttRemotingServer != null) {
this.innerMqttRemotingServer.registerProcessor(RequestCode.MQTT_MESSAGE, innerMqttMessageProcessor, handleMqttMessageExecutor);
}
}
public void start() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册