diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPullConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPullConsumerTest.java new file mode 100644 index 0000000000000000000000000000000000000000..6dc4ed8d33fedd2820b603bfcb87d046b68d5350 --- /dev/null +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPullConsumerTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.client.consumer; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.apache.rocketmq.client.ClientConfig; +import org.apache.rocketmq.client.impl.CommunicationMode; +import org.apache.rocketmq.client.impl.FindBrokerResult; +import org.apache.rocketmq.client.impl.MQClientAPIImpl; +import org.apache.rocketmq.client.impl.MQClientManager; +import org.apache.rocketmq.client.impl.consumer.PullAPIWrapper; +import org.apache.rocketmq.client.impl.consumer.PullResultExt; +import org.apache.rocketmq.client.impl.factory.MQClientInstance; +import org.apache.rocketmq.common.message.MessageExt; +import org.apache.rocketmq.common.message.MessageQueue; +import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Spy; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.nullable; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) +public class DefaultMQPullConsumerTest { + @Spy + private MQClientInstance mQClientFactory = MQClientManager.getInstance().getAndCreateMQClientInstance(new ClientConfig()); + @Mock + private MQClientAPIImpl mQClientAPIImpl; + private DefaultMQPullConsumer pullConsumer; + private String consumerGroup = "FooBarGroup"; + private String topic = "FooBar"; + private String brokerName = "BrokerA"; + + @Before + public void init() throws Exception { + pullConsumer = new DefaultMQPullConsumer(consumerGroup); + pullConsumer.setNamesrvAddr("127.0.0.1:9876"); + pullConsumer.start(); + PullAPIWrapper pullAPIWrapper = pullConsumer.getDefaultMQPullConsumerImpl().getPullAPIWrapper(); + Field field = PullAPIWrapper.class.getDeclaredField("mQClientFactory"); + field.setAccessible(true); + field.set(pullAPIWrapper, mQClientFactory); + + field = MQClientInstance.class.getDeclaredField("mQClientAPIImpl"); + field.setAccessible(true); + field.set(mQClientFactory, mQClientAPIImpl); + + when(mQClientFactory.findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean())).thenReturn(new FindBrokerResult("127.0.0.1:10911", false)); + } + + @After + public void terminate() { + pullConsumer.shutdown(); + } + + @Test + public void testPullMessage_Success() throws Exception { + doAnswer(new Answer() { + @Override public Object answer(InvocationOnMock mock) throws Throwable { + PullMessageRequestHeader requestHeader = mock.getArgument(1); + return createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(new MessageExt())); + } + }).when(mQClientAPIImpl).pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)); + + MessageQueue messageQueue = new MessageQueue(topic, brokerName, 0); + PullResult pullResult = pullConsumer.pull(messageQueue, "*", 1024, 3); + assertThat(pullResult).isNotNull(); + assertThat(pullResult.getPullStatus()).isEqualTo(PullStatus.FOUND); + assertThat(pullResult.getNextBeginOffset()).isEqualTo(1024 + 1); + assertThat(pullResult.getMinOffset()).isEqualTo(123); + assertThat(pullResult.getMaxOffset()).isEqualTo(2048); + assertThat(pullResult.getMsgFoundList()).isEqualTo(new ArrayList<>()); + } + + @Test + public void testPullMessage_NotFound() throws Exception{ + doAnswer(new Answer() { + @Override public Object answer(InvocationOnMock mock) throws Throwable { + PullMessageRequestHeader requestHeader = mock.getArgument(1); + return createPullResult(requestHeader, PullStatus.NO_NEW_MSG, new ArrayList()); + } + }).when(mQClientAPIImpl).pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)); + + MessageQueue messageQueue = new MessageQueue(topic, brokerName, 0); + PullResult pullResult = pullConsumer.pull(messageQueue, "*", 1024, 3); + assertThat(pullResult.getPullStatus()).isEqualTo(PullStatus.NO_NEW_MSG); + } + + @Test + public void testPullMessageAsync_Success() throws Exception { + doAnswer(new Answer() { + @Override public Object answer(InvocationOnMock mock) throws Throwable { + PullMessageRequestHeader requestHeader = mock.getArgument(1); + PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(new MessageExt())); + + PullCallback pullCallback = mock.getArgument(4); + pullCallback.onSuccess(pullResult); + return null; + } + }).when(mQClientAPIImpl).pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)); + + MessageQueue messageQueue = new MessageQueue(topic, brokerName, 0); + pullConsumer.pull(messageQueue, "*", 1024, 3, new PullCallback() { + @Override public void onSuccess(PullResult pullResult) { + assertThat(pullResult).isNotNull(); + assertThat(pullResult.getPullStatus()).isEqualTo(PullStatus.FOUND); + assertThat(pullResult.getNextBeginOffset()).isEqualTo(1024 + 1); + assertThat(pullResult.getMinOffset()).isEqualTo(123); + assertThat(pullResult.getMaxOffset()).isEqualTo(2048); + assertThat(pullResult.getMsgFoundList()).isEqualTo(new ArrayList<>()); + } + + @Override public void onException(Throwable e) { + + } + }); + } + + private PullResultExt createPullResult(PullMessageRequestHeader requestHeader, PullStatus pullStatus, List messageExtList) { + return new PullResultExt(pullStatus, requestHeader.getQueueOffset() + messageExtList.size(), 123, 2048, messageExtList, 0, new byte[] {}); + } +} \ No newline at end of file diff --git a/client/src/test/java/org/apache/rocketmq/client/impl/producer/DefaultMQProducerTest.java b/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java similarity index 97% rename from client/src/test/java/org/apache/rocketmq/client/impl/producer/DefaultMQProducerTest.java rename to client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java index 5946ea838554e43f57540149680da0cc20018caf..bbaff003eef8143a5b4357071979a146559fb942 100644 --- a/client/src/test/java/org/apache/rocketmq/client/impl/producer/DefaultMQProducerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.rocketmq.client.impl.producer; +package org.apache.rocketmq.client.producer; import java.lang.reflect.Field; import java.util.ArrayList; @@ -31,10 +31,8 @@ import org.apache.rocketmq.client.impl.CommunicationMode; import org.apache.rocketmq.client.impl.MQClientAPIImpl; import org.apache.rocketmq.client.impl.MQClientManager; import org.apache.rocketmq.client.impl.factory.MQClientInstance; -import org.apache.rocketmq.client.producer.DefaultMQProducer; -import org.apache.rocketmq.client.producer.SendCallback; -import org.apache.rocketmq.client.producer.SendResult; -import org.apache.rocketmq.client.producer.SendStatus; +import org.apache.rocketmq.client.impl.producer.DefaultMQProducerImpl; +import org.apache.rocketmq.client.impl.producer.TopicPublishInfo; import org.apache.rocketmq.common.message.Message; import org.apache.rocketmq.common.protocol.header.SendMessageRequestHeader; import org.apache.rocketmq.common.protocol.route.BrokerData;