提交 a3aa00a2 编写于 作者: Y yukon

[ROCKETMQ-52] Resolve infinite loop issue in rocketmq-client UT

上级 a3aff81e
...@@ -24,7 +24,7 @@ import java.util.HashSet; ...@@ -24,7 +24,7 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import org.apache.rocketmq.client.ClientConfig; import java.util.concurrent.TimeUnit;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
import org.apache.rocketmq.client.consumer.listener.ConsumeOrderlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeOrderlyContext;
...@@ -35,7 +35,6 @@ import org.apache.rocketmq.client.exception.MQBrokerException; ...@@ -35,7 +35,6 @@ import org.apache.rocketmq.client.exception.MQBrokerException;
import org.apache.rocketmq.client.impl.CommunicationMode; import org.apache.rocketmq.client.impl.CommunicationMode;
import org.apache.rocketmq.client.impl.FindBrokerResult; import org.apache.rocketmq.client.impl.FindBrokerResult;
import org.apache.rocketmq.client.impl.MQClientAPIImpl; import org.apache.rocketmq.client.impl.MQClientAPIImpl;
import org.apache.rocketmq.client.impl.MQClientManager;
import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService;
import org.apache.rocketmq.client.impl.consumer.ConsumeMessageOrderlyService; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageOrderlyService;
import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl; import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl;
...@@ -50,7 +49,6 @@ import org.apache.rocketmq.common.message.MessageClientExt; ...@@ -50,7 +49,6 @@ import org.apache.rocketmq.common.message.MessageClientExt;
import org.apache.rocketmq.common.message.MessageDecoder; import org.apache.rocketmq.common.message.MessageDecoder;
import org.apache.rocketmq.common.message.MessageExt; import org.apache.rocketmq.common.message.MessageExt;
import org.apache.rocketmq.common.message.MessageQueue; import org.apache.rocketmq.common.message.MessageQueue;
import org.apache.rocketmq.common.protocol.body.LockBatchRequestBody;
import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader;
import org.apache.rocketmq.remoting.exception.RemotingException; import org.apache.rocketmq.remoting.exception.RemotingException;
import org.junit.After; import org.junit.After;
...@@ -58,7 +56,6 @@ import org.junit.Before; ...@@ -58,7 +56,6 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
...@@ -69,17 +66,16 @@ import static org.mockito.ArgumentMatchers.anyBoolean; ...@@ -69,17 +66,16 @@ import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class DefaultMQPushConsumerTest { public class DefaultMQPushConsumerTest {
private String consumerGroup; private String consumerGroup;
private String topic = "FooBar"; private String topic = "FooBar";
private String brokerName = "BrokerA"; private String brokerName = "BrokerA";
@Spy private MQClientInstance mQClientFactory;
private MQClientInstance mQClientFactory = MQClientManager.getInstance().getAndCreateMQClientInstance(new ClientConfig());
@Mock @Mock
private MQClientAPIImpl mQClientAPIImpl; private MQClientAPIImpl mQClientAPIImpl;
private PullAPIWrapper pullAPIWrapper; private PullAPIWrapper pullAPIWrapper;
...@@ -89,7 +85,6 @@ public class DefaultMQPushConsumerTest { ...@@ -89,7 +85,6 @@ public class DefaultMQPushConsumerTest {
@Before @Before
public void init() throws Exception { public void init() throws Exception {
consumerGroup = "FooBarGroup" + System.currentTimeMillis(); consumerGroup = "FooBarGroup" + System.currentTimeMillis();
pullAPIWrapper = spy(new PullAPIWrapper(mQClientFactory, consumerGroup, false));
pushConsumer = new DefaultMQPushConsumer(consumerGroup); pushConsumer = new DefaultMQPushConsumer(consumerGroup);
pushConsumer.setNamesrvAddr("127.0.0.1:9876"); pushConsumer.setNamesrvAddr("127.0.0.1:9876");
pushConsumer.setPullInterval(60 * 1000); pushConsumer.setPullInterval(60 * 1000);
...@@ -106,10 +101,10 @@ public class DefaultMQPushConsumerTest { ...@@ -106,10 +101,10 @@ public class DefaultMQPushConsumerTest {
Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl"); Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl");
field.setAccessible(true); field.setAccessible(true);
field.set(pushConsumerImpl, rebalancePushImpl); field.set(pushConsumerImpl, rebalancePushImpl);
pushConsumer.subscribe(topic, "*"); pushConsumer.subscribe(topic, "*");
pushConsumer.start(); pushConsumer.start();
mQClientFactory = spy(pushConsumerImpl.getmQClientFactory());
field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory");
field.setAccessible(true); field.setAccessible(true);
field.set(pushConsumerImpl, mQClientFactory); field.set(pushConsumerImpl, mQClientFactory);
...@@ -119,15 +114,17 @@ public class DefaultMQPushConsumerTest { ...@@ -119,15 +114,17 @@ public class DefaultMQPushConsumerTest {
field.setAccessible(true); field.setAccessible(true);
field.set(mQClientFactory, mQClientAPIImpl); field.set(mQClientFactory, mQClientAPIImpl);
pullAPIWrapper = spy(new PullAPIWrapper(mQClientFactory, consumerGroup, false));
field = DefaultMQPushConsumerImpl.class.getDeclaredField("pullAPIWrapper"); field = DefaultMQPushConsumerImpl.class.getDeclaredField("pullAPIWrapper");
field.setAccessible(true); field.setAccessible(true);
field.set(pushConsumerImpl, pullAPIWrapper); field.set(pushConsumerImpl, pullAPIWrapper);
pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory); pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory);
mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl); mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl);
mQClientFactory.start();
doAnswer(new Answer() { when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class),
anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)))
.thenAnswer(new Answer<Object>() {
@Override public Object answer(InvocationOnMock mock) throws Throwable { @Override public Object answer(InvocationOnMock mock) throws Throwable {
PullMessageRequestHeader requestHeader = mock.getArgument(1); PullMessageRequestHeader requestHeader = mock.getArgument(1);
MessageClientExt messageClientExt = new MessageClientExt(); MessageClientExt messageClientExt = new MessageClientExt();
...@@ -140,17 +137,15 @@ public class DefaultMQPushConsumerTest { ...@@ -140,17 +137,15 @@ public class DefaultMQPushConsumerTest {
messageClientExt.setStoreHost(new InetSocketAddress(8080)); messageClientExt.setStoreHost(new InetSocketAddress(8080));
PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.<MessageExt>singletonList(messageClientExt)); PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.<MessageExt>singletonList(messageClientExt));
((PullCallback)mock.getArgument(4)).onSuccess(pullResult); ((PullCallback)mock.getArgument(4)).onSuccess(pullResult);
return null; return pullResult;
} }
}).when(mQClientAPIImpl).pullMessage(anyString(), any(PullMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)); });
doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean()); doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean());
doReturn(Collections.singletonList(mQClientFactory.getClientId())).when(mQClientFactory).findConsumerIdList(anyString(), anyString()); doReturn(Collections.singletonList(mQClientFactory.getClientId())).when(mQClientFactory).findConsumerIdList(anyString(), anyString());
Set<MessageQueue> messageQueueSet = new HashSet<>(); Set<MessageQueue> messageQueueSet = new HashSet<>();
messageQueueSet.add(createPullRequest().getMessageQueue()); messageQueueSet.add(createPullRequest().getMessageQueue());
pushConsumer.getDefaultMQPushConsumerImpl().updateTopicSubscribeInfo(topic, messageQueueSet); pushConsumer.getDefaultMQPushConsumerImpl().updateTopicSubscribeInfo(topic, messageQueueSet);
doReturn(messageQueueSet).when(mQClientAPIImpl).lockBatchMQ(anyString(), any(LockBatchRequestBody.class), anyLong());
doReturn(123L).when(rebalancePushImpl).computePullFromWhere(any(MessageQueue.class)); doReturn(123L).when(rebalancePushImpl).computePullFromWhere(any(MessageQueue.class));
} }
...@@ -180,23 +175,26 @@ public class DefaultMQPushConsumerTest { ...@@ -180,23 +175,26 @@ public class DefaultMQPushConsumerTest {
} }
@Test @Test
public void testPullMessage_SuccessWithOrderlyService() throws InterruptedException, RemotingException, MQBrokerException { public void testPullMessage_SuccessWithOrderlyService() throws Exception {
final CountDownLatch countDownLatch = new CountDownLatch(1); final CountDownLatch countDownLatch = new CountDownLatch(1);
final MessageExt[] messageExts = new MessageExt[1]; final MessageExt[] messageExts = new MessageExt[1];
pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageOrderlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerOrderly() {
MessageListenerOrderly listenerOrderly = new MessageListenerOrderly() {
@Override @Override
public ConsumeOrderlyStatus consumeMessage(List<MessageExt> msgs, ConsumeOrderlyContext context) { public ConsumeOrderlyStatus consumeMessage(List<MessageExt> msgs, ConsumeOrderlyContext context) {
messageExts[0] = msgs.get(0); messageExts[0] = msgs.get(0);
countDownLatch.countDown(); countDownLatch.countDown();
return null; return null;
} }
})); };
pushConsumer.registerMessageListener(listenerOrderly);
pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageOrderlyService(pushConsumer.getDefaultMQPushConsumerImpl(), listenerOrderly));
pushConsumer.getDefaultMQPushConsumerImpl().setConsumeOrderly(true);
pushConsumer.getDefaultMQPushConsumerImpl().doRebalance(); pushConsumer.getDefaultMQPushConsumerImpl().doRebalance();
PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); PullMessageService pullMessageService = mQClientFactory.getPullMessageService();
pullMessageService.executePullRequestImmediately(createPullRequest()); pullMessageService.executePullRequestLater(createPullRequest(), 100);
countDownLatch.await(); countDownLatch.await(10, TimeUnit.SECONDS);
assertThat(messageExts[0].getTopic()).isEqualTo(topic); assertThat(messageExts[0].getTopic()).isEqualTo(topic);
assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'});
} }
...@@ -211,7 +209,10 @@ public class DefaultMQPushConsumerTest { ...@@ -211,7 +209,10 @@ public class DefaultMQPushConsumerTest {
messageQueue.setQueueId(0); messageQueue.setQueueId(0);
messageQueue.setTopic(topic); messageQueue.setTopic(topic);
pullRequest.setMessageQueue(messageQueue); pullRequest.setMessageQueue(messageQueue);
pullRequest.setProcessQueue(new ProcessQueue()); ProcessQueue processQueue = new ProcessQueue();
processQueue.setLocked(true);
processQueue.setLastLockTimestamp(System.currentTimeMillis());
pullRequest.setProcessQueue(processQueue);
return pullRequest; return pullRequest;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册