diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPullConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPullConsumerTest.java index 6dc4ed8d33fedd2820b603bfcb87d046b68d5350..6672b1e88d6d9aa5a7fc61a7d76d07cda866bc67 100644 --- a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPullConsumerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPullConsumerTest.java @@ -146,7 +146,7 @@ public class DefaultMQPullConsumerTest { }); } - private PullResultExt createPullResult(PullMessageRequestHeader requestHeader, PullStatus pullStatus, List messageExtList) { + private PullResultExt createPullResult(PullMessageRequestHeader requestHeader, PullStatus pullStatus, List messageExtList) throws Exception { return new PullResultExt(pullStatus, requestHeader.getQueueOffset() + messageExtList.size(), 123, 2048, messageExtList, 0, new byte[] {}); } } \ No newline at end of file diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java new file mode 100644 index 0000000000000000000000000000000000000000..7f5c87417b13d9033e4961261a29a22e2d8f44c4 --- /dev/null +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.client.consumer; + +import java.io.ByteArrayOutputStream; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import org.apache.rocketmq.client.ClientConfig; +import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; +import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; +import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently; +import org.apache.rocketmq.client.exception.MQBrokerException; +import org.apache.rocketmq.client.impl.CommunicationMode; +import org.apache.rocketmq.client.impl.FindBrokerResult; +import org.apache.rocketmq.client.impl.MQClientAPIImpl; +import org.apache.rocketmq.client.impl.MQClientManager; +import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService; +import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl; +import org.apache.rocketmq.client.impl.consumer.ProcessQueue; +import org.apache.rocketmq.client.impl.consumer.PullAPIWrapper; +import org.apache.rocketmq.client.impl.consumer.PullMessageService; +import org.apache.rocketmq.client.impl.consumer.PullRequest; +import org.apache.rocketmq.client.impl.consumer.PullResultExt; +import org.apache.rocketmq.client.impl.factory.MQClientInstance; +import org.apache.rocketmq.common.message.MessageClientExt; +import org.apache.rocketmq.common.message.MessageDecoder; +import org.apache.rocketmq.common.message.MessageExt; +import org.apache.rocketmq.common.message.MessageQueue; +import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; +import org.apache.rocketmq.remoting.exception.RemotingException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Spy; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; + +import static org.apache.rocketmq.client.producer.DefaultMQProducerTest.createTopicRoute; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.nullable; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) +public class DefaultMQPushConsumerTest { + private String consumerGroup; + private String topic = "FooBar"; + private String brokerName = "BrokerA"; + @Spy + private MQClientInstance mQClientFactory = MQClientManager.getInstance().getAndCreateMQClientInstance(new ClientConfig()); + @Mock + private MQClientAPIImpl mQClientAPIImpl; + private PullAPIWrapper pullAPIWrapper; + private DefaultMQPushConsumer pushConsumer; + + @Before + public void init() throws Exception { + consumerGroup = "FooBarGroup" + System.currentTimeMillis(); + pullAPIWrapper = spy(new PullAPIWrapper(mQClientFactory, consumerGroup, false)); + pushConsumer = new DefaultMQPushConsumer(consumerGroup); + pushConsumer.setNamesrvAddr("127.0.0.1:9876"); + pushConsumer.setPullInterval(60 * 1000); + + pushConsumer.registerMessageListener(new MessageListenerConcurrently() { + @Override public ConsumeConcurrentlyStatus consumeMessage(List msgs, + ConsumeConcurrentlyContext context) { + return null; + } + }); + pushConsumer.subscribe(topic, "*"); + pushConsumer.start(); + + DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl(); + Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); + field.setAccessible(true); + field.set(pushConsumerImpl, mQClientFactory); + + field = MQClientInstance.class.getDeclaredField("mQClientAPIImpl"); + field.setAccessible(true); + field.set(mQClientFactory, mQClientAPIImpl); + + field = DefaultMQPushConsumerImpl.class.getDeclaredField("pullAPIWrapper"); + field.setAccessible(true); + field.set(pushConsumerImpl, pullAPIWrapper); + + when(mQClientFactory.findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean())).thenReturn(new FindBrokerResult("127.0.0.1:10911", false)); + mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl); + mQClientFactory.start(); + when(mQClientAPIImpl.getTopicRouteInfoFromNameServer(anyString(), anyLong())).thenReturn(createTopicRoute()); + } + + @After + public void terminate() { + pushConsumer.shutdown(); + } + + @Test + public void testPullMessage_Success() throws InterruptedException, RemotingException, MQBrokerException { + final CountDownLatch countDownLatch = new CountDownLatch(1); + final MessageExt[] messageExts = new MessageExt[1]; + pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() { + @Override public ConsumeConcurrentlyStatus consumeMessage(List msgs, + ConsumeConcurrentlyContext context) { + messageExts[0] = msgs.get(0); + countDownLatch.countDown(); + return null; + } + })); + doAnswer(new Answer() { + @Override public Object answer(InvocationOnMock mock) throws Throwable { + PullMessageRequestHeader requestHeader = mock.getArgument(1); + MessageClientExt messageClientExt = new MessageClientExt(); + messageClientExt.setTopic(topic); + messageClientExt.setQueueId(0); + messageClientExt.setMsgId("123"); + messageClientExt.setBody(new byte[] {'a'}); + messageClientExt.setOffsetMsgId("234"); + messageClientExt.setBornHost(new InetSocketAddress(8080)); + messageClientExt.setStoreHost(new InetSocketAddress(8080)); + PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); + ((PullCallback)mock.getArgument(4)).onSuccess(pullResult); + return null; + } + }).when(mQClientAPIImpl).pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)); + + PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); + pullMessageService.executePullRequestImmediately(createPullRequest()); + countDownLatch.await(); + assertThat(messageExts[0].getTopic()).isEqualTo(topic); + assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); + } + + private PullRequest createPullRequest() { + PullRequest pullRequest = new PullRequest(); + pullRequest.setConsumerGroup(consumerGroup); + pullRequest.setNextOffset(1024); + + MessageQueue messageQueue = new MessageQueue(); + messageQueue.setBrokerName(brokerName); + messageQueue.setQueueId(0); + messageQueue.setTopic(topic); + pullRequest.setMessageQueue(messageQueue); + pullRequest.setProcessQueue(new ProcessQueue()); + + return pullRequest; + } + + private PullResultExt createPullResult(PullMessageRequestHeader requestHeader, PullStatus pullStatus, List messageExtList) throws Exception { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + for (MessageExt messageExt : messageExtList) { + outputStream.write(MessageDecoder.encode(messageExt, false)); + } + return new PullResultExt(pullStatus, requestHeader.getQueueOffset() + messageExtList.size(), 123, 2048, messageExtList, 0, outputStream.toByteArray()); + } +} \ No newline at end of file diff --git a/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java b/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java index bbaff003eef8143a5b4357071979a146559fb942..eddd3806c8513151d4033e369030344ec9dd9af2 100644 --- a/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java @@ -190,7 +190,7 @@ public class DefaultMQProducerTest { } } - private TopicRouteData createTopicRoute() { + public static TopicRouteData createTopicRoute() { TopicRouteData topicRouteData = new TopicRouteData(); topicRouteData.setFilterServerTable(new HashMap>());