From 00f8de7ebc0951c78949a39352fd8e225dc98494 Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Wed, 9 Dec 2020 18:15:06 +0100 Subject: [PATCH] [FLINK-20491] Add preferred/pass-though inputs in MultiInputSortingDataInput This will allow processing the broadcast side of a broadcast operator first, before processing the keyed side that requires sorting for stateful BATCH execution. For now, the wiring from the API is not there, this will be added in follow-up changes. --- .../sort/MultiInputSortingDataInput.java | 89 +++++++++++---- .../sort/ObservableStreamTaskInput.java | 78 +++++++++++++ .../StreamMultipleInputProcessorFactory.java | 3 +- .../io/StreamTwoInputProcessorFactory.java | 5 +- .../sort/LargeSortingDataInputITCase.java | 3 +- .../sort/MultiInputSortingDataInputsTest.java | 106 +++++++++++++++++- 6 files changed, 257 insertions(+), 27 deletions(-) create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sort/ObservableStreamTaskInput.java diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sort/MultiInputSortingDataInput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sort/MultiInputSortingDataInput.java index 94f605ba34c..d162a3b4bbd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sort/MultiInputSortingDataInput.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sort/MultiInputSortingDataInput.java @@ -33,6 +33,7 @@ import org.apache.flink.runtime.memory.MemoryAllocationException; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.sort.ExternalSorter; import org.apache.flink.runtime.operators.sort.PushSorter; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; import org.apache.flink.streaming.api.operators.InputSelectable; import org.apache.flink.streaming.api.operators.InputSelection; import org.apache.flink.streaming.api.watermark.Watermark; @@ -48,8 +49,13 @@ import org.apache.flink.util.MutableObjectIterator; import javax.annotation.Nonnull; import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; import java.util.PriorityQueue; +import java.util.Queue; import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; import java.util.stream.IntStream; /** @@ -107,11 +113,15 @@ public final class MultiInputSortingDataInput implements StreamTaskInput< */ public static class SelectableSortingInputs { private final InputSelectable inputSelectable; - private final StreamTaskInput[] sortingInputs; + private final StreamTaskInput[] sortedInputs; + private final StreamTaskInput[] passThroughInputs; public SelectableSortingInputs( - StreamTaskInput[] sortingInputs, InputSelectable inputSelectable) { - this.sortingInputs = sortingInputs; + StreamTaskInput[] sortedInputs, + StreamTaskInput[] passThroughInputs, + InputSelectable inputSelectable) { + this.sortedInputs = sortedInputs; + this.passThroughInputs = passThroughInputs; this.inputSelectable = inputSelectable; } @@ -119,17 +129,22 @@ public final class MultiInputSortingDataInput implements StreamTaskInput< return inputSelectable; } - public StreamTaskInput[] getSortingInputs() { - return sortingInputs; + public StreamTaskInput[] getSortedInputs() { + return sortedInputs; + } + + public StreamTaskInput[] getPassThroughInputs() { + return passThroughInputs; } } public static SelectableSortingInputs wrapInputs( AbstractInvokable containingTask, - StreamTaskInput[] inputs, + StreamTaskInput[] sortingInputs, KeySelector[] keySelectors, TypeSerializer[] inputSerializers, TypeSerializer keySerializer, + StreamTaskInput[] passThroughInputs, MemoryManager memoryManager, IOManager ioManager, boolean objectReuse, @@ -146,20 +161,28 @@ public final class MultiInputSortingDataInput implements StreamTaskInput< comparator = new VariableLengthByteKeyComparator<>(); } - int numberOfInputs = inputs.length; - CommonContext commonContext = new CommonContext(numberOfInputs); - StreamTaskInput[] sortingInputs = - IntStream.range(0, numberOfInputs) + List passThroughInputIndices = + Arrays.stream(passThroughInputs) + .map(StreamTaskInput::getInputIndex) + .collect(Collectors.toList()); + int numberOfInputs = sortingInputs.length + passThroughInputs.length; + CommonContext commonContext = new CommonContext(sortingInputs); + InputSelector inputSelector = + new InputSelector(commonContext, numberOfInputs, passThroughInputIndices); + + StreamTaskInput[] wrappedSortingInputs = + IntStream.range(0, sortingInputs.length) .mapToObj( idx -> { try { KeyAndValueSerializer keyAndValueSerializer = new KeyAndValueSerializer<>( inputSerializers[idx], keyLength); + return new MultiInputSortingDataInput<>( commonContext, - inputs[idx], - idx, + sortingInputs[idx], + sortingInputs[idx].getInputIndex(), ExternalSorter.newBuilder( memoryManager, containingTask, @@ -189,8 +212,14 @@ public final class MultiInputSortingDataInput implements StreamTaskInput< } }) .toArray(StreamTaskInput[]::new); + + StreamTaskInput[] wrappedPassThroughInputs = + Arrays.stream(passThroughInputs) + .map(input -> new ObservableStreamTaskInput<>(input, inputSelector)) + .toArray(StreamTaskInput[]::new); + return new SelectableSortingInputs( - sortingInputs, new InputSelector(commonContext, numberOfInputs)); + wrappedSortingInputs, wrappedPassThroughInputs, inputSelector); } @Override @@ -318,23 +347,40 @@ public final class MultiInputSortingDataInput implements StreamTaskInput< * all sorting inputs. Should be used by the {@link StreamInputProcessor} to choose the next * input to consume from. */ - private static class InputSelector implements InputSelectable { + private static class InputSelector implements InputSelectable, BoundedMultiInput { private final CommonContext commonContext; - private final int numberOfInputs; + private final int numInputs; + private final Queue passThroughInputsIndices; - private InputSelector(CommonContext commonContext, int numberOfInputs) { + private InputSelector( + CommonContext commonContext, int numInputs, List passThroughInputIndices) { this.commonContext = commonContext; - this.numberOfInputs = numberOfInputs; + this.numInputs = numInputs; + this.passThroughInputsIndices = new LinkedList<>(passThroughInputIndices); + } + + @Override + public void endInput(int inputId) throws Exception { + passThroughInputsIndices.remove(inputId); } @Override public InputSelection nextSelection() { + Integer currentPassThroughInputIndex = passThroughInputsIndices.peek(); + + if (currentPassThroughInputIndex != null) { + // yes, 0-based to 1-based mapping ... 🙏 + return new InputSelection.Builder() + .select(currentPassThroughInputIndex + 1) + .build(numInputs); + } + if (commonContext.allSorted()) { HeadElement headElement = commonContext.getQueueOfHeads().peek(); if (headElement != null) { int headIdx = headElement.inputIndex; - return new InputSelection.Builder().select(headIdx + 1).build(numberOfInputs); + return new InputSelection.Builder().select(headIdx + 1).build(numInputs); } } return InputSelection.ALL; @@ -419,9 +465,10 @@ public final class MultiInputSortingDataInput implements StreamTaskInput< private long notFinishedSortingMask = 0; private long finishedEmitting = 0; - public CommonContext(int numberOfInputs) { - for (int i = 0; i < numberOfInputs; i++) { - notFinishedSortingMask = setBitMask(notFinishedSortingMask, i); + public CommonContext(StreamTaskInput[] sortingInputs) { + for (StreamTaskInput sortingInput : sortingInputs) { + notFinishedSortingMask = + setBitMask(notFinishedSortingMask, sortingInput.getInputIndex()); } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sort/ObservableStreamTaskInput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sort/ObservableStreamTaskInput.java new file mode 100644 index 00000000000..0cd5d6c4b9e --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sort/ObservableStreamTaskInput.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.sort; + +import org.apache.flink.core.io.InputStatus; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.runtime.io.StreamTaskInput; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; + +/** + * A wrapping {@link StreamTaskInput} that invokes a given {@link BoundedMultiInput} when reaching + * {@link InputStatus#END_OF_INPUT}. + */ +class ObservableStreamTaskInput implements StreamTaskInput { + + private final StreamTaskInput wrappedInput; + private final BoundedMultiInput endOfInputObserver; + + public ObservableStreamTaskInput( + StreamTaskInput wrappedInput, BoundedMultiInput endOfInputObserver) { + this.wrappedInput = wrappedInput; + this.endOfInputObserver = endOfInputObserver; + } + + @Override + public InputStatus emitNext(DataOutput output) throws Exception { + InputStatus result = wrappedInput.emitNext(output); + if (result == InputStatus.END_OF_INPUT) { + endOfInputObserver.endInput(wrappedInput.getInputIndex()); + } + return result; + } + + @Override + public int getInputIndex() { + return wrappedInput.getInputIndex(); + } + + @Override + public CompletableFuture prepareSnapshot( + ChannelStateWriter channelStateWriter, long checkpointId) throws IOException { + return wrappedInput.prepareSnapshot(channelStateWriter, checkpointId); + } + + @Override + public void close() throws IOException { + wrappedInput.close(); + } + + @Override + public CompletableFuture getAvailableFuture() { + return wrappedInput.getAvailableFuture(); + } + + @Override + public boolean isAvailable() { + return wrappedInput.isAvailable(); + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java index 801c4bc9982..bd1a6c08524 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java @@ -142,6 +142,7 @@ public class StreamMultipleInputProcessorFactory { idx, userClassloader)) .toArray(TypeSerializer[]::new), streamConfig.getStateKeySerializer(userClassloader), + new StreamTaskInput[0], memoryManager, ioManager, executionConfig.isObjectReuseEnabled(), @@ -151,7 +152,7 @@ public class StreamMultipleInputProcessorFactory { userClassloader), jobConfig); - inputs = selectableSortingInputs.getSortingInputs(); + inputs = selectableSortingInputs.getSortedInputs(); inputSelectable = selectableSortingInputs.getInputSelectable(); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java index d8b9ad62fd9..25f81a9f7e6 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java @@ -108,6 +108,7 @@ public class StreamTwoInputProcessorFactory { }, new TypeSerializer[] {typeSerializer1, typeSerializer2}, streamConfig.getStateKeySerializer(userClassloader), + new StreamTaskInput[0], memoryManager, ioManager, executionConfig.isObjectReuseEnabled(), @@ -117,8 +118,8 @@ public class StreamTwoInputProcessorFactory { userClassloader), jobConfig); inputSelectable = selectableSortingInputs.getInputSelectable(); - input1 = getSortedInput(selectableSortingInputs.getSortingInputs()[0]); - input2 = getSortedInput(selectableSortingInputs.getSortingInputs()[1]); + input1 = getSortedInput(selectableSortingInputs.getSortedInputs()[0]); + input2 = getSortedInput(selectableSortingInputs.getSortedInputs()[1]); } StreamTaskNetworkOutput output1 = diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sort/LargeSortingDataInputITCase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sort/LargeSortingDataInputITCase.java index 0ae66db91ec..39f0de8936c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sort/LargeSortingDataInputITCase.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sort/LargeSortingDataInputITCase.java @@ -137,13 +137,14 @@ public class LargeSortingDataInputITCase { GeneratedRecordsDataInput.SERIALIZER }, new StringSerializer(), + new StreamTaskInput[0], environment.getMemoryManager(), environment.getIOManager(), true, 1.0, new Configuration()); - StreamTaskInput[] sortingDataInputs = selectableSortingInputs.getSortingInputs(); + StreamTaskInput[] sortingDataInputs = selectableSortingInputs.getSortedInputs(); try (StreamTaskInput> sortedInput1 = (StreamTaskInput>) sortingDataInputs[0]; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sort/MultiInputSortingDataInputsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sort/MultiInputSortingDataInputsTest.java index 63e7609dd24..15fc0f4c396 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sort/MultiInputSortingDataInputsTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/sort/MultiInputSortingDataInputsTest.java @@ -45,6 +45,106 @@ import static org.junit.Assert.assertThat; /** Tests for {@link MultiInputSortingDataInput}. */ public class MultiInputSortingDataInputsTest { + + @Test + public void passThroughThenSortedInput() throws Exception { + twoInputOrderTest(1, 0); + } + + @Test + public void sortedThenPassThroughInput() throws Exception { + twoInputOrderTest(0, 1); + } + + @SuppressWarnings("unchecked") + public void twoInputOrderTest(int preferredIndex, int sortedIndex) throws Exception { + CollectingDataOutput collectingDataOutput = new CollectingDataOutput<>(); + + List sortedInputElements = + Arrays.asList( + new StreamRecord<>(1, 3), + new StreamRecord<>(1, 1), + new StreamRecord<>(2, 1), + new StreamRecord<>(2, 3), + new StreamRecord<>(1, 2), + new StreamRecord<>(2, 2), + Watermark.MAX_WATERMARK); + CollectionDataInput sortedInput = + new CollectionDataInput<>(sortedInputElements, sortedIndex); + + List preferredInputElements = + Arrays.asList( + new StreamRecord<>(99, 3), new StreamRecord<>(99, 1), new Watermark(99L)); + CollectionDataInput preferredInput = + new CollectionDataInput<>(preferredInputElements, preferredIndex); + + KeySelector keySelector = value -> value; + + try (MockEnvironment environment = MockEnvironment.builder().build()) { + SelectableSortingInputs selectableSortingInputs = + MultiInputSortingDataInput.wrapInputs( + new DummyInvokable(), + new StreamTaskInput[] {sortedInput}, + new KeySelector[] {keySelector}, + new TypeSerializer[] {new IntSerializer()}, + new IntSerializer(), + new StreamTaskInput[] {preferredInput}, + environment.getMemoryManager(), + environment.getIOManager(), + true, + 1.0, + new Configuration()); + + StreamTaskInput[] sortingDataInputs = selectableSortingInputs.getSortedInputs(); + StreamTaskInput[] preferredDataInputs = + selectableSortingInputs.getPassThroughInputs(); + + try (StreamTaskInput preferredTaskInput = + (StreamTaskInput) preferredDataInputs[0]; + StreamTaskInput sortedTaskInput = + (StreamTaskInput) sortingDataInputs[0]) { + + MultipleInputSelectionHandler selectionHandler = + new MultipleInputSelectionHandler( + selectableSortingInputs.getInputSelectable(), 2); + + @SuppressWarnings("rawtypes") + StreamOneInputProcessor[] inputProcessors = new StreamOneInputProcessor[2]; + inputProcessors[preferredIndex] = + new StreamOneInputProcessor<>( + preferredTaskInput, collectingDataOutput, new DummyOperatorChain()); + + inputProcessors[sortedIndex] = + new StreamOneInputProcessor<>( + sortedTaskInput, collectingDataOutput, new DummyOperatorChain()); + + StreamMultipleInputProcessor processor = + new StreamMultipleInputProcessor(selectionHandler, inputProcessors); + + InputStatus inputStatus; + do { + inputStatus = processor.processInput(); + } while (inputStatus != InputStatus.END_OF_INPUT); + } + } + + assertThat( + collectingDataOutput.events, + equalTo( + Arrays.asList( + new StreamRecord<>(99, 3), + new StreamRecord<>(99, 1), + new Watermark(99L), // max watermark from the preferred input + new StreamRecord<>(1, 1), + new StreamRecord<>(1, 2), + new StreamRecord<>(1, 3), + new StreamRecord<>(2, 1), + new StreamRecord<>(2, 2), + new StreamRecord<>(2, 3), + Watermark.MAX_WATERMARK // max watermark from the sorted input + ))); + } + @Test @SuppressWarnings("unchecked") public void simpleFixedLengthKeySorting() throws Exception { @@ -69,13 +169,14 @@ public class MultiInputSortingDataInputsTest { new KeySelector[] {keySelector, keySelector}, new TypeSerializer[] {new IntSerializer(), new IntSerializer()}, new IntSerializer(), + new StreamTaskInput[0], environment.getMemoryManager(), environment.getIOManager(), true, 1.0, new Configuration()); - StreamTaskInput[] sortingDataInputs = selectableSortingInputs.getSortingInputs(); + StreamTaskInput[] sortingDataInputs = selectableSortingInputs.getSortedInputs(); try (StreamTaskInput input1 = (StreamTaskInput) sortingDataInputs[0]; StreamTaskInput input2 = (StreamTaskInput) sortingDataInputs[1]) { @@ -148,13 +249,14 @@ public class MultiInputSortingDataInputsTest { new KeySelector[] {keySelector, keySelector}, new TypeSerializer[] {new IntSerializer(), new IntSerializer()}, new IntSerializer(), + new StreamTaskInput[0], environment.getMemoryManager(), environment.getIOManager(), true, 1.0, new Configuration()); - StreamTaskInput[] sortingDataInputs = selectableSortingInputs.getSortingInputs(); + StreamTaskInput[] sortingDataInputs = selectableSortingInputs.getSortedInputs(); try (StreamTaskInput input1 = (StreamTaskInput) sortingDataInputs[0]; StreamTaskInput input2 = (StreamTaskInput) sortingDataInputs[1]) { -- GitLab