提交 31d3eaa7 编写于 作者: A Aljoscha Krettek 提交者: mbalassi

[FLINK-2182] Add stateful Streaming Sequence Source

Closes #804
上级 2eb5cfeb
......@@ -57,11 +57,11 @@ import org.apache.flink.streaming.api.functions.source.FromSplittableIteratorFun
import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SocketTextStreamFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.types.StringValue;
import org.apache.flink.util.NumberSequenceIterator;
import org.apache.flink.util.SplittableIterator;
import java.io.File;
......@@ -399,8 +399,10 @@ public abstract class StreamExecutionEnvironment {
// --------------------------------------------------------------------------------------------
/**
* Creates a new data stream that contains a sequence of numbers. The data stream will be created with parallelism
* one, so the order of the elements is guaranteed.
* Creates a new data stream that contains a sequence of numbers. This is a parallel source,
* if you manually set the parallelism to {@code 1}
* (using {@link org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator.setParallelism()})
* the generated sequence of elements is in order.
*
* @param from
* The number to start at (inclusive)
......@@ -412,22 +414,7 @@ public abstract class StreamExecutionEnvironment {
if (from > to) {
throw new IllegalArgumentException("Start of sequence must not be greater than the end");
}
return fromCollection(new NumberSequenceIterator(from, to), BasicTypeInfo.LONG_TYPE_INFO, "Sequence Source");
}
/**
* Creates a new data stream that contains a sequence of numbers. The data stream will be created in parallel, so
* there is no guarantee about the oder of the elements.
*
* @param from
* The number to start at (inclusive)
* @param to
* The number to stop at (inclusive)
* @return A data stream, containing all number in the [from, to] interval
*/
public DataStreamSource<Long> generateParallelSequence(long from, long to) {
return fromParallelCollection(new NumberSequenceIterator(from, to), BasicTypeInfo.LONG_TYPE_INFO, "Parallel " +
"Sequence Source");
return addSource(new StatefulSequenceSource(from, to), "Sequence Source");
}
/**
......
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.functions.source;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
/**
* A stateful streaming source that emits each number from a given interval exactly once,
* possibly in parallel.
*/
public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements Checkpointed<Long> {
private static final long serialVersionUID = 1L;
private final long start;
private final long end;
private long collected;
private volatile boolean isRunning = true;
/**
* Creates a source that emits all numbers from the given interval exactly once.
*
* @param start Start of the range of numbers to emit.
* @param end End of the range of numbers to emit.
*/
public StatefulSequenceSource(long start, long end) {
this.start = start;
this.end = end;
this.collected = 0;
}
@Override
public void run(SourceContext<Long> ctx) throws Exception {
final Object checkpointLock = ctx.getCheckpointLock();
RuntimeContext context = getRuntimeContext();
final long stepSize = context.getNumberOfParallelSubtasks();
final long congruence = start + context.getIndexOfThisSubtask();
final long toCollect =
((end - start + 1) % stepSize > (congruence - start)) ?
((end - start + 1) / stepSize + 1) :
((end - start + 1) / stepSize);
while (isRunning && collected < toCollect) {
synchronized (checkpointLock) {
ctx.collect(collected * stepSize + congruence);
collected++;
}
}
}
@Override
public void cancel() {
isRunning = false;
}
@Override
public Long snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
return collected;
}
@Override
public void restoreState(Long state) {
collected = state;
}
}
......@@ -295,7 +295,7 @@ public class DataStreamTest {
} catch (IllegalArgumentException success) {
}
DataStreamSource<Long> parallelSource = env.generateParallelSequence(0, 0);
DataStreamSource<Long> parallelSource = env.generateSequence(0, 0);
assertEquals(7, graph.getStreamNode(parallelSource.getId()).getParallelism());
parallelSource.setParallelism(3);
......
......@@ -24,7 +24,8 @@ import java.util.Arrays;
import java.util.List;
import org.apache.flink.streaming.api.functions.source.FromElementsFunction;
import org.apache.flink.streaming.util.MockSource;
import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
import org.apache.flink.streaming.util.SourceFunctionUtil;
import org.junit.Test;
public class SourceFunctionTest {
......@@ -32,19 +33,29 @@ public class SourceFunctionTest {
@Test
public void fromElementsTest() throws Exception {
List<Integer> expectedList = Arrays.asList(1, 2, 3);
List<Integer> actualList = MockSource.createAndExecute(new FromElementsFunction<Integer>(1,
2, 3));
List<Integer> actualList = SourceFunctionUtil.runSourceFunction(new FromElementsFunction<Integer>(
1,
2,
3));
assertEquals(expectedList, actualList);
}
@Test
public void fromCollectionTest() throws Exception {
List<Integer> expectedList = Arrays.asList(1, 2, 3);
List<Integer> actualList = MockSource.createAndExecute(new FromElementsFunction<Integer>(
List<Integer> actualList = SourceFunctionUtil.runSourceFunction(new FromElementsFunction<Integer>(
Arrays.asList(1, 2, 3)));
assertEquals(expectedList, actualList);
}
@Test
public void generateSequenceTest() throws Exception {
List<Long> expectedList = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L);
List<Long> actualList = SourceFunctionUtil.runSourceFunction(new StatefulSequenceSource(1,
7));
assertEquals(expectedList, actualList);
}
@Test
public void socketTextStreamTest() throws Exception {
// TODO: does not work because we cannot set the internal socket anymore
......
......@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
......@@ -32,9 +31,9 @@ import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.FromElementsFunction;
import org.apache.flink.streaming.api.functions.source.FromIteratorFunction;
import org.apache.flink.streaming.api.functions.source.FromSplittableIteratorFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
......@@ -73,33 +72,12 @@ public class StreamExecutionEnvironmentTest {
plan.contains("\"contents\":\"Parallel Collection Source\",\"parallelism\":4"));
}
@Test
public void testGenerateSequenceParallelism() throws Exception {
StreamExecutionEnvironment env = new TestStreamEnvironment(PARALLELISM, MEMORYSIZE);
boolean seenExpectedException = false;
try {
DataStream<Long> dataStream1 = env.generateSequence(0, 0).setParallelism(4);
} catch (IllegalArgumentException e) {
seenExpectedException = true;
}
DataStream<Long> dataStream2 = env.generateParallelSequence(0, 0).setParallelism(4);
String plan = env.getExecutionPlan();
assertTrue("Expected Exception for setting parallelism was not thrown.", seenExpectedException);
assertTrue("Parallelism for dataStream1 is not right.",
plan.contains("\"contents\":\"Sequence Source\",\"parallelism\":1"));
assertTrue("Parallelism for dataStream2 is not right.",
plan.contains("\"contents\":\"Parallel Sequence Source\",\"parallelism\":4"));
}
@Test
public void testSources() {
StreamExecutionEnvironment env = new TestStreamEnvironment(PARALLELISM, MEMORYSIZE);
SourceFunction<Integer> srcFun = new SourceFunction<Integer>() {
private static final long serialVersionUID = 1L;
@Override
public void run(SourceContext<Integer> ctx) throws Exception {
......@@ -110,21 +88,18 @@ public class StreamExecutionEnvironmentTest {
}
};
DataStreamSource<Integer> src1 = env.addSource(srcFun);
assertEquals(srcFun, getFunctionForDataSource(src1));
assertEquals(srcFun, getFunctionFromDataSource(src1));
List<Long> list = Arrays.asList(0L, 1L, 2L);
DataStreamSource<Long> src2 = env.generateSequence(0, 2);
assertTrue(getFunctionForDataSource(src2) instanceof FromIteratorFunction);
assertTrue(getFunctionFromDataSource(src2) instanceof StatefulSequenceSource);
DataStreamSource<Long> src3 = env.fromElements(0L, 1L, 2L);
assertTrue(getFunctionForDataSource(src3) instanceof FromElementsFunction);
assertTrue(getFunctionFromDataSource(src3) instanceof FromElementsFunction);
DataStreamSource<Long> src4 = env.fromCollection(list);
assertTrue(getFunctionForDataSource(src4) instanceof FromElementsFunction);
DataStreamSource<Long> src5 = env.generateParallelSequence(0, 2);
assertTrue(getFunctionForDataSource(src5) instanceof FromSplittableIteratorFunction);
assertTrue(getFunctionFromDataSource(src4) instanceof FromElementsFunction);
}
/////////////////////////////////////////////////////////////
......@@ -132,15 +107,15 @@ public class StreamExecutionEnvironmentTest {
/////////////////////////////////////////////////////////////
private static StreamOperator<?> getOperatorForDataStream(DataStream<?> dataStream) {
private static StreamOperator<?> getOperatorFromDataStream(DataStream<?> dataStream) {
StreamExecutionEnvironment env = dataStream.getExecutionEnvironment();
StreamGraph streamGraph = env.getStreamGraph();
return streamGraph.getStreamNode(dataStream.getId()).getOperator();
}
private static <T> SourceFunction<T> getFunctionForDataSource(DataStreamSource<T> dataStreamSource) {
private static <T> SourceFunction<T> getFunctionFromDataSource(DataStreamSource<T> dataStreamSource) {
AbstractUdfStreamOperator<?, ?> operator =
(AbstractUdfStreamOperator<?, ?>) getOperatorForDataStream(dataStreamSource);
(AbstractUdfStreamOperator<?, ?>) getOperatorFromDataStream(dataStreamSource);
return (SourceFunction<T>) operator.getUserFunction();
}
......
......@@ -101,7 +101,7 @@ public class DirectedOutputTest {
TestListResultSink<Long> evenAndOddSink = new TestListResultSink<Long>();
TestListResultSink<Long> allSink = new TestListResultSink<Long>();
SplitDataStream<Long> source = env.generateParallelSequence(1, 11).split(new MyOutputSelector());
SplitDataStream<Long> source = env.generateSequence(1, 11).split(new MyOutputSelector());
source.select(EVEN).addSink(evenSink);
source.select(ODD, TEN).addSink(oddAndTenSink);
source.select(EVEN, ODD).addSink(evenAndOddSink);
......
......@@ -207,6 +207,11 @@ public class ComplexIntegrationTest extends StreamingMultipleProgramsTestBase {
"16937\n" + "11927\n" + "9973\n" + "14431\n" + "19507\n" + "12497\n" + "17497\n" + "14983\n" +
"19997\n";
expected1 = "541\n" + "1223\n" + "1987\n" + "2741\n" + "3571\n" + "10939\n" + "4409\n" +
"5279\n" + "11927\n" + "6133\n" + "6997\n" + "12823\n" + "7919\n" + "8831\n" +
"13763\n" + "9733\n" + "9973\n" + "14759\n" + "15671\n" + "16673\n" + "17659\n" +
"18617\n" + "19697\n" + "19997\n";
for (int i = 2; i < 100; i++) {
expected2 += "(" + i + "," + 20000 / i + ")\n";
}
......@@ -217,11 +222,15 @@ public class ComplexIntegrationTest extends StreamingMultipleProgramsTestBase {
expected2 += "(" + 20000 + "," + 1 + ")";
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// set to parallelism 1 because otherwise we don't know which elements go to which parallel
// count-window.
env.setParallelism(1);
env.setBufferTimeout(0);
DataStream<Long> sourceStream31 = env.generateParallelSequence(1, 10000);
DataStream<Long> sourceStream32 = env.generateParallelSequence(10001, 20000);
DataStream<Long> sourceStream31 = env.generateSequence(1, 10000);
DataStream<Long> sourceStream32 = env.generateSequence(10001, 20000);
sourceStream31.filter(new PrimeFilterFunction())
.window(Count.of(100))
......@@ -299,14 +308,18 @@ public class ComplexIntegrationTest extends StreamingMultipleProgramsTestBase {
//Turning on and off chaining
expected1 = "1\n" + "2\n" + "2\n" + "3\n" + "3\n" + "3\n" + "4\n" + "4\n" + "4\n" + "4\n" + "5\n" + "5\n" +
"5\n" + "5\n" + "5\n" + "1\n" + "3\n" + "3\n" + "4\n" + "5\n" + "5\n" + "6\n" + "8\n" + "9\n" + "10\n" +
"12\n" + "15\n" + "16\n" + "20\n" + "25\n";
"5\n" + "5\n" + "5\n" + "1\n" + "3\n" + "5\n" + "8\n" + "11\n" + "14\n" + "18\n" + "22\n" + "26\n" +
"30\n" + "35\n" + "40\n" + "45\n" + "50\n" + "55\n";
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// Set to parallelism 1 to make it deterministic, otherwise, it is not clear which
// elements will go to which parallel instance of the fold
env.setParallelism(1);
env.setBufferTimeout(0);
DataStream<Long> dataStream51 = env.generateParallelSequence(1, 5)
DataStream<Long> dataStream51 = env.generateSequence(1, 5)
.map(new MapFunction<Long, Long>() {
@Override
......@@ -346,6 +359,8 @@ public class ComplexIntegrationTest extends StreamingMultipleProgramsTestBase {
});
dataStream53.union(dataStream52).print();
dataStream53.union(dataStream52)
.writeAsText(resultPath1, FileSystem.WriteMode.OVERWRITE);
......
......@@ -92,7 +92,7 @@ public class ProjectTest implements Serializable {
StreamExecutionEnvironment env = new TestStreamEnvironment(1, MEMORY_SIZE);
env.generateParallelSequence(1, 10).map(new MapFunction<Long, Tuple3<Long, Character, Double>>() {
env.generateSequence(1, 10).map(new MapFunction<Long, Tuple3<Long, Character, Double>>() {
@Override
public Tuple3<Long, Character, Double> map(Long value) throws Exception {
return new Tuple3<Long, Character, Double>(value, 'c', value.doubleValue());
......
......@@ -158,7 +158,7 @@ public class StreamVertexTest {
StreamExecutionEnvironment env = new TestStreamEnvironment(SOURCE_PARALELISM, MEMORYSIZE);
DataStream<String> fromStringElements = env.fromElements("aa", "bb", "cc");
DataStream<Long> generatedSequence = env.generateParallelSequence(0, 3);
DataStream<Long> generatedSequence = env.generateSequence(0, 3);
fromStringElements.connect(generatedSequence).map(new CoMap()).addSink(new SetSink());
......
......@@ -20,17 +20,27 @@ package org.apache.flink.streaming.util;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
import org.apache.flink.util.Collector;
public class MockSource<T> {
public class SourceFunctionUtil<T> {
public static <T> List<T> createAndExecute(SourceFunction<T> sourceFunction) throws Exception {
public static <T> List<T> runSourceFunction(SourceFunction<T> sourceFunction) throws Exception {
List<T> outputs = new ArrayList<T>();
if (sourceFunction instanceof RichSourceFunction) {
((RichSourceFunction<T>) sourceFunction).open(new Configuration());
if (sourceFunction instanceof RichFunction) {
RuntimeContext runtimeContext = new StreamingRuntimeContext("MockTask", new MockEnvironment(3 * 1024 * 1024, new MockInputSplitProvider(), 1024), null,
new ExecutionConfig());
((RichFunction) sourceFunction).setRuntimeContext(runtimeContext);
((RichFunction) sourceFunction).open(new Configuration());
}
try {
final Collector<T> collector = new MockOutput<T>(outputs);
......
......@@ -232,22 +232,12 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) {
// --------------------------------------------------------------------------------------------
/**
* Creates a new DataStream that contains a sequence of numbers.
*
* Note that this operation will result in a non-parallel data source, i.e. a data source with
* a parallelism of one.
* Creates a new DataStream that contains a sequence of numbers. This source is a parallel source.
* If you manually set the parallelism to `1` the emitted elements are in order.
*/
def generateSequence(from: Long, to: Long): DataStream[Long] = {
new DataStream[java.lang.Long](javaEnv.generateSequence(from, to)).
asInstanceOf[DataStream[Long]]
}
/**
* Creates a new DataStream that contains a sequence of numbers in a parallel fashion.
*/
def generateParallelSequence(from: Long, to: Long): DataStream[Long] = {
new DataStream[java.lang.Long](javaEnv.generateParallelSequence(from, to)).
asInstanceOf[DataStream[Long]]
new DataStream[java.lang.Long](javaEnv.generateSequence(from, to))
.asInstanceOf[DataStream[Long]]
}
/**
......
......@@ -228,7 +228,7 @@ class DataStreamTest {
assert(7 == graph.getStreamNode(windowed.getId).getParallelism)
assert(7 == graph.getStreamNode(sink.getId).getParallelism)
val parallelSource = env.generateParallelSequence(0, 0)
val parallelSource = env.generateSequence(0, 0)
assert(7 == graph.getStreamNode(parallelSource.getId).getParallelism)
......
......@@ -29,6 +29,7 @@ import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FileStateHandle;
......@@ -126,7 +127,7 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
public void run(SourceContext<Long> sourceCtx) throws Exception {
final Object checkpointLock = sourceCtx.getCheckpointLock();
StreamingRuntimeContext runtimeCtx = (StreamingRuntimeContext) getRuntimeContext();
RuntimeContext runtimeCtx = getRuntimeContext();
final long stepSize = runtimeCtx.getNumberOfParallelSubtasks();
final long congruence = runtimeCtx.getIndexOfThisSubtask();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册