diff --git a/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java b/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java index b9be49c033b538915c70f015313402ba6e8bf612..2e23eaa7f3b478e9ae91abc5af623bfa21b2658a 100644 --- a/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/producer/DefaultMQProducerTest.java @@ -52,7 +52,9 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Spy; +import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Fail.failBecauseExceptionWasNotThrown; @@ -106,6 +108,17 @@ public class DefaultMQProducerTest { when(mQClientAPIImpl.sendMessage(anyString(), anyString(), any(Message.class), any(SendMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), nullable(SendCallback.class), nullable(TopicPublishInfo.class), nullable(MQClientInstance.class), anyInt(), nullable(SendMessageContext.class), any(DefaultMQProducerImpl.class))) .thenReturn(createSendResult(SendStatus.SEND_OK)); + when(mQClientAPIImpl.sendMessage(anyString(), anyString(), any(Message.class), any(SendMessageRequestHeader.class), anyLong(), any(CommunicationMode.class), + any(SendCallback.class), nullable(TopicPublishInfo.class), any(MQClientInstance.class), anyInt(), nullable(SendMessageContext.class), any(DefaultMQProducerImpl.class))) + .thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + SendCallback callback = (SendCallback) args[6]; + callback.onSuccess(createSendResult(SendStatus.SEND_OK)); + return new SendResult(); + } + }); } @After @@ -168,6 +181,8 @@ public class DefaultMQProducerTest { @Test public void testSendMessageAsync_Success() throws RemotingException, InterruptedException, MQBrokerException, MQClientException { final CountDownLatch countDownLatch = new CountDownLatch(1); + final AtomicInteger cc = new AtomicInteger(0); + when(mQClientAPIImpl.getTopicRouteInfoFromNameServer(anyString(), anyLong())).thenReturn(createTopicRoute()); producer.send(message, new SendCallback() { @Override public void onSuccess(SendResult sendResult) { @@ -175,14 +190,15 @@ public class DefaultMQProducerTest { assertThat(sendResult.getOffsetMsgId()).isEqualTo("123"); assertThat(sendResult.getQueueOffset()).isEqualTo(456L); countDownLatch.countDown(); + cc.incrementAndGet(); } @Override public void onException(Throwable e) { - countDownLatch.countDown(); } }); countDownLatch.await(3000L, TimeUnit.MILLISECONDS); + assertThat(cc.get()).isEqualTo(1); } @Test @@ -190,9 +206,11 @@ public class DefaultMQProducerTest { final AtomicInteger cc = new AtomicInteger(0); final CountDownLatch countDownLatch = new CountDownLatch(6); + when(mQClientAPIImpl.getTopicRouteInfoFromNameServer(anyString(), anyLong())).thenReturn(createTopicRoute()); SendCallback sendCallback = new SendCallback() { @Override public void onSuccess(SendResult sendResult) { + countDownLatch.countDown(); } @Override @@ -213,20 +231,21 @@ public class DefaultMQProducerTest { message.setTopic("test"); message.setBody("hello world".getBytes()); producer.send(new Message(), sendCallback); - producer.send(message, sendCallback, 1000); producer.send(message, new MessageQueue(), sendCallback); producer.send(new Message(), new MessageQueue(), sendCallback, 1000); producer.send(new Message(), messageQueueSelector, null, sendCallback); producer.send(message, messageQueueSelector, null, sendCallback, 1000); + //this message is send success + producer.send(message, sendCallback, 1000); countDownLatch.await(3000L, TimeUnit.MILLISECONDS); - assertThat(cc.get()).isEqualTo(6); + assertThat(cc.get()).isEqualTo(5); } @Test public void testSendMessageAsync_BodyCompressed() throws RemotingException, InterruptedException, MQBrokerException, MQClientException { - final CountDownLatch countDownLatch = new CountDownLatch(1); + when(mQClientAPIImpl.getTopicRouteInfoFromNameServer(anyString(), anyLong())).thenReturn(createTopicRoute()); producer.send(bigMessage, new SendCallback() { @Override public void onSuccess(SendResult sendResult) { @@ -238,7 +257,6 @@ public class DefaultMQProducerTest { @Override public void onException(Throwable e) { - countDownLatch.countDown(); } }); countDownLatch.await(3000L, TimeUnit.MILLISECONDS);