diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java index 48d9b3a91a72af27ae1724b6bd88750aa72f332d..c2c60824ec80e9f4f3e2eac99c95704739349ecd 100644 --- a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java @@ -17,7 +17,6 @@ package org.apache.rocketmq.client.consumer; import java.io.ByteArrayOutputStream; -import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.util.Collections; import java.util.HashSet; @@ -45,7 +44,6 @@ import org.apache.rocketmq.client.impl.consumer.ConsumeMessageConcurrentlyServic import org.apache.rocketmq.client.impl.consumer.ConsumeMessageOrderlyService; import org.apache.rocketmq.client.impl.consumer.DefaultMQPushConsumerImpl; import org.apache.rocketmq.client.impl.consumer.ProcessQueue; -import org.apache.rocketmq.client.impl.consumer.PullAPIWrapper; import org.apache.rocketmq.client.impl.consumer.PullMessageService; import org.apache.rocketmq.client.impl.consumer.PullRequest; import org.apache.rocketmq.client.impl.consumer.PullResultExt; @@ -88,7 +86,6 @@ public class DefaultMQPushConsumerTest { @Mock private MQClientAPIImpl mQClientAPIImpl; - private PullAPIWrapper pullAPIWrapper; private RebalanceImpl rebalanceImpl; private DefaultMQPushConsumer pushConsumer; @@ -98,6 +95,27 @@ public class DefaultMQPushConsumerTest { factoryTable.forEach((s, instance) -> instance.shutdown()); factoryTable.clear(); + when(mQClientAPIImpl.pullMessage(anyString(), any(PullMessageRequestHeader.class), + anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) + .thenAnswer(new Answer() { + @Override + public PullResult answer(InvocationOnMock mock) throws Throwable { + PullMessageRequestHeader requestHeader = mock.getArgument(1); + MessageClientExt messageClientExt = new MessageClientExt(); + messageClientExt.setTopic(topic); + messageClientExt.setQueueId(0); + messageClientExt.setMsgId("123"); + messageClientExt.setBody(new byte[] {'a'}); + messageClientExt.setOffsetMsgId("234"); + messageClientExt.setBornHost(new InetSocketAddress(8080)); + messageClientExt.setStoreHost(new InetSocketAddress(8080)); + PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); + ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); + return pullResult; + } + }); + + consumerGroup = "FooBarGroup" + System.currentTimeMillis(); pushConsumer = new DefaultMQPushConsumer(consumerGroup); pushConsumer.setNamesrvAddr("127.0.0.1:9876"); @@ -115,58 +133,24 @@ public class DefaultMQPushConsumerTest { // suppress updateTopicRouteInfoFromNameServer pushConsumer.changeInstanceNameToPID(); - mQClientFactory = spy(MQClientManager.getInstance().getOrCreateMQClientInstance(pushConsumer, (RPCHook) FieldUtils.readDeclaredField(pushConsumerImpl, "rpcHook", true))); + mQClientFactory = MQClientManager.getInstance().getOrCreateMQClientInstance(pushConsumer, (RPCHook) FieldUtils.readDeclaredField(pushConsumerImpl, "rpcHook", true)); + FieldUtils.writeDeclaredField(mQClientFactory, "mQClientAPIImpl", mQClientAPIImpl, true); + mQClientFactory = spy(mQClientFactory); factoryTable.put(pushConsumer.buildMQClientId(), mQClientFactory); doReturn(false).when(mQClientFactory).updateTopicRouteInfoFromNameServer(anyString()); - rebalanceImpl = spy(pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl()); - Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl"); - field.setAccessible(true); - field.set(pushConsumerImpl, rebalanceImpl); - - pushConsumer.subscribe(topic, "*"); - pushConsumer.start(); - - field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); - field.setAccessible(true); - field.set(pushConsumerImpl, mQClientFactory); - - field = MQClientInstance.class.getDeclaredField("mQClientAPIImpl"); - field.setAccessible(true); - field.set(mQClientFactory, mQClientAPIImpl); - - pullAPIWrapper = spy(new PullAPIWrapper(mQClientFactory, consumerGroup, false)); - field = DefaultMQPushConsumerImpl.class.getDeclaredField("pullAPIWrapper"); - field.setAccessible(true); - field.set(pushConsumerImpl, pullAPIWrapper); - - mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl); + doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean()); - when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), - anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) - .thenAnswer(new Answer() { - @Override - public PullResult answer(InvocationOnMock mock) throws Throwable { - PullMessageRequestHeader requestHeader = mock.getArgument(1); - MessageClientExt messageClientExt = new MessageClientExt(); - messageClientExt.setTopic(topic); - messageClientExt.setQueueId(0); - messageClientExt.setMsgId("123"); - messageClientExt.setBody(new byte[] {'a'}); - messageClientExt.setOffsetMsgId("234"); - messageClientExt.setBornHost(new InetSocketAddress(8080)); - messageClientExt.setStoreHost(new InetSocketAddress(8080)); - PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); - ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); - return pullResult; - } - }); + rebalanceImpl = spy(pushConsumerImpl.getRebalanceImpl()); + doReturn(123L).when(rebalanceImpl).computePullFromWhere(any(MessageQueue.class)); + FieldUtils.writeDeclaredField(pushConsumerImpl, "rebalanceImpl", rebalanceImpl, true); - doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean()); Set messageQueueSet = new HashSet(); messageQueueSet.add(createPullRequest().getMessageQueue()); - pushConsumer.getDefaultMQPushConsumerImpl().updateTopicSubscribeInfo(topic, messageQueueSet); - doReturn(123L).when(rebalanceImpl).computePullFromWhere(any(MessageQueue.class)); + pushConsumerImpl.updateTopicSubscribeInfo(topic, messageQueueSet); + + pushConsumer.subscribe(topic, "*"); + pushConsumer.start(); } @After @@ -292,7 +276,7 @@ public class DefaultMQPushConsumerTest { PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); pullMessageService.executePullRequestImmediately(createPullRequest()); - assertThat(countDownLatch.await(30, TimeUnit.SECONDS)).isTrue(); + assertThat(countDownLatch.await(10, TimeUnit.SECONDS)).isTrue(); pushConsumer.shutdown(); assertThat(messageConsumedFlag.get()).isTrue(); diff --git a/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQConsumerWithOpenTracingTest.java b/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQConsumerWithOpenTracingTest.java index 1d8ac85b272c46d85f2a7e8eff501e7f500b525c..ecf72ae44cfefdb682346d366b557b0a322f4077 100644 --- a/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQConsumerWithOpenTracingTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQConsumerWithOpenTracingTest.java @@ -21,7 +21,6 @@ import io.opentracing.mock.MockSpan; import io.opentracing.mock.MockTracer; import io.opentracing.tag.Tags; import java.io.ByteArrayOutputStream; -import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.util.Collections; import java.util.HashSet; @@ -106,6 +105,27 @@ public class DefaultMQConsumerWithOpenTracingTest { factoryTable.forEach((s, instance) -> instance.shutdown()); factoryTable.clear(); + when(mQClientAPIImpl.pullMessage(anyString(), any(PullMessageRequestHeader.class), + anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) + .thenAnswer(new Answer() { + @Override + public PullResult answer(InvocationOnMock mock) throws Throwable { + PullMessageRequestHeader requestHeader = mock.getArgument(1); + MessageClientExt messageClientExt = new MessageClientExt(); + messageClientExt.setTopic(topic); + messageClientExt.setQueueId(0); + messageClientExt.setMsgId("123"); + messageClientExt.setBody(new byte[]{'a'}); + messageClientExt.setOffsetMsgId("234"); + messageClientExt.setBornHost(new InetSocketAddress(8080)); + messageClientExt.setStoreHost(new InetSocketAddress(8080)); + PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); + ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); + return pullResult; + } + }); + + consumerGroup = "FooBarGroup" + System.currentTimeMillis(); pushConsumer = new DefaultMQPushConsumer(consumerGroup); pushConsumer.getDefaultMQPushConsumerImpl().registerConsumeMessageHook( @@ -129,58 +149,20 @@ public class DefaultMQConsumerWithOpenTracingTest { // suppress updateTopicRouteInfoFromNameServer pushConsumer.changeInstanceNameToPID(); - mQClientFactory = spy(MQClientManager.getInstance().getOrCreateMQClientInstance(pushConsumer, (RPCHook) FieldUtils.readDeclaredField(pushConsumerImpl, "rpcHook", true))); + mQClientFactory = MQClientManager.getInstance().getOrCreateMQClientInstance(pushConsumer, (RPCHook) FieldUtils.readDeclaredField(pushConsumerImpl, "rpcHook", true)); + FieldUtils.writeDeclaredField(mQClientFactory, "mQClientAPIImpl", mQClientAPIImpl, true); + mQClientFactory = spy(mQClientFactory); factoryTable.put(pushConsumer.buildMQClientId(), mQClientFactory); doReturn(false).when(mQClientFactory).updateTopicRouteInfoFromNameServer(anyString()); - rebalancePushImpl = spy(new RebalancePushImpl(pushConsumer.getDefaultMQPushConsumerImpl())); - Field field = DefaultMQPushConsumerImpl.class.getDeclaredField("rebalanceImpl"); - field.setAccessible(true); - field.set(pushConsumerImpl, rebalancePushImpl); - pushConsumer.subscribe(topic, "*"); - - pushConsumer.start(); - - field = DefaultMQPushConsumerImpl.class.getDeclaredField("mQClientFactory"); - field.setAccessible(true); - field.set(pushConsumerImpl, mQClientFactory); - - field = MQClientInstance.class.getDeclaredField("mQClientAPIImpl"); - field.setAccessible(true); - field.set(mQClientFactory, mQClientAPIImpl); - - pullAPIWrapper = spy(new PullAPIWrapper(mQClientFactory, consumerGroup, false)); - field = DefaultMQPushConsumerImpl.class.getDeclaredField("pullAPIWrapper"); - field.setAccessible(true); - field.set(pushConsumerImpl, pullAPIWrapper); - - pushConsumer.getDefaultMQPushConsumerImpl().getRebalanceImpl().setmQClientFactory(mQClientFactory); - mQClientFactory.registerConsumer(consumerGroup, pushConsumerImpl); - - when(mQClientAPIImpl.pullMessage(anyString(), any(PullMessageRequestHeader.class), - anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) - .thenAnswer(new Answer() { - @Override - public PullResult answer(InvocationOnMock mock) throws Throwable { - PullMessageRequestHeader requestHeader = mock.getArgument(1); - MessageClientExt messageClientExt = new MessageClientExt(); - messageClientExt.setTopic(topic); - messageClientExt.setQueueId(0); - messageClientExt.setMsgId("123"); - messageClientExt.setBody(new byte[]{'a'}); - messageClientExt.setOffsetMsgId("234"); - messageClientExt.setBornHost(new InetSocketAddress(8080)); - messageClientExt.setStoreHost(new InetSocketAddress(8080)); - PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); - ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); - return pullResult; - } - }); - doReturn(new FindBrokerResult("127.0.0.1:10911", false)).when(mQClientFactory).findBrokerAddressInSubscribe(anyString(), anyLong(), anyBoolean()); + Set messageQueueSet = new HashSet(); messageQueueSet.add(createPullRequest().getMessageQueue()); - pushConsumer.getDefaultMQPushConsumerImpl().updateTopicSubscribeInfo(topic, messageQueueSet); + pushConsumerImpl.updateTopicSubscribeInfo(topic, messageQueueSet); + + pushConsumer.subscribe(topic, "*"); + pushConsumer.start(); } @After