提交 722719f2 编写于 作者: D Dániel Bali 提交者: vasia

[FLINK-1514] [gelly] Add a Gather-Sum-Apply iteration method

上级 b2aafe58
...@@ -57,4 +57,35 @@ under the License. ...@@ -57,4 +57,35 @@ under the License.
<version>${guava.version}</version> <version>${guava.version}</version>
</dependency> </dependency>
</dependencies> </dependencies>
<!-- See main pom.xml for explanation of profiles -->
<profiles>
<profile>
<id>hadoop-1</id>
<activation>
<property>
<!-- Please do not remove the 'hadoop1' comment. See ./tools/generate_specific_pom.sh -->
<!--hadoop1--><name>hadoop.profile</name><value>1</value>
</property>
</activation>
<dependencies>
<!-- Add this here, for hadoop-2 we don't need it since we get guava transitively -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>hadoop-2</id>
<activation>
<property>
<!-- Please do not remove the 'hadoop2' comment. See ./tools/generate_specific_pom.sh -->
<!--hadoop2--><name>!hadoop.profile</name>
</property>
</activation>
</profile>
</profiles>
</project> </project>
...@@ -43,15 +43,18 @@ import org.apache.flink.api.java.operators.DeltaIteration; ...@@ -43,15 +43,18 @@ import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.graph.spargel.IterationConfiguration; import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.GatherSumApplyIteration;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.spargel.MessagingFunction; import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexCentricIteration; import org.apache.flink.graph.spargel.VertexCentricIteration;
import org.apache.flink.graph.spargel.VertexUpdateFunction; import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.utils.EdgeToTuple3Map; import org.apache.flink.graph.utils.EdgeToTuple3Map;
import org.apache.flink.graph.utils.GraphUtils;
import org.apache.flink.graph.utils.Tuple2ToVertexMap; import org.apache.flink.graph.utils.Tuple2ToVertexMap;
import org.apache.flink.graph.utils.Tuple3ToEdgeMap; import org.apache.flink.graph.utils.Tuple3ToEdgeMap;
import org.apache.flink.graph.utils.VertexToTuple2Map; import org.apache.flink.graph.utils.VertexToTuple2Map;
...@@ -79,7 +82,8 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab ...@@ -79,7 +82,8 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab
private final DataSet<Edge<K, EV>> edges; private final DataSet<Edge<K, EV>> edges;
/** /**
* Creates a graph from two DataSets: vertices and edges * Creates a graph from two DataSets: vertices and edges and allow setting
* the undirected property
* *
* @param vertices a DataSet of vertices. * @param vertices a DataSet of vertices.
* @param edges a DataSet of edges. * @param edges a DataSet of edges.
...@@ -347,7 +351,7 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab ...@@ -347,7 +351,7 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab
@Override @Override
public void join(Tuple4<K, K, VV, EV> tripletWithSrcValSet, public void join(Tuple4<K, K, VV, EV> tripletWithSrcValSet,
Vertex<K, VV> vertex, Collector<Triplet<K, VV, EV>> collector) throws Exception { Vertex<K, VV> vertex, Collector<Triplet<K, VV, EV>> collector) throws Exception {
collector.collect(new Triplet<K, VV, EV>(tripletWithSrcValSet.f0, tripletWithSrcValSet.f1, collector.collect(new Triplet<K, VV, EV>(tripletWithSrcValSet.f0, tripletWithSrcValSet.f1,
tripletWithSrcValSet.f2, vertex.getValue(), tripletWithSrcValSet.f3)); tripletWithSrcValSet.f2, vertex.getValue(), tripletWithSrcValSet.f3));
...@@ -914,7 +918,7 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab ...@@ -914,7 +918,7 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab
} }
/** /**
* @return a long integer representing the number of edges * @return Singleton DataSet containing the edge count
*/ */
public long numberOfEdges() throws Exception { public long numberOfEdges() throws Exception {
return edges.count(); return edges.count();
...@@ -1011,6 +1015,13 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab ...@@ -1011,6 +1015,13 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab
} }
} }
private static final class CheckIfOneComponentMapper implements MapFunction<Integer, Boolean> {
@Override
public Boolean map(Integer n) {
return (n == 1);
}
}
/** /**
* Adds the input vertex and edges to the graph. If the vertex already * Adds the input vertex and edges to the graph. If the vertex already
* exists in the graph, it will not be added again, but the given edges * exists in the graph, it will not be added again, but the given edges
...@@ -1165,7 +1176,7 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab ...@@ -1165,7 +1176,7 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab
int maximumNumberOfIterations) { int maximumNumberOfIterations) {
return this.runVertexCentricIteration(vertexUpdateFunction, messagingFunction, return this.runVertexCentricIteration(vertexUpdateFunction, messagingFunction,
maximumNumberOfIterations, null); maximumNumberOfIterations, null);
} }
/** /**
...@@ -1397,4 +1408,4 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab ...@@ -1397,4 +1408,4 @@ public class Graph<K extends Comparable<K> & Serializable, VV extends Serializab
return TypeExtractor.createTypeInfo(NeighborsFunctionWithVertexValue.class, function.getClass(), 3, null, null); return TypeExtractor.createTypeInfo(NeighborsFunctionWithVertexValue.class, function.getClass(), 3, null, null);
} }
} }
} }
\ No newline at end of file
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.example;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.GatherSumApplyIteration;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.gsa.RichEdge;
import org.apache.flink.util.Collector;
import java.util.HashSet;
/**
* This is an implementation of the Greedy Graph Coloring algorithm, using a gather-sum-apply iteration
*/
public class GSAGreedyGraphColoringExample implements ProgramDescription {
// --------------------------------------------------------------------------------------------
// Program
// --------------------------------------------------------------------------------------------
public static void main(String[] args) throws Exception {
if (!parseParameters(args)) {
return;
}
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Vertex<Long, Double>> vertices = getVertexDataSet(env);
DataSet<Edge<Long, Double>> edges = getEdgeDataSet(env);
Graph<Long, Double, Double> graph = Graph.fromDataSet(vertices, edges, env);
// Gather the target vertices into a one-element set
GatherFunction<Double, Double, HashSet<Double>> gather = new GreedyGraphColoringGather();
// Merge the sets between neighbors
SumFunction<Double, Double, HashSet<Double>> sum = new GreedyGraphColoringSum();
// Find the minimum vertex id in the set which will be propagated
ApplyFunction<Double, Double, HashSet<Double>> apply = new GreedyGraphColoringApply();
// Execute the GSA iteration
GatherSumApplyIteration<Long, Double, Double, HashSet<Double>> iteration =
graph.createGatherSumApplyIteration(gather, sum, apply, maxIterations);
Graph<Long, Double, Double> result = graph.runGatherSumApplyIteration(iteration);
// Extract the vertices as the result
DataSet<Vertex<Long, Double>> greedyGraphColoring = result.getVertices();
// emit result
if (fileOutput) {
greedyGraphColoring.writeAsCsv(outputPath, "\n", " ");
} else {
greedyGraphColoring.print();
}
env.execute("GSA Greedy Graph Coloring");
}
// --------------------------------------------------------------------------------------------
// Greedy Graph Coloring UDFs
// --------------------------------------------------------------------------------------------
private static final class GreedyGraphColoringGather
extends GatherFunction<Double, Double, HashSet<Double>> {
@Override
public HashSet<Double> gather(RichEdge<Double, Double> richEdge) {
HashSet<Double> result = new HashSet<Double>();
result.add(richEdge.getSrcVertexValue());
return result;
}
};
private static final class GreedyGraphColoringSum
extends SumFunction<Double, Double, HashSet<Double>> {
@Override
public HashSet<Double> sum(HashSet<Double> newValue, HashSet<Double> currentValue) {
HashSet<Double> result = new HashSet<Double>();
result.addAll(newValue);
result.addAll(currentValue);
return result;
}
};
private static final class GreedyGraphColoringApply
extends ApplyFunction<Double, Double, HashSet<Double>> {
@Override
public void apply(HashSet<Double> set, Double src) {
double minValue = src;
for (Double d : set) {
if (d < minValue) {
minValue = d;
}
}
// This is the condition that enables the termination of the iteration
if (minValue < src) {
setResult(minValue);
}
}
};
// --------------------------------------------------------------------------------------------
// Util methods
// --------------------------------------------------------------------------------------------
private static boolean fileOutput = false;
private static String vertexInputPath = null;
private static String edgeInputPath = null;
private static String outputPath = null;
private static int maxIterations = 16;
private static boolean parseParameters(String[] args) {
if(args.length > 0) {
// parse input arguments
fileOutput = true;
if(args.length != 4) {
System.err.println("Usage: GSAGreedyGraphColoringExample <vertex path> <edge path> " +
"<result path> <max iterations>");
return false;
}
vertexInputPath = args[0];
edgeInputPath = args[1];
outputPath = args[2];
maxIterations = Integer.parseInt(args[3]);
} else {
System.out.println("Executing GSA Greedy Graph Coloring example with built-in default data.");
System.out.println(" Provide parameters to read input data from files.");
System.out.println(" See the documentation for the correct format of input files.");
System.out.println(" Usage: GSAGreedyGraphColoringExample <vertex path> <edge path> "
+ "<result path> <max iterations>");
}
return true;
}
private static DataSet<Vertex<Long, Double>> getVertexDataSet(ExecutionEnvironment env) {
if(fileOutput) {
return env
.readCsvFile(vertexInputPath)
.fieldDelimiter(" ")
.lineDelimiter("\n")
.types(Long.class, Double.class)
.map(new MapFunction<Tuple2<Long, Double>, Vertex<Long, Double>>() {
@Override
public Vertex<Long, Double> map(Tuple2<Long, Double> value) throws Exception {
return new Vertex<Long, Double>(value.f0, value.f1);
}
});
}
return env.generateSequence(0, 5).map(new MapFunction<Long, Vertex<Long, Double>>() {
@Override
public Vertex<Long, Double> map(Long value) throws Exception {
return new Vertex<Long, Double>(value, (double) value);
}
});
}
private static DataSet<Edge<Long, Double>> getEdgeDataSet(ExecutionEnvironment env) {
if(fileOutput) {
return env.readCsvFile(edgeInputPath)
.fieldDelimiter(" ")
.lineDelimiter("\n")
.types(Long.class, Long.class, Double.class)
.map(new MapFunction<Tuple3<Long, Long, Double>, Edge<Long, Double>>() {
@Override
public Edge<Long, Double> map(Tuple3<Long, Long, Double> value) throws Exception {
return new Edge<Long, Double>(value.f0, value.f1, value.f2);
}
});
}
return env.generateSequence(0, 5).flatMap(new FlatMapFunction<Long, Edge<Long, Double>>() {
@Override
public void flatMap(Long value, Collector<Edge<Long, Double>> out) throws Exception {
out.collect(new Edge<Long, Double>(value, (value + 1) % 6, 0.0));
out.collect(new Edge<Long, Double>(value, (value + 2) % 6, 0.0));
}
});
}
@Override
public String getDescription() {
return "GSA Greedy Graph Coloring";
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.example;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.example.utils.SingleSourceShortestPathsData;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.GatherSumApplyIteration;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.gsa.RichEdge;
import java.io.Serializable;
/**
* This is an implementation of the Single Source Shortest Paths algorithm, using a gather-sum-apply iteration
*/
public class GSASingleSourceShortestPathsExample implements ProgramDescription {
// --------------------------------------------------------------------------------------------
// Program
// --------------------------------------------------------------------------------------------
public static void main(String[] args) throws Exception {
if(!parseParameters(args)) {
return;
}
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Vertex<Long, Double>> vertices = getVertexDataSet(env);
DataSet<Edge<Long, Double>> edges = getEdgeDataSet(env);
Graph<Long, Double, Double> graph = Graph.fromDataSet(vertices, edges, env);
// The path from src to trg through edge e costs src + e
GatherFunction<Double, Double, Double> gather = new SingleSourceShortestPathGather();
// Return the smaller path length to minimize distance
SumFunction<Double, Double, Double> sum = new SingleSourceShortestPathSum();
// Iterate as long as the distance is updated
ApplyFunction<Double, Double, Double> apply = new SingleSourceShortestPathApply();
// Execute the GSA iteration
GatherSumApplyIteration<Long, Double, Double, Double> iteration = graph.createGatherSumApplyIteration(
gather, sum, apply, maxIterations);
Graph<Long, Double, Double> result = graph.mapVertices(new InitVerticesMapper<Long>(srcVertexId))
.runGatherSumApplyIteration(iteration);
// Extract the vertices as the result
DataSet<Vertex<Long, Double>> singleSourceShortestPaths = result.getVertices();
// emit result
if(fileOutput) {
singleSourceShortestPaths.writeAsCsv(outputPath, "\n", " ");
} else {
singleSourceShortestPaths.print();
}
env.execute("GSA Single Source Shortest Paths Example");
}
public static final class InitVerticesMapper<K extends Comparable<K> & Serializable>
implements MapFunction<Vertex<K, Double>, Double> {
private K srcVertexId;
public InitVerticesMapper(K srcId) {
this.srcVertexId = srcId;
}
public Double map(Vertex<K, Double> value) {
if (value.f0.equals(srcVertexId)) {
return 0.0;
} else {
return Double.POSITIVE_INFINITY;
}
}
}
// --------------------------------------------------------------------------------------------
// Single Source Shortest Path UDFs
// --------------------------------------------------------------------------------------------
private static final class SingleSourceShortestPathGather
extends GatherFunction<Double, Double, Double> {
@Override
public Double gather(RichEdge<Double, Double> richEdge) {
return richEdge.getSrcVertexValue() + richEdge.getEdgeValue();
}
};
private static final class SingleSourceShortestPathSum
extends SumFunction<Double, Double, Double> {
@Override
public Double sum(Double newValue, Double currentValue) {
return Math.min(newValue, currentValue);
}
};
private static final class SingleSourceShortestPathApply
extends ApplyFunction<Double, Double, Double> {
@Override
public void apply(Double summed, Double target) {
if (summed < target) {
setResult(summed);
}
}
};
// --------------------------------------------------------------------------------------------
// Util methods
// --------------------------------------------------------------------------------------------
private static boolean fileOutput = false;
private static String vertexInputPath = null;
private static String edgeInputPath = null;
private static String outputPath = null;
private static int maxIterations = 2;
private static long srcVertexId = 1;
private static boolean parseParameters(String[] args) {
if(args.length > 0) {
// parse input arguments
fileOutput = true;
if(args.length != 5) {
System.err.println("Usage: GSASingleSourceShortestPathsExample <vertex path> <edge path> " +
"<result path> <src vertex> <max iterations>");
return false;
}
vertexInputPath = args[0];
edgeInputPath = args[1];
outputPath = args[2];
srcVertexId = Long.parseLong(args[3]);
maxIterations = Integer.parseInt(args[4]);
} else {
System.out.println("Executing GSA Single Source Shortest Paths example with built-in default data.");
System.out.println(" Provide parameters to read input data from files.");
System.out.println(" See the documentation for the correct format of input files.");
System.out.println(" Usage: GSASingleSourceShortestPathsExample <vertex path> <edge path> "
+ "<result path> <src vertex> <max iterations>");
}
return true;
}
private static DataSet<Vertex<Long, Double>> getVertexDataSet(ExecutionEnvironment env) {
if(fileOutput) {
return env
.readCsvFile(vertexInputPath)
.fieldDelimiter(" ")
.lineDelimiter("\n")
.types(Long.class, Double.class)
.map(new MapFunction<Tuple2<Long, Double>, Vertex<Long, Double>>() {
@Override
public Vertex<Long, Double> map(Tuple2<Long, Double> value) throws Exception {
return new Vertex<Long, Double>(value.f0, value.f1);
}
});
} else {
return SingleSourceShortestPathsData.getDefaultVertexDataSet(env);
}
}
private static DataSet<Edge<Long, Double>> getEdgeDataSet(ExecutionEnvironment env) {
if(fileOutput) {
return env.readCsvFile(edgeInputPath)
.fieldDelimiter(" ")
.lineDelimiter("\n")
.types(Long.class, Long.class, Double.class)
.map(new MapFunction<Tuple3<Long, Long, Double>, Edge<Long, Double>>() {
@Override
public Edge<Long, Double> map(Tuple3<Long, Long, Double> value) throws Exception {
return new Edge<Long, Double>(value.f0, value.f1, value.f2);
}
});
} else {
return SingleSourceShortestPathsData.getDefaultEdgeDataSet(env);
}
}
@Override
public String getDescription() {
return "GSA Single Source Shortest Paths";
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.gsa;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.util.Collector;
import java.io.Serializable;
public abstract class ApplyFunction<VV extends Serializable, EV extends Serializable, M> implements Serializable {
public abstract void apply(M message, VV vertexValue);
/**
* Sets the result for the apply function
*
* @param result the result of the apply phase
*/
public void setResult(VV result) {
out.collect(result);
}
/**
* This method is executed once per superstep before the vertex update function is invoked for each vertex.
*
* @throws Exception Exceptions in the pre-superstep phase cause the superstep to fail.
*/
public void preSuperstep() {};
/**
* This method is executed once per superstep after the vertex update function has been invoked for each vertex.
*
* @throws Exception Exceptions in the post-superstep phase cause the superstep to fail.
*/
public void postSuperstep() {};
// --------------------------------------------------------------------------------------------
// Internal methods
// --------------------------------------------------------------------------------------------
private IterationRuntimeContext runtimeContext;
private Collector<VV> out;
public void init(IterationRuntimeContext iterationRuntimeContext) {
this.runtimeContext = iterationRuntimeContext;
};
public void setOutput(Collector<VV> out) {
this.out = out;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.gsa;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import java.io.Serializable;
public abstract class GatherFunction<VV extends Serializable, EV extends Serializable, M> implements Serializable {
public abstract M gather(RichEdge<VV, EV> richEdge);
/**
* This method is executed once per superstep before the vertex update function is invoked for each vertex.
*
* @throws Exception Exceptions in the pre-superstep phase cause the superstep to fail.
*/
public void preSuperstep() {};
/**
* This method is executed once per superstep after the vertex update function has been invoked for each vertex.
*
* @throws Exception Exceptions in the post-superstep phase cause the superstep to fail.
*/
public void postSuperstep() {};
// --------------------------------------------------------------------------------------------
// Internal methods
// --------------------------------------------------------------------------------------------
private IterationRuntimeContext runtimeContext;
public void init(IterationRuntimeContext iterationRuntimeContext) {
this.runtimeContext = iterationRuntimeContext;
};
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.gsa;
import org.apache.commons.lang3.Validate;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.CustomUnaryOperation;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Vertex;
import org.apache.flink.util.Collector;
import java.io.Serializable;
/**
* This class represents iterative graph computations, programmed in a gather-sum-apply perspective.
*
* @param <K> The type of the vertex key in the graph
* @param <VV> The type of the vertex value in the graph
* @param <EV> The type of the edge value in the graph
* @param <M> The intermediate type used by the gather, sum and apply functions
*/
public class GatherSumApplyIteration<K extends Comparable<K> & Serializable,
VV extends Serializable, EV extends Serializable, M> implements CustomUnaryOperation<Vertex<K, VV>,
Vertex<K, VV>> {
private DataSet<Vertex<K, VV>> vertexDataSet;
private DataSet<Edge<K, EV>> edgeDataSet;
private final GatherFunction<VV, EV, M> gather;
private final SumFunction<VV, EV, M> sum;
private final ApplyFunction<VV, EV, M> apply;
private final int maximumNumberOfIterations;
private String name;
private int parallelism = -1;
// ----------------------------------------------------------------------------------
private GatherSumApplyIteration(GatherFunction<VV, EV, M> gather, SumFunction<VV, EV, M> sum,
ApplyFunction<VV, EV, M> apply, DataSet<Edge<K, EV>> edges, int maximumNumberOfIterations) {
Validate.notNull(gather);
Validate.notNull(sum);
Validate.notNull(apply);
Validate.notNull(edges);
Validate.isTrue(maximumNumberOfIterations > 0, "The maximum number of iterations must be at least one.");
this.gather = gather;
this.sum = sum;
this.apply = apply;
this.edgeDataSet = edges;
this.maximumNumberOfIterations = maximumNumberOfIterations;
}
/**
* Sets the name for the gather-sum-apply iteration. The name is displayed in logs and messages.
*
* @param name The name for the iteration.
*/
public void setName(String name) {
this.name = name;
}
/**
* Gets the name from this gather-sum-apply iteration.
*
* @return The name of the iteration.
*/
public String getName() {
return name;
}
/**
* Sets the degree of parallelism for the iteration.
*
* @param parallelism The degree of parallelism.
*/
public void setParallelism(int parallelism) {
Validate.isTrue(parallelism > 0 || parallelism == -1,
"The degree of parallelism must be positive, or -1 (use default).");
this.parallelism = parallelism;
}
/**
* Gets the iteration's degree of parallelism.
*
* @return The iterations parallelism, or -1, if not set.
*/
public int getParallelism() {
return parallelism;
}
// --------------------------------------------------------------------------------------------
// Custom Operator behavior
// --------------------------------------------------------------------------------------------
/**
* Sets the input data set for this operator. In the case of this operator this input data set represents
* the set of vertices with their initial state.
*
* @param dataSet The input data set, which in the case of this operator represents the set of
* vertices with their initial state.
*/
@Override
public void setInput(DataSet<Vertex<K, VV>> dataSet) {
this.vertexDataSet = dataSet;
}
/**
* Computes the results of the gather-sum-apply iteration
*
* @return The resulting DataSet
*/
@Override
public DataSet<Vertex<K, VV>> createResult() {
if (vertexDataSet == null) {
throw new IllegalStateException("The input data set has not been set.");
}
// Prepare type information
TypeInformation<K> keyType = ((TupleTypeInfo<?>) vertexDataSet.getType()).getTypeAt(0);
TypeInformation<M> messageType = TypeExtractor.createTypeInfo(GatherFunction.class, gather.getClass(), 2, null, null);
TypeInformation<Tuple2<K, M>> innerType = new TupleTypeInfo<Tuple2<K, M>>(keyType, messageType);
TypeInformation<Vertex<K, VV>> outputType = vertexDataSet.getType();
// Prepare UDFs
GatherUdf<K, VV, EV, M> gatherUdf = new GatherUdf<K, VV, EV, M>(gather, innerType);
SumUdf<K, VV, EV, M> sumUdf = new SumUdf<K, VV, EV, M>(sum, innerType);
ApplyUdf<K, VV, EV, M> applyUdf = new ApplyUdf<K, VV, EV, M>(apply, outputType);
final int[] zeroKeyPos = new int[] {0};
final DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration =
vertexDataSet.iterateDelta(vertexDataSet, maximumNumberOfIterations, zeroKeyPos);
// Prepare the rich edges
DataSet<Tuple2<Vertex<K, VV>, Edge<K, EV>>> richEdges = iteration
.getWorkset()
.join(edgeDataSet)
.where(0)
.equalTo(0);
// Gather, sum and apply
DataSet<Tuple2<K, M>> gatheredSet = richEdges.map(gatherUdf);
DataSet<Tuple2<K, M>> summedSet = gatheredSet.groupBy(0).reduce(sumUdf);
DataSet<Vertex<K, VV>> appliedSet = summedSet
.join(iteration.getSolutionSet())
.where(0)
.equalTo(0)
.with(applyUdf);
return iteration.closeWith(appliedSet, appliedSet);
}
/**
* Creates a new gather-sum-apply iteration operator for graphs
*
* @param edges The edge DataSet
*
* @param gather The gather function of the GSA iteration
* @param sum The sum function of the GSA iteration
* @param apply The apply function of the GSA iteration
*
* @param maximumNumberOfIterations The maximum number of iterations executed
*
* @param <K> The type of the vertex key in the graph
* @param <VV> The type of the vertex value in the graph
* @param <EV> The type of the edge value in the graph
* @param <M> The intermediate type used by the gather, sum and apply functions
*
* @return An in stance of the gather-sum-apply graph computation operator.
*/
public static final <K extends Comparable<K> & Serializable, VV extends Serializable, EV extends Serializable, M>
GatherSumApplyIteration<K, VV, EV, M> withEdges(DataSet<Edge<K, EV>> edges,
GatherFunction<VV, EV, M> gather, SumFunction<VV, EV, M> sum, ApplyFunction<VV, EV, M> apply,
int maximumNumberOfIterations) {
return new GatherSumApplyIteration<K, VV, EV, M>(gather, sum, apply, edges, maximumNumberOfIterations);
}
// --------------------------------------------------------------------------------------------
// Wrapping UDFs
// --------------------------------------------------------------------------------------------
private static final class GatherUdf<K extends Comparable<K> & Serializable, VV extends Serializable,
EV extends Serializable, M> extends RichMapFunction<Tuple2<Vertex<K, VV>, Edge<K, EV>>,
Tuple2<K, M>> implements ResultTypeQueryable<Tuple2<K, M>> {
private final GatherFunction<VV, EV, M> gatherFunction;
private transient TypeInformation<Tuple2<K, M>> resultType;
private GatherUdf(GatherFunction<VV, EV, M> gatherFunction, TypeInformation<Tuple2<K, M>> resultType) {
this.gatherFunction = gatherFunction;
this.resultType = resultType;
}
@Override
public Tuple2<K, M> map(Tuple2<Vertex<K, VV>, Edge<K, EV>> richEdge) throws Exception {
RichEdge<VV, EV> userRichEdge = new RichEdge<VV, EV>(richEdge.f0.getValue(),
richEdge.f1.getValue());
K key = richEdge.f1.getTarget();
M result = this.gatherFunction.gather(userRichEdge);
return new Tuple2<K, M>(key, result);
}
@Override
public void open(Configuration parameters) throws Exception {
if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
this.gatherFunction.init(getIterationRuntimeContext());
}
this.gatherFunction.preSuperstep();
}
@Override
public void close() throws Exception {
this.gatherFunction.postSuperstep();
}
@Override
public TypeInformation<Tuple2<K, M>> getProducedType() {
return this.resultType;
}
}
private static final class SumUdf<K extends Comparable<K> & Serializable, VV extends Serializable,
EV extends Serializable, M> extends RichReduceFunction<Tuple2<K, M>>
implements ResultTypeQueryable<Tuple2<K, M>>{
private final SumFunction<VV, EV, M> sumFunction;
private transient TypeInformation<Tuple2<K, M>> resultType;
private SumUdf(SumFunction<VV, EV, M> sumFunction, TypeInformation<Tuple2<K, M>> resultType) {
this.sumFunction = sumFunction;
this.resultType = resultType;
}
@Override
public Tuple2<K, M> reduce(Tuple2<K, M> arg0, Tuple2<K, M> arg1) throws Exception {
K key = arg0.f0;
M result = this.sumFunction.sum(arg0.f1, arg1.f1);
return new Tuple2<K, M>(key, result);
}
@Override
public void open(Configuration parameters) throws Exception {
if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
this.sumFunction.init(getIterationRuntimeContext());
}
this.sumFunction.preSuperstep();
}
@Override
public void close() throws Exception {
this.sumFunction.postSuperstep();
}
@Override
public TypeInformation<Tuple2<K, M>> getProducedType() {
return this.resultType;
}
}
private static final class ApplyUdf<K extends Comparable<K> & Serializable,
VV extends Serializable, EV extends Serializable, M> extends RichFlatJoinFunction<Tuple2<K, M>,
Vertex<K, VV>, Vertex<K, VV>> implements ResultTypeQueryable<Vertex<K, VV>> {
private final ApplyFunction<VV, EV, M> applyFunction;
private transient TypeInformation<Vertex<K, VV>> resultType;
private ApplyUdf(ApplyFunction<VV, EV, M> applyFunction, TypeInformation<Vertex<K, VV>> resultType) {
this.applyFunction = applyFunction;
this.resultType = resultType;
}
@Override
public void join(Tuple2<K, M> arg0, Vertex<K, VV> arg1, final Collector<Vertex<K, VV>> out) throws Exception {
final K key = arg1.getId();
Collector<VV> userOut = new Collector<VV>() {
@Override
public void collect(VV record) {
out.collect(new Vertex<K, VV>(key, record));
}
@Override
public void close() {
out.close();
}
};
this.applyFunction.setOutput(userOut);
this.applyFunction.apply(arg0.f1, arg1.getValue());
}
@Override
public void open(Configuration parameters) throws Exception {
if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
this.applyFunction.init(getIterationRuntimeContext());
}
this.applyFunction.preSuperstep();
}
@Override
public void close() throws Exception {
this.applyFunction.postSuperstep();
}
@Override
public TypeInformation<Vertex<K, VV>> getProducedType() {
return this.resultType;
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.gsa;
import org.apache.flink.api.java.tuple.Tuple2;
import java.io.Serializable;
/**
* A wrapper around Tuple3<VV, EV, VV> for convenience in the GatherFunction
* @param <VV> the vertex value type
* @param <EV> the edge value type
*/
public class RichEdge<VV extends Serializable, EV extends Serializable>
extends Tuple2<VV, EV> {
public RichEdge() {}
public RichEdge(VV src, EV edge) {
super(src, edge);
}
public VV getSrcVertexValue() {
return this.f0;
}
public EV getEdgeValue() {
return this.f1;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.gsa;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import java.io.Serializable;
public abstract class SumFunction<VV extends Serializable, EV extends Serializable, M> implements Serializable {
public abstract M sum(M arg0, M arg1);
/**
* This method is executed once per superstep before the vertex update function is invoked for each vertex.
*
* @throws Exception Exceptions in the pre-superstep phase cause the superstep to fail.
*/
public void preSuperstep() {};
/**
* This method is executed once per superstep after the vertex update function has been invoked for each vertex.
*
* @throws Exception Exceptions in the post-superstep phase cause the superstep to fail.
*/
public void postSuperstep() {};
// --------------------------------------------------------------------------------------------
// Internal methods
// --------------------------------------------------------------------------------------------
private IterationRuntimeContext runtimeContext;
public void init(IterationRuntimeContext iterationRuntimeContext) {
this.runtimeContext = iterationRuntimeContext;
};
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.test;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
import org.apache.flink.graph.example.GSAGreedyGraphColoringExample;
import org.apache.flink.graph.example.GSASingleSourceShortestPathsExample;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.io.File;
@RunWith(Parameterized.class)
public class GatherSumApplyITCase extends MultipleProgramsTestBase {
public GatherSumApplyITCase(TestExecutionMode mode){
super(mode);
}
private String verticesPath;
private String edgesPath;
private String resultPath;
private String expectedResult;
@Rule
public TemporaryFolder tempFolder = new TemporaryFolder();
@Before
public void before() throws Exception{
resultPath = tempFolder.newFile().toURI().toString();
File verticesFile = tempFolder.newFile();
Files.write(GatherSumApplyITCase.VERTICES, verticesFile, Charsets.UTF_8);
File edgesFile = tempFolder.newFile();
Files.write(GatherSumApplyITCase.EDGES, edgesFile, Charsets.UTF_8);
verticesPath = verticesFile.toURI().toString();
edgesPath = edgesFile.toURI().toString();
}
@After
public void after() throws Exception{
compareResultsByLinesInMemory(expectedResult, resultPath);
}
// --------------------------------------------------------------------------------------------
// Greedy Graph Coloring Test
// --------------------------------------------------------------------------------------------
@Test
public void testGreedyGraphColoring() throws Exception {
GSAGreedyGraphColoringExample.main(new String[] {verticesPath, edgesPath, resultPath, "16"});
expectedResult = "1 1.0\n" +
"2 1.0\n" +
"3 1.0\n" +
"4 1.0\n" +
"5 1.0\n";
}
// --------------------------------------------------------------------------------------------
// Single Source Shortest Path Test
// --------------------------------------------------------------------------------------------
@Test
public void testSingleSourceShortestPath() throws Exception {
GSASingleSourceShortestPathsExample.main(new String[]{verticesPath, edgesPath, resultPath, "1", "16"});
expectedResult = "1 0.0\n" +
"2 12.0\n" +
"3 13.0\n" +
"4 47.0\n" +
"5 48.0\n";
}
// --------------------------------------------------------------------------------------------
// Sample data
// --------------------------------------------------------------------------------------------
private static final String VERTICES = "1 1.0\n" +
"2 2.0\n" +
"3 3.0\n" +
"4 4.0\n" +
"5 5.0\n";
private static final String EDGES = "1 2 12.0\n" +
"1 3 13.0\n" +
"2 3 23.0\n" +
"3 4 34.0\n" +
"3 5 35.0\n" +
"4 5 45.0\n" +
"5 1 51.0\n";
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册