提交 c798ae03 编写于 作者: 浅梦2013's avatar 浅梦2013

加强 topic 校验

上级 3de882b5
......@@ -56,7 +56,13 @@ public final class MqttCodecUtil {
* @return 是否 topic filter
*/
public static boolean isTopicFilter(String topicFilter) {
return topicFilter.indexOf(TOPIC_WILDCARDS_ONE) >= 0 || topicFilter.indexOf(TOPIC_WILDCARDS_MORE) >= 0;
char[] topicFilterChars = topicFilter.toCharArray();
for (char ch : topicFilterChars) {
if (TOPIC_WILDCARDS_ONE == ch || TOPIC_WILDCARDS_MORE == ch) {
return true;
}
}
return false;
}
/**
......
......@@ -19,13 +19,16 @@ package net.dreamlu.iot.mqtt.core.client;
import net.dreamlu.iot.mqtt.codec.MqttQoS;
import net.dreamlu.iot.mqtt.core.common.MqttPendingPublish;
import net.dreamlu.iot.mqtt.core.common.MqttPendingQos2Publish;
import net.dreamlu.iot.mqtt.core.util.collection.MultiValueMap;
import net.dreamlu.iot.mqtt.core.util.collection.IntObjectHashMap;
import net.dreamlu.iot.mqtt.core.util.collection.IntObjectMap;
import net.dreamlu.iot.mqtt.core.util.collection.MultiValueMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
/**
* 客户端 session 管理,包括 sub 和 pub
......
......@@ -18,6 +18,7 @@ package net.dreamlu.iot.mqtt.core.client;
import net.dreamlu.iot.mqtt.codec.*;
import net.dreamlu.iot.mqtt.core.common.MqttPendingPublish;
import net.dreamlu.iot.mqtt.core.util.TopicUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.client.ClientChannelContext;
......@@ -185,6 +186,8 @@ public final class MqttClient {
// 1. 先判断是否已经订阅过,重复订阅,直接跳出
List<MqttClientSubscription> needSubscriptionList = new ArrayList<>();
for (MqttClientSubscription subscription : subscriptionList) {
// 校验 topicFilter
TopicUtil.validateTopicFilter(subscription.getTopicFilter());
boolean subscribed = clientSession.isSubscribed(subscription);
if (!subscribed) {
needSubscriptionList.add(subscription);
......@@ -229,6 +232,8 @@ public final class MqttClient {
* @return MqttClient
*/
public MqttClient unSubscribe(List<String> topicFilters) {
// 校验 topicFilter
TopicUtil.validateTopicFilter(topicFilters);
int messageId = messageIdGenerator.getId();
MqttUnsubscribeMessage message = MqttMessageBuilders.unsubscribe()
.addTopicFilters(topicFilters)
......@@ -363,6 +368,9 @@ public final class MqttClient {
* @return 是否发送成功
*/
public boolean publish(String topic, ByteBuffer payload, MqttQoS qos, Consumer<MqttMessageBuilders.PublishBuilder> builder) {
// 校验 topic
TopicUtil.validateTopicName(topic);
// qos 判断
boolean isHighLevelQoS = MqttQoS.AT_LEAST_ONCE == qos || MqttQoS.EXACTLY_ONCE == qos;
int messageId = isHighLevelQoS ? messageIdGenerator.getId() : -1;
if (payload == null) {
......
......@@ -26,10 +26,10 @@ import net.dreamlu.iot.mqtt.core.server.model.Message;
import net.dreamlu.iot.mqtt.core.server.model.Subscribe;
import net.dreamlu.iot.mqtt.core.server.session.IMqttSessionManager;
import net.dreamlu.iot.mqtt.core.server.store.IMqttMessageStore;
import net.dreamlu.iot.mqtt.core.util.TopicUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.core.ChannelContext;
import org.tio.core.Node;
import org.tio.core.Tio;
import org.tio.server.ServerTioConfig;
import org.tio.server.TioServer;
......@@ -154,6 +154,9 @@ public final class MqttServer {
* @return 是否发送成功
*/
public boolean publish(String clientId, String topic, ByteBuffer payload, MqttQoS qos, boolean retain) {
// 校验 topic
TopicUtil.validateTopicName(topic);
// 获取 context
ChannelContext context = Tio.getByBsId(getServerConfig(), clientId);
if (context == null || context.isClosed) {
logger.warn("Mqtt Topic:{} publish to clientId:{} ChannelContext is null may be disconnected.", topic, clientId);
......@@ -165,8 +168,7 @@ public final class MqttServer {
return false;
}
MqttQoS mqttQoS = qos.value() > subMqttQoS ? MqttQoS.valueOf(subMqttQoS) : qos;
publish(context, clientId, topic, payload, mqttQoS, retain);
return true;
return publish(context, clientId, topic, payload, mqttQoS, retain);
}
/**
......@@ -180,7 +182,7 @@ public final class MqttServer {
* @param retain 是否在服务器上保留消息
* @return 是否发送成功
*/
public boolean publish(ChannelContext context, String clientId, String topic, ByteBuffer payload, MqttQoS qos, boolean retain) {
private boolean publish(ChannelContext context, String clientId, String topic, ByteBuffer payload, MqttQoS qos, boolean retain) {
boolean isHighLevelQoS = MqttQoS.AT_LEAST_ONCE == qos || MqttQoS.EXACTLY_ONCE == qos;
int messageId = isHighLevelQoS ? sessionManager.getMessageId(clientId) : -1;
// 下行 payload 为空时,构造一个空结构体
......@@ -254,6 +256,8 @@ public final class MqttServer {
* @return 是否发送成功
*/
public boolean publishAll(String topic, ByteBuffer payload, MqttQoS qos, boolean retain) {
// 校验 topic
TopicUtil.validateTopicName(topic);
// 查找订阅该 topic 的客户端
List<Subscribe> subscribeList = sessionManager.searchSubscribe(topic);
if (subscribeList.isEmpty()) {
......
......@@ -20,7 +20,10 @@ import net.dreamlu.iot.mqtt.core.server.http.api.code.ResultCode;
import net.dreamlu.iot.mqtt.core.server.http.api.result.Result;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.http.common.*;
import org.tio.http.common.HttpConfig;
import org.tio.http.common.HttpRequest;
import org.tio.http.common.HttpResponse;
import org.tio.http.common.RequestLine;
import org.tio.http.common.handler.HttpRequestHandler;
import java.util.List;
......
......@@ -34,6 +34,7 @@ import net.dreamlu.iot.mqtt.core.server.event.IMqttSessionListener;
import net.dreamlu.iot.mqtt.core.server.model.Message;
import net.dreamlu.iot.mqtt.core.server.session.IMqttSessionManager;
import net.dreamlu.iot.mqtt.core.server.store.IMqttMessageStore;
import net.dreamlu.iot.mqtt.core.util.TopicUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.core.ChannelContext;
......@@ -334,6 +335,8 @@ public class DefaultMqttServerProcessor implements MqttServerProcessor {
boolean enableSubscribeValidator = subscribeValidator != null;
for (MqttTopicSubscription subscription : topicSubscriptionList) {
String topicFilter = subscription.topicName();
// 校验 topicFilter 是否合法
TopicUtil.validateTopicFilter(topicFilter);
MqttQoS mqttQoS = subscription.qualityOfService();
// 校验是否可以订阅
if (enableSubscribeValidator && !subscribeValidator.verifyTopicFilter(context, clientId, topicFilter, mqttQoS)) {
......
......@@ -16,6 +16,10 @@
package net.dreamlu.iot.mqtt.core.util;
import net.dreamlu.iot.mqtt.codec.MqttCodecUtil;
import java.util.List;
/**
* Mqtt Topic 工具
*
......@@ -23,6 +27,62 @@ package net.dreamlu.iot.mqtt.core.util;
*/
public final class TopicUtil {
/**
* 校验 topicFilter
*
* @param topicFilterList topicFilter 集合
*/
public static void validateTopicFilter(List<String> topicFilterList) {
for (String topicFilter : topicFilterList) {
validateTopicFilter(topicFilter);
}
}
/**
* 校验 topicFilter
*
* @param topicFilter topicFilter
*/
public static void validateTopicFilter(String topicFilter) {
if (topicFilter == null || topicFilter.isEmpty()) {
throw new IllegalArgumentException("TopicFilter is blank:" + topicFilter);
}
char[] topicFilterChars = topicFilter.toCharArray();
int topicFilterLength = topicFilterChars.length;
int topicFilterIdxEnd = topicFilterLength - 1;
char ch;
for (int i = 0; i < topicFilterLength; i++) {
ch = topicFilterChars[i];
if (Character.isWhitespace(ch)) {
throw new IllegalArgumentException("Mqtt subscribe topicFilter has white space:" + topicFilter);
} else if (ch == MqttCodecUtil.TOPIC_WILDCARDS_MORE) {
// 校验: # 通配符只能在最后一位
if (i < topicFilterIdxEnd) {
throw new IllegalArgumentException("Mqtt subscribe topicFilter illegal:" + topicFilter);
}
} else if (ch == MqttCodecUtil.TOPIC_WILDCARDS_ONE) {
// 校验: 单独 + 是允许的,判断 + 号前一位是否为 /
if (i > 0 && topicFilterChars[i - 1] != '/') {
throw new IllegalArgumentException("Mqtt subscribe topicFilter illegal:" + topicFilter);
}
}
}
}
/**
* 校验 topicName
*
* @param topicName topicName
*/
public static void validateTopicName(String topicName) throws IllegalArgumentException {
if (topicName == null || topicName.isEmpty()) {
throw new IllegalArgumentException("Topic is blank:" + topicName);
}
if (MqttCodecUtil.isTopicFilter(topicName)) {
throw new IllegalArgumentException("Topic has wildcards char [+] or [#], topicName:" + topicName);
}
}
/**
* 判断 topicFilter topicName 是否匹配
*
......@@ -42,13 +102,13 @@ public final class TopicUtil {
boolean inLayerWildcard = false;
for (int i = 0; i < topicFilterLength; i++) {
ch = topicFilterChars[i];
if (ch == '#') {
if (ch == MqttCodecUtil.TOPIC_WILDCARDS_MORE) {
// 校验: # 通配符只能在最后一位
if (i < topicFilterIdxEnd) {
throw new IllegalArgumentException("Mqtt subscribe topicFilter illegal:" + topicFilter);
}
return true;
} else if (ch == '+') {
} else if (ch == MqttCodecUtil.TOPIC_WILDCARDS_ONE) {
// 校验: 单独 + 是允许的,判断 + 号前一位是否为 /
if (i > 0 && topicFilterChars[i - 1] != '/') {
throw new IllegalArgumentException("Mqtt subscribe topicFilter illegal:" + topicFilter);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册