Commit 94123cec authored by Nico Kruber, committed by zentol

[FLINK-7749][network] Refactor ResultPartitionWriter into an interface

This closes #5127.
Parent 175e1b38
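In short: the concrete wrapper around `ResultPartition` becomes an interface that `ResultPartition` implements directly. A minimal before/after sketch of a call site (only `partition` and `eventBuffer` are illustrative placeholders; the types and methods are the ones touched by the diff below):

```java
// Before: Task wrapped every ResultPartition in a separate writer object.
ResultPartitionWriter writer = new ResultPartitionWriter(partition);
writer.writeBufferToAllChannels(eventBuffer);

// After: the partition itself is the writer, so the wrapper objects, Task's
// parallel writers[] array, and NetworkEnvironment's "unequal number of
// writers and partitions" check can all be dropped.
ResultPartitionWriter writer = partition;
writer.writeBufferToAllSubpartitions(eventBuffer);
```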
...
@@ -22,7 +22,6 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager.IOMode;
-import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.ResultPartition;
@@ -174,11 +173,6 @@ public class NetworkEnvironment {
 	public void registerTask(Task task) throws IOException {
 		final ResultPartition[] producedPartitions = task.getProducedPartitions();
-		final ResultPartitionWriter[] writers = task.getAllWriters();
-
-		if (writers.length != producedPartitions.length) {
-			throw new IllegalStateException("Unequal number of writers and partitions.");
-		}

 		synchronized (lock) {
 			if (isShutdown) {
@@ -187,7 +181,6 @@ public class NetworkEnvironment {
 			for (int i = 0; i < producedPartitions.length; i++) {
 				final ResultPartition partition = producedPartitions[i];
-				final ResultPartitionWriter writer = writers[i];

 				// Buffer pool for the partition
 				BufferPool bufferPool = null;
@@ -214,7 +207,7 @@ public class NetworkEnvironment {
 				}

 				// Register writer with task event dispatcher
-				taskEventDispatcher.registerPartition(writer.getPartitionId());
+				taskEventDispatcher.registerPartition(partition.getPartitionId());
 			}

 			// Setup the buffer pool for each buffer reader
@@ -263,19 +256,10 @@ public class NetworkEnvironment {
 			resultPartitionManager.releasePartitionsProducedBy(executionId, task.getFailureCause());
 		}

-		ResultPartitionWriter[] writers = task.getAllWriters();
-		if (writers != null) {
-			for (ResultPartitionWriter writer : writers) {
-				taskEventDispatcher.unregisterPartition(writer.getPartitionId());
-			}
-		}
-
-		ResultPartition[] partitions = task.getProducedPartitions();
-		if (partitions != null) {
-			for (ResultPartition partition : partitions) {
-				partition.destroyBufferPool();
-			}
+		for (ResultPartition partition : task.getProducedPartitions()) {
+			taskEventDispatcher.unregisterPartition(partition.getPartitionId());
+			partition.destroyBufferPool();
 		}

 		final SingleInputGate[] inputGates = task.getAllInputGates();
...
@@ -36,11 +36,11 @@ import static org.apache.flink.runtime.io.network.api.serialization.RecordSerial
 /**
  * A record-oriented runtime result writer.
- * <p>
- * The RecordWriter wraps the runtime's {@link ResultPartitionWriter} and takes care of
+ *
+ * <p>The RecordWriter wraps the runtime's {@link ResultPartitionWriter} and takes care of
  * serializing records into buffers.
- * <p>
- * <strong>Important</strong>: it is necessary to call {@link #flush()} after
+ *
+ * <p><strong>Important</strong>: it is necessary to call {@link #flush()} after
  * all records have been written with {@link #emit(IOReadableWritable)}. This
  * ensures that all produced records are written to the output stream (incl.
  * partially filled ones).
@@ -71,7 +71,7 @@ public class RecordWriter<T extends IOReadableWritable> {
 		this.targetPartition = writer;
 		this.channelSelector = channelSelector;

-		this.numChannels = writer.getNumberOfOutputChannels();
+		this.numChannels = writer.getNumberOfSubpartitions();

 		/**
 		 * The runtime exposes a channel abstraction for the produced results
...
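The Javadoc contract above — every sequence of `emit()` calls must end with `flush()` so that partially filled buffers also reach the `ResultPartitionWriter` — can be made concrete with a small helper. This is a hypothetical sketch, not part of the commit; it assumes the `RecordWriter#emit`/`#flush` signatures of this Flink version:

```java
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.runtime.io.network.api.writer.RecordWriter;

import java.io.IOException;

// Hypothetical helper (not part of this commit) illustrating the Javadoc
// contract: emit() serializes records into buffers that may remain partially
// filled, so flush() must follow once all records have been written.
final class RecordEmitter {

	static <T extends IOReadableWritable> void emitAll(RecordWriter<T> writer, Iterable<T> records)
			throws IOException, InterruptedException {
		for (T record : records) {
			// serializes the record; full buffers are handed to the
			// underlying ResultPartitionWriter as they fill up
			writer.emit(record);
		}
		// hand over partially filled buffers as well
		writer.flush();
	}
}
```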
@@ -20,73 +20,50 @@ package org.apache.flink.runtime.io.network.api.writer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferProvider;
-import org.apache.flink.runtime.io.network.partition.ResultPartition;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;

 import java.io.IOException;

 /**
- * A buffer-oriented runtime result writer.
- * <p>
- * The {@link ResultPartitionWriter} is the runtime API for producing results. It
- * supports two kinds of data to be sent: buffers and events.
+ * A buffer-oriented runtime result writer API for producing results.
  */
-public class ResultPartitionWriter {
-
-	private final ResultPartition partition;
-
-	public ResultPartitionWriter(ResultPartition partition) {
-		this.partition = partition;
-	}
-
-	// ------------------------------------------------------------------------
-	// Attributes
-	// ------------------------------------------------------------------------
-
-	public ResultPartitionID getPartitionId() {
-		return partition.getPartitionId();
-	}
+public interface ResultPartitionWriter {

-	public BufferProvider getBufferProvider() {
-		return partition.getBufferProvider();
-	}
+	BufferProvider getBufferProvider();

-	public int getNumberOfOutputChannels() {
-		return partition.getNumberOfSubpartitions();
-	}
+	ResultPartitionID getPartitionId();

-	public int getNumTargetKeyGroups() {
-		return partition.getNumTargetKeyGroups();
-	}
+	int getNumberOfSubpartitions();

-	// ------------------------------------------------------------------------
-	// Data processing
-	// ------------------------------------------------------------------------
+	int getNumTargetKeyGroups();

-	public void writeBuffer(Buffer buffer, int targetChannel) throws IOException {
-		partition.add(buffer, targetChannel);
-	}
+	/**
+	 * Adds a buffer to the subpartition with the given index.
+	 *
+	 * <p>For PIPELINED {@link org.apache.flink.runtime.io.network.partition.ResultPartitionType}s,
+	 * this will trigger the deployment of consuming tasks after the first buffer has been added.
+	 */
+	void writeBuffer(Buffer buffer, int subpartitionIndex) throws IOException;

 	/**
-	 * Writes the given buffer to all available target channels.
+	 * Writes the given buffer to all available target subpartitions.
 	 *
-	 * The buffer is taken over and used for each of the channels.
+	 * <p>The buffer is taken over and used for each of the channels.
 	 * It will be recycled afterwards.
 	 *
-	 * @param eventBuffer the buffer to write
-	 * @throws IOException
+	 * @param buffer the buffer to write
 	 */
-	public void writeBufferToAllChannels(final Buffer eventBuffer) throws IOException {
+	default void writeBufferToAllSubpartitions(final Buffer buffer) throws IOException {
 		try {
-			for (int targetChannel = 0; targetChannel < partition.getNumberOfSubpartitions(); targetChannel++) {
+			for (int subpartition = 0; subpartition < getNumberOfSubpartitions(); subpartition++) {
 				// retain the buffer so that it can be recycled by each channel of targetPartition
-				eventBuffer.retain();
-				writeBuffer(eventBuffer, targetChannel);
+				buffer.retain();
+				writeBuffer(buffer, subpartition);
 			}
 		} finally {
 			// we do not need to further retain the eventBuffer
 			// (it will be recycled after the last channel stops using it)
-			eventBuffer.recycle();
+			buffer.recycle();
 		}
 	}
 }
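With the refactoring, an implementor only has to provide the five abstract methods; the broadcast logic comes for free via the default method. A hypothetical minimal implementation (for example, for tests; the class name and collecting behavior are illustrative, not from the commit):

```java
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferProvider;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical implementation (not part of this commit) that just collects
// buffers per subpartition. The inherited writeBufferToAllSubpartitions()
// works against it unchanged: it retains the buffer once per writeBuffer()
// call and recycles the caller's reference afterwards.
class CollectingResultPartitionWriter implements ResultPartitionWriter {

	private final List<List<Buffer>> subpartitions = new ArrayList<>();
	private final BufferProvider bufferProvider;
	private final ResultPartitionID partitionId = new ResultPartitionID();

	CollectingResultPartitionWriter(int numberOfSubpartitions, BufferProvider bufferProvider) {
		for (int i = 0; i < numberOfSubpartitions; i++) {
			subpartitions.add(new ArrayList<>());
		}
		this.bufferProvider = bufferProvider;
	}

	@Override
	public BufferProvider getBufferProvider() {
		return bufferProvider;
	}

	@Override
	public ResultPartitionID getPartitionId() {
		return partitionId;
	}

	@Override
	public int getNumberOfSubpartitions() {
		return subpartitions.size();
	}

	@Override
	public int getNumTargetKeyGroups() {
		return 1;
	}

	@Override
	public void writeBuffer(Buffer buffer, int subpartitionIndex) throws IOException {
		// ownership of one buffer reference passes to the subpartition
		subpartitions.get(subpartitionIndex).add(buffer);
	}
}
```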
...
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.io.network.partition;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.BufferPoolOwner;
@@ -74,7 +75,7 @@ import static org.apache.flink.util.Preconditions.checkState;
  *
  * <h2>State management</h2>
  */
-public class ResultPartition implements BufferPoolOwner {
+public class ResultPartition implements ResultPartitionWriter, BufferPoolOwner {

 	private static final Logger LOG = LoggerFactory.getLogger(ResultPartition.class);
@@ -209,10 +210,12 @@ public class ResultPartition implements BufferPoolOwner {
 		return partitionId;
 	}

+	@Override
 	public int getNumberOfSubpartitions() {
 		return subpartitions.length;
 	}

+	@Override
 	public BufferProvider getBufferProvider() {
 		return bufferPool;
 	}
@@ -260,13 +263,8 @@ public class ResultPartition implements BufferPoolOwner {
 	// ------------------------------------------------------------------------

-	/**
-	 * Adds a buffer to the subpartition with the given index.
-	 *
-	 * <p> For PIPELINED results, this will trigger the deployment of consuming tasks after the
-	 * first buffer has been added.
-	 */
-	public void add(Buffer buffer, int subpartitionIndex) throws IOException {
+	@Override
+	public void writeBuffer(Buffer buffer, int subpartitionIndex) throws IOException {
 		boolean success = false;

 		try {
@@ -381,6 +379,7 @@ public class ResultPartition implements BufferPoolOwner {
 		return cause;
 	}

+	@Override
 	public int getNumTargetKeyGroups() {
 		return numTargetKeyGroups;
 	}
...
@@ -441,6 +441,6 @@ public class IterationHeadTask<X, Y, S extends Function, OT> extends AbstractIte
 			log.info(formatLogString("sending " + WorkerDoneEvent.class.getSimpleName() + " to sync"));
 		}

-		this.toSync.writeBufferToAllChannels(EventSerializer.toBuffer(event));
+		this.toSync.writeBufferToAllSubpartitions(EventSerializer.toBuffer(event));
 	}
 }
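This is the same mechanical rename applied at every broadcast call site (here and in `StreamTask` below). The underlying pattern can be summarized in a hedged sketch; `EventBroadcaster` is an illustrative name, not from the commit:

```java
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;

import java.io.IOException;

// Illustrative helper mirroring the call above: serialize the event into a
// buffer once, then let the interface's default method fan it out. No manual
// retain()/recycle() is needed at the call site; the default method retains
// the buffer once per subpartition and recycles the original reference.
final class EventBroadcaster {

	static void broadcast(ResultPartitionWriter writer, AbstractEvent event) throws IOException {
		Buffer buffer = EventSerializer.toBuffer(event);
		writer.writeBufferToAllSubpartitions(buffer);
	}
}
```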
...
@@ -51,7 +51,6 @@ import org.apache.flink.runtime.executiongraph.TaskInformation;
 import org.apache.flink.runtime.filecache.FileCache;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.network.NetworkEnvironment;
-import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
 import org.apache.flink.runtime.io.network.partition.ResultPartition;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
@@ -187,8 +186,6 @@ public class Task implements Runnable, TaskActions {
 	private final ResultPartition[] producedPartitions;

-	private final ResultPartitionWriter[] writers;
-
 	private final SingleInputGate[] inputGates;

 	private final Map<IntermediateDataSetID, SingleInputGate> inputGatesById;
@@ -360,7 +357,6 @@ public class Task implements Runnable, TaskActions {
 		// Produced intermediate result partitions
 		this.producedPartitions = new ResultPartition[resultPartitionDeploymentDescriptors.size()];
-		this.writers = new ResultPartitionWriter[resultPartitionDeploymentDescriptors.size()];

 		int counter = 0;
@@ -380,8 +376,6 @@ public class Task implements Runnable, TaskActions {
 				ioManager,
 				desc.sendScheduleOrUpdateConsumersMessage());

-			writers[counter] = new ResultPartitionWriter(producedPartitions[counter]);
-
 			++counter;
 		}
@@ -445,10 +439,6 @@ public class Task implements Runnable, TaskActions {
 		return this.taskConfiguration;
 	}

-	public ResultPartitionWriter[] getAllWriters() {
-		return writers;
-	}
-
 	public SingleInputGate[] getAllInputGates() {
 		return inputGates;
 	}
@@ -682,7 +672,7 @@ public class Task implements Runnable, TaskActions {
 				kvStateRegistry,
 				inputSplitProvider,
 				distributedCacheEntries,
-				writers,
+				producedPartitions,
 				inputGates,
 				network.getTaskEventDispatcher(),
 				checkpointResponder,
...
@@ -37,8 +37,6 @@ import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;

-import java.io.IOException;
-
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.doAnswer;
@@ -82,9 +80,6 @@ public class NetworkEnvironmentTest {
 		ResultPartition rp3 = createResultPartition(ResultPartitionType.PIPELINED_BOUNDED, 2);
 		ResultPartition rp4 = createResultPartition(ResultPartitionType.PIPELINED_BOUNDED, 8);
 		final ResultPartition[] resultPartitions = new ResultPartition[] {rp1, rp2, rp3, rp4};
-		final ResultPartitionWriter[] resultPartitionWriters = new ResultPartitionWriter[] {
-			new ResultPartitionWriter(rp1), new ResultPartitionWriter(rp2),
-			new ResultPartitionWriter(rp3), new ResultPartitionWriter(rp4)};

 		// input gates
 		SingleInputGate ig1 = createSingleInputGateMock(ResultPartitionType.PIPELINED, 2);
@@ -96,7 +91,6 @@ public class NetworkEnvironmentTest {
 		// overall task to register
 		Task task = mock(Task.class);
 		when(task.getProducedPartitions()).thenReturn(resultPartitions);
-		when(task.getAllWriters()).thenReturn(resultPartitionWriters);
 		when(task.getAllInputGates()).thenReturn(inputGates);

 		network.registerTask(task);
...
@@ -181,7 +181,7 @@ public class RecordWriterTest {
 		ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class);
 		when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferPool));
-		when(partitionWriter.getNumberOfOutputChannels()).thenReturn(1);
+		when(partitionWriter.getNumberOfSubpartitions()).thenReturn(1);

 		// Recycle buffer and throw Exception
 		doAnswer(new Answer<Void>() {
@@ -454,7 +454,7 @@ public class RecordWriterTest {
 		ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class);
 		when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferProvider));
-		when(partitionWriter.getNumberOfOutputChannels()).thenReturn(numChannels);
+		when(partitionWriter.getNumberOfSubpartitions()).thenReturn(numChannels);

 		doAnswer(new Answer<Void>() {
 			@Override
@@ -512,7 +512,7 @@ public class RecordWriterTest {
 		ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class);
 		when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferProvider));
-		when(partitionWriter.getNumberOfOutputChannels()).thenReturn(1);
+		when(partitionWriter.getNumberOfSubpartitions()).thenReturn(1);

 		// Recycle each written buffer.
 		doAnswer(new Answer<Void>() {
...
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.network.api.writer;
-
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.runtime.io.disk.iomanager.IOManager;
-import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
-import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
-import org.apache.flink.runtime.io.network.buffer.Buffer;
-import org.apache.flink.runtime.io.network.partition.ResultPartition;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
-import org.apache.flink.runtime.taskmanager.TaskActions;
-
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.mock;
-
-public class ResultPartitionWriterTest {
-
-	// ---------------------------------------------------------------------------------------------
-	// Resource release tests
-	// ---------------------------------------------------------------------------------------------
-
-	/**
-	 * Tests that event buffers are properly recycled when broadcasting events
-	 * to multiple channels.
-	 *
-	 * @throws Exception
-	 */
-	@Test
-	public void testWriteBufferToAllChannelsReferenceCounting() throws Exception {
-		Buffer buffer = EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE);
-
-		ResultPartition partition = new ResultPartition(
-			"TestTask",
-			mock(TaskActions.class),
-			new JobID(),
-			new ResultPartitionID(),
-			ResultPartitionType.PIPELINED,
-			2,
-			2,
-			mock(ResultPartitionManager.class),
-			mock(ResultPartitionConsumableNotifier.class),
-			mock(IOManager.class),
-			false);
-
-		ResultPartitionWriter partitionWriter =
-			new ResultPartitionWriter(
-				partition);
-
-		partitionWriter.writeBufferToAllChannels(buffer);
-
-		// Verify added to all queues, i.e. two buffers in total
-		assertEquals(2, partition.getTotalNumberOfBuffers());
-
-		// release the buffers in the partition
-		partition.release();
-
-		assertTrue(buffer.isRecycled());
-	}
-}
...
@@ -20,12 +20,16 @@ package org.apache.flink.runtime.io.network.partition;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.util.TestBufferFactory;
 import org.apache.flink.runtime.taskmanager.TaskActions;

 import org.junit.Assert;
 import org.junit.Test;

+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -44,7 +48,7 @@ public class ResultPartitionTest {
 			// Pipelined, send message => notify
 			ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class);
 			ResultPartition partition = createPartition(notifier, ResultPartitionType.PIPELINED, true);
-			partition.add(TestBufferFactory.createBuffer(), 0);
+			partition.writeBuffer(TestBufferFactory.createBuffer(), 0);
 			verify(notifier, times(1)).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class));
 		}
@@ -52,7 +56,7 @@ public class ResultPartitionTest {
 			// Pipelined, don't send message => don't notify
 			ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class);
 			ResultPartition partition = createPartition(notifier, ResultPartitionType.PIPELINED, false);
-			partition.add(TestBufferFactory.createBuffer(), 0);
+			partition.writeBuffer(TestBufferFactory.createBuffer(), 0);
 			verify(notifier, never()).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class));
 		}
@@ -60,7 +64,7 @@ public class ResultPartitionTest {
 			// Blocking, send message => don't notify
 			ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class);
 			ResultPartition partition = createPartition(notifier, ResultPartitionType.BLOCKING, true);
-			partition.add(TestBufferFactory.createBuffer(), 0);
+			partition.writeBuffer(TestBufferFactory.createBuffer(), 0);
 			verify(notifier, never()).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class));
 		}
@@ -68,7 +72,7 @@ public class ResultPartitionTest {
 			// Blocking, don't send message => don't notify
 			ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class);
 			ResultPartition partition = createPartition(notifier, ResultPartitionType.BLOCKING, false);
-			partition.add(TestBufferFactory.createBuffer(), 0);
+			partition.writeBuffer(TestBufferFactory.createBuffer(), 0);
 			verify(notifier, never()).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class));
 		}
 	}
@@ -84,7 +88,7 @@ public class ResultPartitionTest {
 	}

 	/**
-	 * Tests {@link ResultPartition#add} on a partition which has already finished.
+	 * Tests {@link ResultPartition#writeBuffer} on a partition which has already finished.
 	 *
 	 * @param pipelined the result partition type to set up
 	 */
@@ -97,7 +101,7 @@ public class ResultPartitionTest {
 			partition.finish();
 			reset(notifier);
 			// partition.add() should fail
-			partition.add(buffer, 0);
+			partition.writeBuffer(buffer, 0);
 			Assert.fail("exception expected");
 		} catch (IllegalStateException e) {
 			// expected => ignored
@@ -122,7 +126,7 @@ public class ResultPartitionTest {
 	}

 	/**
-	 * Tests {@link ResultPartition#add} on a partition which has already been released.
+	 * Tests {@link ResultPartition#writeBuffer} on a partition which has already been released.
 	 *
 	 * @param pipelined the result partition type to set up
 	 */
@@ -134,7 +138,7 @@ public class ResultPartitionTest {
 			ResultPartition partition = createPartition(notifier, pipelined, true);
 			partition.release();
 			// partition.add() silently drops the buffer but recycles it
-			partition.add(buffer, 0);
+			partition.writeBuffer(buffer, 0);
 		} finally {
 			if (!buffer.isRecycled()) {
 				Assert.fail("buffer not recycled");
@@ -145,6 +149,37 @@ public class ResultPartitionTest {
 		}
 	}

+	/**
+	 * Tests that event buffers are properly added and recycled when broadcasting events
+	 * to multiple channels.
+	 */
+	@Test
+	public void testWriteBufferToAllSubpartitionsReferenceCounting() throws Exception {
+		Buffer buffer = EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE);
+
+		ResultPartition partition = new ResultPartition(
+			"TestTask",
+			mock(TaskActions.class),
+			new JobID(),
+			new ResultPartitionID(),
+			ResultPartitionType.PIPELINED,
+			2,
+			2,
+			mock(ResultPartitionManager.class),
+			mock(ResultPartitionConsumableNotifier.class),
+			mock(IOManager.class),
+			false);
+
+		partition.writeBufferToAllSubpartitions(buffer);
+
+		// Verify added to all queues, i.e. two buffers in total
+		assertEquals(2, partition.getTotalNumberOfBuffers());
+
+		// release the buffers in the partition
+		partition.release();
+
+		assertTrue(buffer.isRecycled());
+	}
+
 	// ------------------------------------------------------------------------

 	private static ResultPartition createPartition(
...
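The reference counting asserted here works out as follows: `EventSerializer.toBuffer()` creates a buffer holding one reference. `writeBufferToAllSubpartitions()` retains it once per subpartition (twice in this test) and recycles the original reference in its `finally` block, leaving exactly one reference per subpartition queue, hence `getTotalNumberOfBuffers()` returns 2. `release()` then frees those remaining references, after which `buffer.isRecycled()` returns true.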
@@ -77,12 +77,12 @@ public class TestPartitionProducer implements Callable<Boolean> {
 			int targetChannelIndex = bufferOrEvent.getChannelIndex();

 			if (bufferOrEvent.isBuffer()) {
-				partition.add(bufferOrEvent.getBuffer(), targetChannelIndex);
+				partition.writeBuffer(bufferOrEvent.getBuffer(), targetChannelIndex);
 			}
 			else if (bufferOrEvent.isEvent()) {
 				final Buffer buffer = EventSerializer.toBuffer(bufferOrEvent.getEvent());
-				partition.add(buffer, targetChannelIndex);
+				partition.writeBuffer(buffer, targetChannelIndex);
 			}
 			else {
 				throw new IllegalStateException("BufferOrEvent instance w/o buffer nor event.");
...
@@ -204,7 +204,7 @@ public class MockEnvironment implements Environment {
 		});

 		ResultPartitionWriter mockWriter = mock(ResultPartitionWriter.class);
-		when(mockWriter.getNumberOfOutputChannels()).thenReturn(1);
+		when(mockWriter.getNumberOfSubpartitions()).thenReturn(1);
 		when(mockWriter.getBufferProvider()).thenReturn(mockBufferProvider);

 		final Record record = new Record();
...
@@ -632,7 +632,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		for (ResultPartitionWriter output : getEnvironment().getAllWriters()) {
 			try {
-				output.writeBufferToAllChannels(EventSerializer.toBuffer(message));
+				output.writeBufferToAllSubpartitions(EventSerializer.toBuffer(message));
 			} catch (Exception e) {
 				exception = ExceptionUtils.firstOrSuppressed(
 					new Exception("Could not send cancel checkpoint marker to downstream tasks.", e),
...
@@ -98,7 +98,7 @@ public class StreamRecordWriterTest {
 		ResultPartitionWriter mockWriter = mock(ResultPartitionWriter.class);
 		when(mockWriter.getBufferProvider()).thenReturn(mockProvider);
-		when(mockWriter.getNumberOfOutputChannels()).thenReturn(numPartitions);
+		when(mockWriter.getNumberOfSubpartitions()).thenReturn(numPartitions);

 		return mockWriter;
 	}
...
@@ -161,7 +161,7 @@ public class StreamMockEnvironment implements Environment {
 		});

 		ResultPartitionWriter mockWriter = mock(ResultPartitionWriter.class);
-		when(mockWriter.getNumberOfOutputChannels()).thenReturn(1);
+		when(mockWriter.getNumberOfSubpartitions()).thenReturn(1);
 		when(mockWriter.getBufferProvider()).thenReturn(mockBufferProvider);

 		final RecordDeserializer<DeserializationDelegate<T>> recordDeserializer = new AdaptiveSpanningRecordDeserializer<DeserializationDelegate<T>>();
@@ -186,7 +186,7 @@ public class StreamMockEnvironment implements Environment {
 				addBufferToOutputList(recordDeserializer, delegate, buffer, outputList);
 				return null;
 			}
-		}).when(mockWriter).writeBufferToAllChannels(any(Buffer.class));
+		}).when(mockWriter).writeBufferToAllSubpartitions(any(Buffer.class));

 		outputs.add(mockWriter);
 	}
...