diff --git a/common/src/main/java/org/apache/rocketmq/common/MqttConfig.java b/common/src/main/java/org/apache/rocketmq/common/MqttConfig.java index 72dcaae4e47270cf0fce501ef4456d355d764c3e..186f22cd6420f9c76cef77f9baa0447b4e95a676 100644 --- a/common/src/main/java/org/apache/rocketmq/common/MqttConfig.java +++ b/common/src/main/java/org/apache/rocketmq/common/MqttConfig.java @@ -52,6 +52,8 @@ public class MqttConfig { private long houseKeepingInterval = 10 * 1000; + private long persistOffsetInterval = 2 * 1000; + public int getListenPort() { return listenPort; } @@ -139,4 +141,12 @@ public class MqttConfig { public void setHouseKeepingInterval(long houseKeepingInterval) { this.houseKeepingInterval = houseKeepingInterval; } + + public long getPersistOffsetInterval() { + return persistOffsetInterval; + } + + public void setPersistOffsetInterval(long persistOffsetInterval) { + this.persistOffsetInterval = persistOffsetInterval; + } } diff --git a/mqtt/pom.xml b/mqtt/pom.xml index 13678904fe3ec4aa476aa2da73459260c92dcc67..eee314e48d5da096d286690c37be92b6e52291f9 100644 --- a/mqtt/pom.xml +++ b/mqtt/pom.xml @@ -103,5 +103,14 @@ org.eclipse.paho org.eclipse.paho.client.mqttv3 + + org.jctools + jctools-core + 2.1.2 + + + org.apache.commons + commons-lang3 + diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/IOTClientManagerImpl.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/IOTClientManagerImpl.java index 3465f3865dc309e12eecde9f2eaca6f969c2afce..8caec92e41ff3a8fe3c9ba6cc458eb5a18beed68 100644 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/IOTClientManagerImpl.java +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/IOTClientManagerImpl.java @@ -21,11 +21,13 @@ import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; +import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; import org.apache.rocketmq.common.client.Client; import org.apache.rocketmq.common.client.ClientManagerImpl; import org.apache.rocketmq.common.client.Subscription; import org.apache.rocketmq.common.constant.LoggerName; +import org.apache.rocketmq.common.message.MessageExt; import org.apache.rocketmq.logging.InternalLogger; import org.apache.rocketmq.logging.InternalLoggerFactory; import org.apache.rocketmq.remoting.RemotingChannel; @@ -41,6 +43,8 @@ public class IOTClientManagerImpl extends ClientManagerImpl { 1024); private final ConcurrentHashMap clientId2Subscription = new ConcurrentHashMap<>(1024); private final Map snode2MqttClient = new HashMap<>(); + private final ConcurrentHashMap>> processTable = new ConcurrentHashMap<>(); + public IOTClientManagerImpl() { } @@ -79,6 +83,9 @@ public class IOTClientManagerImpl extends ClientManagerImpl { } public void cleanSessionState(String clientId) { + if (clientId2Subscription.remove(clientId) == null) { + return; + } Map> toBeRemoveFromPersistentStore = new HashMap<>(); for (Iterator>> iterator = topic2Clients.entrySet().iterator(); iterator.hasNext(); ) { Map.Entry> next = iterator.next(); @@ -94,7 +101,7 @@ public class IOTClientManagerImpl extends ClientManagerImpl { } } //TODO update persistent store base on toBeRemoveFromPersistentStore - clientId2Subscription.remove(clientId); + //TODO update persistent store //TODO remove offline messages } @@ -118,4 +125,8 @@ public class IOTClientManagerImpl extends ClientManagerImpl { public Map getSnode2MqttClient() { return snode2MqttClient; } + + public ConcurrentHashMap>> getProcessTable() { + return processTable; + } } diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/InFlightMessage.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/InFlightMessage.java index 2a619ab519cd46f0ec02f227b17630ae1741d7d0..57562dd672d50a6216d80b7ae8ce0eed4b4a81d8 100644 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/InFlightMessage.java +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/InFlightMessage.java @@ -16,16 +16,47 @@ */ package org.apache.rocketmq.mqtt.client; -import io.netty.buffer.ByteBuf; +import org.apache.rocketmq.common.protocol.route.BrokerData; public class InFlightMessage { - final String topic; - final Integer pushQos; - final ByteBuf payload; + private final String topic; + private final Integer pushQos; + private final BrokerData brokerData; + private final byte[] body; + private final String messageId; + private final long queueOffset; - InFlightMessage(String topic, Integer pushQos, ByteBuf payload) { + InFlightMessage(String topic, Integer pushQos, byte[] body, BrokerData brokerData, String messageId, + long queueOffset) { this.topic = topic; this.pushQos = pushQos; - this.payload = payload; + this.body = body; + this.brokerData = brokerData; + this.messageId = messageId; + this.queueOffset = queueOffset; + } + + public String getTopic() { + return topic; + } + + public BrokerData getBrokerData() { + return brokerData; + } + + public String getMessageId() { + return messageId; + } + + public long getQueueOffset() { + return queueOffset; + } + + public Integer getPushQos() { + return pushQos; + } + + public byte[] getBody() { + return body; } } diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/MQTTSession.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/MQTTSession.java index 7bf9db426315a817de4134ead67cc14d54e03c2a..7ed89562e6b2bda57bb530186d46535c7e75b5ed 100644 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/MQTTSession.java +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/client/MQTTSession.java @@ -16,32 +16,47 @@ */ package org.apache.rocketmq.mqtt.client; -import io.netty.buffer.ByteBuf; import java.util.HashMap; import java.util.Hashtable; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.DelayQueue; import java.util.concurrent.Delayed; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.apache.rocketmq.common.client.Client; import org.apache.rocketmq.common.client.ClientRole; +import org.apache.rocketmq.common.constant.LoggerName; +import org.apache.rocketmq.common.message.MessageExt; +import org.apache.rocketmq.common.protocol.RequestCode; +import org.apache.rocketmq.common.protocol.route.BrokerData; +import org.apache.rocketmq.logging.InternalLogger; +import org.apache.rocketmq.logging.InternalLoggerFactory; +import org.apache.rocketmq.mqtt.constant.MqttConstant; import org.apache.rocketmq.mqtt.exception.MqttRuntimeException; import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; +import org.apache.rocketmq.mqtt.util.MqttUtil; import org.apache.rocketmq.remoting.RemotingChannel; +import org.apache.rocketmq.remoting.netty.NettyChannelHandlerContextImpl; +import org.apache.rocketmq.remoting.netty.NettyChannelImpl; +import org.apache.rocketmq.remoting.protocol.RemotingCommand; import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader; public class MQTTSession extends Client { + private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); + private boolean cleanSession; private boolean isConnected; private boolean willFlag; + + private final DefaultMqttMessageProcessor defaultMqttMessageProcessor; private final AtomicInteger inflightSlots = new AtomicInteger(10); private final Map inflightWindow = new HashMap<>(); private final DelayQueue inflightTimeouts = new DelayQueue<>(); private static final int FLIGHT_BEFORE_RESEND_MS = 5_000; - private final AtomicInteger lastPacketId = new AtomicInteger(0); private Hashtable inUsePacketIds = new Hashtable(); private int nextPacketId = 0; @@ -74,11 +89,13 @@ public class MQTTSession extends Client { } } - public MQTTSession(String clientId, ClientRole clientRole, Set groups, boolean isConnected, boolean cleanSession, - RemotingChannel remotingChannel, long lastUpdateTimestamp) { + public MQTTSession(String clientId, ClientRole clientRole, Set groups, boolean isConnected, + boolean cleanSession, RemotingChannel remotingChannel, long lastUpdateTimestamp, + DefaultMqttMessageProcessor defaultMqttMessageProcessor) { super(clientId, clientRole, groups, remotingChannel, lastUpdateTimestamp); this.isConnected = isConnected; this.cleanSession = cleanSession; + this.defaultMqttMessageProcessor = defaultMqttMessageProcessor; } @Override @@ -117,21 +134,88 @@ public class MQTTSession extends Client { this.willFlag = willFlag; } - public void pushMessageAtQos(MqttHeader mqttHeader, ByteBuf payload, - DefaultMqttMessageProcessor defaultMqttMessageProcessor) { + public void pushMessageQos0(MqttHeader mqttHeader, byte[] body) { + pushMessage2Client(mqttHeader, body); + } - if (mqttHeader.getQosLevel() > 0) { + public void pushMessageQos1(MqttHeader mqttHeader, MessageExt messageExt, BrokerData brokerData) { + if (inflightSlots.get() > 0) { inflightSlots.decrementAndGet(); mqttHeader.setPacketId(getNextPacketId()); - inflightWindow.put(mqttHeader.getPacketId(), new InFlightMessage(mqttHeader.getTopicName(), mqttHeader.getQosLevel(), payload)); + inflightWindow.put(mqttHeader.getPacketId(), new InFlightMessage(mqttHeader.getTopicName(), mqttHeader.getQosLevel(), messageExt.getBody(), brokerData, messageExt.getMsgId(), messageExt.getQueueOffset())); inflightTimeouts.add(new InFlightPacket(mqttHeader.getPacketId(), FLIGHT_BEFORE_RESEND_MS)); + put2processTable(((IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager()).getProcessTable(), brokerData.getBrokerName(), MqttUtil.getRootTopic(mqttHeader.getTopicName()), messageExt); + pushMessage2Client(mqttHeader, messageExt.getBody()); } - defaultMqttMessageProcessor.getMqttPushService().pushMessageQos(mqttHeader, payload, this); } - public void pubAckReceived(int ackPacketId) { - inflightWindow.remove(ackPacketId); + + public InFlightMessage pubAckReceived(int ackPacketId) { + InFlightMessage remove = inflightWindow.remove(ackPacketId); + String rootTopic = MqttUtil.getRootTopic(remove.getTopic()); + ConcurrentHashMap>> processTable = ((IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager()).getProcessTable(); + ConcurrentHashMap> map = processTable.get(remove.getBrokerData().getBrokerName()); + if (map != null) { + TreeMap treeMap = map.get(rootTopic + "@" + this.getClientId()); + if (treeMap != null) { + treeMap.remove(remove.getQueueOffset()); + } + } inflightSlots.incrementAndGet(); + releasePacketId(ackPacketId); + return remove; + } + + private void pushMessage2Client(MqttHeader mqttHeader, byte[] body) { + try { + //set remaining length + int remainingLength = mqttHeader.getTopicName().getBytes().length + body.length; + if (mqttHeader.getQosLevel() > 0) { + remainingLength += 2; //add packetId length + } + mqttHeader.setRemainingLength(remainingLength); + RemotingCommand requestCommand = RemotingCommand.createRequestCommand(RequestCode.MQTT_MESSAGE, mqttHeader); + + RemotingChannel remotingChannel = this.getRemotingChannel(); + if (this.getRemotingChannel() instanceof NettyChannelHandlerContextImpl) { + remotingChannel = new NettyChannelImpl(((NettyChannelHandlerContextImpl) this.getRemotingChannel()).getChannelHandlerContext().channel()); + } + requestCommand.setBody(body); + this.defaultMqttMessageProcessor.getMqttRemotingServer().push(remotingChannel, requestCommand, MqttConstant.DEFAULT_TIMEOUT_MILLS); + } catch (Exception ex) { + log.warn("Exception was thrown when pushing MQTT message. Topic: {}, clientId:{}, exception={}", mqttHeader.getTopicName(), this.getClientId(), ex.getMessage()); + } + } + + private void put2processTable( + ConcurrentHashMap>> processTable, + String brokerName, + String rootTopic, + MessageExt messageExt) { + ConcurrentHashMap> map; + TreeMap treeMap; + String offsetKey = rootTopic + "@" + this.getClientId(); + if (processTable.contains(brokerName)) { + map = processTable.get(brokerName); + if (map.contains(offsetKey)) { + treeMap = map.get(offsetKey); + treeMap.putIfAbsent(messageExt.getQueueOffset(), messageExt); + } else { + treeMap = new TreeMap<>(); + treeMap.put(messageExt.getQueueOffset(), messageExt); + map.putIfAbsent(offsetKey, treeMap); + } + } else { + map = new ConcurrentHashMap<>(); + treeMap = new TreeMap<>(); + treeMap.put(messageExt.getQueueOffset(), messageExt); + map.put(offsetKey, treeMap); + ConcurrentHashMap> old = processTable.putIfAbsent(brokerName, map); + if (old != null) { + old.putIfAbsent(offsetKey, treeMap); + } + } } + private synchronized void releasePacketId(int msgId) { this.inUsePacketIds.remove(new Integer(msgId)); } @@ -159,4 +243,13 @@ public class MQTTSession extends Client { this.inUsePacketIds.put(id, id); return this.nextPacketId; } + + public AtomicInteger getInflightSlots() { + return inflightSlots; + } + + public Map getInflightWindow() { + return inflightWindow; + } + } diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttMessageForwardHandler.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttMessageForwardHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..6fb715f1e96062b42c6c6855466ea50ad15cfca7 --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttMessageForwardHandler.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.mqtt.mqtthandler.impl; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.mqtt.MqttFixedHeader; +import io.netty.handler.codec.mqtt.MqttMessage; +import io.netty.handler.codec.mqtt.MqttMessageType; +import io.netty.handler.codec.mqtt.MqttPublishMessage; +import io.netty.handler.codec.mqtt.MqttPublishVariableHeader; +import io.netty.handler.codec.mqtt.MqttQoS; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import org.apache.rocketmq.common.client.Client; +import org.apache.rocketmq.common.constant.LoggerName; +import org.apache.rocketmq.logging.InternalLogger; +import org.apache.rocketmq.logging.InternalLoggerFactory; +import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; +import org.apache.rocketmq.mqtt.client.MQTTSession; +import org.apache.rocketmq.mqtt.mqtthandler.MessageHandler; +import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; +import org.apache.rocketmq.mqtt.processor.InnerMqttMessageProcessor; +import org.apache.rocketmq.mqtt.task.MqttPushTask; +import org.apache.rocketmq.mqtt.transfer.TransferDataQos1; +import org.apache.rocketmq.mqtt.util.orderedexecutor.SafeRunnable; +import org.apache.rocketmq.remoting.RemotingChannel; +import org.apache.rocketmq.remoting.protocol.RemotingCommand; +import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader; + +public class MqttMessageForwardHandler implements MessageHandler { + private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); + private final InnerMqttMessageProcessor innerMqttMessageProcessor; + private final DefaultMqttMessageProcessor defaultMqttMessageProcessor; + + public MqttMessageForwardHandler(InnerMqttMessageProcessor processor) { + this.innerMqttMessageProcessor = processor; + this.defaultMqttMessageProcessor = innerMqttMessageProcessor.getDefaultMqttMessageProcessor(); + } + + /** + * handle messages transferred from other nodes + * + * @param message the message that transferred from other node + */ + @Override public RemotingCommand handleMessage(MqttMessage message, RemotingChannel remotingChannel) { + MqttPublishMessage mqttPublishMessage = (MqttPublishMessage) message; + MqttFixedHeader fixedHeader = mqttPublishMessage.fixedHeader(); + MqttPublishVariableHeader variableHeader = mqttPublishMessage.variableHeader(); + ByteBuf payload = mqttPublishMessage.payload(); + byte[] body = new byte[payload.readableBytes()]; + payload.readBytes(body); + + if (fixedHeader.qosLevel().equals(MqttQoS.AT_MOST_ONCE)) { + MqttHeader mqttHeaderQos0 = new MqttHeader(); + mqttHeaderQos0.setTopicName(variableHeader.topicName()); + mqttHeaderQos0.setMessageType(MqttMessageType.PUBLISH.value()); + mqttHeaderQos0.setQosLevel(MqttQoS.AT_MOST_ONCE.value()); + mqttHeaderQos0.setRetain(false); //TODO set to false temporarily, need to be implemented later. + Set clientsTobePublish = findCurrentNodeClientsTobePublish(variableHeader.topicName(), (IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager()); + for (Client client : clientsTobePublish) { + ((MQTTSession) client).pushMessageQos0(mqttHeaderQos0, body, this.defaultMqttMessageProcessor); + } + } else if (fixedHeader.qosLevel().equals(MqttQoS.AT_LEAST_ONCE)) { + TransferDataQos1 transferDataQos1 = TransferDataQos1.decode(body, TransferDataQos1.class); + //TODO : find clients that subscribed this topic from current node + List clientsTobePublished = new ArrayList<>(); + for (Client client : clientsTobePublished) { + //For each client, wrap a task: + //Pull message one by one, and push them if current client match. + MqttHeader mqttHeaderQos1 = new MqttHeader(); + mqttHeaderQos1.setTopicName(variableHeader.topicName()); + mqttHeaderQos1.setMessageType(MqttMessageType.PUBLISH.value()); + mqttHeaderQos1.setRetain(false); //TODO set to false temporarily, need to be implemented later. + MqttPushTask mqttPushTask = new MqttPushTask(this.defaultMqttMessageProcessor, mqttHeaderQos1, client, transferDataQos1.getBrokerData()); + //add task to orderedExecutor + this.defaultMqttMessageProcessor.getOrderedExecutor().executeOrdered(client.getClientId(), SafeRunnable.safeRun(mqttPushTask)); + } + return doResponse(fixedHeader); + } + return null; + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttMessageForwarder.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttMessageForwarder.java deleted file mode 100644 index 5718fadb83f96bff4e69a59f74ed354f5cbca577..0000000000000000000000000000000000000000 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttMessageForwarder.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.rocketmq.mqtt.mqtthandler.impl; - -import io.netty.buffer.ByteBuf; -import io.netty.handler.codec.mqtt.MqttFixedHeader; -import io.netty.handler.codec.mqtt.MqttMessage; -import io.netty.handler.codec.mqtt.MqttPublishMessage; -import io.netty.handler.codec.mqtt.MqttPublishVariableHeader; -import io.netty.handler.codec.mqtt.MqttQoS; -import java.util.Set; -import org.apache.rocketmq.common.client.Client; -import org.apache.rocketmq.common.constant.LoggerName; -import org.apache.rocketmq.logging.InternalLogger; -import org.apache.rocketmq.logging.InternalLoggerFactory; -import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; -import org.apache.rocketmq.mqtt.mqtthandler.MessageHandler; -import org.apache.rocketmq.mqtt.processor.InnerMqttMessageProcessor; -import org.apache.rocketmq.remoting.RemotingChannel; -import org.apache.rocketmq.remoting.protocol.RemotingCommand; - -public class MqttMessageForwarder implements MessageHandler { - private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); - private final InnerMqttMessageProcessor innerMqttMessageProcessor; - - public MqttMessageForwarder(InnerMqttMessageProcessor processor) { - this.innerMqttMessageProcessor = processor; - } - - /** - * handle PUBLISH message from client - * - * @param message - * @return whether the message is handled successfully - */ - @Override public RemotingCommand handleMessage(MqttMessage message, RemotingChannel remotingChannel) { - MqttPublishMessage mqttPublishMessage = (MqttPublishMessage) message; - MqttFixedHeader fixedHeader = mqttPublishMessage.fixedHeader(); - MqttPublishVariableHeader variableHeader = mqttPublishMessage.variableHeader(); - if (fixedHeader.qosLevel().equals(MqttQoS.AT_MOST_ONCE)) { - ByteBuf payload = mqttPublishMessage.payload(); - //Publish message to clients - Set clientsTobePublish = findCurrentNodeClientsTobePublish(variableHeader.topicName(), (IOTClientManagerImpl) this.innerMqttMessageProcessor.getIotClientManager()); - innerMqttMessageProcessor.getDefaultMqttMessageProcessor().getMqttPushService().pushMessageQos0(variableHeader.topicName(), payload, clientsTobePublish); - }else if(fixedHeader.qosLevel().equals(MqttQoS.AT_LEAST_ONCE)){ - //TODO - } - return doResponse(fixedHeader); - } -} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttPubackMessageHandler.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttPubackMessageHandler.java index 71d94589955b151bc319d55b2af108dec9c851e4..d871be0e7649bc9de81b3be3af3bd491c30dd4a8 100644 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttPubackMessageHandler.java +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttPubackMessageHandler.java @@ -18,13 +18,23 @@ package org.apache.rocketmq.mqtt.mqtthandler.impl; import io.netty.handler.codec.mqtt.MqttMessage; +import io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader; +import io.netty.handler.codec.mqtt.MqttMessageType; +import io.netty.handler.codec.mqtt.MqttPubAckMessage; import org.apache.rocketmq.common.constant.LoggerName; import org.apache.rocketmq.logging.InternalLogger; import org.apache.rocketmq.logging.InternalLoggerFactory; +import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; +import org.apache.rocketmq.mqtt.client.InFlightMessage; +import org.apache.rocketmq.mqtt.client.MQTTSession; +import org.apache.rocketmq.mqtt.exception.WrongMessageTypeException; import org.apache.rocketmq.mqtt.mqtthandler.MessageHandler; import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; +import org.apache.rocketmq.mqtt.task.MqttPushTask; +import org.apache.rocketmq.mqtt.util.orderedexecutor.SafeRunnable; import org.apache.rocketmq.remoting.RemotingChannel; import org.apache.rocketmq.remoting.protocol.RemotingCommand; +import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader; public class MqttPubackMessageHandler implements MessageHandler { @@ -43,6 +53,24 @@ public class MqttPubackMessageHandler implements MessageHandler { * @return */ @Override public RemotingCommand handleMessage(MqttMessage message, RemotingChannel remotingChannel) { + if (!(message instanceof MqttPubAckMessage)) { + log.error("Wrong message type! Expected type: PUBACK but {} was received. MqttMessage={}", message.fixedHeader().messageType(), message.toString()); + throw new WrongMessageTypeException("Wrong message type exception."); + } + MqttPubAckMessage mqttPubAckMessage = (MqttPubAckMessage) message; + MqttMessageIdVariableHeader variableHeader = mqttPubAckMessage.variableHeader(); + IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) defaultMqttMessageProcessor.getIotClientManager(); + MQTTSession client = (MQTTSession) iotClientManager.getClient(IOTClientManagerImpl.IOT_GROUP, remotingChannel); + + InFlightMessage removedMessage = client.pubAckReceived(variableHeader.messageId()); + MqttHeader mqttHeader = new MqttHeader(); + mqttHeader.setTopicName(removedMessage.getTopic()); + mqttHeader.setMessageType(MqttMessageType.PUBLISH.value()); + mqttHeader.setDup(false); + mqttHeader.setRetain(false); //TODO set to false temporarily, need to be implemented. + MqttPushTask task = new MqttPushTask(defaultMqttMessageProcessor, mqttHeader, client, removedMessage.getBrokerData()); + //add task to orderedExecutor + this.defaultMqttMessageProcessor.getOrderedExecutor().executeOrdered(client.getClientId(), SafeRunnable.safeRun(task)); return null; } } diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttPublishMessageHandler.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttPublishMessageHandler.java index a8900ce3d6bb28978f42a7df304bc0adc241606f..be53e58a1235543fe64cd7d4c9d7f820c5c9511c 100644 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttPublishMessageHandler.java +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/mqtthandler/impl/MqttPublishMessageHandler.java @@ -23,36 +23,26 @@ import io.netty.handler.codec.mqtt.MqttMessage; import io.netty.handler.codec.mqtt.MqttMessageType; import io.netty.handler.codec.mqtt.MqttPublishMessage; import io.netty.handler.codec.mqtt.MqttPublishVariableHeader; +import io.netty.handler.codec.mqtt.MqttQoS; import io.netty.util.ReferenceCountUtil; -import java.nio.ByteBuffer; -import java.util.Enumeration; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; import org.apache.rocketmq.common.MqttConfig; import org.apache.rocketmq.common.SnodeConfig; import org.apache.rocketmq.common.client.Client; -import org.apache.rocketmq.common.client.Subscription; import org.apache.rocketmq.common.constant.LoggerName; import org.apache.rocketmq.common.exception.MQClientException; import org.apache.rocketmq.common.message.Message; import org.apache.rocketmq.common.message.MessageAccessor; import org.apache.rocketmq.common.message.MessageConst; import org.apache.rocketmq.common.message.MessageDecoder; -import org.apache.rocketmq.common.message.MessageExt; import org.apache.rocketmq.common.protocol.RequestCode; -import org.apache.rocketmq.common.protocol.ResponseCode; -import org.apache.rocketmq.common.protocol.header.GetMaxOffsetRequestHeader; import org.apache.rocketmq.common.protocol.header.SendMessageRequestHeader; -import org.apache.rocketmq.common.protocol.header.SendMessageResponseHeader; -import org.apache.rocketmq.common.protocol.heartbeat.MqttSubscriptionData; -import org.apache.rocketmq.common.protocol.heartbeat.SubscriptionData; import org.apache.rocketmq.common.protocol.route.BrokerData; import org.apache.rocketmq.common.protocol.route.TopicRouteData; import org.apache.rocketmq.common.service.EnodeService; @@ -61,12 +51,14 @@ import org.apache.rocketmq.logging.InternalLoggerFactory; import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; import org.apache.rocketmq.mqtt.client.MQTTSession; import org.apache.rocketmq.mqtt.constant.MqttConstant; +import org.apache.rocketmq.mqtt.exception.MqttRuntimeException; import org.apache.rocketmq.mqtt.exception.WrongMessageTypeException; import org.apache.rocketmq.mqtt.mqtthandler.MessageHandler; import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; +import org.apache.rocketmq.mqtt.task.MqttPushTask; import org.apache.rocketmq.mqtt.util.MqttUtil; +import org.apache.rocketmq.mqtt.util.orderedexecutor.SafeRunnable; import org.apache.rocketmq.remoting.RemotingChannel; -import org.apache.rocketmq.remoting.exception.RemotingCommandException; import org.apache.rocketmq.remoting.exception.RemotingConnectException; import org.apache.rocketmq.remoting.exception.RemotingSendRequestException; import org.apache.rocketmq.remoting.exception.RemotingTimeoutException; @@ -106,19 +98,20 @@ public class MqttPublishMessageHandler implements MessageHandler { } ByteBuf payload = mqttPublishMessage.payload(); - MqttHeader mqttHeader = new MqttHeader(); - mqttHeader.setTopicName(variableHeader.topicName()); - mqttHeader.setMessageType(MqttMessageType.PUBLISH.value()); - mqttHeader.setDup(false); - mqttHeader.setQosLevel(fixedHeader.qosLevel().value()); - mqttHeader.setRetain(false); //set to false tempararily, need to be implemented. + switch (fixedHeader.qosLevel()) { case AT_MOST_ONCE: //For clients connected to the current snode and isConnected is true Set clientsTobePublish = findCurrentNodeClientsTobePublish(variableHeader.topicName(), this.iotClientManager); - + byte[] body = new byte[payload.readableBytes()]; + payload.readBytes(body); + MqttHeader mqttHeaderQos0 = new MqttHeader(); + mqttHeaderQos0.setTopicName(variableHeader.topicName()); + mqttHeaderQos0.setMessageType(MqttMessageType.PUBLISH.value()); + mqttHeaderQos0.setQosLevel(MqttQoS.AT_MOST_ONCE.value()); + mqttHeaderQos0.setRetain(false); //TODO set to false temporarily, need to be implemented later. for (Client client : clientsTobePublish) { - ((MQTTSession) client).pushMessageAtQos(mqttHeader, payload, this.defaultMqttMessageProcessor); + ((MQTTSession) client).pushMessageQos0(mqttHeaderQos0, body, this.defaultMqttMessageProcessor); } //For clients that connected to other snodes, transfer the message to them @@ -134,6 +127,7 @@ public class MqttPublishMessageHandler implements MessageHandler { } finally { ReferenceCountUtil.release(message); } + case AT_LEAST_ONCE: // Store msg and invoke callback to publish msg to subscribers // 1. Check if the root topic has been created @@ -156,95 +150,30 @@ public class MqttPublishMessageHandler implements MessageHandler { responseFuture.whenComplete((data, ex) -> { if (ex == null) { //publish msg to subscribers - try { - SendMessageResponseHeader responseHeader = (SendMessageResponseHeader) data.decodeCommandCustomHeader(SendMessageResponseHeader.class); - //find clients that subscribed this topic from all snodes and put it to map. - Map> snodeAddr2Clients = new HashMap<>(); - - //for clientIds connected to current snode, trigger the logic of push message - List clients = snodeAddr2Clients.get(this.defaultMqttMessageProcessor.getSnodeConfig().getSnodeIP1()); - for (Client client : clients) { - Subscription subscription = this.iotClientManager.getSubscriptionByClientId(client.getClientId()); - ConcurrentHashMap subscriptionTable = subscription.getSubscriptionTable(); - - //for each client, wrap a task: pull messages from commitlog one by one, and push them if current client subscribe it. - Runnable task = new Runnable() { - - @Override - public void run() { - //compare current consumeOffset of rootTopic@clientId with maxOffset, pull message if consumeOffset < maxOffset - long maxOffsetInQueue; - try { - maxOffsetInQueue = getMaxOffset(brokerData.getBrokerName(), rootTopic); - long consumeOffset = enodeService.queryOffset(brokerData.getBrokerName(), client.getClientId(), rootTopic, 0); - long i = consumeOffset; - while (i < maxOffsetInQueue) { - //TODO query messages from enode - RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null); - ByteBuffer byteBuffer = ByteBuffer.wrap(response.getBody()); - MessageExt messageExt = MessageDecoder.clientDecode(byteBuffer, true); - final String realTopic = messageExt.getProperty(MessageConst.PROPERTY_REAL_TOPIC); - - boolean needSkip = needSkip(realTopic); - if (needSkip) { - log.info("Current client doesn't subscribe topic:{}, skip this message", realTopic); - maxOffsetInQueue = getMaxOffset(brokerData.getBrokerName(), rootTopic); - i += 1; - continue; - } - Integer pushQos = lowerQosToTheSubscriptionDesired(realTopic, Integer.valueOf(messageExt.getProperty(MqttConstant.PROPERTY_MQTT_QOS))); - mqttHeader.setQosLevel(pushQos); - //push message - MQTTSession mqttSession = (MQTTSession) client; - mqttSession.pushMessageAtQos(mqttHeader, payload, defaultMqttMessageProcessor); - - maxOffsetInQueue = getMaxOffset(brokerData.getBrokerName(), rootTopic); - i += 1; - } - } catch (Exception ex) { - log.error("Get max offset error, remoting: {} error: {} ", remotingChannel.remoteAddress(), ex); - } - } - - private boolean needSkip(final String realTopic) { - Enumeration topicFilters = subscriptionTable.keys(); - while (topicFilters.hasMoreElements()) { - if (MqttUtil.isMatch(topicFilters.nextElement(), realTopic)) { - return false; - } - } - return true; - } - - private Integer lowerQosToTheSubscriptionDesired(String publishTopic, - Integer publishingQos) { - Integer pushQos = Integer.valueOf(publishingQos); - Iterator> iterator = subscriptionTable.entrySet().iterator(); - Integer maxRequestedQos = 0; - while (iterator.hasNext()) { - final String topicFilter = iterator.next().getKey(); - if (MqttUtil.isMatch(topicFilter, publishTopic)) { - MqttSubscriptionData mqttSubscriptionData = (MqttSubscriptionData) iterator.next().getValue(); - maxRequestedQos = mqttSubscriptionData.getQos() > maxRequestedQos ? mqttSubscriptionData.getQos() : maxRequestedQos; - } - } - if (publishingQos > maxRequestedQos) { - pushQos = maxRequestedQos; - } - return pushQos; - } - }; - - } - //for clientIds connected to other snodes, forward msg - } catch (RemotingCommandException e) { - e.printStackTrace(); + //TODO find clients that subscribed this topic from all snodes and put it to map. + Map> snodeAddr2Clients = new HashMap<>(); + + //for clientIds connected to current snode, trigger the logic of push message + List clients = snodeAddr2Clients.get(this.defaultMqttMessageProcessor.getSnodeConfig().getSnodeIP1()); + for (Client client : clients) { + //For each client, wrap a task: + //Pull message one by one, and push them if current client match. + MqttHeader mqttHeaderQos1 = new MqttHeader(); + mqttHeaderQos1.setTopicName(variableHeader.topicName()); + mqttHeaderQos1.setMessageType(MqttMessageType.PUBLISH.value()); + mqttHeaderQos1.setRetain(false); //TODO set to false temporarily, need to be implemented later. + MqttPushTask mqttPushTask = new MqttPushTask(defaultMqttMessageProcessor, mqttHeaderQos1, client, brokerData); + //add task to orderedExecutor + this.defaultMqttMessageProcessor.getOrderedExecutor().executeOrdered(client.getClientId(), SafeRunnable.safeRun(mqttPushTask)); } + //TODO for clientIds connected to other snodes, forward msg } else { log.error("Store Qos=1 Message error: {}", ex); } }); + case EXACTLY_ONCE: + throw new MqttRuntimeException("Qos = 2 messages are not supported yet."); } return doResponse(fixedHeader); } @@ -297,13 +226,4 @@ public class MqttPublishMessageHandler implements MessageHandler { return request; } - private long getMaxOffset(String enodeName, - String topic) throws InterruptedException, RemotingTimeoutException, RemotingCommandException, RemotingSendRequestException, RemotingConnectException { - GetMaxOffsetRequestHeader requestHeader = new GetMaxOffsetRequestHeader(); - requestHeader.setTopic(topic); - requestHeader.setQueueId(0); - RemotingCommand request = RemotingCommand.createRequestCommand(RequestCode.GET_MAX_OFFSET, requestHeader); - - return this.defaultMqttMessageProcessor.getEnodeService().getMaxOffsetInQueue(enodeName, topic, 0, request); - } } diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/processor/DefaultMqttMessageProcessor.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/processor/DefaultMqttMessageProcessor.java index fb520e0da2f1c2bb084f346ffbe7685a6e79a28e..803648b2935585f23c0a4f4f7d5dcc5e3b307185 100644 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/processor/DefaultMqttMessageProcessor.java +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/processor/DefaultMqttMessageProcessor.java @@ -39,6 +39,7 @@ import org.apache.rocketmq.common.constant.LoggerName; import org.apache.rocketmq.common.exception.MQClientException; import org.apache.rocketmq.common.service.EnodeService; import org.apache.rocketmq.common.service.NnodeService; +import org.apache.rocketmq.common.service.ScheduledService; import org.apache.rocketmq.logging.InternalLogger; import org.apache.rocketmq.logging.InternalLoggerFactory; import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; @@ -55,8 +56,9 @@ import org.apache.rocketmq.mqtt.mqtthandler.impl.MqttPubrelMessageHandler; import org.apache.rocketmq.mqtt.mqtthandler.impl.MqttSubscribeMessageHandler; import org.apache.rocketmq.mqtt.mqtthandler.impl.MqttUnsubscribeMessagHandler; import org.apache.rocketmq.mqtt.service.WillMessageService; -import org.apache.rocketmq.mqtt.service.impl.MqttPushServiceImpl; +import org.apache.rocketmq.mqtt.service.impl.MqttScheduledServiceImpl; import org.apache.rocketmq.mqtt.service.impl.WillMessageServiceImpl; +import org.apache.rocketmq.mqtt.util.orderedexecutor.OrderedExecutor; import org.apache.rocketmq.remoting.RemotingChannel; import org.apache.rocketmq.remoting.RemotingServer; import org.apache.rocketmq.remoting.RequestProcessor; @@ -74,7 +76,6 @@ public class DefaultMqttMessageProcessor implements RequestProcessor { private static final int MIN_AVAILABLE_VERSION = 3; private static final int MAX_AVAILABLE_VERSION = 4; private WillMessageService willMessageService; - private MqttPushServiceImpl mqttPushService; private ClientManager iotClientManager; private RemotingServer mqttRemotingServer; private MqttClientHousekeepingService mqttClientHousekeepingService; @@ -82,6 +83,9 @@ public class DefaultMqttMessageProcessor implements RequestProcessor { private SnodeConfig snodeConfig; private EnodeService enodeService; private NnodeService nnodeService; + private ScheduledService mqttScheduledService; + + private final OrderedExecutor orderedExecutor; public DefaultMqttMessageProcessor(MqttConfig mqttConfig, SnodeConfig snodeConfig, RemotingServer mqttRemotingServer, @@ -89,7 +93,6 @@ public class DefaultMqttMessageProcessor implements RequestProcessor { this.mqttConfig = mqttConfig; this.snodeConfig = snodeConfig; this.willMessageService = new WillMessageServiceImpl(); - this.mqttPushService = new MqttPushServiceImpl(this, mqttConfig); this.iotClientManager = new IOTClientManagerImpl(); this.mqttRemotingServer = mqttRemotingServer; this.enodeService = enodeService; @@ -97,6 +100,10 @@ public class DefaultMqttMessageProcessor implements RequestProcessor { this.mqttClientHousekeepingService = new MqttClientHousekeepingService(iotClientManager); this.mqttClientHousekeepingService.start(mqttConfig.getHouseKeepingInterval()); + this.orderedExecutor = OrderedExecutor.newBuilder().name("PushMessageToConsumerThreads").numThreads(mqttConfig.getPushMqttMessageMaxPoolSize()).build(); + this.mqttScheduledService = new MqttScheduledServiceImpl(this); + mqttScheduledService.startScheduleTask(); + registerMessageHandler(MqttMessageType.CONNECT, new MqttConnectMessageHandler(this)); registerMessageHandler(MqttMessageType.DISCONNECT, @@ -164,10 +171,6 @@ public class DefaultMqttMessageProcessor implements RequestProcessor { return willMessageService; } - public MqttPushServiceImpl getMqttPushService() { - return mqttPushService; - } - public ClientManager getIotClientManager() { return iotClientManager; } @@ -207,4 +210,8 @@ public class DefaultMqttMessageProcessor implements RequestProcessor { public void setNnodeService(NnodeService nnodeService) { this.nnodeService = nnodeService; } + + public OrderedExecutor getOrderedExecutor() { + return orderedExecutor; + } } diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/processor/InnerMqttMessageProcessor.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/processor/InnerMqttMessageProcessor.java index 0f50e8e252d10fb900c60d8b1e50cc7056d37b7e..90171241e4b4d0639d9413da37f53a1dfd8b7acb 100644 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/processor/InnerMqttMessageProcessor.java +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/processor/InnerMqttMessageProcessor.java @@ -33,9 +33,8 @@ import org.apache.rocketmq.common.service.EnodeService; import org.apache.rocketmq.common.service.NnodeService; import org.apache.rocketmq.logging.InternalLogger; import org.apache.rocketmq.logging.InternalLoggerFactory; -import org.apache.rocketmq.mqtt.mqtthandler.impl.MqttMessageForwarder; +import org.apache.rocketmq.mqtt.mqtthandler.impl.MqttMessageForwardHandler; import org.apache.rocketmq.mqtt.service.WillMessageService; -import org.apache.rocketmq.mqtt.service.impl.MqttPushServiceImpl; import org.apache.rocketmq.remoting.RemotingChannel; import org.apache.rocketmq.remoting.RemotingServer; import org.apache.rocketmq.remoting.RequestProcessor; @@ -52,24 +51,22 @@ public class InnerMqttMessageProcessor implements RequestProcessor { private final DefaultMqttMessageProcessor defaultMqttMessageProcessor; private WillMessageService willMessageService; - private MqttPushServiceImpl mqttPushService; private ClientManager iotClientManager; private RemotingServer innerMqttRemotingServer; private MqttConfig mqttConfig; private SnodeConfig snodeConfig; private EnodeService enodeService; private NnodeService nnodeService; - private MqttMessageForwarder mqttMessageForwarder; + private MqttMessageForwardHandler mqttMessageForwardHandler; public InnerMqttMessageProcessor(DefaultMqttMessageProcessor defaultMqttMessageProcessor, RemotingServer innerMqttRemotingServer) { this.defaultMqttMessageProcessor = defaultMqttMessageProcessor; this.willMessageService = this.defaultMqttMessageProcessor.getWillMessageService(); - this.mqttPushService = this.defaultMqttMessageProcessor.getMqttPushService(); this.iotClientManager = this.defaultMqttMessageProcessor.getIotClientManager(); this.innerMqttRemotingServer = innerMqttRemotingServer; this.enodeService = this.defaultMqttMessageProcessor.getEnodeService(); this.nnodeService = this.defaultMqttMessageProcessor.getNnodeService(); - this.mqttMessageForwarder = new MqttMessageForwarder(this); + this.mqttMessageForwardHandler = new MqttMessageForwardHandler(this); } @Override @@ -82,7 +79,7 @@ public class InnerMqttMessageProcessor implements RequestProcessor { mqttHeader.getRemainingLength()); MqttPublishVariableHeader mqttPublishVariableHeader = new MqttPublishVariableHeader(mqttHeader.getTopicName(), mqttHeader.getPacketId()); MqttMessage mqttMessage = new MqttPublishMessage(fixedHeader, mqttPublishVariableHeader, Unpooled.copiedBuffer(message.getBody())); - return mqttMessageForwarder.handleMessage(mqttMessage, remotingChannel); + return mqttMessageForwardHandler.handleMessage(mqttMessage, remotingChannel); }else{ return defaultMqttMessageProcessor.processRequest(remotingChannel, message); } @@ -98,10 +95,6 @@ public class InnerMqttMessageProcessor implements RequestProcessor { return willMessageService; } - public MqttPushServiceImpl getMqttPushService() { - return mqttPushService; - } - public ClientManager getIotClientManager() { return iotClientManager; } diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/service/impl/MqttPushServiceImpl.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/service/impl/MqttPushServiceImpl.java deleted file mode 100644 index a87a6b6c3d85e162d2ef0fcbf48c3b614f4b008c..0000000000000000000000000000000000000000 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/service/impl/MqttPushServiceImpl.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.rocketmq.mqtt.service.impl; - -import io.netty.buffer.ByteBuf; -import io.netty.util.ReferenceCountUtil; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import org.apache.rocketmq.common.MqttConfig; -import org.apache.rocketmq.common.client.Client; -import org.apache.rocketmq.common.constant.LoggerName; -import org.apache.rocketmq.common.protocol.RequestCode; -import org.apache.rocketmq.common.utils.ThreadUtils; -import org.apache.rocketmq.logging.InternalLogger; -import org.apache.rocketmq.logging.InternalLoggerFactory; -import org.apache.rocketmq.mqtt.constant.MqttConstant; -import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; -import org.apache.rocketmq.mqtt.service.MqttPushService; -import org.apache.rocketmq.remoting.RemotingChannel; -import org.apache.rocketmq.remoting.netty.NettyChannelHandlerContextImpl; -import org.apache.rocketmq.remoting.netty.NettyChannelImpl; -import org.apache.rocketmq.remoting.protocol.RemotingCommand; -import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader; - -public class MqttPushServiceImpl implements MqttPushService { - private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); - - private ExecutorService pushMqttMessageExecutorService; - private static DefaultMqttMessageProcessor defaultMqttMessageProcessor; - - public MqttPushServiceImpl(DefaultMqttMessageProcessor defaultMqttMessageProcessor, MqttConfig mqttConfig) { - this.defaultMqttMessageProcessor = defaultMqttMessageProcessor; - pushMqttMessageExecutorService = ThreadUtils.newThreadPoolExecutor( - mqttConfig.getPushMqttMessageMinPoolSize(), - mqttConfig.getPushMqttMessageMaxPoolSize(), - 3000, - TimeUnit.MILLISECONDS, - new ArrayBlockingQueue<>(mqttConfig.getPushMqttMessageThreadPoolQueueCapacity()), - "pushMqttMessageThread", - false); - } - - public static class MqttPushTask implements Runnable { - private AtomicBoolean canceled = new AtomicBoolean(false); - private final ByteBuf message; - private final MqttHeader mqttHeader; - private Client client; -// private final String topic; -// private final Integer qos; -// private boolean retain; -// private Integer packetId; - - public MqttPushTask(final MqttHeader mqttHeader, final ByteBuf message, Client client) { - this.message = message; - this.mqttHeader = mqttHeader; -// this.topic = topic; -// this.qos = qos; -// this.retain = retain; -// this.packetId = packetId; - this.client = client; - } - - @Override - public void run() { - if (!canceled.get()) { - try { - RemotingCommand requestCommand = buildRequestCommand(this.mqttHeader); - - RemotingChannel remotingChannel = client.getRemotingChannel(); - if (client.getRemotingChannel() instanceof NettyChannelHandlerContextImpl) { - remotingChannel = new NettyChannelImpl(((NettyChannelHandlerContextImpl) client.getRemotingChannel()).getChannelHandlerContext().channel()); - } - byte[] body = new byte[message.readableBytes()]; - message.readBytes(body); - requestCommand.setBody(body); - defaultMqttMessageProcessor.getMqttRemotingServer().push(remotingChannel, requestCommand, MqttConstant.DEFAULT_TIMEOUT_MILLS); - } catch (Exception ex) { - log.warn("Exception was thrown when pushing MQTT message to topic: {}, clientId:{}, exception={}", mqttHeader.getTopicName(), client.getClientId(), ex.getMessage()); - } finally { - ReferenceCountUtil.release(message); - } - } else { - log.info("Push message to topic: {}, clientId:{}, canceled!", mqttHeader.getTopicName(), client.getClientId()); - } - } - - private RemotingCommand buildRequestCommand(MqttHeader mqttHeader) { -// if (qos == 0) { -// mqttHeader.setDup(false);//DUP is always 0 for qos=0 messages -// } else { -// mqttHeader.setDup(false);//DUP is depending on whether it is a re-delivery of an earlier attempt. -// } -// mqttHeader.setRemainingLength(4 + topic.getBytes().length + message.readableBytes()); - - RemotingCommand pushMessage = RemotingCommand.createRequestCommand(RequestCode.MQTT_MESSAGE, mqttHeader); - return pushMessage; - } - - public void setCanceled(AtomicBoolean canceled) { - this.canceled = canceled; - } - - } - - public void pushMessageQos(MqttHeader mqttHeader, final ByteBuf message, Client client) { - MqttPushTask pushTask = new MqttPushTask(mqttHeader, message, client); - pushMqttMessageExecutorService.submit(pushTask); - } - - public void shutdown() { - this.pushMqttMessageExecutorService.shutdown(); - } -} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/service/impl/MqttScheduledServiceImpl.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/service/impl/MqttScheduledServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..c4e57e56e6d6427ae2b7ff2158937bee2741c8b7 --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/service/impl/MqttScheduledServiceImpl.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.mqtt.service.impl; + +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import org.apache.rocketmq.common.constant.LoggerName; +import org.apache.rocketmq.common.message.MessageExt; +import org.apache.rocketmq.common.service.ScheduledService; +import org.apache.rocketmq.logging.InternalLogger; +import org.apache.rocketmq.logging.InternalLoggerFactory; +import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; +import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; + +public class MqttScheduledServiceImpl implements ScheduledService { + private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); + + private final DefaultMqttMessageProcessor defaultMqttMessageProcessor; + + public MqttScheduledServiceImpl(DefaultMqttMessageProcessor defaultMqttMessageProcessor) { + this.defaultMqttMessageProcessor = defaultMqttMessageProcessor; + } + + private final ScheduledExecutorService mqttScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + return new Thread(r, "MqttScheduledThread"); + } + }); + + @Override + public void startScheduleTask() { + + this.mqttScheduledExecutorService.scheduleAtFixedRate(new Runnable() { + @Override + public void run() { + IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) defaultMqttMessageProcessor.getIotClientManager(); + ConcurrentHashMap>> processTable = iotClientManager.getProcessTable(); + for (Map.Entry>> entry : processTable.entrySet()) { + String brokerName = entry.getKey(); + ConcurrentHashMap> map = entry.getValue(); + for (Map.Entry> innerEntry : map.entrySet()) { + String topicClient = innerEntry.getKey(); + TreeMap inflightMessages = innerEntry.getValue(); + Long offset = inflightMessages.firstKey(); + defaultMqttMessageProcessor.getEnodeService().persistOffset(null, brokerName, topicClient.split("@")[1], topicClient.split("@")[0], 0, offset); + } + } + } + }, 0, defaultMqttMessageProcessor.getMqttConfig().getPersistOffsetInterval(), TimeUnit.MILLISECONDS); + } + + @Override + public void shutdown() { + if (this.mqttScheduledExecutorService != null) { + this.mqttScheduledExecutorService.shutdown(); + } + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/task/MqttPushTask.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/task/MqttPushTask.java new file mode 100644 index 0000000000000000000000000000000000000000..d0870b61cb67bb685bf01c05c9a75a277277a827 --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/task/MqttPushTask.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.mqtt.task; + +import java.nio.ByteBuffer; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.rocketmq.common.client.Client; +import org.apache.rocketmq.common.client.Subscription; +import org.apache.rocketmq.common.constant.LoggerName; +import org.apache.rocketmq.common.message.MessageConst; +import org.apache.rocketmq.common.message.MessageDecoder; +import org.apache.rocketmq.common.message.MessageExt; +import org.apache.rocketmq.common.protocol.RequestCode; +import org.apache.rocketmq.common.protocol.header.GetMaxOffsetRequestHeader; +import org.apache.rocketmq.common.protocol.heartbeat.MqttSubscriptionData; +import org.apache.rocketmq.common.protocol.heartbeat.SubscriptionData; +import org.apache.rocketmq.common.protocol.route.BrokerData; +import org.apache.rocketmq.common.service.EnodeService; +import org.apache.rocketmq.logging.InternalLogger; +import org.apache.rocketmq.logging.InternalLoggerFactory; +import org.apache.rocketmq.mqtt.client.IOTClientManagerImpl; +import org.apache.rocketmq.mqtt.client.MQTTSession; +import org.apache.rocketmq.mqtt.constant.MqttConstant; +import org.apache.rocketmq.mqtt.processor.DefaultMqttMessageProcessor; +import org.apache.rocketmq.mqtt.util.MqttUtil; +import org.apache.rocketmq.remoting.exception.RemotingCommandException; +import org.apache.rocketmq.remoting.exception.RemotingConnectException; +import org.apache.rocketmq.remoting.exception.RemotingSendRequestException; +import org.apache.rocketmq.remoting.exception.RemotingTimeoutException; +import org.apache.rocketmq.remoting.protocol.RemotingCommand; +import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader; + +public class MqttPushTask implements Runnable { + + private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); + private final DefaultMqttMessageProcessor defaultMqttMessageProcessor; + private MqttHeader mqttHeader; + private MQTTSession client; + private BrokerData brokerData; + + public MqttPushTask(DefaultMqttMessageProcessor processor, final MqttHeader mqttHeader, Client client, + BrokerData brokerData) { + this.defaultMqttMessageProcessor = processor; + this.mqttHeader = mqttHeader; + this.client = (MQTTSession) client; + this.brokerData = brokerData; + } + + @Override + public void run() { + + String rootTopic = MqttUtil.getRootTopic(mqttHeader.getTopicName()); + EnodeService enodeService = this.defaultMqttMessageProcessor.getEnodeService(); + IOTClientManagerImpl iotClientManager = (IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager(); + Subscription subscription = iotClientManager.getSubscriptionByClientId(client.getClientId()); + ConcurrentHashMap subscriptionTable = subscription.getSubscriptionTable(); + //compare current consumeOffset of rootTopic@clientId with maxOffset, pull message if consumeOffset < maxOffset + long maxOffsetInQueue; + try { + maxOffsetInQueue = getMaxOffset(brokerData.getBrokerName(), rootTopic); + final long consumeOffset = enodeService.queryOffset(brokerData.getBrokerName(), client.getClientId(), rootTopic, 0); + long i = consumeOffset + 1; + while (i <= maxOffsetInQueue) { + //TODO query messages(queueOffset=i) from enode above(brokerData.getBrokerName) + RemotingCommand response = null; + ByteBuffer byteBuffer = ByteBuffer.wrap(response.getBody()); + MessageExt messageExt = MessageDecoder.clientDecode(byteBuffer, true); + + final String realTopic = messageExt.getProperty(MessageConst.PROPERTY_REAL_TOPIC); + + boolean needSkip = needSkip(realTopic, subscriptionTable); + boolean alreadyInFlight = alreadyInFight(brokerData.getBrokerName(), realTopic, client.getClientId(), messageExt.getQueueOffset()); + if (needSkip) { + log.info("Current client doesn't subscribe topic:{}, skip this message", realTopic); + maxOffsetInQueue = getMaxOffset(brokerData.getBrokerName(), rootTopic); + i += 1; + continue; + } + if (alreadyInFlight) { + log.info("The message is already inflight. MessageId={}", messageExt.getMsgId()); + break; + } + Integer pushQos = lowerQosToTheSubscriptionDesired(realTopic, Integer.valueOf(messageExt.getProperty(MqttConstant.PROPERTY_MQTT_QOS)), subscriptionTable); + mqttHeader.setQosLevel(pushQos); + mqttHeader.setTopicName(realTopic); + if (client.getInflightSlots().get() == 0) { + log.info("The in-flight window is full, stop pushing message to consumers and update consumeOffset. ClientId={}, rootTopic={}", client.getClientId(), rootTopic); + break; + } + //push message if in-flight window has slot(not full) + client.pushMessageQos1(mqttHeader, messageExt, brokerData); + + maxOffsetInQueue = getMaxOffset(brokerData.getBrokerName(), rootTopic); + i += 1; + } + + //TODO update consumeOffset of rootTopic@clientId in brokerData.getBrokerName() + enodeService.persistOffset(null, brokerData.getBrokerName(), client.getClientId(), rootTopic, 0, i - 1); + } catch (Exception ex) { + log.error("Exception was thrown when pushing messages to consumer.{}", ex); + } + } + + private boolean needSkip(final String realTopic, ConcurrentHashMap subscriptionTable) { + Enumeration topicFilters = subscriptionTable.keys(); + while (topicFilters.hasMoreElements()) { + if (MqttUtil.isMatch(topicFilters.nextElement(), realTopic)) { + return false; + } + } + return true; + } + + private boolean alreadyInFight(String brokerName, String topic, String clientId, Long queueOffset) { + ConcurrentHashMap>> processTable = ((IOTClientManagerImpl) this.defaultMqttMessageProcessor.getIotClientManager()).getProcessTable(); + ConcurrentHashMap> map = processTable.get(brokerName); + if (map != null) { + TreeMap treeMap = map.get(MqttUtil.getRootTopic(topic) + "@" + clientId); + if (treeMap != null && treeMap.get(queueOffset) != null) { + return true; + } + } + return false; + } + + private Integer lowerQosToTheSubscriptionDesired(String publishTopic, Integer publishingQos, + ConcurrentHashMap subscriptionTable) { + Integer pushQos = Integer.valueOf(publishingQos); + Iterator> iterator = subscriptionTable.entrySet().iterator(); + Integer maxRequestedQos = 0; + while (iterator.hasNext()) { + final String topicFilter = iterator.next().getKey(); + if (MqttUtil.isMatch(topicFilter, publishTopic)) { + MqttSubscriptionData mqttSubscriptionData = (MqttSubscriptionData) iterator.next().getValue(); + maxRequestedQos = mqttSubscriptionData.getQos() > maxRequestedQos ? mqttSubscriptionData.getQos() : maxRequestedQos; + } + } + if (publishingQos > maxRequestedQos) { + pushQos = maxRequestedQos; + } + return pushQos; + } + + private long getMaxOffset(String enodeName, + String topic) throws InterruptedException, RemotingTimeoutException, RemotingCommandException, RemotingSendRequestException, RemotingConnectException { + GetMaxOffsetRequestHeader requestHeader = new GetMaxOffsetRequestHeader(); + requestHeader.setTopic(topic); + requestHeader.setQueueId(0); + RemotingCommand request = RemotingCommand.createRequestCommand(RequestCode.GET_MAX_OFFSET, requestHeader); + + return this.defaultMqttMessageProcessor.getEnodeService().getMaxOffsetInQueue(enodeName, topic, 0, request); + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/service/MqttPushService.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/transfer/TransferDataQos1.java similarity index 57% rename from mqtt/src/main/java/org/apache/rocketmq/mqtt/service/MqttPushService.java rename to mqtt/src/main/java/org/apache/rocketmq/mqtt/transfer/TransferDataQos1.java index bbc706ef5c4758990dc8ebe748f17792a86ffca2..70e32d051e1c5e458ed3cd83fc8e11912186ef90 100644 --- a/mqtt/src/main/java/org/apache/rocketmq/mqtt/service/MqttPushService.java +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/transfer/TransferDataQos1.java @@ -14,12 +14,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.rocketmq.mqtt.service; -import io.netty.buffer.ByteBuf; -import org.apache.rocketmq.common.client.Client; -import org.apache.rocketmq.remoting.transport.mqtt.MqttHeader; -public interface MqttPushService { - void pushMessageQos(MqttHeader mqttHeader, final ByteBuf message, Client client); +package org.apache.rocketmq.mqtt.transfer; + +import org.apache.rocketmq.common.protocol.route.BrokerData; +import org.apache.rocketmq.remoting.serialize.RemotingSerializable; + +public class TransferDataQos1 extends RemotingSerializable { + + private BrokerData brokerData; + private String topic; + + public BrokerData getBrokerData() { + return brokerData; + } + + public void setBrokerData(BrokerData brokerData) { + this.brokerData = brokerData; + } + + public String getTopic() { + return topic; + } + + public void setTopic(String topic) { + this.topic = topic; + } } diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/BoundedExecutorService.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/BoundedExecutorService.java new file mode 100644 index 0000000000000000000000000000000000000000..3e5249c40caebee970bc4be34ab942424cbd0ea1 --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/BoundedExecutorService.java @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +import com.google.common.util.concurrent.ForwardingExecutorService; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * Implements {@link ExecutorService} and allows limiting the number of tasks to + * be scheduled in the thread's queue. + */ +public class BoundedExecutorService extends ForwardingExecutorService { + private final BlockingQueue queue; + private final ThreadPoolExecutor thread; + private final int maxTasksInQueue; + + public BoundedExecutorService(ThreadPoolExecutor thread, int maxTasksInQueue) { + this.queue = thread.getQueue(); + this.thread = thread; + this.maxTasksInQueue = maxTasksInQueue; + } + + @Override + protected ExecutorService delegate() { + return this.thread; + } + + private void checkQueue(int numberOfTasks) { + if (maxTasksInQueue > 0 && (queue.size() + numberOfTasks) > maxTasksInQueue) { + throw new RejectedExecutionException("Queue at limit of " + maxTasksInQueue + " items"); + } + } + + @Override + public List> invokeAll(Collection> tasks) throws InterruptedException { + checkQueue(tasks.size()); + return super.invokeAll(tasks); + } + + @Override + public List> invokeAll(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException { + checkQueue(tasks.size()); + return super.invokeAll(tasks, timeout, unit); + } + + @Override + public T invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { + checkQueue(tasks.size()); + return super.invokeAny(tasks); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + checkQueue(tasks.size()); + return super.invokeAny(tasks, timeout, unit); + } + + @Override + public void execute(Runnable command) { + checkQueue(1); + super.execute(command); + } + + @Override + public Future submit(Callable task) { + checkQueue(1); + return super.submit(task); + } + + @Override + public Future submit(Runnable task) { + checkQueue(1); + return super.submit(task); + } + + @Override + public Future submit(Runnable task, T result) { + checkQueue(1); + return super.submit(task, result); + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/Counter.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/Counter.java new file mode 100644 index 0000000000000000000000000000000000000000..fbc8ef08fadea4d62b8318947a7a29bb4e450ab4 --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/Counter.java @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +/** + * Simple stats that require only increment and decrement + * functions on a Long. Metrics like the number of topics, persist queue size + * etc. should use this. + */ +public interface Counter { + /** + * Clear this stat. + */ + void clear(); + + /** + * Increment the value associated with this stat. + */ + void inc(); + + /** + * Decrement the value associated with this stat. + */ + void dec(); + + /** + * Add delta to the value associated with this stat. + * @param delta + */ + void add(long delta); + + /** + * Get the value associated with this stat. + */ + Long get(); +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/Gauge.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/Gauge.java new file mode 100644 index 0000000000000000000000000000000000000000..1b532afdd9693d67fffa4a3f4cff0e8f83141097 --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/Gauge.java @@ -0,0 +1,28 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +/** + * A guage is a value that has only one value at a specific point in time. + * An example is the number of elements in a queue. The value of T must be + * some numeric type. + */ +public interface Gauge { + T getDefaultValue(); + T getSample(); +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/MathUtils.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/MathUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..85196c18be70be6891df2b51da73c901cb7f2eab --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/MathUtils.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +import java.util.concurrent.TimeUnit; + +/** + * Provides misc math functions that don't come standard. + */ +public class MathUtils { + + private static final long NANOSECONDS_PER_MILLISECOND = 1000000; + + public static int signSafeMod(long dividend, int divisor) { + int mod = (int) (dividend % divisor); + + if (mod < 0) { + mod += divisor; + } + + return mod; + } + + public static int findNextPositivePowerOfTwo(final int value) { + return 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); + } + + /** + * Current time from some arbitrary time base in the past, counting in + * nanoseconds, and not affected by settimeofday or similar system clock + * changes. This is appropriate to use when computing how much longer to + * wait for an interval to expire. + * + *

NOTE: only use it for measuring. + * http://docs.oracle.com/javase/1.5.0/docs/api/java/lang/System.html#nanoTime%28%29 + * + * @return current time in nanoseconds. + */ + public static long nowInNano() { + return System.nanoTime(); + } + + /** + * Milliseconds elapsed since the time specified, the input is nanoTime + * the only conversion happens when computing the elapsed time. + * + * @param startNanoTime the start of the interval that we are measuring + * @return elapsed time in milliseconds. + */ + public static long elapsedMSec(long startNanoTime) { + return (System.nanoTime() - startNanoTime) / NANOSECONDS_PER_MILLISECOND; + } + + /** + * Microseconds elapsed since the time specified, the input is nanoTime + * the only conversion happens when computing the elapsed time. + * + * @param startNanoTime the start of the interval that we are measuring + * @return elapsed time in milliseconds. + */ + public static long elapsedMicroSec(long startNanoTime) { + return TimeUnit.NANOSECONDS.toMicros(System.nanoTime() - startNanoTime); + } + + /** + * Nanoseconds elapsed since the time specified, the input is nanoTime + * the only conversion happens when computing the elapsed time. + * + * @param startNanoTime the start of the interval that we are measuring + * @return elapsed time in milliseconds. + */ + public static long elapsedNanos(long startNanoTime) { + return System.nanoTime() - startNanoTime; + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/MdcUtils.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/MdcUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..0dab4fad8c7ce38e22fc073fe5d8edf715d0aebe --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/MdcUtils.java @@ -0,0 +1,39 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +import java.util.Map; +import org.slf4j.MDC; + +/** + * Utils for work with Slf4j MDC. + */ +public class MdcUtils { + + public static void restoreContext(Map mdcContextMap) { + if (mdcContextMap == null || mdcContextMap.isEmpty()) { + MDC.clear(); + } else { + MDC.setContextMap(mdcContextMap); + } + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/NullStatsLogger.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/NullStatsLogger.java new file mode 100644 index 0000000000000000000000000000000000000000..ed5490c464b0112fc456ca912887389a033c1347 --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/NullStatsLogger.java @@ -0,0 +1,129 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +import java.util.concurrent.TimeUnit; + +/** + * A no-op {@code StatsLogger}. + * + *

Metrics are not recorded, making this receiver useful in unit tests and as defaults in + * situations where metrics are not strictly required. + */ +public class NullStatsLogger implements StatsLogger { + + public static final NullStatsLogger INSTANCE = new NullStatsLogger(); + + /** + * A no-op {@code OpStatsLogger}. + */ + static class NullOpStatsLogger implements OpStatsLogger { + final OpStatsData nullOpStats = new OpStatsData(0, 0, 0, new long[6]); + + @Override + public void registerFailedEvent(long eventLatency, TimeUnit unit) { + // nop + } + + @Override + public void registerSuccessfulEvent(long eventLatency, TimeUnit unit) { + // nop + } + + @Override + public void registerSuccessfulValue(long value) { + // nop + } + + @Override + public void registerFailedValue(long value) { + // nop + } + + @Override + public OpStatsData toOpStatsData() { + return nullOpStats; + } + + @Override + public void clear() { + // nop + } + } + static NullOpStatsLogger nullOpStatsLogger = new NullOpStatsLogger(); + + /** + * A no-op {@code Counter}. + */ + static class NullCounter implements Counter { + @Override + public void clear() { + // nop + } + + @Override + public void inc() { + // nop + } + + @Override + public void dec() { + // nop + } + + @Override + public void add(long delta) { + // nop + } + + @Override + public Long get() { + return 0L; + } + } + static NullCounter nullCounter = new NullCounter(); + + @Override + public OpStatsLogger getOpStatsLogger(String name) { + return nullOpStatsLogger; + } + + @Override + public Counter getCounter(String name) { + return nullCounter; + } + + @Override + public void registerGauge(String name, Gauge gauge) { + // nop + } + + @Override + public void unregisterGauge(String name, Gauge gauge) { + // nop + } + + @Override + public StatsLogger scope(String name) { + return this; + } + + @Override + public void removeScope(String name, StatsLogger statsLogger) { + // nop + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OpStatsData.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OpStatsData.java new file mode 100644 index 0000000000000000000000000000000000000000..bd96936bd7114affccbb4f553042ae3a7987bb5a --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OpStatsData.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +import java.util.Arrays; + +/** + * This class provides a read view of operation specific stats. + * We expose this to JMX. + * We use primitives because the class has to conform to CompositeViewData. + */ +public class OpStatsData { + private final long numSuccessfulEvents, numFailedEvents; + // All latency values are in Milliseconds. + private final double avgLatencyMillis; + // 10.0 50.0, 90.0, 99.0, 99.9, 99.99 in that order. + // TODO: Figure out if we can use a Map + private final long[] percentileLatenciesMillis; + public OpStatsData(long numSuccessfulEvents, long numFailedEvents, + double avgLatencyMillis, long[] percentileLatenciesMillis) { + this.numSuccessfulEvents = numSuccessfulEvents; + this.numFailedEvents = numFailedEvents; + this.avgLatencyMillis = avgLatencyMillis; + this.percentileLatenciesMillis = + Arrays.copyOf(percentileLatenciesMillis, percentileLatenciesMillis.length); + } + + public long getP10Latency() { + return this.percentileLatenciesMillis[0]; + } + public long getP50Latency() { + return this.percentileLatenciesMillis[1]; + } + + public long getP90Latency() { + return this.percentileLatenciesMillis[2]; + } + + public long getP99Latency() { + return this.percentileLatenciesMillis[3]; + } + + public long getP999Latency() { + return this.percentileLatenciesMillis[4]; + } + + public long getP9999Latency() { + return this.percentileLatenciesMillis[5]; + } + + public long getNumSuccessfulEvents() { + return this.numSuccessfulEvents; + } + + public long getNumFailedEvents() { + return this.numFailedEvents; + } + + public double getAvgLatencyMillis() { + return this.avgLatencyMillis; + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OpStatsLogger.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OpStatsLogger.java new file mode 100644 index 0000000000000000000000000000000000000000..d6eab79856bb949610197d0e2d656cb003815ca9 --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OpStatsLogger.java @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +import java.util.concurrent.TimeUnit; + +/** + * This interface handles logging of statistics related to each operation. (PUBLISH, CONSUME etc.) + */ +public interface OpStatsLogger { + + /** + * Increment the failed op counter with the given eventLatency. + * @param eventLatencyMillis The event latency + * @param unit + */ + void registerFailedEvent(long eventLatencyMillis, TimeUnit unit); + + /** + * An operation succeeded with the given eventLatency. Update + * stats to reflect the same + * @param eventLatencyMillis The event latency + * @param unit + */ + void registerSuccessfulEvent(long eventLatencyMillis, TimeUnit unit); + + /** + * An operation with the given value succeeded. + * @param value + */ + void registerSuccessfulValue(long value); + + /** + * An operation with the given value failed. + */ + void registerFailedValue(long value); + + /** + * @return Returns an OpStatsData object with necessary values. We need this function + * to support JMX exports. This should be deprecated sometime in the near future. + * populated. + */ + OpStatsData toOpStatsData(); + + /** + * Clear stats for this operation. + */ + void clear(); +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OrderedExecutor.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OrderedExecutor.java new file mode 100644 index 0000000000000000000000000000000000000000..ad9daa9e0cdad8e15b5d86badf28ae8afdf331fc --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/OrderedExecutor.java @@ -0,0 +1,684 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +import com.google.common.util.concurrent.ForwardingExecutorService; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.util.concurrent.DefaultThreadFactory; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; +import org.apache.rocketmq.common.constant.LoggerName; +import org.apache.rocketmq.logging.InternalLogger; +import org.apache.rocketmq.logging.InternalLoggerFactory; +import org.slf4j.MDC; + +import static com.google.common.base.Preconditions.checkArgument; + +/** + * This class provides 2 things over the java {@link ExecutorService}. + * + *

1. It takes {@link SafeRunnable objects} instead of plain Runnable objects. + * This means that exceptions in scheduled tasks wont go unnoticed and will be logged. + * + *

2. It supports submitting tasks with an ordering key, so that tasks submitted + * with the same key will always be executed in order, but tasks across different keys can be unordered. This retains + * parallelism while retaining the basic amount of ordering we want (e.g. , per ledger handle). Ordering is achieved by + * hashing the key objects to threads by their {@link #hashCode()} method. + */ +public class OrderedExecutor implements ExecutorService { + private static final InternalLogger log = InternalLoggerFactory.getLogger(LoggerName.MQTT_LOGGER_NAME); + public static final int NO_TASK_LIMIT = -1; + protected static final long WARN_TIME_MICRO_SEC_DEFAULT = TimeUnit.SECONDS.toMicros(1); + + final String name; + final ExecutorService threads[]; + final long threadIds[]; + final Random rand = new Random(); + final OpStatsLogger taskExecutionStats; + final OpStatsLogger taskPendingStats; + final boolean traceTaskExecution; + final boolean preserveMdcForTaskExecution; + final long warnTimeMicroSec; + final int maxTasksInQueue; + + public static Builder newBuilder() { + return new Builder(); + } + + /** + * A builder class for an OrderedExecutor. + */ + public static class Builder extends AbstractBuilder { + + @Override + public OrderedExecutor build() { + if (null == threadFactory) { + threadFactory = new DefaultThreadFactory("bookkeeper-ordered-safe-executor"); + } + return new OrderedExecutor(name, numThreads, threadFactory, statsLogger, + traceTaskExecution, preserveMdcForTaskExecution, warnTimeMicroSec, maxTasksInQueue); + } + } + + /** + * Abstract builder class to build {@link OrderedExecutor}. + */ + public abstract static class AbstractBuilder { + protected String name = getClass().getSimpleName(); + protected int numThreads = Runtime.getRuntime().availableProcessors(); + protected ThreadFactory threadFactory = null; + protected StatsLogger statsLogger = NullStatsLogger.INSTANCE; + protected boolean traceTaskExecution = false; + protected boolean preserveMdcForTaskExecution = false; + protected long warnTimeMicroSec = WARN_TIME_MICRO_SEC_DEFAULT; + protected int maxTasksInQueue = NO_TASK_LIMIT; + protected boolean enableBusyWait = false; + + public AbstractBuilder name(String name) { + this.name = name; + return this; + } + + public AbstractBuilder numThreads(int num) { + this.numThreads = num; + return this; + } + + public AbstractBuilder maxTasksInQueue(int num) { + this.maxTasksInQueue = num; + return this; + } + + public AbstractBuilder threadFactory(ThreadFactory threadFactory) { + this.threadFactory = threadFactory; + return this; + } + + public AbstractBuilder statsLogger(StatsLogger statsLogger) { + this.statsLogger = statsLogger; + return this; + } + + public AbstractBuilder traceTaskExecution(boolean enabled) { + this.traceTaskExecution = enabled; + return this; + } + + public AbstractBuilder preserveMdcForTaskExecution(boolean enabled) { + this.preserveMdcForTaskExecution = enabled; + return this; + } + + public AbstractBuilder traceTaskWarnTimeMicroSec(long warnTimeMicroSec) { + this.warnTimeMicroSec = warnTimeMicroSec; + return this; + } + + public AbstractBuilder enableBusyWait(boolean enableBusyWait) { + this.enableBusyWait = enableBusyWait; + return this; + } + + @SuppressWarnings("unchecked") + public T build() { + if (null == threadFactory) { + threadFactory = new DefaultThreadFactory(name); + } + return (T) new OrderedExecutor( + name, + numThreads, + threadFactory, + statsLogger, + traceTaskExecution, + preserveMdcForTaskExecution, + warnTimeMicroSec, + maxTasksInQueue); + } + } + + /** + * Decorator class for a runnable that measure the execution time. + */ + protected class TimedRunnable implements Runnable { + final Runnable runnable; + final long initNanos; + + TimedRunnable(Runnable runnable) { + this.runnable = runnable; + this.initNanos = MathUtils.nowInNano(); + } + + @Override + public void run() { + taskPendingStats.registerSuccessfulEvent(MathUtils.elapsedNanos(initNanos), TimeUnit.NANOSECONDS); + long startNanos = MathUtils.nowInNano(); + try { + this.runnable.run(); + } finally { + long elapsedMicroSec = MathUtils.elapsedMicroSec(startNanos); + taskExecutionStats.registerSuccessfulEvent(elapsedMicroSec, TimeUnit.MICROSECONDS); + if (elapsedMicroSec >= warnTimeMicroSec) { + log.warn("Runnable {}:{} took too long {} micros to execute.", runnable, runnable.getClass(), + elapsedMicroSec); + } + } + } + } + + /** + * Decorator class for a callable that measure the execution time. + */ + protected class TimedCallable implements Callable { + final Callable callable; + final long initNanos; + + TimedCallable(Callable callable) { + this.callable = callable; + this.initNanos = MathUtils.nowInNano(); + } + + @Override + public T call() throws Exception { + taskPendingStats.registerSuccessfulEvent(MathUtils.elapsedNanos(initNanos), TimeUnit.NANOSECONDS); + long startNanos = MathUtils.nowInNano(); + try { + return this.callable.call(); + } finally { + long elapsedMicroSec = MathUtils.elapsedMicroSec(startNanos); + taskExecutionStats.registerSuccessfulEvent(elapsedMicroSec, TimeUnit.MICROSECONDS); + if (elapsedMicroSec >= warnTimeMicroSec) { + log.warn("Callable {}:{} took too long {} micros to execute.", callable, callable.getClass(), + elapsedMicroSec); + } + } + } + } + + /** + * Decorator class for a runnable that preserves MDC context. + */ + static class ContextPreservingRunnable implements Runnable { + private final Runnable runnable; + private final Map mdcContextMap; + + ContextPreservingRunnable(Runnable runnable) { + this.runnable = runnable; + this.mdcContextMap = MDC.getCopyOfContextMap(); + } + + @Override + public void run() { + MdcUtils.restoreContext(mdcContextMap); + try { + runnable.run(); + } finally { + MDC.clear(); + } + } + } + + /** + * Decorator class for a callable that preserves MDC context. + */ + static class ContextPreservingCallable implements Callable { + private final Callable callable; + private final Map mdcContextMap; + + ContextPreservingCallable(Callable callable) { + this.callable = callable; + this.mdcContextMap = MDC.getCopyOfContextMap(); + } + + @Override + public T call() throws Exception { + MdcUtils.restoreContext(mdcContextMap); + try { + return callable.call(); + } finally { + MDC.clear(); + } + } + } + + protected ThreadPoolExecutor createSingleThreadExecutor(ThreadFactory factory) { + BlockingQueue queue = new LinkedBlockingQueue<>(); + return new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, queue, factory); + } + + protected ExecutorService getBoundedExecutor(ThreadPoolExecutor executor) { + return new BoundedExecutorService(executor, this.maxTasksInQueue); + } + + protected ExecutorService addExecutorDecorators(ExecutorService executor) { + return new ForwardingExecutorService() { + @Override + protected ExecutorService delegate() { + return executor; + } + + @Override + public List> invokeAll(Collection> tasks) + throws InterruptedException { + return super.invokeAll(timedCallables(tasks)); + } + + @Override + public List> invokeAll(Collection> tasks, + long timeout, TimeUnit unit) + throws InterruptedException { + return super.invokeAll(timedCallables(tasks), timeout, unit); + } + + @Override + public T invokeAny(Collection> tasks) + throws InterruptedException, ExecutionException { + return super.invokeAny(timedCallables(tasks)); + } + + @Override + public T invokeAny(Collection> tasks, + long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return super.invokeAny(timedCallables(tasks), timeout, unit); + } + + @Override + public void execute(Runnable command) { + super.execute(timedRunnable(command)); + } + + @Override + public Future submit(Callable task) { + return super.submit(timedCallable(task)); + } + + @Override + public Future submit(Runnable task) { + return super.submit(timedRunnable(task)); + } + + @Override + public Future submit(Runnable task, T result) { + return super.submit(timedRunnable(task), result); + } + }; + } + + /** + * Constructs Safe executor. + * + * @param numThreads - number of threads + * @param baseName - base name of executor threads + * @param threadFactory - for constructing threads + * @param statsLogger - for reporting executor stats + * @param traceTaskExecution - should we stat task execution + * @param preserveMdcForTaskExecution - should we preserve MDC for task execution + * @param warnTimeMicroSec - log long task exec warning after this interval + * @param maxTasksInQueue - maximum items allowed in a thread queue. -1 for no limit + */ + protected OrderedExecutor(String baseName, int numThreads, ThreadFactory threadFactory, + StatsLogger statsLogger, boolean traceTaskExecution, + boolean preserveMdcForTaskExecution, long warnTimeMicroSec, int maxTasksInQueue) { + checkArgument(numThreads > 0); + checkArgument(!StringUtils.isBlank(baseName)); + + this.maxTasksInQueue = maxTasksInQueue; + this.warnTimeMicroSec = warnTimeMicroSec; + name = baseName; + threads = new ExecutorService[numThreads]; + threadIds = new long[numThreads]; + for (int i = 0; i < numThreads; i++) { + ThreadPoolExecutor thread = createSingleThreadExecutor( + new ThreadFactoryBuilder().setNameFormat(name + "-" + getClass().getSimpleName() + "-" + i + "-%d") + .setThreadFactory(threadFactory).build()); + + threads[i] = addExecutorDecorators(getBoundedExecutor(thread)); + + final int idx = i; + try { + threads[idx].submit(() -> { + threadIds[idx] = Thread.currentThread().getId(); + }).get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Couldn't start thread " + i, e); + } catch (ExecutionException e) { + throw new RuntimeException("Couldn't start thread " + i, e); + } + + // Register gauges + statsLogger.registerGauge(String.format("%s-queue-%d", name, idx), new Gauge() { + @Override + public Number getDefaultValue() { + return 0; + } + + @Override + public Number getSample() { + return thread.getQueue().size(); + } + }); + statsLogger.registerGauge(String.format("%s-completed-tasks-%d", name, idx), new Gauge() { + @Override + public Number getDefaultValue() { + return 0; + } + + @Override + public Number getSample() { + return thread.getCompletedTaskCount(); + } + }); + statsLogger.registerGauge(String.format("%s-total-tasks-%d", name, idx), new Gauge() { + @Override + public Number getDefaultValue() { + return 0; + } + + @Override + public Number getSample() { + return thread.getTaskCount(); + } + }); + } + + // Stats + this.taskExecutionStats = statsLogger.scope(name).getOpStatsLogger("task_execution"); + this.taskPendingStats = statsLogger.scope(name).getOpStatsLogger("task_queued"); + this.traceTaskExecution = traceTaskExecution; + this.preserveMdcForTaskExecution = preserveMdcForTaskExecution; + } + + /** + * Flag describing executor's expectation in regards of MDC. All tasks submitted through executor's submit/execute + * methods will automatically respect this. + * + * @return true if runnable/callable is expected to preserve MDC, false otherwise. + */ + public boolean preserveMdc() { + return preserveMdcForTaskExecution; + } + + /** + * Schedules a one time action to execute with an ordering guarantee on the key. + * + * @param orderingKey + * @param r + */ + public void executeOrdered(Object orderingKey, SafeRunnable r) { + chooseThread(orderingKey).execute(r); + } + + /** + * Schedules a one time action to execute with an ordering guarantee on the key. + * + * @param orderingKey + * @param r + */ + public void executeOrdered(long orderingKey, SafeRunnable r) { + chooseThread(orderingKey).execute(r); + } + + /** + * Schedules a one time action to execute with an ordering guarantee on the key. + * + * @param orderingKey + * @param r + */ + public void executeOrdered(int orderingKey, SafeRunnable r) { + chooseThread(orderingKey).execute(r); + } + + public ListenableFuture submitOrdered(long orderingKey, Callable task) { + SettableFuture future = SettableFuture.create(); + executeOrdered(orderingKey, () -> { + try { + T result = task.call(); + future.set(result); + } catch (Throwable t) { + future.setException(t); + } + }); + + return future; + } + + public long getThreadID(long orderingKey) { + // skip hashcode generation in this special case + if (threadIds.length == 1) { + return threadIds[0]; + } + + return threadIds[MathUtils.signSafeMod(orderingKey, threadIds.length)]; + } + + public ExecutorService chooseThread() { + // skip random # generation in this special case + if (threads.length == 1) { + return threads[0]; + } + + return threads[rand.nextInt(threads.length)]; + } + + public ExecutorService chooseThread(Object orderingKey) { + // skip hashcode generation in this special case + if (threads.length == 1) { + return threads[0]; + } + + if (null == orderingKey) { + return threads[rand.nextInt(threads.length)]; + } else { + return threads[MathUtils.signSafeMod(orderingKey.hashCode(), threads.length)]; + } + } + + /** + * skip hashcode generation in this special case. + * + * @param orderingKey long ordering key + * @return the thread for executing this order key + */ + public ExecutorService chooseThread(long orderingKey) { + if (threads.length == 1) { + return threads[0]; + } + + return threads[MathUtils.signSafeMod(orderingKey, threads.length)]; + } + + protected Runnable timedRunnable(Runnable r) { + final Runnable runMe = traceTaskExecution ? new TimedRunnable(r) : r; + return preserveMdcForTaskExecution ? new ContextPreservingRunnable(runMe) : runMe; + } + + protected Callable timedCallable(Callable c) { + final Callable callMe = traceTaskExecution ? new TimedCallable<>(c) : c; + return preserveMdcForTaskExecution ? new ContextPreservingCallable<>(callMe) : callMe; + } + + protected Collection> timedCallables(Collection> tasks) { + if (traceTaskExecution || preserveMdcForTaskExecution) { + return tasks.stream() + .map(this::timedCallable) + .collect(Collectors.toList()); + } + return tasks; + } + + /** + * {@inheritDoc} + */ + @Override + public Future submit(Callable task) { + return chooseThread().submit(timedCallable(task)); + } + + /** + * {@inheritDoc} + */ + @Override + public Future submit(Runnable task, T result) { + return chooseThread().submit(task, result); + } + + /** + * {@inheritDoc} + */ + @Override + public Future submit(Runnable task) { + return chooseThread().submit(task); + } + + /** + * {@inheritDoc} + */ + @Override + public List> invokeAll(Collection> tasks) + throws InterruptedException { + return chooseThread().invokeAll(timedCallables(tasks)); + } + + /** + * {@inheritDoc} + */ + @Override + public List> invokeAll(Collection> tasks, + long timeout, + TimeUnit unit) + throws InterruptedException { + return chooseThread().invokeAll(timedCallables(tasks), timeout, unit); + } + + /** + * {@inheritDoc} + */ + @Override + public T invokeAny(Collection> tasks) + throws InterruptedException, ExecutionException { + return chooseThread().invokeAny(timedCallables(tasks)); + } + + /** + * {@inheritDoc} + */ + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return chooseThread().invokeAny(timedCallables(tasks), timeout, unit); + } + + /** + * {@inheritDoc} + */ + @Override + public void execute(Runnable command) { + chooseThread().execute(timedRunnable(command)); + } + + /** + * {@inheritDoc} + */ + @Override + public void shutdown() { + for (int i = 0; i < threads.length; i++) { + threads[i].shutdown(); + } + } + + /** + * {@inheritDoc} + */ + @Override + public List shutdownNow() { + List runnables = new ArrayList(); + for (ExecutorService executor : threads) { + runnables.addAll(executor.shutdownNow()); + } + return runnables; + } + + /** + * {@inheritDoc} + */ + @Override + public boolean isShutdown() { + for (ExecutorService executor : threads) { + if (!executor.isShutdown()) { + return false; + } + } + return true; + } + + /** + * {@inheritDoc} + */ + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + boolean ret = true; + for (int i = 0; i < threads.length; i++) { + ret = ret && threads[i].awaitTermination(timeout, unit); + } + return ret; + } + + /** + * {@inheritDoc} + */ + @Override + public boolean isTerminated() { + for (ExecutorService executor : threads) { + if (!executor.isTerminated()) { + return false; + } + } + return true; + } + + /** + * Force threads shutdown (cancel active requests) after specified delay, to be used after shutdown() rejects new + * requests. + */ + public void forceShutdown(long timeout, TimeUnit unit) { + for (int i = 0; i < threads.length; i++) { + try { + if (!threads[i].awaitTermination(timeout, unit)) { + threads[i].shutdownNow(); + } + } catch (InterruptedException exception) { + threads[i].shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/SafeRunnable.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/SafeRunnable.java new file mode 100644 index 0000000000000000000000000000000000000000..df36cf4964004890599cb745c23500a77dd8e5bc --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/SafeRunnable.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +import java.util.function.Consumer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A runnable that catches runtime exceptions. + */ +@FunctionalInterface +public interface SafeRunnable extends Runnable { + + Logger LOGGER = LoggerFactory.getLogger(SafeRunnable.class); + + @Override + default void run() { + try { + safeRun(); + } catch (Throwable t) { + LOGGER.error("Unexpected throwable caught ", t); + } + } + + void safeRun(); + + /** + * Utility method to use SafeRunnable from lambdas. + * + *

Eg: + *

+     * 
+     * executor.submit(SafeRunnable.safeRun(() -> {
+     *    // My not-safe code
+     * });
+     * 
+     * 
+ */ + static SafeRunnable safeRun(Runnable runnable) { + return new SafeRunnable() { + @Override + public void safeRun() { + runnable.run(); + } + }; + } + + /** + * Utility method to use SafeRunnable from lambdas with + * a custom exception handler. + * + *

Eg: + *

+     * 
+     * executor.submit(SafeRunnable.safeRun(() -> {
+     *    // My not-safe code
+     * }, exception -> {
+     *    // Handle exception
+     * );
+     * 
+     * 
+ * + * @param runnable + * @param exceptionHandler + * handler that will be called when there are any exception + * @return + */ + static SafeRunnable safeRun(Runnable runnable, Consumer exceptionHandler) { + return () -> { + try { + runnable.run(); + } catch (Throwable t) { + exceptionHandler.accept(t); + throw t; + } + }; + } +} diff --git a/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/StatsLogger.java b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/StatsLogger.java new file mode 100644 index 0000000000000000000000000000000000000000..790b266f0e46f7a054faf4dc989e8097cd3ba9cf --- /dev/null +++ b/mqtt/src/main/java/org/apache/rocketmq/mqtt/util/orderedexecutor/StatsLogger.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.apache.rocketmq.mqtt.util.orderedexecutor; + +/** + * A simple interface that exposes just 2 useful methods. One to get the logger for an Op stat + * and another to get the logger for a simple stat + */ +public interface StatsLogger { + /** + * @param name + * Stats Name + * @return Get the logger for an OpStat described by the name. + */ + OpStatsLogger getOpStatsLogger(String name); + + /** + * @param name + * Stats Name + * @return Get the logger for a simple stat described by the name + */ + Counter getCounter(String name); + + /** + * Register given gauge as name name. + * + * @param name + * gauge name + * @param gauge + * gauge function + */ + void registerGauge(String name, Gauge gauge); + + /** + * Unregister given gauge from name name. + * + * @param name + * name of the gauge + * @param gauge + * gauge function + */ + void unregisterGauge(String name, Gauge gauge); + + /** + * Provide the stats logger under scope name. + * + * @param name + * scope name. + * @return stats logger under scope name. + */ + StatsLogger scope(String name); + + /** + * Remove the given statsLogger for scope name. + * It can be no-op if the underlying stats provider doesn't have the ability to remove scope. + * + * @param name name of the scope + * @param statsLogger the stats logger of this scope. + */ + void removeScope(String name, StatsLogger statsLogger); + +} diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/transport/mqtt/MqttHeader.java b/remoting/src/main/java/org/apache/rocketmq/remoting/transport/mqtt/MqttHeader.java index d9c31b70864b22d16ada6363f4f4ce69e3e86580..84811e26db18793afff19eba2d7db7c54a21ecd8 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/transport/mqtt/MqttHeader.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/transport/mqtt/MqttHeader.java @@ -30,11 +30,11 @@ public class MqttHeader implements CommandCustomHeader { @CFNotNull private Integer messageType; @CFNotNull - private boolean isDup; + private boolean isDup = false; @CFNotNull private Integer qosLevel; @CFNotNull - private boolean isRetain; + private boolean isRetain = false; @CFNotNull private int remainingLength;