[FLINK-20491] Add preferred/pass-through inputs in MultiInputSortingDataInput

This will allow processing the broadcast side of a broadcast operator
first, before processing the keyed side that requires sorting for
stateful BATCH execution.

For now, the wiring from the API is not there; this will be added in
follow-up changes.
上级 e31b162a
...@@ -33,6 +33,7 @@ import org.apache.flink.runtime.memory.MemoryAllocationException; ...@@ -33,6 +33,7 @@ import org.apache.flink.runtime.memory.MemoryAllocationException;
import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.operators.sort.ExternalSorter; import org.apache.flink.runtime.operators.sort.ExternalSorter;
import org.apache.flink.runtime.operators.sort.PushSorter; import org.apache.flink.runtime.operators.sort.PushSorter;
import org.apache.flink.streaming.api.operators.BoundedMultiInput;
import org.apache.flink.streaming.api.operators.InputSelectable; import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection; import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.api.watermark.Watermark;
...@@ -48,8 +49,13 @@ import org.apache.flink.util.MutableObjectIterator; ...@@ -48,8 +49,13 @@ import org.apache.flink.util.MutableObjectIterator;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue; import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
/** /**
...@@ -107,11 +113,15 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput< ...@@ -107,11 +113,15 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput<
*/ */
public static class SelectableSortingInputs { public static class SelectableSortingInputs {
private final InputSelectable inputSelectable; private final InputSelectable inputSelectable;
private final StreamTaskInput<?>[] sortingInputs; private final StreamTaskInput<?>[] sortedInputs;
private final StreamTaskInput<?>[] passThroughInputs;
public SelectableSortingInputs( public SelectableSortingInputs(
StreamTaskInput<?>[] sortingInputs, InputSelectable inputSelectable) { StreamTaskInput<?>[] sortedInputs,
this.sortingInputs = sortingInputs; StreamTaskInput<?>[] passThroughInputs,
InputSelectable inputSelectable) {
this.sortedInputs = sortedInputs;
this.passThroughInputs = passThroughInputs;
this.inputSelectable = inputSelectable; this.inputSelectable = inputSelectable;
} }
...@@ -119,17 +129,22 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput< ...@@ -119,17 +129,22 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput<
return inputSelectable; return inputSelectable;
} }
public StreamTaskInput<?>[] getSortingInputs() { public StreamTaskInput<?>[] getSortedInputs() {
return sortingInputs; return sortedInputs;
}
public StreamTaskInput<?>[] getPassThroughInputs() {
return passThroughInputs;
} }
} }
public static <K> SelectableSortingInputs wrapInputs( public static <K> SelectableSortingInputs wrapInputs(
AbstractInvokable containingTask, AbstractInvokable containingTask,
StreamTaskInput<Object>[] inputs, StreamTaskInput<Object>[] sortingInputs,
KeySelector<Object, K>[] keySelectors, KeySelector<Object, K>[] keySelectors,
TypeSerializer<Object>[] inputSerializers, TypeSerializer<Object>[] inputSerializers,
TypeSerializer<K> keySerializer, TypeSerializer<K> keySerializer,
StreamTaskInput<Object>[] passThroughInputs,
MemoryManager memoryManager, MemoryManager memoryManager,
IOManager ioManager, IOManager ioManager,
boolean objectReuse, boolean objectReuse,
...@@ -146,20 +161,28 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput< ...@@ -146,20 +161,28 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput<
comparator = new VariableLengthByteKeyComparator<>(); comparator = new VariableLengthByteKeyComparator<>();
} }
int numberOfInputs = inputs.length; List<Integer> passThroughInputIndices =
CommonContext commonContext = new CommonContext(numberOfInputs); Arrays.stream(passThroughInputs)
StreamTaskInput<?>[] sortingInputs = .map(StreamTaskInput::getInputIndex)
IntStream.range(0, numberOfInputs) .collect(Collectors.toList());
int numberOfInputs = sortingInputs.length + passThroughInputs.length;
CommonContext commonContext = new CommonContext(sortingInputs);
InputSelector inputSelector =
new InputSelector(commonContext, numberOfInputs, passThroughInputIndices);
StreamTaskInput<?>[] wrappedSortingInputs =
IntStream.range(0, sortingInputs.length)
.mapToObj( .mapToObj(
idx -> { idx -> {
try { try {
KeyAndValueSerializer<Object> keyAndValueSerializer = KeyAndValueSerializer<Object> keyAndValueSerializer =
new KeyAndValueSerializer<>( new KeyAndValueSerializer<>(
inputSerializers[idx], keyLength); inputSerializers[idx], keyLength);
return new MultiInputSortingDataInput<>( return new MultiInputSortingDataInput<>(
commonContext, commonContext,
inputs[idx], sortingInputs[idx],
idx, sortingInputs[idx].getInputIndex(),
ExternalSorter.newBuilder( ExternalSorter.newBuilder(
memoryManager, memoryManager,
containingTask, containingTask,
...@@ -189,8 +212,14 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput< ...@@ -189,8 +212,14 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput<
} }
}) })
.toArray(StreamTaskInput[]::new); .toArray(StreamTaskInput[]::new);
StreamTaskInput<?>[] wrappedPassThroughInputs =
Arrays.stream(passThroughInputs)
.map(input -> new ObservableStreamTaskInput<>(input, inputSelector))
.toArray(StreamTaskInput[]::new);
return new SelectableSortingInputs( return new SelectableSortingInputs(
sortingInputs, new InputSelector(commonContext, numberOfInputs)); wrappedSortingInputs, wrappedPassThroughInputs, inputSelector);
} }
@Override @Override
...@@ -318,23 +347,40 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput< ...@@ -318,23 +347,40 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput<
* all sorting inputs. Should be used by the {@link StreamInputProcessor} to choose the next * all sorting inputs. Should be used by the {@link StreamInputProcessor} to choose the next
* input to consume from. * input to consume from.
*/ */
private static class InputSelector implements InputSelectable { private static class InputSelector implements InputSelectable, BoundedMultiInput {
private final CommonContext commonContext; private final CommonContext commonContext;
private final int numberOfInputs; private final int numInputs;
private final Queue<Integer> passThroughInputsIndices;
private InputSelector(CommonContext commonContext, int numberOfInputs) { private InputSelector(
CommonContext commonContext, int numInputs, List<Integer> passThroughInputIndices) {
this.commonContext = commonContext; this.commonContext = commonContext;
this.numberOfInputs = numberOfInputs; this.numInputs = numInputs;
this.passThroughInputsIndices = new LinkedList<>(passThroughInputIndices);
}
@Override
public void endInput(int inputId) throws Exception {
passThroughInputsIndices.remove(inputId);
} }
@Override @Override
public InputSelection nextSelection() { public InputSelection nextSelection() {
Integer currentPassThroughInputIndex = passThroughInputsIndices.peek();
if (currentPassThroughInputIndex != null) {
// yes, 0-based to 1-based mapping ... 🙏
return new InputSelection.Builder()
.select(currentPassThroughInputIndex + 1)
.build(numInputs);
}
if (commonContext.allSorted()) { if (commonContext.allSorted()) {
HeadElement headElement = commonContext.getQueueOfHeads().peek(); HeadElement headElement = commonContext.getQueueOfHeads().peek();
if (headElement != null) { if (headElement != null) {
int headIdx = headElement.inputIndex; int headIdx = headElement.inputIndex;
return new InputSelection.Builder().select(headIdx + 1).build(numberOfInputs); return new InputSelection.Builder().select(headIdx + 1).build(numInputs);
} }
} }
return InputSelection.ALL; return InputSelection.ALL;
...@@ -419,9 +465,10 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput< ...@@ -419,9 +465,10 @@ public final class MultiInputSortingDataInput<IN, K> implements StreamTaskInput<
private long notFinishedSortingMask = 0; private long notFinishedSortingMask = 0;
private long finishedEmitting = 0; private long finishedEmitting = 0;
public CommonContext(int numberOfInputs) { public CommonContext(StreamTaskInput<Object>[] sortingInputs) {
for (int i = 0; i < numberOfInputs; i++) { for (StreamTaskInput<Object> sortingInput : sortingInputs) {
notFinishedSortingMask = setBitMask(notFinishedSortingMask, i); notFinishedSortingMask =
setBitMask(notFinishedSortingMask, sortingInput.getInputIndex());
} }
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.operators.sort;
import org.apache.flink.core.io.InputStatus;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.streaming.api.operators.BoundedMultiInput;
import org.apache.flink.streaming.runtime.io.StreamTaskInput;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
/**
 * A {@link StreamTaskInput} decorator that forwards every call to the input it wraps and, each
 * time the wrapped input's {@code emitNext} returns {@link InputStatus#END_OF_INPUT}, notifies a
 * {@link BoundedMultiInput} observer.
 *
 * <p>The observer is called with the value of the wrapped input's {@code getInputIndex()},
 * passed through unchanged.
 *
 * @param <T> type of the elements emitted by the wrapped input
 */
class ObservableStreamTaskInput<T> implements StreamTaskInput<T> {

    /** The input all calls are forwarded to. */
    private final StreamTaskInput<T> delegate;

    /** Notified via {@code endInput} when the delegate reports end of input. */
    private final BoundedMultiInput observer;

    /**
     * Creates the decorator.
     *
     * @param wrappedInput the input to forward all calls to
     * @param endOfInputObserver the callback to invoke when {@code wrappedInput} reports {@link
     *     InputStatus#END_OF_INPUT}
     */
    public ObservableStreamTaskInput(
            StreamTaskInput<T> wrappedInput, BoundedMultiInput endOfInputObserver) {
        this.delegate = wrappedInput;
        this.observer = endOfInputObserver;
    }

    @Override
    public int getInputIndex() {
        return delegate.getInputIndex();
    }

    @Override
    public boolean isAvailable() {
        return delegate.isAvailable();
    }

    @Override
    public CompletableFuture<?> getAvailableFuture() {
        return delegate.getAvailableFuture();
    }

    /**
     * Forwards to the wrapped input and returns its status unchanged; the observer is notified
     * before returning whenever that status is {@link InputStatus#END_OF_INPUT}.
     */
    @Override
    public InputStatus emitNext(DataOutput<T> output) throws Exception {
        final InputStatus status = delegate.emitNext(output);
        if (status != InputStatus.END_OF_INPUT) {
            return status;
        }
        observer.endInput(delegate.getInputIndex());
        return status;
    }

    @Override
    public CompletableFuture<Void> prepareSnapshot(
            ChannelStateWriter channelStateWriter, long checkpointId) throws IOException {
        return delegate.prepareSnapshot(channelStateWriter, checkpointId);
    }

    @Override
    public void close() throws IOException {
        delegate.close();
    }
}
...@@ -142,6 +142,7 @@ public class StreamMultipleInputProcessorFactory { ...@@ -142,6 +142,7 @@ public class StreamMultipleInputProcessorFactory {
idx, userClassloader)) idx, userClassloader))
.toArray(TypeSerializer[]::new), .toArray(TypeSerializer[]::new),
streamConfig.getStateKeySerializer(userClassloader), streamConfig.getStateKeySerializer(userClassloader),
new StreamTaskInput[0],
memoryManager, memoryManager,
ioManager, ioManager,
executionConfig.isObjectReuseEnabled(), executionConfig.isObjectReuseEnabled(),
...@@ -151,7 +152,7 @@ public class StreamMultipleInputProcessorFactory { ...@@ -151,7 +152,7 @@ public class StreamMultipleInputProcessorFactory {
userClassloader), userClassloader),
jobConfig); jobConfig);
inputs = selectableSortingInputs.getSortingInputs(); inputs = selectableSortingInputs.getSortedInputs();
inputSelectable = selectableSortingInputs.getInputSelectable(); inputSelectable = selectableSortingInputs.getInputSelectable();
} }
......
...@@ -108,6 +108,7 @@ public class StreamTwoInputProcessorFactory { ...@@ -108,6 +108,7 @@ public class StreamTwoInputProcessorFactory {
}, },
new TypeSerializer[] {typeSerializer1, typeSerializer2}, new TypeSerializer[] {typeSerializer1, typeSerializer2},
streamConfig.getStateKeySerializer(userClassloader), streamConfig.getStateKeySerializer(userClassloader),
new StreamTaskInput[0],
memoryManager, memoryManager,
ioManager, ioManager,
executionConfig.isObjectReuseEnabled(), executionConfig.isObjectReuseEnabled(),
...@@ -117,8 +118,8 @@ public class StreamTwoInputProcessorFactory { ...@@ -117,8 +118,8 @@ public class StreamTwoInputProcessorFactory {
userClassloader), userClassloader),
jobConfig); jobConfig);
inputSelectable = selectableSortingInputs.getInputSelectable(); inputSelectable = selectableSortingInputs.getInputSelectable();
input1 = getSortedInput(selectableSortingInputs.getSortingInputs()[0]); input1 = getSortedInput(selectableSortingInputs.getSortedInputs()[0]);
input2 = getSortedInput(selectableSortingInputs.getSortingInputs()[1]); input2 = getSortedInput(selectableSortingInputs.getSortedInputs()[1]);
} }
StreamTaskNetworkOutput<IN1> output1 = StreamTaskNetworkOutput<IN1> output1 =
......
...@@ -137,13 +137,14 @@ public class LargeSortingDataInputITCase { ...@@ -137,13 +137,14 @@ public class LargeSortingDataInputITCase {
GeneratedRecordsDataInput.SERIALIZER GeneratedRecordsDataInput.SERIALIZER
}, },
new StringSerializer(), new StringSerializer(),
new StreamTaskInput[0],
environment.getMemoryManager(), environment.getMemoryManager(),
environment.getIOManager(), environment.getIOManager(),
true, true,
1.0, 1.0,
new Configuration()); new Configuration());
StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortingInputs(); StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortedInputs();
try (StreamTaskInput<Tuple3<Integer, String, byte[]>> sortedInput1 = try (StreamTaskInput<Tuple3<Integer, String, byte[]>> sortedInput1 =
(StreamTaskInput<Tuple3<Integer, String, byte[]>>) (StreamTaskInput<Tuple3<Integer, String, byte[]>>)
sortingDataInputs[0]; sortingDataInputs[0];
......
...@@ -45,6 +45,106 @@ import static org.junit.Assert.assertThat; ...@@ -45,6 +45,106 @@ import static org.junit.Assert.assertThat;
/** Tests for {@link MultiInputSortingDataInput}. */ /** Tests for {@link MultiInputSortingDataInput}. */
public class MultiInputSortingDataInputsTest { public class MultiInputSortingDataInputsTest {
/**
 * Runs the two-input ordering scenario with the pass-through (preferred) input at index 1 and
 * the sorted input at index 0.
 */
@Test
public void passThroughThenSortedInput() throws Exception {
    final int passThroughIndex = 1;
    final int sortedInputIndex = 0;
    twoInputOrderTest(passThroughIndex, sortedInputIndex);
}
/**
 * Runs the two-input ordering scenario with the pass-through (preferred) input at index 0 and
 * the sorted input at index 1.
 */
@Test
public void sortedThenPassThroughInput() throws Exception {
    final int passThroughIndex = 0;
    final int sortedInputIndex = 1;
    twoInputOrderTest(passThroughIndex, sortedInputIndex);
}
/**
 * Shared scenario for the two-input ordering tests: wraps one sorted input and one pass-through
 * ("preferred") input with {@link MultiInputSortingDataInput#wrapInputs}, drives both through a
 * {@link StreamMultipleInputProcessor} until it reports {@link InputStatus#END_OF_INPUT}, and
 * asserts that every element of the pass-through input is emitted, in its original order, before
 * any element of the sorted input.
 *
 * @param preferredIndex input index assigned to the pass-through input; also its slot in the
 *     processor array
 * @param sortedIndex input index assigned to the sorted input; also its slot in the processor
 *     array
 */
@SuppressWarnings("unchecked")
public void twoInputOrderTest(int preferredIndex, int sortedIndex) throws Exception {
CollectingDataOutput<Object> collectingDataOutput = new CollectingDataOutput<>();
// Records for the sorted input, deliberately out of order, followed by the max watermark.
List<StreamElement> sortedInputElements =
Arrays.asList(
new StreamRecord<>(1, 3),
new StreamRecord<>(1, 1),
new StreamRecord<>(2, 1),
new StreamRecord<>(2, 3),
new StreamRecord<>(1, 2),
new StreamRecord<>(2, 2),
Watermark.MAX_WATERMARK);
CollectionDataInput<Integer> sortedInput =
new CollectionDataInput<>(sortedInputElements, sortedIndex);
// Elements for the pass-through input; the assertion below expects them back unchanged.
List<StreamElement> preferredInputElements =
Arrays.asList(
new StreamRecord<>(99, 3), new StreamRecord<>(99, 1), new Watermark(99L));
CollectionDataInput<Integer> preferredInput =
new CollectionDataInput<>(preferredInputElements, preferredIndex);
// The record value itself is used as the key.
KeySelector<Integer, Integer> keySelector = value -> value;
try (MockEnvironment environment = MockEnvironment.builder().build()) {
SelectableSortingInputs selectableSortingInputs =
MultiInputSortingDataInput.wrapInputs(
new DummyInvokable(),
new StreamTaskInput[] {sortedInput},
new KeySelector[] {keySelector},
new TypeSerializer[] {new IntSerializer()},
new IntSerializer(),
new StreamTaskInput[] {preferredInput},
environment.getMemoryManager(),
environment.getIOManager(),
true,
1.0,
new Configuration());
StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortedInputs();
StreamTaskInput<?>[] preferredDataInputs =
selectableSortingInputs.getPassThroughInputs();
try (StreamTaskInput<Object> preferredTaskInput =
(StreamTaskInput<Object>) preferredDataInputs[0];
StreamTaskInput<Object> sortedTaskInput =
(StreamTaskInput<Object>) sortingDataInputs[0]) {
// The selectable returned by wrapInputs decides which of the two inputs is read next.
MultipleInputSelectionHandler selectionHandler =
new MultipleInputSelectionHandler(
selectableSortingInputs.getInputSelectable(), 2);
@SuppressWarnings("rawtypes")
StreamOneInputProcessor[] inputProcessors = new StreamOneInputProcessor[2];
// Each processor is placed at the slot matching its input's index.
inputProcessors[preferredIndex] =
new StreamOneInputProcessor<>(
preferredTaskInput, collectingDataOutput, new DummyOperatorChain());
inputProcessors[sortedIndex] =
new StreamOneInputProcessor<>(
sortedTaskInput, collectingDataOutput, new DummyOperatorChain());
StreamMultipleInputProcessor processor =
new StreamMultipleInputProcessor(selectionHandler, inputProcessors);
// Keep processing until the combined processor reports end of input.
InputStatus inputStatus;
do {
inputStatus = processor.processInput();
} while (inputStatus != InputStatus.END_OF_INPUT);
}
}
// Expected: pass-through elements first and in their original order, then the sorted input's
// records grouped by their first constructor argument and ordered by the second within each
// group (presumably value/key and timestamp; confirm against StreamRecord), and last the
// sorted input's max watermark.
assertThat(
collectingDataOutput.events,
equalTo(
Arrays.asList(
new StreamRecord<>(99, 3),
new StreamRecord<>(99, 1),
new Watermark(99L), // watermark from the preferred input
new StreamRecord<>(1, 1),
new StreamRecord<>(1, 2),
new StreamRecord<>(1, 3),
new StreamRecord<>(2, 1),
new StreamRecord<>(2, 2),
new StreamRecord<>(2, 3),
Watermark.MAX_WATERMARK // max watermark from the sorted input
)));
}
@Test @Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void simpleFixedLengthKeySorting() throws Exception { public void simpleFixedLengthKeySorting() throws Exception {
...@@ -69,13 +169,14 @@ public class MultiInputSortingDataInputsTest { ...@@ -69,13 +169,14 @@ public class MultiInputSortingDataInputsTest {
new KeySelector[] {keySelector, keySelector}, new KeySelector[] {keySelector, keySelector},
new TypeSerializer[] {new IntSerializer(), new IntSerializer()}, new TypeSerializer[] {new IntSerializer(), new IntSerializer()},
new IntSerializer(), new IntSerializer(),
new StreamTaskInput[0],
environment.getMemoryManager(), environment.getMemoryManager(),
environment.getIOManager(), environment.getIOManager(),
true, true,
1.0, 1.0,
new Configuration()); new Configuration());
StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortingInputs(); StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortedInputs();
try (StreamTaskInput<Object> input1 = (StreamTaskInput<Object>) sortingDataInputs[0]; try (StreamTaskInput<Object> input1 = (StreamTaskInput<Object>) sortingDataInputs[0];
StreamTaskInput<Object> input2 = StreamTaskInput<Object> input2 =
(StreamTaskInput<Object>) sortingDataInputs[1]) { (StreamTaskInput<Object>) sortingDataInputs[1]) {
...@@ -148,13 +249,14 @@ public class MultiInputSortingDataInputsTest { ...@@ -148,13 +249,14 @@ public class MultiInputSortingDataInputsTest {
new KeySelector[] {keySelector, keySelector}, new KeySelector[] {keySelector, keySelector},
new TypeSerializer[] {new IntSerializer(), new IntSerializer()}, new TypeSerializer[] {new IntSerializer(), new IntSerializer()},
new IntSerializer(), new IntSerializer(),
new StreamTaskInput[0],
environment.getMemoryManager(), environment.getMemoryManager(),
environment.getIOManager(), environment.getIOManager(),
true, true,
1.0, 1.0,
new Configuration()); new Configuration());
StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortingInputs(); StreamTaskInput<?>[] sortingDataInputs = selectableSortingInputs.getSortedInputs();
try (StreamTaskInput<Object> input1 = (StreamTaskInput<Object>) sortingDataInputs[0]; try (StreamTaskInput<Object> input1 = (StreamTaskInput<Object>) sortingDataInputs[0];
StreamTaskInput<Object> input2 = StreamTaskInput<Object> input2 =
(StreamTaskInput<Object>) sortingDataInputs[1]) { (StreamTaskInput<Object>) sortingDataInputs[1]) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册