提交 1752fdb3 编写于 作者: Z Zhijiang 提交者: Stefan Richter

[FLINK-7416][network] Implement Netty receiver outgoing pipeline for credit-based

上级 268867ce
......@@ -25,10 +25,14 @@ import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException;
import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException;
import org.apache.flink.runtime.io.network.netty.exception.TransportException;
import org.apache.flink.runtime.io.network.netty.NettyMessage.AddCredit;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
......@@ -37,6 +41,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.ArrayDeque;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicReference;
......@@ -52,14 +57,23 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
/** Channels, which already requested partitions from the producers. */
private final ConcurrentMap<InputChannelID, RemoteInputChannel> inputChannels = new ConcurrentHashMap<>();
/** Channels, which will notify the producers about unannounced credit. */
private final ArrayDeque<RemoteInputChannel> inputChannelsWithCredit = new ArrayDeque<>();
private final AtomicReference<Throwable> channelError = new AtomicReference<>();
private final ChannelFutureListener writeListener = new WriteAndFlushNextMessageIfPossibleListener();
/**
* Set of cancelled partition requests. A request is cancelled iff an input channel is cleared
* while data is still coming in for this channel.
*/
private final ConcurrentMap<InputChannelID, InputChannelID> cancelled = new ConcurrentHashMap<>();
/**
* The channel handler context is initialized in channel active event by netty thread, the context may also
* be accessed by task thread or canceler thread to cancel partition request during releasing resources.
*/
private volatile ChannelHandlerContext ctx;
// ------------------------------------------------------------------------
......@@ -88,6 +102,22 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
}
}
/**
* The credit begins to announce after receiving the sender's backlog from buffer response.
* Than means it should only happen after some interactions with the channel to make sure
* the context will not be null.
*
* @param inputChannel The input channel with unannounced credits.
*/
void notifyCreditAvailable(final RemoteInputChannel inputChannel) {
ctx.executor().execute(new Runnable() {
@Override
public void run() {
ctx.pipeline().fireUserEventTriggered(inputChannel);
}
});
}
// ------------------------------------------------------------------------
// Network events
// ------------------------------------------------------------------------
......@@ -123,7 +153,6 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
*/
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause instanceof TransportException) {
notifyAllChannelsOfErrorAndClose(cause);
} else {
......@@ -152,6 +181,29 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
}
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof RemoteInputChannel) {
// Queue an input channel for available credits announcement.
// If the queue is empty, we try to trigger the actual write. Otherwise
// this will be handled by the writeAndFlushNextMessageIfPossible calls.
boolean triggerWrite = inputChannelsWithCredit.isEmpty();
inputChannelsWithCredit.add((RemoteInputChannel) msg);
if (triggerWrite) {
writeAndFlushNextMessageIfPossible(ctx.channel());
}
} else {
ctx.fireUserEventTriggered(msg);
}
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
writeAndFlushNextMessageIfPossible(ctx.channel());
}
private void notifyAllChannelsOfErrorAndClose(Throwable cause) {
if (channelError.compareAndSet(null, cause)) {
try {
......@@ -163,6 +215,7 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
LOG.warn("An Exception was thrown during error notification of a remote input channel.", t);
} finally {
inputChannels.clear();
inputChannelsWithCredit.clear();
if (ctx != null) {
ctx.close();
......@@ -274,4 +327,57 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
bufferOrEvent.releaseBuffer();
}
}
/**
* Fetches one un-released input channel from the queue and writes the
* unannounced credits immediately. After this is done, we will continue
* with the next input channel via listener's callback.
*/
private void writeAndFlushNextMessageIfPossible(Channel channel) {
if (channelError.get() != null || !channel.isWritable()) {
return;
}
while (true) {
RemoteInputChannel inputChannel = inputChannelsWithCredit.poll();
// The input channel may be null because of the write callbacks
// that are executed after each write.
if (inputChannel == null) {
return;
}
//It is no need to notify credit for the released channel.
if (!inputChannel.isReleased()) {
AddCredit msg = new AddCredit(
inputChannel.getPartitionId(),
inputChannel.getAndResetUnannouncedCredit(),
inputChannel.getInputChannelId());
// Write and flush and wait until this is done before
// trying to continue with the next input channel.
channel.writeAndFlush(msg).addListener(writeListener);
return;
}
}
}
private class WriteAndFlushNextMessageIfPossibleListener implements ChannelFutureListener {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
try {
if (future.isSuccess()) {
writeAndFlushNextMessageIfPossible(future.channel());
} else if (future.cause() != null) {
notifyAllChannelsOfErrorAndClose(future.cause());
} else {
notifyAllChannelsOfErrorAndClose(new IllegalStateException("Sending cancelled by user."));
}
} catch (Throwable t) {
notifyAllChannelsOfErrorAndClose(t);
}
}
}
}
......@@ -198,6 +198,9 @@ public abstract class NettyMessage {
case CloseRequest.ID:
decodedMsg = CloseRequest.readFrom(msg);
break;
case AddCredit.ID:
decodedMsg = AddCredit.readFrom(msg);
break;
default:
throw new ProtocolException("Received unknown message from producer: " + msg);
}
......@@ -584,4 +587,65 @@ public abstract class NettyMessage {
return new CloseRequest();
}
}
/**
* Incremental credit announcement from the client to the server.
*/
static class AddCredit extends NettyMessage {
private static final byte ID = 6;
final ResultPartitionID partitionId;
final int credit;
final InputChannelID receiverId;
AddCredit(ResultPartitionID partitionId, int credit, InputChannelID receiverId) {
checkArgument(credit > 0, "The announced credit should be greater than 0");
this.partitionId = partitionId;
this.credit = credit;
this.receiverId = receiverId;
}
@Override
ByteBuf write(ByteBufAllocator allocator) throws IOException {
ByteBuf result = null;
try {
result = allocateBuffer(allocator, ID, 16 + 16 + 4 + 16);
partitionId.getPartitionId().writeTo(result);
partitionId.getProducerId().writeTo(result);
result.writeInt(credit);
receiverId.writeTo(result);
return result;
}
catch (Throwable t) {
if (result != null) {
result.release();
}
throw new IOException(t);
}
}
static AddCredit readFrom(ByteBuf buffer) {
ResultPartitionID partitionId =
new ResultPartitionID(
IntermediateResultPartitionID.fromByteBuf(buffer),
ExecutionAttemptID.fromByteBuf(buffer));
int credit = buffer.readInt();
InputChannelID receiverId = InputChannelID.fromByteBuf(buffer);
return new AddCredit(partitionId, credit, receiverId);
}
@Override
public String toString() {
return String.format("AddCredit(%s : %d)", receiverId, credit);
}
}
}
......@@ -167,6 +167,13 @@ public class PartitionRequestClient {
});
}
public void notifyCreditAvailable(RemoteInputChannel inputChannel) {
// We should skip the notification if the client is already closed.
if (!closeReferenceCounter.isDisposed()) {
partitionRequestHandler.notifyCreditAvailable(inputChannel);
}
}
public void close(RemoteInputChannel inputChannel) throws IOException {
partitionRequestHandler.removeInputChannel(inputChannel);
......
......@@ -330,6 +330,13 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter {
}
}
/**
* This class would be replaced by CreditBasedClientHandler in the final,
* so we only implement this method in CreditBasedClientHandler.
*/
void notifyCreditAvailable(RemoteInputChannel inputChannel) {
}
private class AsyncErrorNotificationTask implements Runnable {
private final Throwable error;
......
......@@ -100,6 +100,10 @@ public abstract class InputChannel {
return channelIndex;
}
public ResultPartitionID getPartitionId() {
return partitionId;
}
/**
* Notifies the owning {@link SingleInputGate} that this channel became non-empty.
*
......
......@@ -154,8 +154,9 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
/**
* Requests a remote subpartition.
*/
@VisibleForTesting
@Override
void requestSubpartition(int subpartitionIndex) throws IOException, InterruptedException {
public void requestSubpartition(int subpartitionIndex) throws IOException, InterruptedException {
if (partitionRequestClient == null) {
// Create a client and request the partition
partitionRequestClient = connectionManager
......@@ -279,10 +280,15 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
// ------------------------------------------------------------------------
/**
* Enqueue this input channel in the pipeline for sending unannounced credits to producer.
* Enqueue this input channel in the pipeline for notifying the producer of unannounced credit.
*/
void notifyCreditAvailable() {
//TODO in next PR
checkState(partitionRequestClient != null, "Tried to send task event to producer before requesting a queue.");
// We should skip the notification if this channel is already released.
if (!isReleased.get()) {
partitionRequestClient.notifyCreditAvailable(this);
}
}
/**
......@@ -320,11 +326,14 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
}
}
@VisibleForTesting
public int getNumberOfRequiredBuffers() {
return numRequiredBuffers;
}
public int getSenderBacklog() {
return numRequiredBuffers - initialCredit;
}
/**
* The Buffer pool notifies this channel of an available floating buffer. If the channel is released or
* currently does not need extra buffers, the buffer should be recycled to the buffer pool. Otherwise,
......@@ -379,6 +388,29 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
// Network I/O notifications (called by network I/O thread)
// ------------------------------------------------------------------------
/**
* Gets the currently unannounced credit.
*
* @return Credit which was not announced to the sender yet.
*/
public int getUnannouncedCredit() {
return unannouncedCredit.get();
}
/**
* Gets the unannounced credit and resets it to <tt>0</tt> atomically.
*
* @return Credit which was not announced to the sender yet.
*/
public int getAndResetUnannouncedCredit() {
return unannouncedCredit.getAndSet(0);
}
/**
* Gets the current number of received buffers which have not been processed yet.
*
* @return Buffers queued for processing.
*/
public int getNumberOfQueuedBuffers() {
synchronized (receivedBuffers) {
return receivedBuffers.size();
......@@ -426,7 +458,6 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler,
*
* @param backlog The number of unsent buffers in the producer's sub partition.
*/
@VisibleForTesting
void onSenderBacklog(int backlog) throws IOException {
int numRequestedBuffers = 0;
......
......@@ -158,6 +158,15 @@ public class NettyMessageSerializationTest {
assertEquals(expected.getClass(), actual.getClass());
}
{
NettyMessage.AddCredit expected = new NettyMessage.AddCredit(new ResultPartitionID(new IntermediateResultPartitionID(), new ExecutionAttemptID()), random.nextInt(Integer.MAX_VALUE) + 1, new InputChannelID());
NettyMessage.AddCredit actual = encodeAndDecode(expected);
assertEquals(expected.partitionId, actual.partitionId);
assertEquals(expected.credit, actual.credit);
assertEquals(expected.receiverId, actual.receiverId);
}
}
@SuppressWarnings("unchecked")
......
......@@ -18,34 +18,53 @@
package org.apache.flink.runtime.io.network.netty;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferListener;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.BufferProvider;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.netty.NettyMessage.BufferResponse;
import org.apache.flink.runtime.io.network.netty.NettyMessage.ErrorResponse;
import org.apache.flink.runtime.io.network.netty.NettyMessage.AddCredit;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.taskmanager.TaskActions;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator;
import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.hamcrest.Matchers.instanceOf;
public class PartitionRequestClientHandlerTest {
......@@ -74,7 +93,7 @@ public class PartitionRequestClientHandlerTest {
when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);
final BufferResponse receivedBuffer = createBufferResponse(
TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2);
TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2);
final PartitionRequestClientHandler client = new PartitionRequestClientHandler();
client.addInputChannel(inputChannel);
......@@ -122,21 +141,33 @@ public class PartitionRequestClientHandlerTest {
*/
@Test
public void testReceiveBuffer() throws Exception {
final Buffer buffer = TestBufferFactory.createBuffer();
final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
when(inputChannel.requestBuffer()).thenReturn(buffer);
final int backlog = 2;
final BufferResponse bufferResponse = createBufferResponse(
TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), backlog);
final CreditBasedClientHandler client = new CreditBasedClientHandler();
client.addInputChannel(inputChannel);
client.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
verify(inputChannel, times(1)).onBuffer(buffer, 0, backlog);
final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32);
final SingleInputGate inputGate = createSingleInputGate();
final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
inputGate.setInputChannel(inputChannel.getPartitionId().getPartitionId(), inputChannel);
try {
final BufferPool bufferPool = networkBufferPool.createBufferPool(8, 8);
inputGate.setBufferPool(bufferPool);
final int numExclusiveBuffers = 2;
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
final CreditBasedClientHandler handler = new CreditBasedClientHandler();
handler.addInputChannel(inputChannel);
final int backlog = 2;
final BufferResponse bufferResponse = createBufferResponse(
TestBufferFactory.createBuffer(32), 0, inputChannel.getInputChannelId(), backlog);
handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
assertEquals(1, inputChannel.getNumberOfQueuedBuffers());
assertEquals(2, inputChannel.getSenderBacklog());
} finally {
// Release all the buffer resources
inputGate.releaseAllResources();
networkBufferPool.destroyAllBufferPools();
networkBufferPool.destroy();
}
}
/**
......@@ -145,17 +176,18 @@ public class PartitionRequestClientHandlerTest {
*/
@Test
public void testThrowExceptionForNoAvailableBuffer() throws Exception {
final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
when(inputChannel.requestBuffer()).thenReturn(null);
final SingleInputGate inputGate = createSingleInputGate();
final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate));
final BufferResponse bufferResponse = createBufferResponse(
TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2);
final CreditBasedClientHandler handler = new CreditBasedClientHandler();
handler.addInputChannel(inputChannel);
final CreditBasedClientHandler client = new CreditBasedClientHandler();
client.addInputChannel(inputChannel);
assertEquals("There should be no buffers available in the channel.",
0, inputChannel.getNumberOfAvailableBuffers());
client.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
final BufferResponse bufferResponse = createBufferResponse(
TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2);
handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
verify(inputChannel, times(1)).onError(any(IllegalStateException.class));
}
......@@ -208,8 +240,200 @@ public class PartitionRequestClientHandlerTest {
client.cancelRequestFor(inputChannel.getInputChannelId());
}
/**
* Verifies that {@link RemoteInputChannel} is enqueued in the pipeline for notifying credits,
* and verifies the behaviour of credit notification by triggering channel's writability changed.
*/
@Test
public void testNotifyCreditAvailable() throws Exception {
final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32);
final SingleInputGate inputGate = createSingleInputGate();
final RemoteInputChannel inputChannel1 = createRemoteInputChannel(inputGate);
final RemoteInputChannel inputChannel2 = createRemoteInputChannel(inputGate);
inputGate.setInputChannel(inputChannel1.getPartitionId().getPartitionId(), inputChannel1);
inputGate.setInputChannel(inputChannel2.getPartitionId().getPartitionId(), inputChannel2);
try {
final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6);
inputGate.setBufferPool(bufferPool);
final int numExclusiveBuffers = 2;
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
final CreditBasedClientHandler handler = new CreditBasedClientHandler();
final EmbeddedChannel channel = new EmbeddedChannel(handler);
// The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we
// have to add input channels in CreditBasedClientHandler explicitly
inputChannel1.requestSubpartition(0);
inputChannel2.requestSubpartition(0);
handler.addInputChannel(inputChannel1);
handler.addInputChannel(inputChannel2);
// The buffer response will take one available buffer from input channel, and it will trigger
// requesting (backlog + numExclusiveBuffers - numAvailableBuffers) floating buffers
final BufferResponse bufferResponse1 = createBufferResponse(
TestBufferFactory.createBuffer(32), 0, inputChannel1.getInputChannelId(), 1);
final BufferResponse bufferResponse2 = createBufferResponse(
TestBufferFactory.createBuffer(32), 0, inputChannel2.getInputChannelId(), 1);
handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse1);
handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse2);
// The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we
// have to notify credit available in CreditBasedClientHandler explicitly
handler.notifyCreditAvailable(inputChannel1);
handler.notifyCreditAvailable(inputChannel2);
assertEquals(2, inputChannel1.getUnannouncedCredit());
assertEquals(2, inputChannel2.getUnannouncedCredit());
channel.runPendingTasks();
// The two input channels should notify credits via writable channel
assertTrue(channel.isWritable());
Object readFromOutbound = channel.readOutbound();
assertThat(readFromOutbound, instanceOf(AddCredit.class));
assertEquals(2, ((AddCredit) readFromOutbound).credit);
readFromOutbound = channel.readOutbound();
assertThat(readFromOutbound, instanceOf(AddCredit.class));
assertEquals(2, ((AddCredit) readFromOutbound).credit);
assertNull(channel.readOutbound());
final int highWaterMark = channel.config().getWriteBufferHighWaterMark();
// Set the writer index to the high water mark to ensure that all bytes are written
// to the wire although the buffer is "empty".
ByteBuf channelBlockingBuffer = Unpooled.buffer(highWaterMark).writerIndex(highWaterMark);
channel.write(channelBlockingBuffer);
// Trigger notify credits available via buffer response on the condition of un-writable channel
final BufferResponse bufferResponse3 = createBufferResponse(
TestBufferFactory.createBuffer(32), 1, inputChannel1.getInputChannelId(), 1);
handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse3);
handler.notifyCreditAvailable(inputChannel1);
assertEquals(1, inputChannel1.getUnannouncedCredit());
assertEquals(0, inputChannel2.getUnannouncedCredit());
channel.runPendingTasks();
// The input channel will not notify credits via un-writable channel
assertFalse(channel.isWritable());
assertNull(channel.readOutbound());
// Flush the buffer to make the channel writable again
channel.flush();
assertSame(channelBlockingBuffer, channel.readOutbound());
// The input channel should notify credits via channel's writability changed event
assertTrue(channel.isWritable());
readFromOutbound = channel.readOutbound();
assertThat(readFromOutbound, instanceOf(AddCredit.class));
assertEquals(1, ((AddCredit) readFromOutbound).credit);
assertEquals(0, inputChannel1.getUnannouncedCredit());
assertEquals(0, inputChannel2.getUnannouncedCredit());
// no more messages
assertNull(channel.readOutbound());
} finally {
// Release all the buffer resources
inputGate.releaseAllResources();
networkBufferPool.destroyAllBufferPools();
networkBufferPool.destroy();
}
}
/**
* Verifies that {@link RemoteInputChannel} is enqueued in the pipeline, but {@link AddCredit}
* message is not sent actually when this input channel is released.
*/
@Test
public void testNotifyCreditAvailableAfterReleased() throws Exception {
final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32);
final SingleInputGate inputGate = createSingleInputGate();
final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
inputGate.setInputChannel(inputChannel.getPartitionId().getPartitionId(), inputChannel);
try {
final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6);
inputGate.setBufferPool(bufferPool);
final int numExclusiveBuffers = 2;
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
final CreditBasedClientHandler handler = new CreditBasedClientHandler();
final EmbeddedChannel channel = new EmbeddedChannel(handler);
// The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we
// have to add input channels in CreditBasedClientHandler explicitly
inputChannel.requestSubpartition(0);
handler.addInputChannel(inputChannel);
// Trigger request floating buffers via buffer response to notify credits available
final BufferResponse bufferResponse = createBufferResponse(
TestBufferFactory.createBuffer(32), 0, inputChannel.getInputChannelId(), 1);
handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse);
assertEquals(2, inputChannel.getUnannouncedCredit());
// The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we
// have to notify credit available in CreditBasedClientHandler explicitly
handler.notifyCreditAvailable(inputChannel);
// Release the input channel
inputGate.releaseAllResources();
channel.runPendingTasks();
// It will not notify credits for released input channel
assertNull(channel.readOutbound());
} finally {
// Release all the buffer resources
inputGate.releaseAllResources();
networkBufferPool.destroyAllBufferPools();
networkBufferPool.destroy();
}
}
// ---------------------------------------------------------------------------------------------
/**
* Creates and returns the single input gate for credit-based testing.
*
* @return The new created single input gate.
*/
private SingleInputGate createSingleInputGate() {
return new SingleInputGate(
"InputGate",
new JobID(),
new IntermediateDataSetID(),
ResultPartitionType.PIPELINED_CREDIT_BASED,
0,
1,
mock(TaskActions.class),
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
}
/**
* Creates and returns a remote input channel for the specific input gate.
*
* @param inputGate The input gate owns the created input channel.
* @return The new created remote input channel.
*/
private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate) throws Exception {
final ConnectionManager connectionManager = mock(ConnectionManager.class);
final PartitionRequestClient partitionRequestClient = mock(PartitionRequestClient.class);
when(connectionManager.createPartitionRequestClient(any(ConnectionID.class)))
.thenReturn(partitionRequestClient);
return new RemoteInputChannel(
inputGate,
0,
new ResultPartitionID(),
mock(ConnectionID.class),
connectionManager,
0,
0,
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
}
/**
* Returns a deserialized buffer message as it would be received during runtime.
*/
......
......@@ -328,6 +328,7 @@ public class RemoteInputChannelTest {
final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
inputGate.setBufferPool(bufferPool);
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
inputChannel.requestSubpartition(0);
// Prepare the exclusive and floating buffers to verify recycle logic later
final Buffer exclusiveBuffer = inputChannel.requestBuffer();
......@@ -449,6 +450,7 @@ public class RemoteInputChannelTest {
final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
inputGate.setBufferPool(bufferPool);
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
inputChannel.requestSubpartition(0);
// Prepare the exclusive and floating buffers to verify recycle logic later
final Buffer exclusiveBuffer = inputChannel.requestBuffer();
......@@ -526,6 +528,7 @@ public class RemoteInputChannelTest {
final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
inputGate.setBufferPool(bufferPool);
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
inputChannel.requestSubpartition(0);
// Prepare the exclusive and floating buffers to verify recycle logic later
final Buffer exclusiveBuffer = inputChannel.requestBuffer();
......@@ -621,6 +624,9 @@ public class RemoteInputChannelTest {
final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
inputGate.setBufferPool(bufferPool);
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
channel1.requestSubpartition(0);
channel2.requestSubpartition(0);
channel3.requestSubpartition(0);
// Exhaust all the floating buffers
final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers);
......@@ -690,6 +696,7 @@ public class RemoteInputChannelTest {
final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
inputGate.setBufferPool(bufferPool);
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
inputChannel.requestSubpartition(0);
final Callable<Void> requestBufferTask = new Callable<Void>() {
@Override
......@@ -758,6 +765,7 @@ public class RemoteInputChannelTest {
final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
inputGate.setBufferPool(bufferPool);
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments);
inputChannel.requestSubpartition(0);
final Callable<Void> requestBufferTask = new Callable<Void>() {
@Override
......@@ -772,9 +780,9 @@ public class RemoteInputChannelTest {
// Submit tasks and wait to finish
submitTasksAndWaitForResults(executor, new Callable[]{
recycleExclusiveBufferTask(inputChannel, numExclusiveSegments),
recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
requestBufferTask});
recycleExclusiveBufferTask(inputChannel, numExclusiveSegments),
recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
requestBufferTask});
assertEquals("There should be " + inputChannel.getNumberOfRequiredBuffers() +" buffers available in channel.",
inputChannel.getNumberOfRequiredBuffers(), inputChannel.getNumberOfAvailableBuffers());
......@@ -813,6 +821,7 @@ public class RemoteInputChannelTest {
final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
inputGate.setBufferPool(bufferPool);
inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments);
inputChannel.requestSubpartition(0);
final Callable<Void> releaseTask = new Callable<Void>() {
@Override
......@@ -825,9 +834,9 @@ public class RemoteInputChannelTest {
// Submit tasks and wait to finish
submitTasksAndWaitForResults(executor, new Callable[]{
recycleExclusiveBufferTask(inputChannel, numExclusiveSegments),
recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
releaseTask});
recycleExclusiveBufferTask(inputChannel, numExclusiveSegments),
recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
releaseTask});
assertEquals("There should be no buffers available in the channel.",
0, inputChannel.getNumberOfAvailableBuffers());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册