未验证 提交 5d212571 编写于 作者: H Heng Du 提交者: GitHub

Merge pull request #2903 from ayanamist/fix-test

Fix unit test stability
...@@ -20,8 +20,12 @@ package org.apache.rocketmq.client.consumer; ...@@ -20,8 +20,12 @@ package org.apache.rocketmq.client.consumer;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.*; import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.rocketmq.client.ClientConfig; import org.apache.rocketmq.client.ClientConfig;
import org.apache.rocketmq.client.consumer.store.OffsetStore; import org.apache.rocketmq.client.consumer.store.OffsetStore;
import org.apache.rocketmq.client.consumer.store.ReadOffsetType; import org.apache.rocketmq.client.consumer.store.ReadOffsetType;
...@@ -46,17 +50,15 @@ import org.apache.rocketmq.common.message.MessageExt; ...@@ -46,17 +50,15 @@ import org.apache.rocketmq.common.message.MessageExt;
import org.apache.rocketmq.common.message.MessageQueue; import org.apache.rocketmq.common.message.MessageQueue;
import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader;
import org.apache.rocketmq.common.protocol.heartbeat.MessageModel; import org.apache.rocketmq.common.protocol.heartbeat.MessageModel;
import org.apache.rocketmq.remoting.RPCHook;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Spy; import org.mockito.Spy;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Fail.failBecauseExceptionWasNotThrown; import static org.assertj.core.api.Fail.failBecauseExceptionWasNotThrown;
...@@ -70,9 +72,7 @@ import static org.mockito.Mockito.doReturn; ...@@ -70,9 +72,7 @@ import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest(DefaultLitePullConsumerImpl.class)
@PowerMockIgnore("javax.management.*")
public class DefaultLitePullConsumerTest { public class DefaultLitePullConsumerTest {
@Spy @Spy
private MQClientInstance mQClientFactory = MQClientManager.getInstance().getOrCreateMQClientInstance(new ClientConfig()); private MQClientInstance mQClientFactory = MQClientManager.getInstance().getOrCreateMQClientInstance(new ClientConfig());
...@@ -94,7 +94,10 @@ public class DefaultLitePullConsumerTest { ...@@ -94,7 +94,10 @@ public class DefaultLitePullConsumerTest {
@Before @Before
public void init() throws Exception { public void init() throws Exception {
PowerMockito.suppress(PowerMockito.method(DefaultLitePullConsumerImpl.class, "updateTopicSubscribeInfoWhenSubscriptionChanged")); ConcurrentMap<String, MQClientInstance> factoryTable = (ConcurrentMap<String, MQClientInstance>) FieldUtils.readDeclaredField(MQClientManager.getInstance(), "factoryTable", true);
factoryTable.forEach((s, instance) -> instance.shutdown());
factoryTable.clear();
Field field = MQClientInstance.class.getDeclaredField("rebalanceService"); Field field = MQClientInstance.class.getDeclaredField("rebalanceService");
field.setAccessible(true); field.setAccessible(true);
RebalanceService rebalanceService = (RebalanceService) field.get(mQClientFactory); RebalanceService rebalanceService = (RebalanceService) field.get(mQClientFactory);
...@@ -182,7 +185,9 @@ public class DefaultLitePullConsumerTest { ...@@ -182,7 +185,9 @@ public class DefaultLitePullConsumerTest {
when(mQAdminImpl.minOffset(any(MessageQueue.class))).thenReturn(0L); when(mQAdminImpl.minOffset(any(MessageQueue.class))).thenReturn(0L);
when(mQAdminImpl.maxOffset(any(MessageQueue.class))).thenReturn(500L); when(mQAdminImpl.maxOffset(any(MessageQueue.class))).thenReturn(500L);
MessageQueue messageQueue = createMessageQueue(); MessageQueue messageQueue = createMessageQueue();
litePullConsumer.assign(Collections.singletonList(messageQueue)); List<MessageQueue> messageQueues = Collections.singletonList(messageQueue);
litePullConsumer.assign(messageQueues);
litePullConsumer.pause(messageQueues);
long offset = litePullConsumer.committed(messageQueue); long offset = litePullConsumer.committed(messageQueue);
litePullConsumer.seek(messageQueue, offset); litePullConsumer.seek(messageQueue, offset);
Field field = DefaultLitePullConsumerImpl.class.getDeclaredField("assignedMessageQueue"); Field field = DefaultLitePullConsumerImpl.class.getDeclaredField("assignedMessageQueue");
...@@ -198,7 +203,9 @@ public class DefaultLitePullConsumerTest { ...@@ -198,7 +203,9 @@ public class DefaultLitePullConsumerTest {
when(mQAdminImpl.minOffset(any(MessageQueue.class))).thenReturn(0L); when(mQAdminImpl.minOffset(any(MessageQueue.class))).thenReturn(0L);
when(mQAdminImpl.maxOffset(any(MessageQueue.class))).thenReturn(500L); when(mQAdminImpl.maxOffset(any(MessageQueue.class))).thenReturn(500L);
MessageQueue messageQueue = createMessageQueue(); MessageQueue messageQueue = createMessageQueue();
litePullConsumer.assign(Collections.singletonList(messageQueue)); List<MessageQueue> messageQueues = Collections.singletonList(messageQueue);
litePullConsumer.assign(messageQueues);
litePullConsumer.pause(messageQueues);
litePullConsumer.seekToBegin(messageQueue); litePullConsumer.seekToBegin(messageQueue);
Field field = DefaultLitePullConsumerImpl.class.getDeclaredField("assignedMessageQueue"); Field field = DefaultLitePullConsumerImpl.class.getDeclaredField("assignedMessageQueue");
field.setAccessible(true); field.setAccessible(true);
...@@ -213,7 +220,9 @@ public class DefaultLitePullConsumerTest { ...@@ -213,7 +220,9 @@ public class DefaultLitePullConsumerTest {
when(mQAdminImpl.minOffset(any(MessageQueue.class))).thenReturn(0L); when(mQAdminImpl.minOffset(any(MessageQueue.class))).thenReturn(0L);
when(mQAdminImpl.maxOffset(any(MessageQueue.class))).thenReturn(500L); when(mQAdminImpl.maxOffset(any(MessageQueue.class))).thenReturn(500L);
MessageQueue messageQueue = createMessageQueue(); MessageQueue messageQueue = createMessageQueue();
litePullConsumer.assign(Collections.singletonList(messageQueue)); List<MessageQueue> messageQueues = Collections.singletonList(messageQueue);
litePullConsumer.assign(messageQueues);
litePullConsumer.pause(messageQueues);
litePullConsumer.seekToEnd(messageQueue); litePullConsumer.seekToEnd(messageQueue);
Field field = DefaultLitePullConsumerImpl.class.getDeclaredField("assignedMessageQueue"); Field field = DefaultLitePullConsumerImpl.class.getDeclaredField("assignedMessageQueue");
field.setAccessible(true); field.setAccessible(true);
...@@ -228,7 +237,9 @@ public class DefaultLitePullConsumerTest { ...@@ -228,7 +237,9 @@ public class DefaultLitePullConsumerTest {
when(mQAdminImpl.minOffset(any(MessageQueue.class))).thenReturn(0L); when(mQAdminImpl.minOffset(any(MessageQueue.class))).thenReturn(0L);
when(mQAdminImpl.maxOffset(any(MessageQueue.class))).thenReturn(100L); when(mQAdminImpl.maxOffset(any(MessageQueue.class))).thenReturn(100L);
MessageQueue messageQueue = createMessageQueue(); MessageQueue messageQueue = createMessageQueue();
litePullConsumer.assign(Collections.singletonList(messageQueue)); List<MessageQueue> messageQueues = Collections.singletonList(messageQueue);
litePullConsumer.assign(messageQueues);
litePullConsumer.pause(messageQueues);
try { try {
litePullConsumer.seek(messageQueue, -1); litePullConsumer.seek(messageQueue, -1);
failBecauseExceptionWasNotThrown(MQClientException.class); failBecauseExceptionWasNotThrown(MQClientException.class);
...@@ -517,9 +528,6 @@ public class DefaultLitePullConsumerTest { ...@@ -517,9 +528,6 @@ public class DefaultLitePullConsumerTest {
public void testConsumerAfterShutdown() throws Exception { public void testConsumerAfterShutdown() throws Exception {
DefaultLitePullConsumer defaultLitePullConsumer = createSubscribeLitePullConsumer(); DefaultLitePullConsumer defaultLitePullConsumer = createSubscribeLitePullConsumer();
DefaultLitePullConsumer mockConsumer = spy(defaultLitePullConsumer);
when(mockConsumer.poll(anyLong())).thenReturn(new ArrayList<>());
new AsyncConsumer().executeAsync(defaultLitePullConsumer); new AsyncConsumer().executeAsync(defaultLitePullConsumer);
Thread.sleep(100); Thread.sleep(100);
...@@ -576,9 +584,9 @@ public class DefaultLitePullConsumerTest { ...@@ -576,9 +584,9 @@ public class DefaultLitePullConsumerTest {
when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class),
anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)))
.thenAnswer(new Answer<Object>() { .thenAnswer(new Answer<PullResult>() {
@Override @Override
public Object answer(InvocationOnMock mock) throws Throwable { public PullResult answer(InvocationOnMock mock) throws Throwable {
PullMessageRequestHeader requestHeader = mock.getArgument(1); PullMessageRequestHeader requestHeader = mock.getArgument(1);
MessageClientExt messageClientExt = new MessageClientExt(); MessageClientExt messageClientExt = new MessageClientExt();
messageClientExt.setTopic(topic); messageClientExt.setTopic(topic);
...@@ -604,6 +612,7 @@ public class DefaultLitePullConsumerTest { ...@@ -604,6 +612,7 @@ public class DefaultLitePullConsumerTest {
DefaultLitePullConsumer litePullConsumer = new DefaultLitePullConsumer(consumerGroup + System.currentTimeMillis()); DefaultLitePullConsumer litePullConsumer = new DefaultLitePullConsumer(consumerGroup + System.currentTimeMillis());
litePullConsumer.setNamesrvAddr("127.0.0.1:9876"); litePullConsumer.setNamesrvAddr("127.0.0.1:9876");
litePullConsumer.subscribe(topic, "*"); litePullConsumer.subscribe(topic, "*");
suppressUpdateTopicRouteInfoFromNameServer(litePullConsumer);
litePullConsumer.start(); litePullConsumer.start();
initDefaultLitePullConsumer(litePullConsumer); initDefaultLitePullConsumer(litePullConsumer);
return litePullConsumer; return litePullConsumer;
...@@ -612,6 +621,7 @@ public class DefaultLitePullConsumerTest { ...@@ -612,6 +621,7 @@ public class DefaultLitePullConsumerTest {
private DefaultLitePullConsumer createStartLitePullConsumer() throws Exception { private DefaultLitePullConsumer createStartLitePullConsumer() throws Exception {
DefaultLitePullConsumer litePullConsumer = new DefaultLitePullConsumer(consumerGroup + System.currentTimeMillis()); DefaultLitePullConsumer litePullConsumer = new DefaultLitePullConsumer(consumerGroup + System.currentTimeMillis());
litePullConsumer.setNamesrvAddr("127.0.0.1:9876"); litePullConsumer.setNamesrvAddr("127.0.0.1:9876");
suppressUpdateTopicRouteInfoFromNameServer(litePullConsumer);
litePullConsumer.start(); litePullConsumer.start();
initDefaultLitePullConsumer(litePullConsumer); initDefaultLitePullConsumer(litePullConsumer);
return litePullConsumer; return litePullConsumer;
...@@ -627,6 +637,7 @@ public class DefaultLitePullConsumerTest { ...@@ -627,6 +637,7 @@ public class DefaultLitePullConsumerTest {
litePullConsumer.setNamesrvAddr("127.0.0.1:9876"); litePullConsumer.setNamesrvAddr("127.0.0.1:9876");
litePullConsumer.setMessageModel(MessageModel.BROADCASTING); litePullConsumer.setMessageModel(MessageModel.BROADCASTING);
litePullConsumer.subscribe(topic, "*"); litePullConsumer.subscribe(topic, "*");
suppressUpdateTopicRouteInfoFromNameServer(litePullConsumer);
litePullConsumer.start(); litePullConsumer.start();
initDefaultLitePullConsumer(litePullConsumer); initDefaultLitePullConsumer(litePullConsumer);
return litePullConsumer; return litePullConsumer;
...@@ -648,4 +659,15 @@ public class DefaultLitePullConsumerTest { ...@@ -648,4 +659,15 @@ public class DefaultLitePullConsumerTest {
} }
return new PullResultExt(pullStatus, requestHeader.getQueueOffset() + messageExtList.size(), 123, 2048, messageExtList, 0, outputStream.toByteArray()); return new PullResultExt(pullStatus, requestHeader.getQueueOffset() + messageExtList.size(), 123, 2048, messageExtList, 0, outputStream.toByteArray());
} }
private static void suppressUpdateTopicRouteInfoFromNameServer(DefaultLitePullConsumer litePullConsumer) throws IllegalAccessException {
DefaultLitePullConsumerImpl defaultLitePullConsumerImpl = (DefaultLitePullConsumerImpl) FieldUtils.readDeclaredField(litePullConsumer, "defaultLitePullConsumerImpl", true);
if (litePullConsumer.getMessageModel() == MessageModel.CLUSTERING) {
litePullConsumer.changeInstanceNameToPID();
}
MQClientInstance mQClientFactory = spy(MQClientManager.getInstance().getOrCreateMQClientInstance(litePullConsumer, (RPCHook) FieldUtils.readDeclaredField(defaultLitePullConsumerImpl, "rpcHook", true)));
ConcurrentMap<String, MQClientInstance> factoryTable = (ConcurrentMap<String, MQClientInstance>) FieldUtils.readDeclaredField(MQClientManager.getInstance(), "factoryTable", true);
factoryTable.put(litePullConsumer.buildMQClientId(), mQClientFactory);
doReturn(false).when(mQClientFactory).updateTopicRouteInfoFromNameServer(anyString());
}
} }
...@@ -23,9 +23,12 @@ import java.util.Collections; ...@@ -23,9 +23,12 @@ import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
import org.apache.rocketmq.client.consumer.listener.ConsumeOrderlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeOrderlyContext;
...@@ -37,6 +40,7 @@ import org.apache.rocketmq.client.exception.MQClientException; ...@@ -37,6 +40,7 @@ import org.apache.rocketmq.client.exception.MQClientException;
import org.apache.rocketmq.client.impl.CommunicationMode; import org.apache.rocketmq.client.impl.CommunicationMode;
import org.apache.rocketmq.client.impl.FindBrokerResult; import org.apache.rocketmq.client.impl.FindBrokerResult;
import org.apache.rocketmq.client.impl.MQClientAPIImpl; import org.apache.rocketmq.client.impl.MQClientAPIImpl;
import org.apache.rocketmq.client.impl.MQClientManager;
import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService;
import org.apache.rocketmq.client.impl.consumer.ConsumeMessageOrderlyService; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageOrderlyService;
import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl; import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl;
...@@ -45,13 +49,14 @@ import org.apache.rocketmq.client.impl.consumer.PullAPIWrapper; ...@@ -45,13 +49,14 @@ import org.apache.rocketmq.client.impl.consumer.PullAPIWrapper;
import org.apache.rocketmq.client.impl.consumer.PullMessageService; import org.apache.rocketmq.client.impl.consumer.PullMessageService;
import org.apache.rocketmq.client.impl.consumer.PullRequest; import org.apache.rocketmq.client.impl.consumer.PullRequest;
import org.apache.rocketmq.client.impl.consumer.PullResultExt; import org.apache.rocketmq.client.impl.consumer.PullResultExt;
import org.apache.rocketmq.client.impl.consumer.RebalancePushImpl; import org.apache.rocketmq.client.impl.consumer.RebalanceImpl;
import org.apache.rocketmq.client.impl.factory.MQClientInstance; import org.apache.rocketmq.client.impl.factory.MQClientInstance;
import org.apache.rocketmq.common.message.MessageClientExt; import org.apache.rocketmq.common.message.MessageClientExt;
import org.apache.rocketmq.common.message.MessageDecoder; import org.apache.rocketmq.common.message.MessageDecoder;
import org.apache.rocketmq.common.message.MessageExt; import org.apache.rocketmq.common.message.MessageExt;
import org.apache.rocketmq.common.message.MessageQueue; import org.apache.rocketmq.common.message.MessageQueue;
import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader;
import org.apache.rocketmq.remoting.RPCHook;
import org.apache.rocketmq.remoting.exception.RemotingException; import org.apache.rocketmq.remoting.exception.RemotingException;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
...@@ -60,11 +65,8 @@ import org.junit.Test; ...@@ -60,11 +65,8 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Fail.failBecauseExceptionWasNotThrown; import static org.assertj.core.api.Fail.failBecauseExceptionWasNotThrown;
...@@ -77,9 +79,7 @@ import static org.mockito.Mockito.doReturn; ...@@ -77,9 +79,7 @@ import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest(DefaultMQPushConsumerImpl.class)
@PowerMockIgnore("javax.management.*")
public class DefaultMQPushConsumerTest { public class DefaultMQPushConsumerTest {
private String consumerGroup; private String consumerGroup;
private String topic = "FooBar"; private String topic = "FooBar";
...@@ -89,11 +89,15 @@ public class DefaultMQPushConsumerTest { ...@@ -89,11 +89,15 @@ public class DefaultMQPushConsumerTest {
@Mock @Mock
private MQClientAPIImpl mQClientAPIImpl; private MQClientAPIImpl mQClientAPIImpl;
private PullAPIWrapper pullAPIWrapper; private PullAPIWrapper pullAPIWrapper;
private RebalancePushImpl rebalancePushImpl; private RebalanceImpl rebalanceImpl;
private DefaultMQPushConsumer pushConsumer; private DefaultMQPushConsumer pushConsumer;
@Before @Before
public void init() throws Exception { public void init() throws Exception {
ConcurrentMap<String, MQClientInstance> factoryTable = (ConcurrentMap<String, MQClientInstance>) FieldUtils.readDeclaredField(MQClientManager.getInstance(), "factoryTable", true);
factoryTable.forEach((s, instance) -> instance.shutdown());
factoryTable.clear();
consumerGroup = "FooBarGroup" + System.currentTimeMillis(); consumerGroup = "FooBarGroup" + System.currentTimeMillis();
pushConsumer = new DefaultMQPushConsumer(consumerGroup); pushConsumer = new DefaultMQPushConsumer(consumerGroup);
pushConsumer.setNamesrvAddr("127.0.0.1:9876"); pushConsumer.setNamesrvAddr("127.0.0.1:9876");
...@@ -108,16 +112,21 @@ public class DefaultMQPushConsumerTest { ...@@ -108,16 +112,21 @@ public class DefaultMQPushConsumerTest {
}); });
DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl(); DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl();
PowerMockito.suppress(PowerMockito.method(DefaultMQPushConsumerImpl.class, "updateTopicSubscribeInfoWhenSubscriptionChanged"));
rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl())); // suppress updateTopicRouteInfoFromNameServer
pushConsumer.changeInstanceNameToPID();
mQClientFactory = spy(MQClientManager.getInstance().getOrCreateMQClientInstance(pushConsumer, (RPCHook) FieldUtils.readDeclaredField(pushConsumerImpl, "rpcHook", true)));
factoryTable.put(pushConsumer.buildMQClientId(), mQClientFactory);
doReturn(false).when(mQClientFactory).updateTopicRouteInfoFromNameServer(anyString());
rebalanceImpl = spy(pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl());
Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl"); Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl");
field.setAccessible(true); field.setAccessible(true);
field.set(pushConsumerImpl, rebalancePushImpl); field.set(pushConsumerImpl, rebalanceImpl);
pushConsumer.subscribe(topic, "*"); pushConsumer.subscribe(topic, "*");
pushConsumer.start(); pushConsumer.start();
mQClientFactory = spy(pushConsumerImpl.getmQClientFactory());
field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory");
field.setAccessible(true); field.setAccessible(true);
field.set(pushConsumerImpl, mQClientFactory); field.set(pushConsumerImpl, mQClientFactory);
...@@ -131,14 +140,13 @@ public class DefaultMQPushConsumerTest { ...@@ -131,14 +140,13 @@ public class DefaultMQPushConsumerTest {
field.setAccessible(true); field.setAccessible(true);
field.set(pushConsumerImpl, pullAPIWrapper); field.set(pushConsumerImpl, pullAPIWrapper);
pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory);
mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl); mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl);
when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class),
anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)))
.thenAnswer(new Answer<Object>() { .thenAnswer(new Answer<PullResult>() {
@Override @Override
public Object answer(InvocationOnMock mock) throws Throwable { public PullResult answer(InvocationOnMock mock) throws Throwable {
PullMessageRequestHeader requestHeader = mock.getArgument(1); PullMessageRequestHeader requestHeader = mock.getArgument(1);
MessageClientExt messageClientExt = new MessageClientExt(); MessageClientExt messageClientExt = new MessageClientExt();
messageClientExt.setTopic(topic); messageClientExt.setTopic(topic);
...@@ -155,11 +163,10 @@ public class DefaultMQPushConsumerTest { ...@@ -155,11 +163,10 @@ public class DefaultMQPushConsumerTest {
}); });
doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean()); doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean());
doReturn(Collections.singletonList(mQClientFactory.getClientId())).when(mQClientFactory).findConsumerIdList(anyString(), anyString());
Set<MessageQueue> messageQueueSet = new HashSet<MessageQueue>(); Set<MessageQueue> messageQueueSet = new HashSet<MessageQueue>();
messageQueueSet.add(createPullRequest().getMessageQueue()); messageQueueSet.add(createPullRequest().getMessageQueue());
pushConsumer.getDefaultMQPushConsumerImpl().updateTopicSubscribeInfo(topic, messageQueueSet); pushConsumer.getDefaultMQPushConsumerImpl().updateTopicSubscribeInfo(topic, messageQueueSet);
doReturn(123L).when(rebalancePushImpl).computePullFromWhere(any(MessageQueue.class)); doReturn(123L).when(rebalanceImpl).computePullFromWhere(any(MessageQueue.class));
} }
@After @After
...@@ -175,12 +182,12 @@ public class DefaultMQPushConsumerTest { ...@@ -175,12 +182,12 @@ public class DefaultMQPushConsumerTest {
@Test @Test
public void testPullMessage_Success() throws InterruptedException, RemotingException, MQBrokerException { public void testPullMessage_Success() throws InterruptedException, RemotingException, MQBrokerException {
final CountDownLatch countDownLatch = new CountDownLatch(1); final CountDownLatch countDownLatch = new CountDownLatch(1);
final MessageExt[] messageExts = new MessageExt[1]; final AtomicReference<MessageExt> messageAtomic = new AtomicReference<>();
pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() { pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() {
@Override @Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
ConsumeConcurrentlyContext context) { ConsumeConcurrentlyContext context) {
messageExts[0] = msgs.get(0); messageAtomic.set(msgs.get(0));
countDownLatch.countDown(); countDownLatch.countDown();
return null; return null;
} }
...@@ -188,20 +195,22 @@ public class DefaultMQPushConsumerTest { ...@@ -188,20 +195,22 @@ public class DefaultMQPushConsumerTest {
PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); PullMessageService pullMessageService = mQClientFactory.getPullMessageService();
pullMessageService.executePullRequestImmediately(createPullRequest()); pullMessageService.executePullRequestImmediately(createPullRequest());
countDownLatch.await(); countDownLatch.await(10, TimeUnit.SECONDS);
assertThat(messageExts[0].getTopic()).isEqualTo(topic); MessageExt msg = messageAtomic.get();
assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); assertThat(msg).isNotNull();
assertThat(msg.getTopic()).isEqualTo(topic);
assertThat(msg.getBody()).isEqualTo(new byte[] {'a'});
} }
@Test @Test
public void testPullMessage_SuccessWithOrderlyService() throws Exception { public void testPullMessage_SuccessWithOrderlyService() throws Exception {
final CountDownLatch countDownLatch = new CountDownLatch(1); final CountDownLatch countDownLatch = new CountDownLatch(1);
final MessageExt[] messageExts = new MessageExt[1]; final AtomicReference<MessageExt> messageAtomic = new AtomicReference<>();
MessageListenerOrderly listenerOrderly = new MessageListenerOrderly() { MessageListenerOrderly listenerOrderly = new MessageListenerOrderly() {
@Override @Override
public ConsumeOrderlyStatus consumeMessage(List<MessageExt> msgs, ConsumeOrderlyContext context) { public ConsumeOrderlyStatus consumeMessage(List<MessageExt> msgs, ConsumeOrderlyContext context) {
messageExts[0] = msgs.get(0); messageAtomic.set(msgs.get(0));
countDownLatch.countDown(); countDownLatch.countDown();
return null; return null;
} }
...@@ -214,8 +223,10 @@ public class DefaultMQPushConsumerTest { ...@@ -214,8 +223,10 @@ public class DefaultMQPushConsumerTest {
pullMessageService.executePullRequestLater(createPullRequest(), 100); pullMessageService.executePullRequestLater(createPullRequest(), 100);
countDownLatch.await(10, TimeUnit.SECONDS); countDownLatch.await(10, TimeUnit.SECONDS);
assertThat(messageExts[0].getTopic()).isEqualTo(topic); MessageExt msg = messageAtomic.get();
assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); assertThat(msg).isNotNull();
assertThat(msg.getTopic()).isEqualTo(topic);
assertThat(msg.getBody()).isEqualTo(new byte[] {'a'});
} }
@Test @Test
...@@ -281,7 +292,7 @@ public class DefaultMQPushConsumerTest { ...@@ -281,7 +292,7 @@ public class DefaultMQPushConsumerTest {
PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); PullMessageService pullMessageService = mQClientFactory.getPullMessageService();
pullMessageService.executePullRequestImmediately(createPullRequest()); pullMessageService.executePullRequestImmediately(createPullRequest());
countDownLatch.await(); assertThat(countDownLatch.await(30, TimeUnit.SECONDS)).isTrue();
pushConsumer.shutdown(); pushConsumer.shutdown();
assertThat(messageConsumedFlag.get()).isTrue(); assertThat(messageConsumedFlag.get()).isTrue();
......
...@@ -16,7 +16,19 @@ ...@@ -16,7 +16,19 @@
*/ */
package org.apache.rocketmq.client.impl.consumer; package org.apache.rocketmq.client.impl.consumer;
import org.apache.rocketmq.client.consumer.*; import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.rocketmq.client.consumer.DefaultMQPushConsumer;
import org.apache.rocketmq.client.consumer.PullCallback;
import org.apache.rocketmq.client.consumer.PullResult;
import org.apache.rocketmq.client.consumer.PullStatus;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently; import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently;
...@@ -40,23 +52,15 @@ import org.junit.Ignore; ...@@ -40,23 +52,15 @@ import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
...@@ -117,9 +121,9 @@ public class ConsumeMessageConcurrentlyServiceTest { ...@@ -117,9 +121,9 @@ public class ConsumeMessageConcurrentlyServiceTest {
when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class),
anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)))
.thenAnswer(new Answer<Object>() { .thenAnswer(new Answer<PullResult>() {
@Override @Override
public Object answer(InvocationOnMock mock) throws Throwable { public PullResult answer(InvocationOnMock mock) throws Throwable {
PullMessageRequestHeader requestHeader = mock.getArgument(1); PullMessageRequestHeader requestHeader = mock.getArgument(1);
MessageClientExt messageClientExt = new MessageClientExt(); MessageClientExt messageClientExt = new MessageClientExt();
messageClientExt.setTopic(topic); messageClientExt.setTopic(topic);
...@@ -145,13 +149,13 @@ public class ConsumeMessageConcurrentlyServiceTest { ...@@ -145,13 +149,13 @@ public class ConsumeMessageConcurrentlyServiceTest {
@Test @Test
public void testPullMessage_ConsumeSuccess() throws InterruptedException, RemotingException, MQBrokerException, NoSuchFieldException,Exception { public void testPullMessage_ConsumeSuccess() throws InterruptedException, RemotingException, MQBrokerException, NoSuchFieldException,Exception {
final CountDownLatch countDownLatch = new CountDownLatch(1); final CountDownLatch countDownLatch = new CountDownLatch(1);
final MessageExt[] messageExts = new MessageExt[1]; final AtomicReference<MessageExt> messageAtomic = new AtomicReference<>();
ConsumeMessageConcurrentlyService normalServie = new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() { ConsumeMessageConcurrentlyService normalServie = new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() {
@Override @Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
ConsumeConcurrentlyContext context) { ConsumeConcurrentlyContext context) {
messageExts[0] = msgs.get(0); messageAtomic.set(msgs.get(0));
countDownLatch.countDown(); countDownLatch.countDown();
return ConsumeConcurrentlyStatus.CONSUME_SUCCESS; return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
} }
...@@ -175,8 +179,10 @@ public class ConsumeMessageConcurrentlyServiceTest { ...@@ -175,8 +179,10 @@ public class ConsumeMessageConcurrentlyServiceTest {
StatsItem item = itemSet.getAndCreateStatsItem(topic + "@" + pushConsumer.getDefaultMQPushConsumerImpl().groupName()); StatsItem item = itemSet.getAndCreateStatsItem(topic + "@" + pushConsumer.getDefaultMQPushConsumerImpl().groupName());
assertThat(item.getValue().get()).isGreaterThan(0L); assertThat(item.getValue().get()).isGreaterThan(0L);
assertThat(messageExts[0].getTopic()).isEqualTo(topic); MessageExt msg = messageAtomic.get();
assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); assertThat(msg).isNotNull();
assertThat(msg.getTopic()).isEqualTo(topic);
assertThat(msg.getBody()).isEqualTo(new byte[] {'a'});
} }
@After @After
......
...@@ -22,6 +22,7 @@ import java.util.List; ...@@ -22,6 +22,7 @@ import java.util.List;
import java.util.Properties; import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.rocketmq.client.ClientConfig; import org.apache.rocketmq.client.ClientConfig;
import org.apache.rocketmq.client.admin.MQAdminExtInner; import org.apache.rocketmq.client.admin.MQAdminExtInner;
import org.apache.rocketmq.client.exception.MQBrokerException; import org.apache.rocketmq.client.exception.MQBrokerException;
...@@ -39,7 +40,6 @@ import org.apache.rocketmq.remoting.exception.RemotingException; ...@@ -39,7 +40,6 @@ import org.apache.rocketmq.remoting.exception.RemotingException;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.internal.util.reflection.FieldSetter;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
...@@ -55,7 +55,7 @@ public class MQClientInstanceTest { ...@@ -55,7 +55,7 @@ public class MQClientInstanceTest {
@Before @Before
public void init() throws Exception { public void init() throws Exception {
FieldSetter.setField(mqClientInstance, MQClientInstance.class.getDeclaredField("brokerAddrTable"), brokerAddrTable); FieldUtils.writeDeclaredField(mqClientInstance, "brokerAddrTable", brokerAddrTable, true);
} }
@Test @Test
......
...@@ -20,6 +20,18 @@ package org.apache.rocketmq.client.trace; ...@@ -20,6 +20,18 @@ package org.apache.rocketmq.client.trace;
import io.opentracing.mock.MockSpan; import io.opentracing.mock.MockSpan;
import io.opentracing.mock.MockTracer; import io.opentracing.mock.MockTracer;
import io.opentracing.tag.Tags; import io.opentracing.tag.Tags;
import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.rocketmq.client.consumer.DefaultMQPushConsumer; import org.apache.rocketmq.client.consumer.DefaultMQPushConsumer;
import org.apache.rocketmq.client.consumer.PullCallback; import org.apache.rocketmq.client.consumer.PullCallback;
import org.apache.rocketmq.client.consumer.PullResult; import org.apache.rocketmq.client.consumer.PullResult;
...@@ -27,11 +39,14 @@ import org.apache.rocketmq.client.consumer.PullStatus; ...@@ -27,11 +39,14 @@ import org.apache.rocketmq.client.consumer.PullStatus;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently; import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently;
import org.apache.rocketmq.client.consumer.store.OffsetStore;
import org.apache.rocketmq.client.consumer.store.ReadOffsetType;
import org.apache.rocketmq.client.exception.MQBrokerException; import org.apache.rocketmq.client.exception.MQBrokerException;
import org.apache.rocketmq.client.exception.MQClientException; import org.apache.rocketmq.client.exception.MQClientException;
import org.apache.rocketmq.client.impl.CommunicationMode; import org.apache.rocketmq.client.impl.CommunicationMode;
import org.apache.rocketmq.client.impl.FindBrokerResult; import org.apache.rocketmq.client.impl.FindBrokerResult;
import org.apache.rocketmq.client.impl.MQClientAPIImpl; import org.apache.rocketmq.client.impl.MQClientAPIImpl;
import org.apache.rocketmq.client.impl.MQClientManager;
import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService;
import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl; import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl;
import org.apache.rocketmq.client.impl.consumer.ProcessQueue; import org.apache.rocketmq.client.impl.consumer.ProcessQueue;
...@@ -47,28 +62,17 @@ import org.apache.rocketmq.common.message.MessageDecoder; ...@@ -47,28 +62,17 @@ import org.apache.rocketmq.common.message.MessageDecoder;
import org.apache.rocketmq.common.message.MessageExt; import org.apache.rocketmq.common.message.MessageExt;
import org.apache.rocketmq.common.message.MessageQueue; import org.apache.rocketmq.common.message.MessageQueue;
import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader;
import org.apache.rocketmq.remoting.RPCHook;
import org.apache.rocketmq.remoting.exception.RemotingException; import org.apache.rocketmq.remoting.exception.RemotingException;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
...@@ -80,9 +84,7 @@ import static org.mockito.Mockito.doReturn; ...@@ -80,9 +84,7 @@ import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest(DefaultMQPushConsumerImpl.class)
@PowerMockIgnore("javax.management.*")
public class DefaultMQConsumerWithOpenTracingTest { public class DefaultMQConsumerWithOpenTracingTest {
private String consumerGroup; private String consumerGroup;
...@@ -99,6 +101,10 @@ public class DefaultMQConsumerWithOpenTracingTest { ...@@ -99,6 +101,10 @@ public class DefaultMQConsumerWithOpenTracingTest {
@Before @Before
public void init() throws Exception { public void init() throws Exception {
ConcurrentMap<String, MQClientInstance> factoryTable = (ConcurrentMap<String, MQClientInstance>) FieldUtils.readDeclaredField(MQClientManager.getInstance(), "factoryTable", true);
factoryTable.forEach((s, instance) -> instance.shutdown());
factoryTable.clear();
consumerGroup = "FooBarGroup" + System.currentTimeMillis(); consumerGroup = "FooBarGroup" + System.currentTimeMillis();
pushConsumer = new DefaultMQPushConsumer(consumerGroup); pushConsumer = new DefaultMQPushConsumer(consumerGroup);
pushConsumer.getDefaultMQPushConsumerImpl().registerConsumeMessageHook( pushConsumer.getDefaultMQPushConsumerImpl().registerConsumeMessageHook(
...@@ -106,6 +112,10 @@ public class DefaultMQConsumerWithOpenTracingTest { ...@@ -106,6 +112,10 @@ public class DefaultMQConsumerWithOpenTracingTest {
pushConsumer.setNamesrvAddr("127.0.0.1:9876"); pushConsumer.setNamesrvAddr("127.0.0.1:9876");
pushConsumer.setPullInterval(60 * 1000); pushConsumer.setPullInterval(60 * 1000);
OffsetStore offsetStore = Mockito.mock(OffsetStore.class);
Mockito.when(offsetStore.readOffset(any(MessageQueue.class), any(ReadOffsetType.class))).thenReturn(0L);
pushConsumer.setOffsetStore(offsetStore);
pushConsumer.registerMessageListener(new MessageListenerConcurrently() { pushConsumer.registerMessageListener(new MessageListenerConcurrently() {
@Override @Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
...@@ -114,8 +124,14 @@ public class DefaultMQConsumerWithOpenTracingTest { ...@@ -114,8 +124,14 @@ public class DefaultMQConsumerWithOpenTracingTest {
} }
}); });
PowerMockito.suppress(PowerMockito.method(DefaultMQPushConsumerImpl.class, "updateTopicSubscribeInfoWhenSubscriptionChanged"));
DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl(); DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl();
// suppress updateTopicRouteInfoFromNameServer
pushConsumer.changeInstanceNameToPID();
mQClientFactory = spy(MQClientManager.getInstance().getOrCreateMQClientInstance(pushConsumer, (RPCHook) FieldUtils.readDeclaredField(pushConsumerImpl, "rpcHook", true)));
factoryTable.put(pushConsumer.buildMQClientId(), mQClientFactory);
doReturn(false).when(mQClientFactory).updateTopicRouteInfoFromNameServer(anyString());
rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl())); rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl()));
Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl"); Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl");
field.setAccessible(true); field.setAccessible(true);
...@@ -124,8 +140,6 @@ public class DefaultMQConsumerWithOpenTracingTest { ...@@ -124,8 +140,6 @@ public class DefaultMQConsumerWithOpenTracingTest {
pushConsumer.start(); pushConsumer.start();
mQClientFactory = spy(pushConsumerImpl.getmQClientFactory());
field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory");
field.setAccessible(true); field.setAccessible(true);
field.set(pushConsumerImpl, mQClientFactory); field.set(pushConsumerImpl, mQClientFactory);
...@@ -142,11 +156,11 @@ public class DefaultMQConsumerWithOpenTracingTest { ...@@ -142,11 +156,11 @@ public class DefaultMQConsumerWithOpenTracingTest {
pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory); pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory);
mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl); mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl);
when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), when(mQClientAPIImpl.pullMessage(anyString(), any(PullMessageRequestHeader.class),
anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)))
.thenAnswer(new Answer<Object>() { .thenAnswer(new Answer<PullResult>() {
@Override @Override
public Object answer(InvocationOnMock mock) throws Throwable { public PullResult answer(InvocationOnMock mock) throws Throwable {
PullMessageRequestHeader requestHeader = mock.getArgument(1); PullMessageRequestHeader requestHeader = mock.getArgument(1);
MessageClientExt messageClientExt = new MessageClientExt(); MessageClientExt messageClientExt = new MessageClientExt();
messageClientExt.setTopic(topic); messageClientExt.setTopic(topic);
...@@ -176,12 +190,12 @@ public class DefaultMQConsumerWithOpenTracingTest { ...@@ -176,12 +190,12 @@ public class DefaultMQConsumerWithOpenTracingTest {
@Test @Test
public void testPullMessage_WithTrace_Success() throws InterruptedException, RemotingException, MQBrokerException, MQClientException { public void testPullMessage_WithTrace_Success() throws InterruptedException, RemotingException, MQBrokerException, MQClientException {
final CountDownLatch countDownLatch = new CountDownLatch(1); final CountDownLatch countDownLatch = new CountDownLatch(1);
final MessageExt[] messageExts = new MessageExt[1]; final AtomicReference<MessageExt> messageAtomic = new AtomicReference<>();
pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() { pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() {
@Override @Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
ConsumeConcurrentlyContext context) { ConsumeConcurrentlyContext context) {
messageExts[0] = msgs.get(0); messageAtomic.set(msgs.get(0));
countDownLatch.countDown(); countDownLatch.countDown();
return ConsumeConcurrentlyStatus.CONSUME_SUCCESS; return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
} }
...@@ -189,9 +203,11 @@ public class DefaultMQConsumerWithOpenTracingTest { ...@@ -189,9 +203,11 @@ public class DefaultMQConsumerWithOpenTracingTest {
PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); PullMessageService pullMessageService = mQClientFactory.getPullMessageService();
pullMessageService.executePullRequestImmediately(createPullRequest()); pullMessageService.executePullRequestImmediately(createPullRequest());
countDownLatch.await(3000L, TimeUnit.MILLISECONDS); countDownLatch.await(30, TimeUnit.SECONDS);
assertThat(messageExts[0].getTopic()).isEqualTo(topic); MessageExt msg = messageAtomic.get();
assertThat(messageExts[0].getBody()).isEqualTo(new byte[]{'a'}); assertThat(msg).isNotNull();
assertThat(msg.getTopic()).isEqualTo(topic);
assertThat(msg.getBody()).isEqualTo(new byte[]{'a'});
assertThat(tracer.finishedSpans().size()).isEqualTo(1); assertThat(tracer.finishedSpans().size()).isEqualTo(1);
MockSpan span = tracer.finishedSpans().get(0); MockSpan span = tracer.finishedSpans().get(0);
......
...@@ -26,8 +26,11 @@ import java.util.HashMap; ...@@ -26,8 +26,11 @@ import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.rocketmq.client.consumer.DefaultMQPushConsumer; import org.apache.rocketmq.client.consumer.DefaultMQPushConsumer;
import org.apache.rocketmq.client.consumer.PullCallback; import org.apache.rocketmq.client.consumer.PullCallback;
import org.apache.rocketmq.client.consumer.PullResult; import org.apache.rocketmq.client.consumer.PullResult;
...@@ -40,6 +43,7 @@ import org.apache.rocketmq.client.exception.MQClientException; ...@@ -40,6 +43,7 @@ import org.apache.rocketmq.client.exception.MQClientException;
import org.apache.rocketmq.client.impl.CommunicationMode; import org.apache.rocketmq.client.impl.CommunicationMode;
import org.apache.rocketmq.client.impl.FindBrokerResult; import org.apache.rocketmq.client.impl.FindBrokerResult;
import org.apache.rocketmq.client.impl.MQClientAPIImpl; import org.apache.rocketmq.client.impl.MQClientAPIImpl;
import org.apache.rocketmq.client.impl.MQClientManager;
import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService; import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyService;
import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl; import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl;
import org.apache.rocketmq.client.impl.consumer.ProcessQueue; import org.apache.rocketmq.client.impl.consumer.ProcessQueue;
...@@ -53,17 +57,16 @@ import org.apache.rocketmq.client.impl.producer.DefaultMQProducerImpl; ...@@ -53,17 +57,16 @@ import org.apache.rocketmq.client.impl.producer.DefaultMQProducerImpl;
import org.apache.rocketmq.client.producer.DefaultMQProducer; import org.apache.rocketmq.client.producer.DefaultMQProducer;
import org.apache.rocketmq.client.producer.SendResult; import org.apache.rocketmq.client.producer.SendResult;
import org.apache.rocketmq.client.producer.SendStatus; import org.apache.rocketmq.client.producer.SendStatus;
import org.apache.rocketmq.common.MixAll;
import org.apache.rocketmq.common.message.MessageClientExt; import org.apache.rocketmq.common.message.MessageClientExt;
import org.apache.rocketmq.common.message.MessageDecoder; import org.apache.rocketmq.common.message.MessageDecoder;
import org.apache.rocketmq.common.message.MessageExt; import org.apache.rocketmq.common.message.MessageExt;
import org.apache.rocketmq.common.message.MessageQueue; import org.apache.rocketmq.common.message.MessageQueue;
import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader;
import org.apache.rocketmq.common.protocol.route.BrokerData; import org.apache.rocketmq.common.protocol.route.BrokerData;
import org.apache.rocketmq.common.protocol.route.QueueData; import org.apache.rocketmq.common.protocol.route.QueueData;
import org.apache.rocketmq.common.protocol.route.TopicRouteData; import org.apache.rocketmq.common.protocol.route.TopicRouteData;
import org.apache.rocketmq.common.topic.TopicValidator; import org.apache.rocketmq.common.topic.TopicValidator;
import org.apache.rocketmq.remoting.RPCHook;
import org.apache.rocketmq.remoting.exception.RemotingException; import org.apache.rocketmq.remoting.exception.RemotingException;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
...@@ -71,11 +74,8 @@ import org.junit.Test; ...@@ -71,11 +74,8 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
...@@ -87,9 +87,7 @@ import static org.mockito.Mockito.doReturn; ...@@ -87,9 +87,7 @@ import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@RunWith(PowerMockRunner.class) @RunWith(MockitoJUnitRunner.class)
@PrepareForTest(DefaultMQPushConsumerImpl.class)
@PowerMockIgnore("javax.management.*")
public class DefaultMQConsumerWithTraceTest { public class DefaultMQConsumerWithTraceTest {
private String consumerGroup; private String consumerGroup;
private String consumerGroupNormal; private String consumerGroupNormal;
...@@ -116,6 +114,10 @@ public class DefaultMQConsumerWithTraceTest { ...@@ -116,6 +114,10 @@ public class DefaultMQConsumerWithTraceTest {
@Before @Before
public void init() throws Exception { public void init() throws Exception {
ConcurrentMap<String, MQClientInstance> factoryTable = (ConcurrentMap<String, MQClientInstance>) FieldUtils.readDeclaredField(MQClientManager.getInstance(), "factoryTable", true);
factoryTable.forEach((s, instance) -> instance.shutdown());
factoryTable.clear();
consumerGroup = "FooBarGroup" + System.currentTimeMillis(); consumerGroup = "FooBarGroup" + System.currentTimeMillis();
pushConsumer = new DefaultMQPushConsumer(consumerGroup, true, ""); pushConsumer = new DefaultMQPushConsumer(consumerGroup, true, "");
consumerGroupNormal = "FooBarGroup" + System.currentTimeMillis(); consumerGroupNormal = "FooBarGroup" + System.currentTimeMillis();
...@@ -135,8 +137,14 @@ public class DefaultMQConsumerWithTraceTest { ...@@ -135,8 +137,14 @@ public class DefaultMQConsumerWithTraceTest {
} }
}); });
PowerMockito.suppress(PowerMockito.method(DefaultMQPushConsumerImpl.class, "updateTopicSubscribeInfoWhenSubscriptionChanged"));
DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl(); DefaultMQPushConsumerImpl pushConsumerImpl = pushConsumer.getDefaultMQPushConsumerImpl();
// suppress updateTopicRouteInfoFromNameServer
pushConsumer.changeInstanceNameToPID();
mQClientFactory = spy(MQClientManager.getInstance().getOrCreateMQClientInstance(pushConsumer, (RPCHook) FieldUtils.readDeclaredField(pushConsumerImpl, "rpcHook", true)));
factoryTable.put(pushConsumer.buildMQClientId(), mQClientFactory);
doReturn(false).when(mQClientFactory).updateTopicRouteInfoFromNameServer(anyString());
rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl())); rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl()));
Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl"); Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl");
field.setAccessible(true); field.setAccessible(true);
...@@ -174,9 +182,9 @@ public class DefaultMQConsumerWithTraceTest { ...@@ -174,9 +182,9 @@ public class DefaultMQConsumerWithTraceTest {
when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class),
anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) anyLong(), any(CommunicationMode.class), nullable(PullCallback.class)))
.thenAnswer(new Answer<Object>() { .thenAnswer(new Answer<PullResult>() {
@Override @Override
public Object answer(InvocationOnMock mock) throws Throwable { public PullResult answer(InvocationOnMock mock) throws Throwable {
PullMessageRequestHeader requestHeader = mock.getArgument(1); PullMessageRequestHeader requestHeader = mock.getArgument(1);
MessageClientExt messageClientExt = new MessageClientExt(); MessageClientExt messageClientExt = new MessageClientExt();
messageClientExt.setTopic(topic); messageClientExt.setTopic(topic);
...@@ -208,12 +216,12 @@ public class DefaultMQConsumerWithTraceTest { ...@@ -208,12 +216,12 @@ public class DefaultMQConsumerWithTraceTest {
traceProducer.getDefaultMQProducerImpl().getmQClientFactory().registerProducer(producerGroupTraceTemp, traceProducer.getDefaultMQProducerImpl()); traceProducer.getDefaultMQProducerImpl().getmQClientFactory().registerProducer(producerGroupTraceTemp, traceProducer.getDefaultMQProducerImpl());
final CountDownLatch countDownLatch = new CountDownLatch(1); final CountDownLatch countDownLatch = new CountDownLatch(1);
final MessageExt[] messageExts = new MessageExt[1]; final AtomicReference<MessageExt> messageAtomic = new AtomicReference<>();
pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() { pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() {
@Override @Override
public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
ConsumeConcurrentlyContext context) { ConsumeConcurrentlyContext context) {
messageExts[0] = msgs.get(0); messageAtomic.set(msgs.get(0));
countDownLatch.countDown(); countDownLatch.countDown();
return null; return null;
} }
...@@ -221,9 +229,11 @@ public class DefaultMQConsumerWithTraceTest { ...@@ -221,9 +229,11 @@ public class DefaultMQConsumerWithTraceTest {
PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); PullMessageService pullMessageService = mQClientFactory.getPullMessageService();
pullMessageService.executePullRequestImmediately(createPullRequest()); pullMessageService.executePullRequestImmediately(createPullRequest());
countDownLatch.await(3000L, TimeUnit.MILLISECONDS); countDownLatch.await(30, TimeUnit.SECONDS);
assertThat(messageExts[0].getTopic()).isEqualTo(topic); MessageExt msg = messageAtomic.get();
assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); assertThat(msg).isNotNull();
assertThat(msg.getTopic()).isEqualTo(topic);
assertThat(msg.getBody()).isEqualTo(new byte[] {'a'});
} }
private PullRequest createPullRequest() { private PullRequest createPullRequest() {
......
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Licensed to the Apache Software Foundation (ASF) under one or more
~ contributor license agreements. See the NOTICE file distributed with
~ this work for additional information regarding copyright ownership.
~ The ASF licenses this file to You under the Apache License, Version 2.0
~ (the "License"); you may not use this file except in compliance with
~ the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<Configuration status="WARN">
<Appenders>
<Console name="Console" target="SYSTEM_OUT">
<PatternLayout pattern="%d{HH:mm:ss.SSS} [%t] %-5level %logger{36} - %msg%n"/>
</Console>
</Appenders>
<Loggers>
<Root level="ERROR">
<AppenderRef ref="Console"/>
</Root>
</Loggers>
</Configuration>
\ No newline at end of file
...@@ -33,13 +33,13 @@ public class SysLogger { ...@@ -33,13 +33,13 @@ public class SysLogger {
public static void debug(String msg) { public static void debug(String msg) {
if (debugEnabled && !quietMode) { if (debugEnabled && !quietMode) {
System.out.printf("%s", PREFIX + msg); System.err.println(PREFIX + msg);
} }
} }
public static void debug(String msg, Throwable t) { public static void debug(String msg, Throwable t) {
if (debugEnabled && !quietMode) { if (debugEnabled && !quietMode) {
System.out.printf("%s", PREFIX + msg); System.err.println(PREFIX + msg);
if (t != null) { if (t != null) {
t.printStackTrace(System.out); t.printStackTrace(System.out);
} }
......
...@@ -102,7 +102,6 @@ ...@@ -102,7 +102,6 @@
<!-- Exclude all generated code --> <!-- Exclude all generated code -->
<sonar.jacoco.itReportPath>${project.basedir}/../test/target/jacoco-it.exec</sonar.jacoco.itReportPath> <sonar.jacoco.itReportPath>${project.basedir}/../test/target/jacoco-it.exec</sonar.jacoco.itReportPath>
<sonar.exclusions>file:**/generated-sources/**,**/test/**</sonar.exclusions> <sonar.exclusions>file:**/generated-sources/**,**/test/**</sonar.exclusions>
<powermock.version>2.0.2</powermock.version>
</properties> </properties>
...@@ -425,7 +424,7 @@ ...@@ -425,7 +424,7 @@
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
<artifactId>junit</artifactId> <artifactId>junit</artifactId>
<version>4.11</version> <version>4.13.2</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
...@@ -437,19 +436,7 @@ ...@@ -437,19 +436,7 @@
<dependency> <dependency>
<groupId>org.mockito</groupId> <groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId> <artifactId>mockito-core</artifactId>
<version>2.23.0</version> <version>3.10.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-module-junit4</artifactId>
<version>${powermock.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-api-mockito2</artifactId>
<version>${powermock.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册