diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java index ce8b4d3d1ff5e743571db9866969b6bb0e73a48c..048e456326246429018e418c5afd41dc6e02501c 100644 --- a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java @@ -20,29 +20,37 @@ import java.io.ByteArrayOutputStream; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.util.Collections; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.concurrent.CountDownLatch; import org.apache.rocketmq.client.ClientConfig; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; +import org.apache.rocketmq.client.consumer.listener.ConsumeOrderlyContext; +import org.apache.rocketmq.client.consumer.listener.ConsumeOrderlyStatus; import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently; +import org.apache.rocketmq.client.consumer.listener.MessageListenerOrderly; import org.apache.rocketmq.client.exception.MQBrokerException; import org.apache.rocketmq.client.impl.CommunicationMode; import org.apache.rocketmq.client.impl.FindBrokerResult; import org.apache.rocketmq.client.impl.MQClientAPIImpl; import org.apache.rocketmq.client.impl.MQClientManager; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService; +import org.apache.rocketmq.client.impl.consumer.ConsumeMessageOrderlyService; import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl; import org.apache.rocketmq.client.impl.consumer.ProcessQueue; import org.apache.rocketmq.client.impl.consumer.PullAPIWrapper; import org.apache.rocketmq.client.impl.consumer.PullMessageService; import org.apache.rocketmq.client.impl.consumer.PullRequest; import org.apache.rocketmq.client.impl.consumer.PullResultExt; +import org.apache.rocketmq.client.impl.consumer.RebalancePushImpl; import org.apache.rocketmq.client.impl.factory.MQClientInstance; import org.apache.rocketmq.common.message.MessageClientExt; import org.apache.rocketmq.common.message.MessageDecoder; import org.apache.rocketmq.common.message.MessageExt; import org.apache.rocketmq.common.message.MessageQueue; +import org.apache.rocketmq.common.protocol.body.LockBatchRequestBody; import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; import org.apache.rocketmq.remoting.exception.RemotingException; import org.junit.After; @@ -62,8 +70,8 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class DefaultMQPushConsumerTest { @@ -75,6 +83,7 @@ public class DefaultMQPushConsumerTest { @Mock private MQClientAPIImpl mQClientAPIImpl; private PullAPIWrapper pullAPIWrapper; + private RebalancePushImpl rebalancePushImpl; private DefaultMQPushConsumer pushConsumer; @Before @@ -91,14 +100,21 @@ public class DefaultMQPushConsumerTest { return null; } }); + + DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl(); + rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl())); + Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl"); + field.setAccessible(true); + field.set(pushConsumerImpl, rebalancePushImpl); + pushConsumer.subscribe(topic, "*"); pushConsumer.start(); - DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl(); - Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); + field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); field.setAccessible(true); field.set(pushConsumerImpl, mQClientFactory); + field = MQClientInstance.class.getDeclaredField("mQClientAPIImpl"); field.setAccessible(true); field.set(mQClientFactory, mQClientAPIImpl); @@ -107,9 +123,35 @@ public class DefaultMQPushConsumerTest { field.setAccessible(true); field.set(pushConsumerImpl, pullAPIWrapper); - when(mQClientFactory.findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean())).thenReturn(new FindBrokerResult("127.0.0.1:10911", false)); + pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory); mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl); mQClientFactory.start(); + + doAnswer(new Answer() { + @Override public Object answer(InvocationOnMock mock) throws Throwable { + PullMessageRequestHeader requestHeader = mock.getArgument(1); + MessageClientExt messageClientExt = new MessageClientExt(); + messageClientExt.setTopic(topic); + messageClientExt.setQueueId(0); + messageClientExt.setMsgId("123"); + messageClientExt.setBody(new byte[] {'a'}); + messageClientExt.setOffsetMsgId("234"); + messageClientExt.setBornHost(new InetSocketAddress(8080)); + messageClientExt.setStoreHost(new InetSocketAddress(8080)); + PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); + ((PullCallback)mock.getArgument(4)).onSuccess(pullResult); + return null; + } + }).when(mQClientAPIImpl).pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)); + + doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean()); + doReturn(Collections.singletonList(mQClientFactory.getClientId())).when(mQClientFactory).findConsumerIdList(anyString(), anyString()); + Set messageQueueSet = new HashSet<>(); + messageQueueSet.add(createPullRequest().getMessageQueue()); + pushConsumer.getDefaultMQPushConsumerImpl().updateTopicSubscribeInfo(topic, messageQueueSet); + doReturn(messageQueueSet).when(mQClientAPIImpl).lockBatchMQ(anyString(), any(LockBatchRequestBody.class), anyLong()); + + doReturn(123L).when(rebalancePushImpl).computePullFromWhere(any(MessageQueue.class)); } @After @@ -129,25 +171,31 @@ public class DefaultMQPushConsumerTest { return null; } })); - doAnswer(new Answer() { - @Override public Object answer(InvocationOnMock mock) throws Throwable { - PullMessageRequestHeader requestHeader = mock.getArgument(1); - MessageClientExt messageClientExt = new MessageClientExt(); - messageClientExt.setTopic(topic); - messageClientExt.setQueueId(0); - messageClientExt.setMsgId("123"); - messageClientExt.setBody(new byte[] {'a'}); - messageClientExt.setOffsetMsgId("234"); - messageClientExt.setBornHost(new InetSocketAddress(8080)); - messageClientExt.setStoreHost(new InetSocketAddress(8080)); - PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); - ((PullCallback)mock.getArgument(4)).onSuccess(pullResult); + + PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); + pullMessageService.executePullRequestImmediately(createPullRequest()); + countDownLatch.await(); + assertThat(messageExts[0].getTopic()).isEqualTo(topic); + assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); + } + + @Test + public void testPullMessage_SuccessWithOrderlyService() throws InterruptedException, RemotingException, MQBrokerException { + final CountDownLatch countDownLatch = new CountDownLatch(1); + final MessageExt[] messageExts = new MessageExt[1]; + pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageOrderlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerOrderly() { + @Override + public ConsumeOrderlyStatus consumeMessage(List msgs, ConsumeOrderlyContext context) { + messageExts[0] = msgs.get(0); + countDownLatch.countDown(); return null; } - }).when(mQClientAPIImpl).pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)); + })); + pushConsumer.getDefaultMQPushConsumerImpl().doRebalance(); PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); pullMessageService.executePullRequestImmediately(createPullRequest()); + countDownLatch.await(); assertThat(messageExts[0].getTopic()).isEqualTo(topic); assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); diff --git a/common/src/test/java/org/apache/rocketmq/common/MixAllTest.java b/common/src/test/java/org/apache/rocketmq/common/MixAllTest.java index 3a40dd56a7e3a584da13ae66f984f6a764726e6e..06024c3557787482ec0aafc04db8c20e5c03cda6 100644 --- a/common/src/test/java/org/apache/rocketmq/common/MixAllTest.java +++ b/common/src/test/java/org/apache/rocketmq/common/MixAllTest.java @@ -59,7 +59,7 @@ public class MixAllTest { file.delete(); } file.createNewFile(); - try( PrintWriter out = new PrintWriter( fileName ) ){ + try (PrintWriter out = new PrintWriter(fileName)) { out.write("TestForMixAll"); } String string = MixAll.file2String(fileName);