[FLINK-8731] Replaced mockito with custom mock in TestInputChannel

This closes #6338
上级 5be27a23
......@@ -25,19 +25,16 @@ import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer;
import org.apache.flink.runtime.io.network.api.serialization.SpanningRecordSerializer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability;
import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel.BufferAndAvailabilityProvider;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.MutableObjectIterator;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.util.Optional;
import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.buildSingleBuffer;
import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.createBufferBuilder;
import static org.mockito.Mockito.when;
public class IteratorWrappingTestSingleInputGate<T extends IOReadableWritable> extends TestSingleInputGate {
......@@ -66,12 +63,12 @@ public class IteratorWrappingTestSingleInputGate<T extends IOReadableWritable> e
// The input iterator can produce an infinite stream. That's why we have to serialize each
// record on demand and cannot do it upfront.
final Answer<Optional<BufferAndAvailability>> answer = new Answer<Optional<BufferAndAvailability>>() {
final BufferAndAvailabilityProvider answer = new BufferAndAvailabilityProvider() {
private boolean hasData = inputIterator.next(reuse) != null;
@Override
public Optional<BufferAndAvailability> answer(InvocationOnMock invocationOnMock) throws Throwable {
public Optional<BufferAndAvailability> getBufferAvailability() throws IOException {
if (hasData) {
serializer.clear();
BufferBuilder bufferBuilder = createBufferBuilder(bufferSize);
......@@ -83,22 +80,24 @@ public class IteratorWrappingTestSingleInputGate<T extends IOReadableWritable> e
// Call getCurrentBuffer to ensure size is set
return Optional.of(new BufferAndAvailability(buildSingleBuffer(bufferBuilder), true, 0));
} else {
when(inputChannel.getInputChannel().isReleased()).thenReturn(true);
inputChannel.setReleased();
return Optional.of(new BufferAndAvailability(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), false, 0));
return Optional.of(new BufferAndAvailability(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE),
false,
0));
}
}
};
when(inputChannel.getInputChannel().getNextBuffer()).thenAnswer(answer);
inputChannel.addBufferAndAvailability(answer);
inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannel.getInputChannel());
inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannel);
return this;
}
public IteratorWrappingTestSingleInputGate<T> notifyNonEmpty() {
inputGate.notifyChannelNonEmpty(inputChannel.getInputChannel());
inputGate.notifyChannelNonEmpty(inputChannel);
return this;
}
......
......@@ -105,10 +105,10 @@ public class SingleInputGateTest {
};
inputGate.setInputChannel(
new IntermediateResultPartitionID(), inputChannels[0].getInputChannel());
new IntermediateResultPartitionID(), inputChannels[0]);
inputGate.setInputChannel(
new IntermediateResultPartitionID(), inputChannels[1].getInputChannel());
new IntermediateResultPartitionID(), inputChannels[1]);
// Test
inputChannels[0].readBuffer();
......@@ -117,8 +117,8 @@ public class SingleInputGateTest {
inputChannels[1].readEndOfPartitionEvent();
inputChannels[0].readEndOfPartitionEvent();
inputGate.notifyChannelNonEmpty(inputChannels[0].getInputChannel());
inputGate.notifyChannelNonEmpty(inputChannels[1].getInputChannel());
inputGate.notifyChannelNonEmpty(inputChannels[0]);
inputGate.notifyChannelNonEmpty(inputChannels[1]);
verifyBufferOrEvent(inputGate, true, 0, true);
verifyBufferOrEvent(inputGate, true, 1, true);
......@@ -141,16 +141,16 @@ public class SingleInputGateTest {
};
inputGate.setInputChannel(
new IntermediateResultPartitionID(), inputChannels[0].getInputChannel());
new IntermediateResultPartitionID(), inputChannels[0]);
inputGate.setInputChannel(
new IntermediateResultPartitionID(), inputChannels[1].getInputChannel());
new IntermediateResultPartitionID(), inputChannels[1]);
// Test
inputChannels[0].readBuffer();
inputChannels[0].readBuffer(false);
inputGate.notifyChannelNonEmpty(inputChannels[0].getInputChannel());
inputGate.notifyChannelNonEmpty(inputChannels[0]);
verifyBufferOrEvent(inputGate, true, 0, true);
verifyBufferOrEvent(inputGate, true, 0, false);
......
......@@ -18,18 +18,18 @@
package org.apache.flink.runtime.io.network.partition.consumer;
import org.apache.flink.metrics.SimpleCounter;
import org.apache.flink.runtime.event.TaskEvent;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.stubbing.OngoingStubbing;
import java.io.IOException;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
......@@ -39,20 +39,16 @@ import static org.mockito.Mockito.when;
/**
* A mocked input channel.
*/
public class TestInputChannel {
public class TestInputChannel extends InputChannel {
private final InputChannel mock = Mockito.mock(InputChannel.class);
private final Queue<BufferAndAvailabilityProvider> buffers = new ConcurrentLinkedQueue<>();
private final SingleInputGate inputGate;
private BufferAndAvailabilityProvider lastProvider = null;
// Abusing Mockito here... ;)
protected OngoingStubbing<Optional<BufferAndAvailability>> stubbing;
private boolean isReleased = false;
public TestInputChannel(SingleInputGate inputGate, int channelIndex) {
checkArgument(channelIndex >= 0);
this.inputGate = checkNotNull(inputGate);
when(mock.getChannelIndex()).thenReturn(channelIndex);
TestInputChannel(SingleInputGate inputGate, int channelIndex) {
super(inputGate, channelIndex, new ResultPartitionID(), 0, 0, new SimpleCounter());
}
public TestInputChannel read(Buffer buffer) throws IOException, InterruptedException {
......@@ -60,48 +56,40 @@ public class TestInputChannel {
}
public TestInputChannel read(Buffer buffer, boolean moreAvailable) throws IOException, InterruptedException {
if (stubbing == null) {
stubbing = when(mock.getNextBuffer()).thenReturn(Optional.of(new BufferAndAvailability(buffer, moreAvailable, 0)));
} else {
stubbing = stubbing.thenReturn(Optional.of(new BufferAndAvailability(buffer, moreAvailable, 0)));
}
addBufferAndAvailability(new BufferAndAvailability(buffer, moreAvailable, 0));
return this;
}
public TestInputChannel readBuffer() throws IOException, InterruptedException {
TestInputChannel readBuffer() throws IOException, InterruptedException {
return readBuffer(true);
}
public TestInputChannel readBuffer(boolean moreAvailable) throws IOException, InterruptedException {
TestInputChannel readBuffer(boolean moreAvailable) throws IOException, InterruptedException {
final Buffer buffer = mock(Buffer.class);
when(buffer.isBuffer()).thenReturn(true);
return read(buffer, moreAvailable);
}
public TestInputChannel readEndOfPartitionEvent() throws IOException, InterruptedException {
final Answer<Optional<BufferAndAvailability>> answer = new Answer<Optional<BufferAndAvailability>>() {
@Override
public Optional<BufferAndAvailability> answer(InvocationOnMock invocationOnMock) throws Throwable {
// Return true after finishing
when(mock.isReleased()).thenReturn(true);
return Optional.of(new BufferAndAvailability(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), false, 0));
TestInputChannel readEndOfPartitionEvent() throws InterruptedException {
addBufferAndAvailability(
() -> {
setReleased();
return Optional.of(new BufferAndAvailability(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE),
false,
0));
}
};
if (stubbing == null) {
stubbing = when(mock.getNextBuffer()).thenAnswer(answer);
} else {
stubbing = stubbing.thenAnswer(answer);
}
);
return this;
}
public InputChannel getInputChannel() {
return mock;
void addBufferAndAvailability(BufferAndAvailability bufferAndAvailability) {
buffers.add(() -> Optional.of(bufferAndAvailability));
}
void addBufferAndAvailability(BufferAndAvailabilityProvider bufferAndAvailability) {
buffers.add(bufferAndAvailability);
}
// ------------------------------------------------------------------------
......@@ -111,7 +99,7 @@ public class TestInputChannel {
*
* @return The created test input channels.
*/
public static TestInputChannel[] createInputChannels(SingleInputGate inputGate, int numberOfInputChannels) {
static TestInputChannel[] createInputChannels(SingleInputGate inputGate, int numberOfInputChannels) {
checkNotNull(inputGate);
checkArgument(numberOfInputChannels > 0);
......@@ -120,9 +108,62 @@ public class TestInputChannel {
for (int i = 0; i < numberOfInputChannels; i++) {
mocks[i] = new TestInputChannel(inputGate, i);
inputGate.setInputChannel(new IntermediateResultPartitionID(), mocks[i].getInputChannel());
inputGate.setInputChannel(new IntermediateResultPartitionID(), mocks[i]);
}
return mocks;
}
@Override
void requestSubpartition(int subpartitionIndex) throws IOException, InterruptedException {
}
@Override
Optional<BufferAndAvailability> getNextBuffer() throws IOException, InterruptedException {
BufferAndAvailabilityProvider provider = buffers.poll();
if (provider != null) {
lastProvider = provider;
return provider.getBufferAvailability();
} else if (lastProvider != null) {
return lastProvider.getBufferAvailability();
} else {
return Optional.empty();
}
}
@Override
void sendTaskEvent(TaskEvent event) throws IOException {
}
@Override
boolean isReleased() {
return isReleased;
}
void setReleased() {
this.isReleased = true;
}
@Override
void notifySubpartitionConsumed() throws IOException {
}
@Override
void releaseAllResources() throws IOException {
}
@Override
protected void notifyChannelNonEmpty() {
}
interface BufferAndAvailabilityProvider {
Optional<BufferAndAvailability> getBufferAvailability() throws IOException, InterruptedException;
}
}
......@@ -98,7 +98,7 @@ public class TestSingleInputGate {
if (initialize) {
for (int i = 0; i < numberOfInputChannels; i++) {
inputChannels[i] = new TestInputChannel(inputGate, i);
inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannels[i].getInputChannel());
inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannels[i]);
}
}
}
......
......@@ -86,15 +86,15 @@ public class UnionInputGateTest {
inputChannels[1][1].readEndOfPartitionEvent(); // 0 => 3
inputChannels[1][0].readEndOfPartitionEvent(); // 0 => 3
ig1.notifyChannelNonEmpty(inputChannels[0][0].getInputChannel());
ig1.notifyChannelNonEmpty(inputChannels[0][1].getInputChannel());
ig1.notifyChannelNonEmpty(inputChannels[0][2].getInputChannel());
ig1.notifyChannelNonEmpty(inputChannels[0][0]);
ig1.notifyChannelNonEmpty(inputChannels[0][1]);
ig1.notifyChannelNonEmpty(inputChannels[0][2]);
ig2.notifyChannelNonEmpty(inputChannels[1][0].getInputChannel());
ig2.notifyChannelNonEmpty(inputChannels[1][1].getInputChannel());
ig2.notifyChannelNonEmpty(inputChannels[1][2].getInputChannel());
ig2.notifyChannelNonEmpty(inputChannels[1][3].getInputChannel());
ig2.notifyChannelNonEmpty(inputChannels[1][4].getInputChannel());
ig2.notifyChannelNonEmpty(inputChannels[1][0]);
ig2.notifyChannelNonEmpty(inputChannels[1][1]);
ig2.notifyChannelNonEmpty(inputChannels[1][2]);
ig2.notifyChannelNonEmpty(inputChannels[1][3]);
ig2.notifyChannelNonEmpty(inputChannels[1][4]);
verifyBufferOrEvent(union, true, 0, true); // gate 1, channel 0
verifyBufferOrEvent(union, true, 3, true); // gate 2, channel 0
......
......@@ -28,14 +28,12 @@ import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer;
import org.apache.flink.runtime.io.network.api.serialization.SpanningRecordSerializer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability;
import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel.BufferAndAvailabilityProvider;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.plugable.SerializationDelegate;
import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.ConcurrentLinkedQueue;
......@@ -43,7 +41,6 @@ import java.util.concurrent.ConcurrentLinkedQueue;
import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.buildSingleBuffer;
import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.createBufferBuilder;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.when;
/**
* Test {@link InputGate} that allows setting multiple channels. Use
......@@ -94,44 +91,40 @@ public class StreamTestSingleInputGate<T> extends TestSingleInputGate {
inputQueues[channelIndex] = new ConcurrentLinkedQueue<InputValue<Object>>();
inputChannels[channelIndex] = new TestInputChannel(inputGate, i);
final Answer<Optional<BufferAndAvailability>> answer = new Answer<Optional<BufferAndAvailability>>() {
@Override
public Optional<BufferAndAvailability> answer(InvocationOnMock invocationOnMock) throws Throwable {
ConcurrentLinkedQueue<InputValue<Object>> inputQueue = inputQueues[channelIndex];
InputValue<Object> input;
boolean moreAvailable;
synchronized (inputQueue) {
input = inputQueue.poll();
moreAvailable = !inputQueue.isEmpty();
}
if (input != null && input.isStreamEnd()) {
when(inputChannels[channelIndex].getInputChannel().isReleased()).thenReturn(
true);
return Optional.of(new BufferAndAvailability(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), moreAvailable, 0));
} else if (input != null && input.isStreamRecord()) {
Object inputElement = input.getStreamRecord();
BufferBuilder bufferBuilder = createBufferBuilder(bufferSize);
recordSerializer.continueWritingWithNextBufferBuilder(bufferBuilder);
delegate.setInstance(inputElement);
recordSerializer.addRecord(delegate);
bufferBuilder.finish();
// Call getCurrentBuffer to ensure size is set
return Optional.of(new BufferAndAvailability(buildSingleBuffer(bufferBuilder), moreAvailable, 0));
} else if (input != null && input.isEvent()) {
AbstractEvent event = input.getEvent();
return Optional.of(new BufferAndAvailability(EventSerializer.toBuffer(event), moreAvailable, 0));
} else {
return Optional.empty();
}
final BufferAndAvailabilityProvider answer = () -> {
ConcurrentLinkedQueue<InputValue<Object>> inputQueue = inputQueues[channelIndex];
InputValue<Object> input;
boolean moreAvailable;
synchronized (inputQueue) {
input = inputQueue.poll();
moreAvailable = !inputQueue.isEmpty();
}
if (input != null && input.isStreamEnd()) {
inputChannels[channelIndex].setReleased();
return Optional.of(new BufferAndAvailability(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), moreAvailable, 0));
} else if (input != null && input.isStreamRecord()) {
Object inputElement = input.getStreamRecord();
BufferBuilder bufferBuilder = createBufferBuilder(bufferSize);
recordSerializer.continueWritingWithNextBufferBuilder(bufferBuilder);
delegate.setInstance(inputElement);
recordSerializer.addRecord(delegate);
bufferBuilder.finish();
// Call getCurrentBuffer to ensure size is set
return Optional.of(new BufferAndAvailability(buildSingleBuffer(bufferBuilder), moreAvailable, 0));
} else if (input != null && input.isEvent()) {
AbstractEvent event = input.getEvent();
return Optional.of(new BufferAndAvailability(EventSerializer.toBuffer(event), moreAvailable, 0));
} else {
return Optional.empty();
}
};
when(inputChannels[channelIndex].getInputChannel().getNextBuffer()).thenAnswer(answer);
inputChannels[channelIndex].addBufferAndAvailability(answer);
inputGate.setInputChannel(new IntermediateResultPartitionID(),
inputChannels[channelIndex].getInputChannel());
inputChannels[channelIndex]);
}
}
......@@ -140,7 +133,7 @@ public class StreamTestSingleInputGate<T> extends TestSingleInputGate {
inputQueues[channel].add(InputValue.element(element));
inputQueues[channel].notifyAll();
}
inputGate.notifyChannelNonEmpty(inputChannels[channel].getInputChannel());
inputGate.notifyChannelNonEmpty(inputChannels[channel]);
}
public void sendEvent(AbstractEvent event, int channel) {
......@@ -148,7 +141,7 @@ public class StreamTestSingleInputGate<T> extends TestSingleInputGate {
inputQueues[channel].add(InputValue.event(event));
inputQueues[channel].notifyAll();
}
inputGate.notifyChannelNonEmpty(inputChannels[channel].getInputChannel());
inputGate.notifyChannelNonEmpty(inputChannels[channel]);
}
public void endInput() {
......@@ -157,7 +150,7 @@ public class StreamTestSingleInputGate<T> extends TestSingleInputGate {
inputQueues[i].add(InputValue.streamEnd());
inputQueues[i].notifyAll();
}
inputGate.notifyChannelNonEmpty(inputChannels[i].getInputChannel());
inputGate.notifyChannelNonEmpty(inputChannels[i]);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册