diff --git a/docs/libs/gelly_guide.md b/docs/libs/gelly_guide.md index 77621e6b155bff052ad7e1c326b3a52b62f95aee..abd489cd7f79297fa1d7d0423d85e93ab5f47500 100644 --- a/docs/libs/gelly_guide.md +++ b/docs/libs/gelly_guide.md @@ -373,7 +373,7 @@ public static final class MinDistanceMessenger {...} {% endhighlight %} ### Configuring a Vertex-Centric Iteration -A vertex-centric iteration can be configured using an `IterationConfiguration` object. +A vertex-centric iteration can be configured using a `VertexCentricConfiguration` object. Currently, the following parameters can be specified: * Name: The name for the vertex-centric iteration. The name is displayed in logs and messages @@ -393,7 +393,7 @@ all aggregates globally once per superstep and makes them available in the next Graph graph = ... // configure the iteration -IterationConfiguration parameters = new IterationConfiguration(); +VertexCentricConfiguration parameters = new VertexCentricConfiguration(); // set the iteration name parameters.setName("Gelly Iteration"); diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java index 0ddd2d420ffa6881c2fc9bb01570a8d9286ba379..6e5b8f14ecf4737de489bf0683dde756743e92f5 100755 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java @@ -45,11 +45,12 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.graph.gsa.ApplyFunction; +import org.apache.flink.graph.gsa.GSAConfiguration; import org.apache.flink.graph.gsa.GatherFunction; import org.apache.flink.graph.gsa.GatherSumApplyIteration; import org.apache.flink.graph.gsa.SumFunction; -import org.apache.flink.graph.spargel.IterationConfiguration; import 
org.apache.flink.graph.spargel.MessagingFunction; +import org.apache.flink.graph.spargel.VertexCentricConfiguration; import org.apache.flink.graph.spargel.VertexCentricIteration; import org.apache.flink.graph.spargel.VertexUpdateFunction; import org.apache.flink.graph.utils.EdgeToTuple3Map; @@ -1173,7 +1174,7 @@ public class Graph { public Graph runVertexCentricIteration( VertexUpdateFunction vertexUpdateFunction, MessagingFunction messagingFunction, - int maximumNumberOfIterations, IterationConfiguration parameters) { + int maximumNumberOfIterations, VertexCentricConfiguration parameters) { VertexCentricIteration iteration = VertexCentricIteration.withEdges( edges, vertexUpdateFunction, messagingFunction, maximumNumberOfIterations); @@ -1195,16 +1196,40 @@ public class Graph { * @param maximumNumberOfIterations maximum number of iterations to perform * @param the intermediate type used between gather, sum and apply * - * @return the updated Graph after the vertex-centric iteration has converged or + * @return the updated Graph after the gather-sum-apply iteration has converged or * after maximumNumberOfIterations. */ public Graph runGatherSumApplyIteration( GatherFunction gatherFunction, SumFunction sumFunction, ApplyFunction applyFunction, int maximumNumberOfIterations) { + return this.runGatherSumApplyIteration(gatherFunction, sumFunction, applyFunction, + maximumNumberOfIterations, null); + } + + /** + * Runs a Gather-Sum-Apply iteration on the graph with configuration options. 
+ * + * @param gatherFunction the gather function collects information about adjacent vertices and edges + * @param sumFunction the sum function aggregates the gathered information + * @param applyFunction the apply function updates the vertex values with the aggregates + * @param maximumNumberOfIterations maximum number of iterations to perform + * @param parameters the iteration configuration parameters + * @param the intermediate type used between gather, sum and apply + * + * @return the updated Graph after the gather-sum-apply iteration has converged or + * after maximumNumberOfIterations. + */ + public Graph runGatherSumApplyIteration( + GatherFunction gatherFunction, SumFunction sumFunction, + ApplyFunction applyFunction, int maximumNumberOfIterations, + GSAConfiguration parameters) { + GatherSumApplyIteration iteration = GatherSumApplyIteration.withEdges( edges, gatherFunction, sumFunction, applyFunction, maximumNumberOfIterations); + iteration.configure(parameters); + DataSet> newVertices = vertices.runOperation(iteration); return new Graph(newVertices, this.edges, this.context); diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/IterationConfiguration.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/IterationConfiguration.java similarity index 57% rename from flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/IterationConfiguration.java rename to flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/IterationConfiguration.java index f38c687061d8270ce4e9880352cb0759a4b08f7e..e1d4a1e4d8f445832c0c1d94950c2b66499f624e 100644 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/IterationConfiguration.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/IterationConfiguration.java @@ -16,32 +16,19 @@ * limitations under the License. 
*/ -package org.apache.flink.graph.spargel; +package org.apache.flink.graph; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; import org.apache.flink.api.common.aggregators.Aggregator; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.tuple.Tuple2; import com.google.common.base.Preconditions; /** - * This class is used to configure a vertex-centric iteration. - * - * An IterationConfiguration object can be used to set the iteration name and - * degree of parallelism, to register aggregators and use broadcast sets in - * the {@link VertexUpdateFunction} and {@link MessagingFunction}. - * - * The IterationConfiguration object is passed as an argument to - * {@link org.apache.flink.graph.Graph#runVertexCentricIteration( - * VertexUpdateFunction, MessagingFunction, int, IterationConfiguration)}. - * + * This is used as a base class for vertex-centric iteration or gather-sum-apply iteration configuration. */ -public class IterationConfiguration { +public abstract class IterationConfiguration { /** the iteration name **/ private String name; @@ -52,20 +39,13 @@ public class IterationConfiguration { /** the iteration aggregators **/ private Map> aggregators = new HashMap>(); - /** the broadcast variables for the update function **/ - private List>> bcVarsUpdate = new ArrayList>>(); - - /** the broadcast variables for the messaging function **/ - private List>> bcVarsMessaging = new ArrayList>>(); - /** flag that defines whether the solution set is kept in managed memory **/ private boolean unmanagedSolutionSet = false; public IterationConfiguration() {} - /** - * Sets the name for the vertex-centric iteration. The name is displayed in logs and messages. + * Sets the name for the iteration. The name is displayed in logs and messages. * * @param name The name for the iteration. */ @@ -74,7 +54,7 @@ public class IterationConfiguration { } /** - * Gets the name of the vertex-centric iteration. 
+ * Gets the name of the iteration. * @param defaultName * * @return The name of the iteration. @@ -131,8 +111,8 @@ public class IterationConfiguration { /** * Registers a new aggregator. Aggregators registered here are available during the execution of the vertex updates - * via {@link VertexUpdateFunction#getIterationAggregator(String)} and - * {@link VertexUpdateFunction#getPreviousIterationAggregate(String)}. + * via {@link org.apache.flink.graph.spargel.VertexUpdateFunction#getIterationAggregator(String)} and + * {@link org.apache.flink.graph.spargel.VertexUpdateFunction#getPreviousIterationAggregate(String)}. * * @param name The name of the aggregator, used to retrieve it and its aggregates during execution. * @param aggregator The aggregator. @@ -140,26 +120,6 @@ public class IterationConfiguration { public void registerAggregator(String name, Aggregator aggregator) { this.aggregators.put(name, aggregator); } - - /** - * Adds a data set as a broadcast set to the messaging function. - * - * @param name The name under which the broadcast data is available in the messaging function. - * @param data The data set to be broadcasted. - */ - public void addBroadcastSetForMessagingFunction(String name, DataSet data) { - this.bcVarsMessaging.add(new Tuple2>(name, data)); - } - - /** - * Adds a data set as a broadcast set to the vertex update function. - * - * @param name The name under which the broadcast data is available in the vertex update function. - * @param data The data set to be broadcasted. - */ - public void addBroadcastSetForUpdateFunction(String name, DataSet data) { - this.bcVarsUpdate.add(new Tuple2>(name, data)); - } /** * Gets the set of aggregators that are registered for this vertex-centric iteration. @@ -170,24 +130,4 @@ public class IterationConfiguration { public Map> getAggregators() { return this.aggregators; } - - /** - * Get the broadcast variables of the VertexUpdateFunction. 
- * - * @return a List of Tuple2, where the first field is the broadcast variable name - * and the second field is the broadcast DataSet. - */ - public List>> getUpdateBcastVars() { - return this.bcVarsUpdate; - } - - /** - * Get the broadcast variables of the MessagingFunction. - * - * @return a List of Tuple2, where the first field is the broadcast variable name - * and the second field is the broadcast DataSet. - */ - public List>> getMessagingBcastVars() { - return this.bcVarsMessaging; - } } diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/ApplyFunction.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/ApplyFunction.java index 7d24253d7581beeafd9fd93abf97a1890d24f108..d88fe0d5cdd3052037268321965d90d64d90a7f7 100755 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/ApplyFunction.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/ApplyFunction.java @@ -18,11 +18,14 @@ package org.apache.flink.graph.gsa; +import org.apache.flink.api.common.aggregators.Aggregator; import org.apache.flink.api.common.functions.IterationRuntimeContext; import org.apache.flink.graph.Vertex; +import org.apache.flink.types.Value; import org.apache.flink.util.Collector; import java.io.Serializable; +import java.util.Collection; @SuppressWarnings("serial") public abstract class ApplyFunction implements Serializable { @@ -62,6 +65,38 @@ public abstract class ApplyFunction implements Serializable { return this.runtimeContext.getSuperstepNumber(); } + /** + * Gets the iteration aggregator registered under the given name. The iteration aggregator combines + * all aggregates globally once per superstep and makes them available in the next superstep. + * + * @param name The name of the aggregator. + * @return The aggregator registered under this name, or null, if no aggregator was registered. 
+ */ + public > T getIterationAggregator(String name) { + return this.runtimeContext.getIterationAggregator(name); + } + + /** + * Get the aggregated value that an aggregator computed in the previous iteration. + * + * @param name The name of the aggregator. + * @return The aggregated value of the previous iteration. + */ + public T getPreviousIterationAggregate(String name) { + return this.runtimeContext.getPreviousIterationAggregate(name); + } + + /** + * Gets the broadcast data set registered under the given name. Broadcast data sets + * are available on all parallel instances of a function. + * + * @param name The name under which the broadcast set is registered. + * @return The broadcast data set. + */ + public Collection getBroadcastSet(String name) { + return this.runtimeContext.getBroadcastVariable(name); + } + // -------------------------------------------------------------------------------------------- // Internal methods // -------------------------------------------------------------------------------------------- diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GSAConfiguration.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GSAConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..de472808129e7ab63787e012f05437dc1f50e44b --- /dev/null +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GSAConfiguration.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.graph.gsa; + +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.graph.IterationConfiguration; + +import java.util.ArrayList; +import java.util.List; + +/** + * A GSAConfiguration object can be used to set the iteration name and + * degree of parallelism, to register aggregators and use broadcast sets in + * the {@link org.apache.flink.graph.gsa.GatherFunction}, {@link org.apache.flink.graph.gsa.SumFunction} as well as + * {@link org.apache.flink.graph.gsa.ApplyFunction}. + * + * The GSAConfiguration object is passed as an argument to + * {@link org.apache.flink.graph.Graph#runGatherSumApplyIteration(org.apache.flink.graph.gsa.GatherFunction, + * org.apache.flink.graph.gsa.SumFunction, org.apache.flink.graph.gsa.ApplyFunction, int, org.apache.flink.graph.gsa.GSAConfiguration)}. + */ +public class GSAConfiguration extends IterationConfiguration { + + /** the broadcast variables for the gather function **/ + private List>> bcVarsGather = new ArrayList>>(); + + /** the broadcast variables for the sum function **/ + private List>> bcVarsSum = new ArrayList>>(); + + /** the broadcast variables for the apply function **/ + private List>> bcVarsApply = new ArrayList>>(); + + public GSAConfiguration() {} + + /** + * Adds a data set as a broadcast set to the gather function. + * + * @param name The name under which the broadcast data is available in the gather function. + * @param data The data set to be broadcasted. 
+ */ + public void addBroadcastSetForGatherFunction(String name, DataSet data) { + this.bcVarsGather.add(new Tuple2>(name, data)); + } + + /** + * Adds a data set as a broadcast set to the sum function. + * + * @param name The name under which the broadcast data is available in the sum function. + * @param data The data set to be broadcasted. + */ + public void addBroadcastSetForSumFunction(String name, DataSet data) { + this.bcVarsSum.add(new Tuple2>(name, data)); + } + + /** + * Adds a data set as a broadcast set to the apply function. + * + * @param name The name under which the broadcast data is available in the apply function. + * @param data The data set to be broadcasted. + */ + public void addBroadcastSetForApplyFunction(String name, DataSet data) { + this.bcVarsApply.add(new Tuple2>(name, data)); + } + + /** + * Get the broadcast variables of the GatherFunction. + * + * @return a List of Tuple2, where the first field is the broadcast variable name + * and the second field is the broadcast DataSet. + */ + public List>> getGatherBcastVars() { + return this.bcVarsGather; + } + + /** + * Get the broadcast variables of the SumFunction. + * + * @return a List of Tuple2, where the first field is the broadcast variable name + * and the second field is the broadcast DataSet. + */ + public List>> getSumBcastVars() { + return this.bcVarsSum; + } + + /** + * Get the broadcast variables of the ApplyFunction. + * + * @return a List of Tuple2, where the first field is the broadcast variable name + * and the second field is the broadcast DataSet. 
+ */ + public List>> getApplyBcastVars() { + return this.bcVarsApply; + } +} diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherFunction.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherFunction.java index 1c4b2c429f24b6043d7d95ed814576c35406333a..37ff2d696a2df7f06d9343d61ac9f62c4337f04c 100755 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherFunction.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherFunction.java @@ -18,9 +18,12 @@ package org.apache.flink.graph.gsa; +import org.apache.flink.api.common.aggregators.Aggregator; import org.apache.flink.api.common.functions.IterationRuntimeContext; +import org.apache.flink.types.Value; import java.io.Serializable; +import java.util.Collection; @SuppressWarnings("serial") public abstract class GatherFunction implements Serializable { @@ -50,6 +53,38 @@ public abstract class GatherFunction implements Serializable { return this.runtimeContext.getSuperstepNumber(); } + /** + * Gets the iteration aggregator registered under the given name. The iteration aggregator combines + * all aggregates globally once per superstep and makes them available in the next superstep. + * + * @param name The name of the aggregator. + * @return The aggregator registered under this name, or null, if no aggregator was registered. + */ + public > T getIterationAggregator(String name) { + return this.runtimeContext.getIterationAggregator(name); + } + + /** + * Get the aggregated value that an aggregator computed in the previous iteration. + * + * @param name The name of the aggregator. + * @return The aggregated value of the previous iteration. + */ + public T getPreviousIterationAggregate(String name) { + return this.runtimeContext.getPreviousIterationAggregate(name); + } + + /** + * Gets the broadcast data set registered under the given name. 
Broadcast data sets + * are available on all parallel instances of a function. + * + * @param name The name under which the broadcast set is registered. + * @return The broadcast data set. + */ + public Collection getBroadcastSet(String name) { + return this.runtimeContext.getBroadcastVariable(name); + } + // -------------------------------------------------------------------------------------------- // Internal methods // -------------------------------------------------------------------------------------------- diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java index 7476ea912d35653333f64aafb62a84b48f12bd32..a80369d16ac0736a1b22af101c6b5fe7d0eabc99 100755 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java @@ -18,6 +18,7 @@ package org.apache.flink.graph.gsa; +import org.apache.flink.api.common.aggregators.Aggregator; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.functions.RichFlatJoinFunction; import org.apache.flink.api.common.functions.RichMapFunction; @@ -29,6 +30,8 @@ import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsSec import org.apache.flink.api.java.operators.CustomUnaryOperation; import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.api.java.operators.JoinOperator; +import org.apache.flink.api.java.operators.MapOperator; +import org.apache.flink.api.java.operators.ReduceOperator; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.api.java.typeutils.TupleTypeInfo; @@ -37,6 +40,7 @@ import org.apache.flink.configuration.Configuration; import 
org.apache.flink.graph.Edge; import org.apache.flink.graph.Vertex; import org.apache.flink.util.Collector; +import java.util.Map; import com.google.common.base.Preconditions; @@ -59,6 +63,8 @@ public class GatherSumApplyIteration implements CustomUnaryOperati private final ApplyFunction apply; private final int maximumNumberOfIterations; + private GSAConfiguration configuration; + // ---------------------------------------------------------------------------------- private GatherSumApplyIteration(GatherFunction gather, SumFunction sum, @@ -119,20 +125,69 @@ public class GatherSumApplyIteration implements CustomUnaryOperati final DeltaIteration, Vertex> iteration = vertexDataSet.iterateDelta(vertexDataSet, maximumNumberOfIterations, zeroKeyPos); + // set up the iteration operator + if (this.configuration != null) { + + iteration.name(this.configuration.getName( + "Gather-sum-apply iteration (" + gather + " | " + sum + " | " + apply + ")")); + iteration.parallelism(this.configuration.getParallelism()); + iteration.setSolutionSetUnManaged(this.configuration.isSolutionSetUnmanagedMemory()); + + // register all aggregators + for (Map.Entry> entry : this.configuration.getAggregators().entrySet()) { + iteration.registerAggregator(entry.getKey(), entry.getValue()); + } + } + else { + // no configuration provided; set default name + iteration.name("Gather-sum-apply iteration (" + gather + " | " + sum + " | " + apply + ")"); + } + // Prepare the neighbors DataSet>> neighbors = iteration .getWorkset().join(edgeDataSet) .where(0).equalTo(0).with(new ProjectKeyWithNeighbor()); // Gather, sum and apply - DataSet> gatheredSet = neighbors.map(gatherUdf); - DataSet> summedSet = gatheredSet.groupBy(0).reduce(sumUdf); + MapOperator>, Tuple2> gatherMapOperator = neighbors.map(gatherUdf); + + // configure map gather function with name and broadcast variables + gatherMapOperator = gatherMapOperator.name("Gather"); + + if (this.configuration != null) { + for (Tuple2> e : 
this.configuration.getGatherBcastVars()) { + gatherMapOperator = gatherMapOperator.withBroadcastSet(e.f1, e.f0); + } + } + DataSet> gatheredSet = gatherMapOperator; + + ReduceOperator> sumReduceOperator = gatheredSet.groupBy(0).reduce(sumUdf); + + // configure reduce sum function with name and broadcast variables + sumReduceOperator = sumReduceOperator.name("Sum"); + + if (this.configuration != null) { + for (Tuple2> e : this.configuration.getSumBcastVars()) { + sumReduceOperator = sumReduceOperator.withBroadcastSet(e.f1, e.f0); + } + } + DataSet> summedSet = sumReduceOperator; + JoinOperator> appliedSet = summedSet .join(iteration.getSolutionSet()) .where(0) .equalTo(0) .with(applyUdf); + // configure join apply function with name and broadcast variables + appliedSet = appliedSet.name("Apply"); + + if (this.configuration != null) { + for (Tuple2> e : this.configuration.getApplyBcastVars()) { + appliedSet = appliedSet.withBroadcastSet(e.f1, e.f0); + } + } + // let the operator know that we preserve the key field appliedSet.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0"); @@ -293,4 +348,19 @@ public class GatherSumApplyIteration implements CustomUnaryOperati } } + /** + * Configures this gather-sum-apply iteration with the provided parameters. 
+ * + * @param parameters the configuration parameters + */ + public void configure(GSAConfiguration parameters) { + this.configuration = parameters; + } + + /** + * @return the configuration parameters of this gather-sum-apply iteration + */ + public GSAConfiguration getIterationConfiguration() { + return this.configuration; + } } diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/SumFunction.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/SumFunction.java index 0a5e4aee06a7236e6b4783c25c14db14f75ccb48..16cd682e2f884dd5087df72a9c088a778dbeb174 100755 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/SumFunction.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/gsa/SumFunction.java @@ -18,9 +18,12 @@ package org.apache.flink.graph.gsa; +import org.apache.flink.api.common.aggregators.Aggregator; import org.apache.flink.api.common.functions.IterationRuntimeContext; +import org.apache.flink.types.Value; import java.io.Serializable; +import java.util.Collection; @SuppressWarnings("serial") public abstract class SumFunction implements Serializable { @@ -50,6 +53,38 @@ public abstract class SumFunction implements Serializable { return this.runtimeContext.getSuperstepNumber(); } + /** + * Gets the iteration aggregator registered under the given name. The iteration aggregator combines + * all aggregates globally once per superstep and makes them available in the next superstep. + * + * @param name The name of the aggregator. + * @return The aggregator registered under this name, or null, if no aggregator was registered. + */ + public > T getIterationAggregator(String name) { + return this.runtimeContext.getIterationAggregator(name); + } + + /** + * Get the aggregated value that an aggregator computed in the previous iteration. + * + * @param name The name of the aggregator. + * @return The aggregated value of the previous iteration. 
+ */ + public T getPreviousIterationAggregate(String name) { + return this.runtimeContext.getPreviousIterationAggregate(name); + } + + /** + * Gets the broadcast data set registered under the given name. Broadcast data sets + * are available on all parallel instances of a function. + * + * @param name The name under which the broadcast set is registered. + * @return The broadcast data set. + */ + public Collection getBroadcastSet(String name) { + return this.runtimeContext.getBroadcastVariable(name); + } + // -------------------------------------------------------------------------------------------- // Internal methods // -------------------------------------------------------------------------------------------- diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/MessagingFunction.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/MessagingFunction.java index b7e74e38dfc108ab29185f41dbfaff60ba7b9124..ab60f15f1d6866535d3277ea9e88baae54820aed 100644 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/MessagingFunction.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/MessagingFunction.java @@ -136,7 +136,7 @@ public abstract class MessagingFunction>> bcVarsUpdate = new ArrayList>>(); + + /** the broadcast variables for the messaging function **/ + private List>> bcVarsMessaging = new ArrayList>>(); + + public VertexCentricConfiguration() {} + + /** + * Adds a data set as a broadcast set to the messaging function. + * + * @param name The name under which the broadcast data is available in the messaging function. + * @param data The data set to be broadcasted. + */ + public void addBroadcastSetForMessagingFunction(String name, DataSet data) { + this.bcVarsMessaging.add(new Tuple2>(name, data)); + } + + /** + * Adds a data set as a broadcast set to the vertex update function. 
+ * + * @param name The name under which the broadcast data is available in the vertex update function. + * @param data The data set to be broadcasted. + */ + public void addBroadcastSetForUpdateFunction(String name, DataSet data) { + this.bcVarsUpdate.add(new Tuple2>(name, data)); + } + + /** + * Get the broadcast variables of the VertexUpdateFunction. + * + * @return a List of Tuple2, where the first field is the broadcast variable name + * and the second field is the broadcast DataSet. + */ + public List>> getUpdateBcastVars() { + return this.bcVarsUpdate; + } + + /** + * Get the broadcast variables of the MessagingFunction. + * + * @return a List of Tuple2, where the first field is the broadcast variable name + * and the second field is the broadcast DataSet. + */ + public List>> getMessagingBcastVars() { + return this.bcVarsMessaging; + } + + +} diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/VertexCentricIteration.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/VertexCentricIteration.java index 8488be695c570440db793bddc87df181d3bc826b..79da664aa98dca08ef84be92c28801567ff10285 100644 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/VertexCentricIteration.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/VertexCentricIteration.java @@ -60,9 +60,7 @@ import com.google.common.base.Preconditions; * *

* Vertex-centric graph iterations are instantiated by the - * {@link #withPlainEdges(DataSet, VertexUpdateFunction, MessagingFunction, int)} method, or the - * {@link #withValuedEdges(DataSet, VertexUpdateFunction, MessagingFunction, int)} method, depending on whether - * the graph's edges are carrying values. + * {@link #withEdges(DataSet, VertexUpdateFunction, MessagingFunction, int)} method. * * @param The type of the vertex key (the vertex identifier). * @param The type of the vertex value (the state of the vertex). @@ -84,7 +82,7 @@ public class VertexCentricIteration private DataSet> initialVertices; - private IterationConfiguration configuration; + private VertexCentricConfiguration configuration; // ---------------------------------------------------------------------------------- @@ -362,14 +360,14 @@ public class VertexCentricIteration * * @param parameters the configuration parameters */ - public void configure(IterationConfiguration parameters) { + public void configure(VertexCentricConfiguration parameters) { this.configuration = parameters; } /** * @return the configuration parameters of this vertex-centric iteration */ - public IterationConfiguration getIterationConfiguration() { + public VertexCentricConfiguration getIterationConfiguration() { return this.configuration; } } diff --git a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/VertexUpdateFunction.java b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/VertexUpdateFunction.java index 561c87a5c25443350c13cbfc77db57c221ee47c6..9122053b419aaa8f46b3b41b00d026a2e1ed4cfc 100644 --- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/VertexUpdateFunction.java +++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/spargel/VertexUpdateFunction.java @@ -114,7 +114,7 @@ public abstract class VertexUpdateFunction impl /** * Gets the broadcast data set registered under the given name. 
Broadcast data sets * are available on all parallel instances of a function. They can be registered via - * {@link VertexCentricIteration#addBroadcastSetForUpdateFunction(String, org.apache.flink.api.java.DataSet)}. + * {@link org.apache.flink.graph.spargel.VertexCentricConfiguration#addBroadcastSetForUpdateFunction(String, org.apache.flink.api.java.DataSet)}. * * @param name The name under which the broadcast set is registered. * @return The broadcast data set. diff --git a/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/GatherSumApplyConfigurationITCase.java b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/GatherSumApplyConfigurationITCase.java new file mode 100644 index 0000000000000000000000000000000000000000..5f5f8b29d562fac9a2008e85cce4f3346e205b15 --- /dev/null +++ b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/GatherSumApplyConfigurationITCase.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.graph.test; + +import org.apache.flink.api.common.aggregators.LongSumAggregator; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.tuple.Tuple1; +import org.apache.flink.graph.Graph; +import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.gsa.ApplyFunction; +import org.apache.flink.graph.gsa.GSAConfiguration; +import org.apache.flink.graph.gsa.GatherFunction; +import org.apache.flink.graph.gsa.GatherSumApplyIteration; +import org.apache.flink.graph.gsa.Neighbor; +import org.apache.flink.graph.gsa.SumFunction; +import org.apache.flink.graph.IterationConfiguration; +import org.apache.flink.test.util.MultipleProgramsTestBase; +import org.apache.flink.types.LongValue; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.List; + +@RunWith(Parameterized.class) +public class GatherSumApplyConfigurationITCase extends MultipleProgramsTestBase { + + public GatherSumApplyConfigurationITCase(TestExecutionMode mode) { + super(mode); + } + + private String resultPath; + private String expectedResult; + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Before + public void before() throws Exception{ + resultPath = tempFolder.newFile().toURI().toString(); + } + + @After + public void after() throws Exception{ + compareResultsByLinesInMemory(expectedResult, resultPath); + } + + @Test + public void testRunWithConfiguration() throws Exception { + /* + * Test Graph's runGatherSumApplyIteration when configuration parameters are provided + */ + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + Graph graph = 
Graph.fromCollection(TestGraphUtils.getLongLongVertices(), + TestGraphUtils.getLongLongEdges(), env).mapVertices(new AssignOneMapper()); + + // create the configuration object + GSAConfiguration parameters = new GSAConfiguration(); + + parameters.addBroadcastSetForGatherFunction("gatherBcastSet", env.fromElements(1, 2, 3)); + parameters.addBroadcastSetForSumFunction("sumBcastSet", env.fromElements(4, 5, 6)); + parameters.addBroadcastSetForApplyFunction("applyBcastSet", env.fromElements(7, 8, 9)); + parameters.registerAggregator("superstepAggregator", new LongSumAggregator()); + + Graph result = graph.runGatherSumApplyIteration(new Gather(), new Sum(), + new Apply(), 10, parameters); + + result.getVertices().writeAsCsv(resultPath, "\n", "\t"); + env.execute(); + + expectedResult = "1 11\n" + + "2 11\n" + + "3 11\n" + + "4 11\n" + + "5 11"; + } + + @Test + public void testIterationConfiguration() throws Exception { + + /* + * Test name, parallelism and solutionSetUnmanaged parameters + */ + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + GatherSumApplyIteration iteration = GatherSumApplyIteration + .withEdges(TestGraphUtils.getLongLongEdgeData(env), new DummyGather(), + new DummySum(), new DummyApply(), 10); + + GSAConfiguration parameters = new GSAConfiguration(); + parameters.setName("gelly iteration"); + parameters.setParallelism(2); + parameters.setSolutionSetUnmanagedMemory(true); + + iteration.configure(parameters); + + Assert.assertEquals("gelly iteration", iteration.getIterationConfiguration().getName("")); + Assert.assertEquals(2, iteration.getIterationConfiguration().getParallelism()); + Assert.assertEquals(true, iteration.getIterationConfiguration().isSolutionSetUnmanagedMemory()); + + DataSet> result = TestGraphUtils.getLongLongVertexData(env).runOperation(iteration); + + result.writeAsCsv(resultPath, "\n", "\t"); + env.execute(); + expectedResult = "1 11\n" + + "2 12\n" + + "3 13\n" + + "4 14\n" + + "5 15"; + } + + 
@SuppressWarnings("serial") + private static final class Gather extends GatherFunction { + + @Override + public void preSuperstep() { + + // test bcast variable + @SuppressWarnings("unchecked") + List> bcastSet = (List>)(List)getBroadcastSet("gatherBcastSet"); + Assert.assertEquals(1, bcastSet.get(0)); + Assert.assertEquals(2, bcastSet.get(1)); + Assert.assertEquals(3, bcastSet.get(2)); + + // test aggregator + if (getSuperstepNumber() == 2) { + long aggrValue = ((LongValue)getPreviousIterationAggregate("superstepAggregator")).getValue(); + + Assert.assertEquals(7, aggrValue); + } + } + + public Long gather(Neighbor neighbor) { + return neighbor.getNeighborValue(); + } + } + + @SuppressWarnings("serial") + private static final class Sum extends SumFunction { + + LongSumAggregator aggregator = new LongSumAggregator(); + + @Override + public void preSuperstep() { + + // test bcast variable + @SuppressWarnings("unchecked") + List> bcastSet = (List>)(List)getBroadcastSet("sumBcastSet"); + Assert.assertEquals(4, bcastSet.get(0)); + Assert.assertEquals(5, bcastSet.get(1)); + Assert.assertEquals(6, bcastSet.get(2)); + + // test aggregator + aggregator = getIterationAggregator("superstepAggregator"); + } + + public Long sum(Long newValue, Long currentValue) { + long superstep = getSuperstepNumber(); + aggregator.aggregate(superstep); + return 0l; + } + } + + @SuppressWarnings("serial") + private static final class Apply extends ApplyFunction { + + LongSumAggregator aggregator = new LongSumAggregator(); + + @Override + public void preSuperstep() { + + // test bcast variable + @SuppressWarnings("unchecked") + List> bcastSet = (List>)(List)getBroadcastSet("applyBcastSet"); + Assert.assertEquals(7, bcastSet.get(0)); + Assert.assertEquals(8, bcastSet.get(1)); + Assert.assertEquals(9, bcastSet.get(2)); + + // test aggregator + aggregator = getIterationAggregator("superstepAggregator"); + } + + public void apply(Long summedValue, Long origValue) { + long superstep = 
getSuperstepNumber(); + aggregator.aggregate(superstep); + setResult(origValue + 1); + } + } + + @SuppressWarnings("serial") + private static final class DummyGather extends GatherFunction { + + public Long gather(Neighbor neighbor) { + return neighbor.getNeighborValue(); + } + } + + @SuppressWarnings("serial") + private static final class DummySum extends SumFunction { + + public Long sum(Long newValue, Long currentValue) { + return 0l; + } + } + + @SuppressWarnings("serial") + private static final class DummyApply extends ApplyFunction { + + public void apply(Long summedValue, Long origValue) { + setResult(origValue + 1); + } + } + + @SuppressWarnings("serial") + public static final class AssignOneMapper implements MapFunction, Long> { + + public Long map(Vertex value) { + return 1l; + } + } +} diff --git a/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/VertexCentricConfigurationITCase.java b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/VertexCentricConfigurationITCase.java index b49707085cc3df7f7b5c64eb7cff08e0e5163335..4e8412cadf94899542f804f76c83c33216d6b220 100644 --- a/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/VertexCentricConfigurationITCase.java +++ b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/VertexCentricConfigurationITCase.java @@ -27,12 +27,12 @@ import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.graph.Graph; import org.apache.flink.graph.Vertex; -import org.apache.flink.graph.spargel.IterationConfiguration; +import org.apache.flink.graph.IterationConfiguration; import org.apache.flink.graph.spargel.MessageIterator; import org.apache.flink.graph.spargel.MessagingFunction; +import org.apache.flink.graph.spargel.VertexCentricConfiguration; import org.apache.flink.graph.spargel.VertexCentricIteration; import org.apache.flink.graph.spargel.VertexUpdateFunction; -import 
org.apache.flink.graph.utils.VertexToTuple2Map; import org.apache.flink.test.util.MultipleProgramsTestBase; import org.apache.flink.types.LongValue; import org.junit.After; @@ -78,7 +78,7 @@ public class VertexCentricConfigurationITCase extends MultipleProgramsTestBase { TestGraphUtils.getLongLongEdges(), env).mapVertices(new AssignOneMapper()); // create the configuration object - IterationConfiguration parameters = new IterationConfiguration(); + VertexCentricConfiguration parameters = new VertexCentricConfiguration(); parameters.addBroadcastSetForUpdateFunction("updateBcastSet", env.fromElements(1, 2, 3)); parameters.addBroadcastSetForMessagingFunction("messagingBcastSet", env.fromElements(4, 5, 6)); @@ -87,7 +87,7 @@ public class VertexCentricConfigurationITCase extends MultipleProgramsTestBase { Graph result = graph.runVertexCentricIteration( new UpdateFunction(), new MessageFunction(), 10, parameters); - result.getVertices().map(new VertexToTuple2Map()).writeAsCsv(resultPath, "\n", "\t"); + result.getVertices().writeAsCsv(resultPath, "\n", "\t"); env.execute(); expectedResult = "1 11\n" + "2 11\n" + @@ -108,7 +108,7 @@ public class VertexCentricConfigurationITCase extends MultipleProgramsTestBase { .withEdges(TestGraphUtils.getLongLongEdgeData(env), new DummyUpdateFunction(), new DummyMessageFunction(), 10); - IterationConfiguration parameters = new IterationConfiguration(); + VertexCentricConfiguration parameters = new VertexCentricConfiguration(); parameters.setName("gelly iteration"); parameters.setParallelism(2); parameters.setSolutionSetUnmanagedMemory(true); @@ -121,7 +121,7 @@ public class VertexCentricConfigurationITCase extends MultipleProgramsTestBase { DataSet> result = TestGraphUtils.getLongLongVertexData(env).runOperation(iteration); - result.map(new VertexToTuple2Map()).writeAsCsv(resultPath, "\n", "\t"); + result.writeAsCsv(resultPath, "\n", "\t"); env.execute(); expectedResult = "1 11\n" + "2 12\n" +