提交 3fd3110c 编写于 作者: A Aljoscha Krettek 提交者: Stephan Ewen

[FLINK-1110] Implement Collection-Based Execution for Delta Iterations

上级 b61f63c0
......@@ -18,6 +18,7 @@
package org.apache.flink.api.common.operators;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
......@@ -37,8 +38,13 @@ import org.apache.flink.api.common.functions.util.IterationRuntimeUDFContext;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.base.BulkIterationBase;
import org.apache.flink.api.common.operators.base.BulkIterationBase.PartialSolutionPlaceHolder;
import org.apache.flink.api.common.operators.base.DeltaIterationBase;
import org.apache.flink.api.common.operators.base.DeltaIterationBase.SolutionSetPlaceHolder;
import org.apache.flink.api.common.operators.base.DeltaIterationBase.WorksetPlaceHolder;
import org.apache.flink.api.common.operators.util.TypeComparable;
import org.apache.flink.api.common.typeinfo.CompositeType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.util.Visitor;
/**
......@@ -87,16 +93,19 @@ public class CollectionExecutor {
}
if (operator instanceof BulkIterationBase) {
return executeBulkIteration((BulkIterationBase<?>) operator);
result = executeBulkIteration((BulkIterationBase<?>) operator);
}
else if (operator instanceof DeltaIterationBase) {
result = executeDeltaIteration((DeltaIterationBase<?, ?>) operator);
}
else if (operator instanceof SingleInputOperator) {
return executeUnaryOperator((SingleInputOperator<?, ?, ?>) operator, superStep);
result = executeUnaryOperator((SingleInputOperator<?, ?, ?>) operator, superStep);
}
else if (operator instanceof DualInputOperator) {
return executeBinaryOperator((DualInputOperator<?, ?, ?, ?>) operator, superStep);
result = executeBinaryOperator((DualInputOperator<?, ?, ?, ?>) operator, superStep);
}
else if (operator instanceof GenericDataSourceBase) {
return executeDataSource((GenericDataSourceBase<?, ?>) operator);
result = executeDataSource((GenericDataSourceBase<?, ?>) operator);
}
else if (operator instanceof GenericDataSinkBase) {
executeDataSink((GenericDataSinkBase<?>) operator);
......@@ -217,7 +226,7 @@ public class CollectionExecutor {
private <T> List<T> executeBulkIteration(BulkIterationBase<?> iteration) throws Exception {
Operator<?> inputOp = iteration.getInput();
if (inputOp == null) {
throw new InvalidProgramException("The iteration " + iteration.getName() + " has input (initial partial solution).");
throw new InvalidProgramException("The iteration " + iteration.getName() + " has no input (initial partial solution).");
}
if (iteration.getNextPartialSolution() == null) {
throw new InvalidProgramException("The iteration " + iteration.getName() + " has no next partial solution defined (is not closed).");
......@@ -244,8 +253,7 @@ public class CollectionExecutor {
// grab the current iteration result
currentResult = (List<T>) execute(iteration.getNextPartialSolution(), superstep);
this.intermediateResults.put(iteration.getNextPartialSolution(), currentResult);
// evaluate the termination criterion
if (iteration.getTerminationCriterion() != null) {
List<?> term = execute(((SingleInputOperator<?, ?, ?>) iteration.getTerminationCriterion()).getInput(), superstep);
......@@ -262,6 +270,88 @@ public class CollectionExecutor {
return currentResult;
}
@SuppressWarnings("unchecked")
private <T> List<T> executeDeltaIteration(DeltaIterationBase<?, ?> iteration) throws Exception {
Operator<?> solutionInput = iteration.getInitialSolutionSet();
Operator<?> worksetInput = iteration.getInitialWorkset();
if (solutionInput == null) {
throw new InvalidProgramException("The delta iteration " + iteration.getName() + " has no initial solution set.");
}
if (worksetInput == null) {
throw new InvalidProgramException("The delta iteration " + iteration.getName() + " has no initial workset.");
}
if (iteration.getSolutionSetDelta() == null) {
throw new InvalidProgramException("The iteration " + iteration.getName() + " has no solution set delta defined (is not closed).");
}
if (iteration.getNextWorkset() == null) {
throw new InvalidProgramException("The iteration " + iteration.getName() + " has no workset defined (is not closed).");
}
List<T> solutionInputData = (List<T>) execute(solutionInput);
List<T> worksetInputData = (List<T>) execute(worksetInput);
// get the operators that are iterative
Set<Operator<?>> dynamics = new LinkedHashSet<Operator<?>>();
DynamicPathCollector dynCollector = new DynamicPathCollector(dynamics);
iteration.getSolutionSetDelta().accept(dynCollector);
iteration.getNextWorkset().accept(dynCollector);
BinaryOperatorInformation<?, ?, ?> operatorInfo = iteration.getOperatorInfo();
TypeInformation<?> solutionType = operatorInfo.getFirstInputType();
int[] keyColumns = iteration.getSolutionSetKeyFields();
boolean[] inputOrderings = new boolean[keyColumns.length];
TypeComparator<T> inputComparator = ((CompositeType<T>) solutionType).createComparator(keyColumns, inputOrderings);
Map<TypeComparable<T>, T> solutionMap = new HashMap<TypeComparable<T>, T>(solutionInputData.size());
// fill the solution from the initial input
for (T delta: solutionInputData) {
TypeComparable<T> wrapper = new TypeComparable<T>(delta, inputComparator);
solutionMap.put(wrapper, delta);
}
List<?> currentWorkset = worksetInputData;
final int maxIterations = iteration.getMaximumNumberOfIterations();
for (int superstep = 1; superstep <= maxIterations; superstep++) {
List<T> currentSolution = new ArrayList<T>(solutionMap.size());
currentSolution.addAll(solutionMap.values());
// set the input to the current partial solution
this.intermediateResults.put(iteration.getSolutionSet(), currentSolution);
this.intermediateResults.put(iteration.getWorkset(), currentWorkset);
// grab the current iteration result
List<T> solutionSetDelta = (List<T>) execute(iteration.getSolutionSetDelta(), superstep);
this.intermediateResults.put(iteration.getSolutionSetDelta(), solutionSetDelta);
// update the solution
for (T delta: solutionSetDelta) {
TypeComparable<T> wrapper = new TypeComparable<T>(delta, inputComparator);
solutionMap.put(wrapper, delta);
}
currentWorkset = (List<?>) execute(iteration.getNextWorkset(), superstep);
if (currentWorkset.isEmpty()) {
break;
}
// clear the dynamic results
for (Operator<?> o : dynamics) {
intermediateResults.remove(o);
}
}
List<T> currentSolution = new ArrayList<T>(solutionMap.size());
currentSolution.addAll(solutionMap.values());
return currentSolution;
}
// --------------------------------------------------------------------------------------------
// --------------------------------------------------------------------------------------------
......
......@@ -334,7 +334,6 @@ public class DeltaIterationBase<ST, WT> extends DualInputOperator<ST, WT, ST, Ab
@Override
protected List<ST> executeOnCollections(List<ST> inputData1, List<WT> inputData2, RuntimeContext runtimeContext) throws Exception {
// TODO Auto-generated method stub
return null;
throw new UnsupportedOperationException();
}
}
......@@ -18,18 +18,26 @@
package org.apache.flink.api.common.operators;
//CHECKSTYLE.OFF: AvoidStarImport - Needed for TupleGenerator
import static org.junit.Assert.*;
//CHECKSTYLE.ON: AvoidStarImport
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.CollectionEnvironment;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.LocalCollectionOutputFormat;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;
import org.junit.Test;
@SuppressWarnings("serial")
......@@ -88,6 +96,71 @@ public class CollectionExecutionIterationTest implements java.io.Serializable {
fail(e.getMessage());
}
}
@Test
public void testDeltaIteration() {
try {
ExecutionEnvironment env = new CollectionEnvironment();
DataSet<Tuple2<Integer, Integer>> solInput = env.fromElements(
new Tuple2<Integer, Integer>(1, 0),
new Tuple2<Integer, Integer>(2, 0),
new Tuple2<Integer, Integer>(3, 0),
new Tuple2<Integer, Integer>(4, 0));
DataSet<Tuple1<Integer>> workInput = env.fromElements(
new Tuple1<Integer>(1),
new Tuple1<Integer>(2),
new Tuple1<Integer>(3),
new Tuple1<Integer>(4));
// Perform a delta iteration where we add those values to the workset where
// the second tuple field is smaller than the first tuple field.
// At the end both tuple fields must be the same.
DeltaIteration<Tuple2<Integer, Integer>, Tuple1<Integer>> iteration =
solInput.iterateDelta(workInput, 10, 0);
DataSet<Tuple2<Integer, Integer>> solDelta = iteration.getSolutionSet().join(
iteration.getWorkset()).where(0).equalTo(0).with(
new JoinFunction<Tuple2<Integer, Integer>, Tuple1<Integer>, Tuple2<Integer, Integer>>() {
@Override
public Tuple2<Integer, Integer> join(Tuple2<Integer, Integer> first,
Tuple1<Integer> second) throws Exception {
return new Tuple2<Integer, Integer>(first.f0, first.f1 + 1);
}
});
DataSet<Tuple1<Integer>> nextWorkset = solDelta.flatMap(
new FlatMapFunction<Tuple2<Integer, Integer>, Tuple1<Integer>>() {
@Override
public void flatMap(Tuple2<Integer, Integer> in, Collector<Tuple1<Integer>>
out) throws Exception {
if (in.f1 < in.f0) {
out.collect(new Tuple1<Integer>(in.f0));
}
}
});
List<Tuple2<Integer, Integer>> collected = new ArrayList<Tuple2<Integer, Integer>>();
iteration.closeWith(solDelta, nextWorkset)
.output(new LocalCollectionOutputFormat<Tuple2<Integer, Integer>>(collected));
env.execute();
// verify that both tuple fields are now the same
for (Tuple2<Integer, Integer> t: collected) {
assertEquals(t.f0, t.f1);
}
}
catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
public static class AddSuperstepNumberMapper extends RichMapFunction<Integer, Integer> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册