diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java index 048e456326246429018e418c5afd41dc6e02501c..2e0af5affdd074ae9a94a1c52eb1a07019fcb7d5 100644 --- a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java @@ -24,7 +24,7 @@ import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; -import org.apache.rocketmq.client.ClientConfig; +import java.util.concurrent.TimeUnit; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; import org.apache.rocketmq.client.consumer.listener.ConsumeOrderlyContext; @@ -35,7 +35,6 @@ import org.apache.rocketmq.client.exception.MQBrokerException; import org.apache.rocketmq.client.impl.CommunicationMode; import org.apache.rocketmq.client.impl.FindBrokerResult; import org.apache.rocketmq.client.impl.MQClientAPIImpl; -import org.apache.rocketmq.client.impl.MQClientManager; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageOrderlyService; import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl; @@ -50,7 +49,6 @@ import org.apache.rocketmq.common.message.MessageClientExt; import org.apache.rocketmq.common.message.MessageDecoder; import org.apache.rocketmq.common.message.MessageExt; import org.apache.rocketmq.common.message.MessageQueue; -import org.apache.rocketmq.common.protocol.body.LockBatchRequestBody; import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; import org.apache.rocketmq.remoting.exception.RemotingException; import org.junit.After; @@ -58,7 +56,6 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.Spy; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnitRunner; import org.mockito.stubbing.Answer; @@ -69,17 +66,16 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class DefaultMQPushConsumerTest { private String consumerGroup; private String topic = "FooBar"; private String brokerName = "BrokerA"; - @Spy - private MQClientInstance mQClientFactory = MQClientManager.getInstance().getAndCreateMQClientInstance(new ClientConfig()); + private MQClientInstance mQClientFactory; @Mock private MQClientAPIImpl mQClientAPIImpl; private PullAPIWrapper pullAPIWrapper; @@ -89,7 +85,6 @@ public class DefaultMQPushConsumerTest { @Before public void init() throws Exception { consumerGroup = "FooBarGroup" + System.currentTimeMillis(); - pullAPIWrapper = spy(new PullAPIWrapper(mQClientFactory, consumerGroup, false)); pushConsumer = new DefaultMQPushConsumer(consumerGroup); pushConsumer.setNamesrvAddr("127.0.0.1:9876"); pushConsumer.setPullInterval(60 * 1000); @@ -106,10 +101,10 @@ public class DefaultMQPushConsumerTest { Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl"); field.setAccessible(true); field.set(pushConsumerImpl, rebalancePushImpl); - pushConsumer.subscribe(topic, "*"); pushConsumer.start(); + mQClientFactory = spy(pushConsumerImpl.getmQClientFactory()); field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); field.setAccessible(true); field.set(pushConsumerImpl, mQClientFactory); @@ -119,15 +114,17 @@ public class DefaultMQPushConsumerTest { field.setAccessible(true); field.set(mQClientFactory, mQClientAPIImpl); + pullAPIWrapper = spy(new PullAPIWrapper(mQClientFactory, consumerGroup, false)); field = DefaultMQPushConsumerImpl.class.getDeclaredField("pullAPIWrapper"); field.setAccessible(true); field.set(pushConsumerImpl, pullAPIWrapper); pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory); mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl); - mQClientFactory.start(); - doAnswer(new Answer() { + when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), + anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) + .thenAnswer(new Answer() { @Override public Object answer(InvocationOnMock mock) throws Throwable { PullMessageRequestHeader requestHeader = mock.getArgument(1); MessageClientExt messageClientExt = new MessageClientExt(); @@ -140,17 +137,15 @@ public class DefaultMQPushConsumerTest { messageClientExt.setStoreHost(new InetSocketAddress(8080)); PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); ((PullCallback)mock.getArgument(4)).onSuccess(pullResult); - return null; + return pullResult; } - }).when(mQClientAPIImpl).pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)); + }); doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean()); doReturn(Collections.singletonList(mQClientFactory.getClientId())).when(mQClientFactory).findConsumerIdList(anyString(), anyString()); Set messageQueueSet = new HashSet<>(); messageQueueSet.add(createPullRequest().getMessageQueue()); pushConsumer.getDefaultMQPushConsumerImpl().updateTopicSubscribeInfo(topic, messageQueueSet); - doReturn(messageQueueSet).when(mQClientAPIImpl).lockBatchMQ(anyString(), any(LockBatchRequestBody.class), anyLong()); - doReturn(123L).when(rebalancePushImpl).computePullFromWhere(any(MessageQueue.class)); } @@ -180,23 +175,26 @@ public class DefaultMQPushConsumerTest { } @Test - public void testPullMessage_SuccessWithOrderlyService() throws InterruptedException, RemotingException, MQBrokerException { + public void testPullMessage_SuccessWithOrderlyService() throws Exception { final CountDownLatch countDownLatch = new CountDownLatch(1); final MessageExt[] messageExts = new MessageExt[1]; - pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageOrderlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerOrderly() { + + MessageListenerOrderly listenerOrderly = new MessageListenerOrderly() { @Override public ConsumeOrderlyStatus consumeMessage(List msgs, ConsumeOrderlyContext context) { messageExts[0] = msgs.get(0); countDownLatch.countDown(); return null; } - })); + }; + pushConsumer.registerMessageListener(listenerOrderly); + pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageOrderlyService(pushConsumer.getDefaultMQPushConsumerImpl(), listenerOrderly)); + pushConsumer.getDefaultMQPushConsumerImpl().setConsumeOrderly(true); pushConsumer.getDefaultMQPushConsumerImpl().doRebalance(); - PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); - pullMessageService.executePullRequestImmediately(createPullRequest()); + pullMessageService.executePullRequestLater(createPullRequest(), 100); - countDownLatch.await(); + countDownLatch.await(10, TimeUnit.SECONDS); assertThat(messageExts[0].getTopic()).isEqualTo(topic); assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); } @@ -211,7 +209,10 @@ public class DefaultMQPushConsumerTest { messageQueue.setQueueId(0); messageQueue.setTopic(topic); pullRequest.setMessageQueue(messageQueue); - pullRequest.setProcessQueue(new ProcessQueue()); + ProcessQueue processQueue = new ProcessQueue(); + processQueue.setLocked(true); + processQueue.setLastLockTimestamp(System.currentTimeMillis()); + pullRequest.setProcessQueue(processQueue); return pullRequest; }