From 1752fdb339df4e4d0a5063b24c460abdc0a44264 Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Thu, 28 Sep 2017 23:39:26 +0800 Subject: [PATCH] [FLINK-7416][network] Implement Netty receiver outgoing pipeline for credit-based --- .../netty/CreditBasedClientHandler.java | 108 ++++++- .../io/network/netty/NettyMessage.java | 64 ++++ .../network/netty/PartitionRequestClient.java | 7 + .../netty/PartitionRequestClientHandler.java | 7 + .../partition/consumer/InputChannel.java | 4 + .../consumer/RemoteInputChannel.java | 41 ++- .../netty/NettyMessageSerializationTest.java | 9 + .../PartitionRequestClientHandlerTest.java | 276 ++++++++++++++++-- .../consumer/RemoteInputChannelTest.java | 21 +- 9 files changed, 499 insertions(+), 38 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java index 1f1858843ef..f5279bff1b5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java @@ -25,10 +25,14 @@ import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException; import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException; import org.apache.flink.runtime.io.network.netty.exception.TransportException; +import org.apache.flink.runtime.io.network.netty.NettyMessage.AddCredit; import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import 
org.apache.flink.shaded.netty4.io.netty.channel.ChannelFuture; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelFutureListener; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; @@ -37,6 +41,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicReference; @@ -52,14 +57,23 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter { /** Channels, which already requested partitions from the producers. */ private final ConcurrentMap inputChannels = new ConcurrentHashMap<>(); + /** Channels, which will notify the producers about unannounced credit. */ + private final ArrayDeque inputChannelsWithCredit = new ArrayDeque<>(); + private final AtomicReference channelError = new AtomicReference<>(); + private final ChannelFutureListener writeListener = new WriteAndFlushNextMessageIfPossibleListener(); + /** * Set of cancelled partition requests. A request is cancelled iff an input channel is cleared * while data is still coming in for this channel. */ private final ConcurrentMap cancelled = new ConcurrentHashMap<>(); + /** + * The channel handler context is initialized in channel active event by netty thread, the context may also + * be accessed by task thread or canceler thread to cancel partition request during releasing resources. + */ private volatile ChannelHandlerContext ctx; // ------------------------------------------------------------------------ @@ -88,6 +102,22 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter { } } + /** + * The credit begins to announce after receiving the sender's backlog from buffer response. 
+ * That means it should only happen after some interactions with the channel to make sure + * the context will not be null. + * + * @param inputChannel The input channel with unannounced credits. + */ + void notifyCreditAvailable(final RemoteInputChannel inputChannel) { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + ctx.pipeline().fireUserEventTriggered(inputChannel); + } + }); + } + // ------------------------------------------------------------------------ // Network events // ------------------------------------------------------------------------ @@ -123,7 +153,6 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter { */ @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause instanceof TransportException) { notifyAllChannelsOfErrorAndClose(cause); } else { @@ -152,6 +181,29 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter { } } + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof RemoteInputChannel) { + // Queue an input channel for available credits announcement. + // If the queue is empty, we try to trigger the actual write. Otherwise + // this will be handled by the writeAndFlushNextMessageIfPossible calls.
+ boolean triggerWrite = inputChannelsWithCredit.isEmpty(); + + inputChannelsWithCredit.add((RemoteInputChannel) msg); + + if (triggerWrite) { + writeAndFlushNextMessageIfPossible(ctx.channel()); + } + } else { + ctx.fireUserEventTriggered(msg); + } + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + writeAndFlushNextMessageIfPossible(ctx.channel()); + } + private void notifyAllChannelsOfErrorAndClose(Throwable cause) { if (channelError.compareAndSet(null, cause)) { try { @@ -163,6 +215,7 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter { LOG.warn("An Exception was thrown during error notification of a remote input channel.", t); } finally { inputChannels.clear(); + inputChannelsWithCredit.clear(); if (ctx != null) { ctx.close(); @@ -274,4 +327,57 @@ class CreditBasedClientHandler extends ChannelInboundHandlerAdapter { bufferOrEvent.releaseBuffer(); } } + + /** + * Fetches one un-released input channel from the queue and writes the + * unannounced credits immediately. After this is done, we will continue + * with the next input channel via listener's callback. + */ + private void writeAndFlushNextMessageIfPossible(Channel channel) { + if (channelError.get() != null || !channel.isWritable()) { + return; + } + + while (true) { + RemoteInputChannel inputChannel = inputChannelsWithCredit.poll(); + + // The input channel may be null because of the write callbacks + // that are executed after each write. + if (inputChannel == null) { + return; + } + + // There is no need to notify credit for a released channel. + if (!inputChannel.isReleased()) { + AddCredit msg = new AddCredit( + inputChannel.getPartitionId(), + inputChannel.getAndResetUnannouncedCredit(), + inputChannel.getInputChannelId()); + + // Write and flush and wait until this is done before + // trying to continue with the next input channel.
+ channel.writeAndFlush(msg).addListener(writeListener); + + return; + } + } + } + + private class WriteAndFlushNextMessageIfPossibleListener implements ChannelFutureListener { + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + try { + if (future.isSuccess()) { + writeAndFlushNextMessageIfPossible(future.channel()); + } else if (future.cause() != null) { + notifyAllChannelsOfErrorAndClose(future.cause()); + } else { + notifyAllChannelsOfErrorAndClose(new IllegalStateException("Sending cancelled by user.")); + } + } catch (Throwable t) { + notifyAllChannelsOfErrorAndClose(t); + } + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java index db1b899b832..cffad83f21f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java @@ -198,6 +198,9 @@ public abstract class NettyMessage { case CloseRequest.ID: decodedMsg = CloseRequest.readFrom(msg); break; + case AddCredit.ID: + decodedMsg = AddCredit.readFrom(msg); + break; default: throw new ProtocolException("Received unknown message from producer: " + msg); } @@ -584,4 +587,65 @@ public abstract class NettyMessage { return new CloseRequest(); } } + + /** + * Incremental credit announcement from the client to the server. 
+ */ + static class AddCredit extends NettyMessage { + + private static final byte ID = 6; + + final ResultPartitionID partitionId; + + final int credit; + + final InputChannelID receiverId; + + AddCredit(ResultPartitionID partitionId, int credit, InputChannelID receiverId) { + checkArgument(credit > 0, "The announced credit should be greater than 0"); + + this.partitionId = partitionId; + this.credit = credit; + this.receiverId = receiverId; + } + + @Override + ByteBuf write(ByteBufAllocator allocator) throws IOException { + ByteBuf result = null; + + try { + result = allocateBuffer(allocator, ID, 16 + 16 + 4 + 16); + + partitionId.getPartitionId().writeTo(result); + partitionId.getProducerId().writeTo(result); + result.writeInt(credit); + receiverId.writeTo(result); + + return result; + } + catch (Throwable t) { + if (result != null) { + result.release(); + } + + throw new IOException(t); + } + } + + static AddCredit readFrom(ByteBuf buffer) { + ResultPartitionID partitionId = + new ResultPartitionID( + IntermediateResultPartitionID.fromByteBuf(buffer), + ExecutionAttemptID.fromByteBuf(buffer)); + int credit = buffer.readInt(); + InputChannelID receiverId = InputChannelID.fromByteBuf(buffer); + + return new AddCredit(partitionId, credit, receiverId); + } + + @Override + public String toString() { + return String.format("AddCredit(%s : %d)", receiverId, credit); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClient.java index 8dbc6b7a02c..12a9531784d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClient.java @@ -167,6 +167,13 @@ public class PartitionRequestClient { }); } + public void notifyCreditAvailable(RemoteInputChannel inputChannel) { + // We 
should skip the notification if the client is already closed. + if (!closeReferenceCounter.isDisposed()) { + partitionRequestHandler.notifyCreditAvailable(inputChannel); + } + } + public void close(RemoteInputChannel inputChannel) throws IOException { partitionRequestHandler.removeInputChannel(inputChannel); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java index ab4798e2172..e50c0592c29 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java @@ -330,6 +330,13 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter { } } + /** + * This class will eventually be replaced by CreditBasedClientHandler, + * so this method is only implemented in CreditBasedClientHandler. + */ + void notifyCreditAvailable(RemoteInputChannel inputChannel) { + } + private class AsyncErrorNotificationTask implements Runnable { private final Throwable error; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java index f46abfdc154..68b05d45aa8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java @@ -100,6 +100,10 @@ public abstract class InputChannel { return channelIndex; } + public ResultPartitionID getPartitionId() { + return partitionId; + } + /** * Notifies the owning {@link SingleInputGate} that this channel became non-empty.
* diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 02c7b34863e..7605075c6f1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -154,8 +154,9 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, /** * Requests a remote subpartition. */ + @VisibleForTesting @Override - void requestSubpartition(int subpartitionIndex) throws IOException, InterruptedException { + public void requestSubpartition(int subpartitionIndex) throws IOException, InterruptedException { if (partitionRequestClient == null) { // Create a client and request the partition partitionRequestClient = connectionManager @@ -279,10 +280,15 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, // ------------------------------------------------------------------------ /** - * Enqueue this input channel in the pipeline for sending unannounced credits to producer. + * Enqueue this input channel in the pipeline for notifying the producer of unannounced credit. */ void notifyCreditAvailable() { - //TODO in next PR + checkState(partitionRequestClient != null, "Tried to send task event to producer before requesting a queue."); + + // We should skip the notification if this channel is already released. 
+ if (!isReleased.get()) { + partitionRequestClient.notifyCreditAvailable(this); + } } /** @@ -320,11 +326,14 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, } } - @VisibleForTesting public int getNumberOfRequiredBuffers() { return numRequiredBuffers; } + public int getSenderBacklog() { + return numRequiredBuffers - initialCredit; + } + /** * The Buffer pool notifies this channel of an available floating buffer. If the channel is released or * currently does not need extra buffers, the buffer should be recycled to the buffer pool. Otherwise, @@ -379,6 +388,29 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, // Network I/O notifications (called by network I/O thread) // ------------------------------------------------------------------------ + /** + * Gets the currently unannounced credit. + * + * @return Credit which was not announced to the sender yet. + */ + public int getUnannouncedCredit() { + return unannouncedCredit.get(); + } + + /** + * Gets the unannounced credit and resets it to 0 atomically. + * + * @return Credit which was not announced to the sender yet. + */ + public int getAndResetUnannouncedCredit() { + return unannouncedCredit.getAndSet(0); + } + + /** + * Gets the current number of received buffers which have not been processed yet. + * + * @return Buffers queued for processing. + */ public int getNumberOfQueuedBuffers() { synchronized (receivedBuffers) { return receivedBuffers.size(); @@ -426,7 +458,6 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, * * @param backlog The number of unsent buffers in the producer's sub partition. 
*/ - @VisibleForTesting void onSenderBacklog(int backlog) throws IOException { int numRequestedBuffers = 0; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java index 8c87cebca24..98614bcbe62 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java @@ -158,6 +158,15 @@ public class NettyMessageSerializationTest { assertEquals(expected.getClass(), actual.getClass()); } + + { + NettyMessage.AddCredit expected = new NettyMessage.AddCredit(new ResultPartitionID(new IntermediateResultPartitionID(), new ExecutionAttemptID()), random.nextInt(Integer.MAX_VALUE) + 1, new InputChannelID()); + NettyMessage.AddCredit actual = encodeAndDecode(expected); + + assertEquals(expected.partitionId, actual.partitionId); + assertEquals(expected.credit, actual.credit); + assertEquals(expected.receiverId, actual.receiverId); + } } @SuppressWarnings("unchecked") diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java index d3ff6c26afc..42a5f11bbb7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java @@ -18,34 +18,53 @@ package org.apache.flink.runtime.io.network.netty; -import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.io.network.ConnectionID; +import 
org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferListener; +import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferProvider; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.netty.NettyMessage.BufferResponse; import org.apache.flink.runtime.io.network.netty.NettyMessage.ErrorResponse; +import org.apache.flink.runtime.io.network.netty.NettyMessage.AddCredit; import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.io.network.util.TestBufferFactory; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator; import org.apache.flink.shaded.netty4.io.netty.channel.Channel; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; import org.junit.Test; import java.io.IOException; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static 
org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.hamcrest.Matchers.instanceOf; public class PartitionRequestClientHandlerTest { @@ -74,7 +93,7 @@ public class PartitionRequestClientHandlerTest { when(inputChannel.getBufferProvider()).thenReturn(bufferProvider); final BufferResponse receivedBuffer = createBufferResponse( - TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); final PartitionRequestClientHandler client = new PartitionRequestClientHandler(); client.addInputChannel(inputChannel); @@ -122,21 +141,33 @@ public class PartitionRequestClientHandlerTest { */ @Test public void testReceiveBuffer() throws Exception { - final Buffer buffer = TestBufferFactory.createBuffer(); - final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); - when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); - when(inputChannel.requestBuffer()).thenReturn(buffer); - - final int backlog = 2; - final BufferResponse bufferResponse = createBufferResponse( - TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), backlog); - - final CreditBasedClientHandler client = new CreditBasedClientHandler(); - client.addInputChannel(inputChannel); - - client.channelRead(mock(ChannelHandlerContext.class), bufferResponse); - - verify(inputChannel, times(1)).onBuffer(buffer, 0, backlog); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = 
createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.getPartitionId().getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(8, 8); + inputGate.setBufferPool(bufferPool); + final int numExclusiveBuffers = 2; + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + final CreditBasedClientHandler handler = new CreditBasedClientHandler(); + handler.addInputChannel(inputChannel); + + final int backlog = 2; + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(32), 0, inputChannel.getInputChannelId(), backlog); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + + assertEquals(1, inputChannel.getNumberOfQueuedBuffers()); + assertEquals(2, inputChannel.getSenderBacklog()); + } finally { + // Release all the buffer resources + inputGate.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } /** @@ -145,17 +176,18 @@ public class PartitionRequestClientHandlerTest { */ @Test public void testThrowExceptionForNoAvailableBuffer() throws Exception { - final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); - when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); - when(inputChannel.requestBuffer()).thenReturn(null); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); - final BufferResponse bufferResponse = createBufferResponse( - TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); + final CreditBasedClientHandler handler = new CreditBasedClientHandler(); + handler.addInputChannel(inputChannel); - final CreditBasedClientHandler client = new CreditBasedClientHandler(); - client.addInputChannel(inputChannel); + assertEquals("There should be no buffers available in the channel.", + 0, 
inputChannel.getNumberOfAvailableBuffers()); - client.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse); verify(inputChannel, times(1)).onError(any(IllegalStateException.class)); } @@ -208,8 +240,200 @@ public class PartitionRequestClientHandlerTest { client.cancelRequestFor(inputChannel.getInputChannelId()); } + /** + * Verifies that {@link RemoteInputChannel} is enqueued in the pipeline for notifying credits, + * and verifies the behaviour of credit notification by triggering channel's writability changed. + */ + @Test + public void testNotifyCreditAvailable() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel1 = createRemoteInputChannel(inputGate); + final RemoteInputChannel inputChannel2 = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel1.getPartitionId().getPartitionId(), inputChannel1); + inputGate.setInputChannel(inputChannel2.getPartitionId().getPartitionId(), inputChannel2); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6); + inputGate.setBufferPool(bufferPool); + final int numExclusiveBuffers = 2; + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + final CreditBasedClientHandler handler = new CreditBasedClientHandler(); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + + // The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we + // have to add input channels in CreditBasedClientHandler explicitly + inputChannel1.requestSubpartition(0); + inputChannel2.requestSubpartition(0); + handler.addInputChannel(inputChannel1); + 
handler.addInputChannel(inputChannel2); + + // The buffer response will take one available buffer from input channel, and it will trigger + // requesting (backlog + numExclusiveBuffers - numAvailableBuffers) floating buffers + final BufferResponse bufferResponse1 = createBufferResponse( + TestBufferFactory.createBuffer(32), 0, inputChannel1.getInputChannelId(), 1); + final BufferResponse bufferResponse2 = createBufferResponse( + TestBufferFactory.createBuffer(32), 0, inputChannel2.getInputChannelId(), 1); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse1); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse2); + + // The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we + // have to notify credit available in CreditBasedClientHandler explicitly + handler.notifyCreditAvailable(inputChannel1); + handler.notifyCreditAvailable(inputChannel2); + + assertEquals(2, inputChannel1.getUnannouncedCredit()); + assertEquals(2, inputChannel2.getUnannouncedCredit()); + + channel.runPendingTasks(); + + // The two input channels should notify credits via writable channel + assertTrue(channel.isWritable()); + Object readFromOutbound = channel.readOutbound(); + assertThat(readFromOutbound, instanceOf(AddCredit.class)); + assertEquals(2, ((AddCredit) readFromOutbound).credit); + readFromOutbound = channel.readOutbound(); + assertThat(readFromOutbound, instanceOf(AddCredit.class)); + assertEquals(2, ((AddCredit) readFromOutbound).credit); + assertNull(channel.readOutbound()); + + final int highWaterMark = channel.config().getWriteBufferHighWaterMark(); + // Set the writer index to the high water mark to ensure that all bytes are written + // to the wire although the buffer is "empty". 
+ ByteBuf channelBlockingBuffer = Unpooled.buffer(highWaterMark).writerIndex(highWaterMark); + channel.write(channelBlockingBuffer); + + // Trigger notify credits available via buffer response on the condition of un-writable channel + final BufferResponse bufferResponse3 = createBufferResponse( + TestBufferFactory.createBuffer(32), 1, inputChannel1.getInputChannelId(), 1); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse3); + handler.notifyCreditAvailable(inputChannel1); + + assertEquals(1, inputChannel1.getUnannouncedCredit()); + assertEquals(0, inputChannel2.getUnannouncedCredit()); + + channel.runPendingTasks(); + + // The input channel will not notify credits via un-writable channel + assertFalse(channel.isWritable()); + assertNull(channel.readOutbound()); + + // Flush the buffer to make the channel writable again + channel.flush(); + assertSame(channelBlockingBuffer, channel.readOutbound()); + + // The input channel should notify credits via channel's writability changed event + assertTrue(channel.isWritable()); + readFromOutbound = channel.readOutbound(); + assertThat(readFromOutbound, instanceOf(AddCredit.class)); + assertEquals(1, ((AddCredit) readFromOutbound).credit); + assertEquals(0, inputChannel1.getUnannouncedCredit()); + assertEquals(0, inputChannel2.getUnannouncedCredit()); + + // no more messages + assertNull(channel.readOutbound()); + } finally { + // Release all the buffer resources + inputGate.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } + } + + /** + * Verifies that {@link RemoteInputChannel} is enqueued in the pipeline, but {@link AddCredit} + * message is not sent actually when this input channel is released. 
+ */ + @Test + public void testNotifyCreditAvailableAfterReleased() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 32); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.getPartitionId().getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(6, 6); + inputGate.setBufferPool(bufferPool); + final int numExclusiveBuffers = 2; + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + final CreditBasedClientHandler handler = new CreditBasedClientHandler(); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + + // The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we + // have to add input channels in CreditBasedClientHandler explicitly + inputChannel.requestSubpartition(0); + handler.addInputChannel(inputChannel); + + // Trigger request floating buffers via buffer response to notify credits available + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(32), 0, inputChannel.getInputChannelId(), 1); + handler.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + + assertEquals(2, inputChannel.getUnannouncedCredit()); + + // The PartitionRequestClient is tied to PartitionRequestClientHandler currently, so we + // have to notify credit available in CreditBasedClientHandler explicitly + handler.notifyCreditAvailable(inputChannel); + + // Release the input channel + inputGate.releaseAllResources(); + + channel.runPendingTasks(); + + // It will not notify credits for released input channel + assertNull(channel.readOutbound()); + } finally { + // Release all the buffer resources + inputGate.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } + } + // 
--------------------------------------------------------------------------------------------- + /** + * Creates and returns the single input gate for credit-based testing. + * + * @return The newly created single input gate. + */ + private SingleInputGate createSingleInputGate() { + return new SingleInputGate( + "InputGate", + new JobID(), + new IntermediateDataSetID(), + ResultPartitionType.PIPELINED_CREDIT_BASED, + 0, + 1, + mock(TaskActions.class), + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + } + + /** + * Creates and returns a remote input channel for the specific input gate. + * + * @param inputGate The input gate that owns the created input channel. + * @return The newly created remote input channel. + */ + private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate) throws Exception { + final ConnectionManager connectionManager = mock(ConnectionManager.class); + final PartitionRequestClient partitionRequestClient = mock(PartitionRequestClient.class); + when(connectionManager.createPartitionRequestClient(any(ConnectionID.class))) + .thenReturn(partitionRequestClient); + + return new RemoteInputChannel( + inputGate, + 0, + new ResultPartitionID(), + mock(ConnectionID.class), + connectionManager, + 0, + 0, + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + } + /** * Returns a deserialized buffer message as it would be received during runtime.
*/ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index 863f8865c6f..eab1d89f63b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -328,6 +328,7 @@ public class RemoteInputChannelTest { final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + inputChannel.requestSubpartition(0); // Prepare the exclusive and floating buffers to verify recycle logic later final Buffer exclusiveBuffer = inputChannel.requestBuffer(); @@ -449,6 +450,7 @@ public class RemoteInputChannelTest { final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + inputChannel.requestSubpartition(0); // Prepare the exclusive and floating buffers to verify recycle logic later final Buffer exclusiveBuffer = inputChannel.requestBuffer(); @@ -526,6 +528,7 @@ public class RemoteInputChannelTest { final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + inputChannel.requestSubpartition(0); // Prepare the exclusive and floating buffers to verify recycle logic later final Buffer exclusiveBuffer = inputChannel.requestBuffer(); @@ -621,6 +624,9 @@ public class RemoteInputChannelTest { final BufferPool bufferPool = 
spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + channel1.requestSubpartition(0); + channel2.requestSubpartition(0); + channel3.requestSubpartition(0); // Exhaust all the floating buffers final List floatingBuffers = new ArrayList<>(numFloatingBuffers); @@ -690,6 +696,7 @@ public class RemoteInputChannelTest { final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + inputChannel.requestSubpartition(0); final Callable requestBufferTask = new Callable() { @Override @@ -758,6 +765,7 @@ public class RemoteInputChannelTest { final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + inputChannel.requestSubpartition(0); final Callable requestBufferTask = new Callable() { @Override @@ -772,9 +780,9 @@ public class RemoteInputChannelTest { // Submit tasks and wait to finish submitTasksAndWaitForResults(executor, new Callable[]{ - recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), - recycleFloatingBufferTask(bufferPool, numFloatingBuffers), - requestBufferTask}); + recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), + recycleFloatingBufferTask(bufferPool, numFloatingBuffers), + requestBufferTask}); assertEquals("There should be " + inputChannel.getNumberOfRequiredBuffers() +" buffers available in channel.", inputChannel.getNumberOfRequiredBuffers(), inputChannel.getNumberOfAvailableBuffers()); @@ -813,6 +821,7 @@ public class RemoteInputChannelTest { final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); 
inputGate.setBufferPool(bufferPool); inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + inputChannel.requestSubpartition(0); final Callable releaseTask = new Callable() { @Override @@ -825,9 +834,9 @@ public class RemoteInputChannelTest { // Submit tasks and wait to finish submitTasksAndWaitForResults(executor, new Callable[]{ - recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), - recycleFloatingBufferTask(bufferPool, numFloatingBuffers), - releaseTask}); + recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), + recycleFloatingBufferTask(bufferPool, numFloatingBuffers), + releaseTask}); assertEquals("There should be no buffers available in the channel.", 0, inputChannel.getNumberOfAvailableBuffers()); -- GitLab