diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java index c2c60824ec80e9f4f3e2eac99c95704739349ecd..20efd6647819f893328636cb8fc71a4583fe80e5 100644 --- a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java @@ -16,17 +16,6 @@ */ package org.apache.rocketmq.client.consumer; -import java.io.ByteArrayOutputStream; -import java.net.InetSocketAddress; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.lang3.reflect.FieldUtils; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; @@ -48,6 +37,7 @@ import org.apache.rocketmq.client.impl.consumer.PullMessageService; import org.apache.rocketmq.client.impl.consumer.PullRequest; import org.apache.rocketmq.client.impl.consumer.PullResultExt; import org.apache.rocketmq.client.impl.consumer.RebalanceImpl; +import org.apache.rocketmq.client.impl.consumer.RebalancePushImpl; import org.apache.rocketmq.client.impl.factory.MQClientInstance; import org.apache.rocketmq.common.message.MessageClientExt; import org.apache.rocketmq.common.message.MessageDecoder; @@ -66,6 +56,18 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnitRunner; import org.mockito.stubbing.Answer; +import java.io.ByteArrayOutputStream; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Fail.failBecauseExceptionWasNotThrown; import static org.mockito.ArgumentMatchers.any; @@ -74,6 +76,7 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -87,6 +90,7 @@ public class DefaultMQPushConsumerTest { @Mock private MQClientAPIImpl mQClientAPIImpl; private RebalanceImpl rebalanceImpl; + private RebalancePushImpl rebalancePushImpl; private DefaultMQPushConsumer pushConsumer; @Before @@ -96,24 +100,24 @@ public class DefaultMQPushConsumerTest { factoryTable.clear(); when(mQClientAPIImpl.pullMessage(anyString(), any(PullMessageRequestHeader.class), - anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) - .thenAnswer(new Answer() { - @Override - public PullResult answer(InvocationOnMock mock) throws Throwable { - PullMessageRequestHeader requestHeader = mock.getArgument(1); - MessageClientExt messageClientExt = new MessageClientExt(); - messageClientExt.setTopic(topic); - messageClientExt.setQueueId(0); - messageClientExt.setMsgId("123"); - messageClientExt.setBody(new byte[] {'a'}); - messageClientExt.setOffsetMsgId("234"); - messageClientExt.setBornHost(new InetSocketAddress(8080)); - messageClientExt.setStoreHost(new InetSocketAddress(8080)); - PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); - ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); - return pullResult; - } - }); + anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) + .thenAnswer(new Answer() { + @Override + public PullResult answer(InvocationOnMock mock) throws Throwable { + PullMessageRequestHeader requestHeader = mock.getArgument(1); + MessageClientExt messageClientExt = new MessageClientExt(); + messageClientExt.setTopic(topic); + messageClientExt.setQueueId(0); + messageClientExt.setMsgId("123"); + messageClientExt.setBody(new byte[]{'a'}); + messageClientExt.setOffsetMsgId("234"); + messageClientExt.setBornHost(new InetSocketAddress(8080)); + messageClientExt.setStoreHost(new InetSocketAddress(8080)); + PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); + ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); + return pullResult; + } + }); consumerGroup = "FooBarGroup" + System.currentTimeMillis(); @@ -124,12 +128,13 @@ public class DefaultMQPushConsumerTest { pushConsumer.registerMessageListener(new MessageListenerConcurrently() { @Override public ConsumeConcurrentlyStatus consumeMessage(List msgs, - ConsumeConcurrentlyContext context) { + ConsumeConcurrentlyContext context) { return null; } }); DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl(); + rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl())); // suppress updateTopicRouteInfoFromNameServer pushConsumer.changeInstanceNameToPID(); @@ -142,7 +147,8 @@ public class DefaultMQPushConsumerTest { doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean()); rebalanceImpl = spy(pushConsumerImpl.getRebalanceImpl()); - doReturn(123L).when(rebalanceImpl).computePullFromWhere(any(MessageQueue.class)); + // doReturn(123L).when(rebalancePushImpl).computePullFromWhere(any(MessageQueue.class)); + doReturn(123L).when(rebalanceImpl).computePullFromWhereWithException(any(MessageQueue.class)); FieldUtils.writeDeclaredField(pushConsumerImpl, "rebalanceImpl", rebalanceImpl, true); Set messageQueueSet = new HashSet(); @@ -170,7 +176,7 @@ public class DefaultMQPushConsumerTest { pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() { @Override public ConsumeConcurrentlyStatus consumeMessage(List msgs, - ConsumeConcurrentlyContext context) { + ConsumeConcurrentlyContext context) { messageAtomic.set(msgs.get(0)); countDownLatch.countDown(); return null; @@ -183,7 +189,7 @@ public class DefaultMQPushConsumerTest { MessageExt msg = messageAtomic.get(); assertThat(msg).isNotNull(); assertThat(msg.getTopic()).isEqualTo(topic); - assertThat(msg.getBody()).isEqualTo(new byte[] {'a'}); + assertThat(msg.getBody()).isEqualTo(new byte[]{'a'}); } @Test @@ -210,7 +216,7 @@ public class DefaultMQPushConsumerTest { MessageExt msg = messageAtomic.get(); assertThat(msg).isNotNull(); assertThat(msg.getTopic()).isEqualTo(topic); - assertThat(msg.getBody()).isEqualTo(new byte[] {'a'}); + assertThat(msg.getBody()).isEqualTo(new byte[]{'a'}); } @Test @@ -287,7 +293,7 @@ public class DefaultMQPushConsumerTest { pushConsumer.registerMessageListener(new MessageListenerConcurrently() { @Override public ConsumeConcurrentlyStatus consumeMessage(List msgs, - ConsumeConcurrentlyContext context) { + ConsumeConcurrentlyContext context) { return null; } }); @@ -313,11 +319,29 @@ public class DefaultMQPushConsumerTest { } private PullResultExt createPullResult(PullMessageRequestHeader requestHeader, PullStatus pullStatus, - List messageExtList) throws Exception { + List messageExtList) throws Exception { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); for (MessageExt messageExt : messageExtList) { outputStream.write(MessageDecoder.encode(messageExt, false)); } return new PullResultExt(pullStatus, requestHeader.getQueueOffset() + messageExtList.size(), 123, 2048, messageExtList, 0, outputStream.toByteArray()); } + + @Test + public void testPullMessage_ExceptionOccursWhenComputePullFromWhere() throws MQClientException { + doThrow(MQClientException.class).when(rebalancePushImpl).computePullFromWhereWithException(any(MessageQueue.class)); + final CountDownLatch countDownLatch = new CountDownLatch(1); + final MessageExt[] messageExts = new MessageExt[1]; + pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService( + new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), + (msgs, context) -> { + messageExts[0] = msgs.get(0); + return null; + })); + + pushConsumer.getDefaultMQPushConsumerImpl().setConsumeOrderly(true); + PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); + pullMessageService.executePullRequestImmediately(createPullRequest()); + assertThat(messageExts[0]).isNull(); + } }