提交 562f8a31 编写于 作者: A andralungu 提交者: vasia

[FLINK-1942] [gelly] GSA Iteration Configuration

This squashes the following commits:

Split Iteration Configuration into VertexCentric/GSAConf

Fixed checkstyle errors

This closes #635
上级 6dd241d3
......@@ -373,7 +373,7 @@ public static final class MinDistanceMessenger {...}
{% endhighlight %}
### Configuring a Vertex-Centric Iteration
A vertex-centric iteration can be configured using an `IterationConfiguration` object.
A vertex-centric iteration can be configured using a `VertexCentricConfiguration` object.
Currently, the following parameters can be specified:
* <strong>Name</strong>: The name for the vertex-centric iteration. The name is displayed in logs and messages
......@@ -393,7 +393,7 @@ all aggregates globally once per superstep and makes them available in the next
Graph<Long, Double, Double> graph = ...
// configure the iteration
IterationConfiguration parameters = new IterationConfiguration();
VertexCentricConfiguration parameters = new VertexCentricConfiguration();
// set the iteration name
parameters.setName("Gelly Iteration");
......
......@@ -45,11 +45,12 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GSAConfiguration;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.GatherSumApplyIteration;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.spargel.IterationConfiguration;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexCentricConfiguration;
import org.apache.flink.graph.spargel.VertexCentricIteration;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.utils.EdgeToTuple3Map;
......@@ -1173,7 +1174,7 @@ public class Graph<K, VV, EV> {
public <M> Graph<K, VV, EV> runVertexCentricIteration(
VertexUpdateFunction<K, VV, M> vertexUpdateFunction,
MessagingFunction<K, VV, M, EV> messagingFunction,
int maximumNumberOfIterations, IterationConfiguration parameters) {
int maximumNumberOfIterations, VertexCentricConfiguration parameters) {
VertexCentricIteration<K, VV, M, EV> iteration = VertexCentricIteration.withEdges(
edges, vertexUpdateFunction, messagingFunction, maximumNumberOfIterations);
......@@ -1195,16 +1196,40 @@ public class Graph<K, VV, EV> {
* @param maximumNumberOfIterations maximum number of iterations to perform
* @param <M> the intermediate type used between gather, sum and apply
*
* @return the updated Graph after the vertex-centric iteration has converged or
* @return the updated Graph after the gather-sum-apply iteration has converged or
* after maximumNumberOfIterations.
*/
public <M> Graph<K, VV, EV> runGatherSumApplyIteration(
        GatherFunction<VV, EV, M> gatherFunction, SumFunction<VV, EV, M> sumFunction,
        ApplyFunction<K, VV, M> applyFunction, int maximumNumberOfIterations) {
    // No configuration was supplied: delegate to the configurable overload
    // and hand it a null configuration.
    final GSAConfiguration noConfiguration = null;
    return runGatherSumApplyIteration(
            gatherFunction, sumFunction, applyFunction, maximumNumberOfIterations, noConfiguration);
}
/**
* Runs a Gather-Sum-Apply iteration on the graph with configuration options.
*
* @param gatherFunction the gather function collects information about adjacent vertices and edges
* @param sumFunction the sum function aggregates the gathered information
* @param applyFunction the apply function updates the vertex values with the aggregates
* @param maximumNumberOfIterations maximum number of iterations to perform
* @param parameters the iteration configuration parameters
* @param <M> the intermediate type used between gather, sum and apply
*
* @return the updated Graph after the gather-sum-apply iteration has converged or
* after maximumNumberOfIterations.
*/
public <M> Graph<K, VV, EV> runGatherSumApplyIteration(
GatherFunction<VV, EV, M> gatherFunction, SumFunction<VV, EV, M> sumFunction,
ApplyFunction<K, VV, M> applyFunction, int maximumNumberOfIterations,
GSAConfiguration parameters) {
GatherSumApplyIteration<K, VV, EV, M> iteration = GatherSumApplyIteration.withEdges(
edges, gatherFunction, sumFunction, applyFunction, maximumNumberOfIterations);
iteration.configure(parameters);
DataSet<Vertex<K, VV>> newVertices = vertices.runOperation(iteration);
return new Graph<K, VV, EV>(newVertices, this.edges, this.context);
......
......@@ -16,32 +16,19 @@
* limitations under the License.
*/
package org.apache.flink.graph.spargel;
package org.apache.flink.graph;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import com.google.common.base.Preconditions;
/**
* This class is used to configure a vertex-centric iteration.
*
* An IterationConfiguration object can be used to set the iteration name and
* degree of parallelism, to register aggregators and use broadcast sets in
* the {@link VertexUpdateFunction} and {@link MessagingFunction}.
*
* The IterationConfiguration object is passed as an argument to
* {@link org.apache.flink.graph.Graph#runVertexCentricIteration(
* VertexUpdateFunction, MessagingFunction, int, IterationConfiguration)}.
*
* This is used as a base class for vertex-centric iteration or gather-sum-apply iteration configuration.
*/
public class IterationConfiguration {
public abstract class IterationConfiguration {
/** the iteration name **/
private String name;
......@@ -52,20 +39,13 @@ public class IterationConfiguration {
/** the iteration aggregators **/
private Map<String, Aggregator<?>> aggregators = new HashMap<String, Aggregator<?>>();
/** the broadcast variables for the update function **/
private List<Tuple2<String, DataSet<?>>> bcVarsUpdate = new ArrayList<Tuple2<String,DataSet<?>>>();
/** the broadcast variables for the messaging function **/
private List<Tuple2<String, DataSet<?>>> bcVarsMessaging = new ArrayList<Tuple2<String,DataSet<?>>>();
/** flag that defines whether the solution set is kept in managed memory **/
private boolean unmanagedSolutionSet = false;
public IterationConfiguration() {}
/**
* Sets the name for the vertex-centric iteration. The name is displayed in logs and messages.
* Sets the name for the iteration. The name is displayed in logs and messages.
*
* @param name The name for the iteration.
*/
......@@ -74,7 +54,7 @@ public class IterationConfiguration {
}
/**
* Gets the name of the vertex-centric iteration.
* Gets the name of the iteration.
* @param defaultName
*
* @return The name of the iteration.
......@@ -131,8 +111,8 @@ public class IterationConfiguration {
/**
* Registers a new aggregator. Aggregators registered here are available during the execution of the vertex updates
* via {@link VertexUpdateFunction#getIterationAggregator(String)} and
* {@link VertexUpdateFunction#getPreviousIterationAggregate(String)}.
* via {@link org.apache.flink.graph.spargel.VertexUpdateFunction#getIterationAggregator(String)} and
* {@link org.apache.flink.graph.spargel.VertexUpdateFunction#getPreviousIterationAggregate(String)}.
*
* @param name The name of the aggregator, used to retrieve it and its aggregates during execution.
* @param aggregator The aggregator.
......@@ -140,26 +120,6 @@ public class IterationConfiguration {
public void registerAggregator(String name, Aggregator<?> aggregator) {
    // Store the aggregator under its name in the aggregator map of this configuration.
    aggregators.put(name, aggregator);
}
/**
* Adds a data set as a broadcast set to the messaging function.
*
* @param name The name under which the broadcast data is available in the messaging function.
* @param data The data set to be broadcasted.
*/
public void addBroadcastSetForMessagingFunction(String name, DataSet<?> data) {
this.bcVarsMessaging.add(new Tuple2<String, DataSet<?>>(name, data));
}
/**
* Adds a data set as a broadcast set to the vertex update function.
*
* @param name The name under which the broadcast data is available in the vertex update function.
* @param data The data set to be broadcasted.
*/
public void addBroadcastSetForUpdateFunction(String name, DataSet<?> data) {
this.bcVarsUpdate.add(new Tuple2<String, DataSet<?>>(name, data));
}
/**
* Gets the set of aggregators that are registered for this iteration.
......@@ -170,24 +130,4 @@ public class IterationConfiguration {
public Map<String, Aggregator<?>> getAggregators() {
    // Hands out the map held by this configuration (no copy is made).
    return aggregators;
}
/**
* Get the broadcast variables of the VertexUpdateFunction.
*
* @return a List of Tuple2, where the first field is the broadcast variable name
* and the second field is the broadcast DataSet.
*/
public List<Tuple2<String, DataSet<?>>> getUpdateBcastVars() {
return this.bcVarsUpdate;
}
/**
* Get the broadcast variables of the MessagingFunction.
*
* @return a List of Tuple2, where the first field is the broadcast variable name
* and the second field is the broadcast DataSet.
*/
public List<Tuple2<String, DataSet<?>>> getMessagingBcastVars() {
return this.bcVarsMessaging;
}
}
......@@ -18,11 +18,14 @@
package org.apache.flink.graph.gsa;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.graph.Vertex;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import java.io.Serializable;
import java.util.Collection;
@SuppressWarnings("serial")
public abstract class ApplyFunction<K, VV, M> implements Serializable {
......@@ -62,6 +65,38 @@ public abstract class ApplyFunction<K, VV, M> implements Serializable {
return this.runtimeContext.getSuperstepNumber();
}
/**
 * Looks up the iteration aggregator that was registered under the given name. An iteration
 * aggregator combines all aggregates globally once per superstep and exposes them in the
 * following superstep.
 *
 * @param name The name the aggregator was registered under.
 * @return The aggregator registered under this name, or null if no aggregator was registered.
 */
public <T extends Aggregator<?>> T getIterationAggregator(String name) {
    final T aggregator = this.runtimeContext.<T>getIterationAggregator(name);
    return aggregator;
}
/**
 * Returns the value that the named aggregator computed in the previous iteration.
 *
 * @param name The name of the aggregator.
 * @return The aggregated value of the previous iteration.
 */
public <T extends Value> T getPreviousIterationAggregate(String name) {
    final T previousAggregate = this.runtimeContext.<T>getPreviousIterationAggregate(name);
    return previousAggregate;
}
/**
 * Returns the broadcast data set that was registered under the given name. Broadcast
 * data sets are available on all parallel instances of a function.
 *
 * @param name The name under which the broadcast set is registered.
 * @return The broadcast data set.
 */
public <T> Collection<T> getBroadcastSet(String name) {
    final Collection<T> broadcastSet = this.runtimeContext.<T>getBroadcastVariable(name);
    return broadcastSet;
}
// --------------------------------------------------------------------------------------------
// Internal methods
// --------------------------------------------------------------------------------------------
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.gsa;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.graph.IterationConfiguration;
import java.util.ArrayList;
import java.util.List;
/**
 * Configuration for a gather-sum-apply (GSA) iteration.
 *
 * <p>A GSAConfiguration object can be used to set the iteration name and
 * degree of parallelism, to register aggregators and to use broadcast sets in
 * the {@link org.apache.flink.graph.gsa.GatherFunction}, {@link org.apache.flink.graph.gsa.SumFunction}
 * as well as the {@link org.apache.flink.graph.gsa.ApplyFunction}.
 *
 * <p>The GSAConfiguration object is passed as the last argument to
 * {@link org.apache.flink.graph.Graph#runGatherSumApplyIteration(org.apache.flink.graph.gsa.GatherFunction,
 * org.apache.flink.graph.gsa.SumFunction, org.apache.flink.graph.gsa.ApplyFunction, int,
 * org.apache.flink.graph.gsa.GSAConfiguration)}.
 */
public class GSAConfiguration extends IterationConfiguration {

    /** the broadcast variables for the gather function **/
    private final List<Tuple2<String, DataSet<?>>> bcVarsGather = new ArrayList<Tuple2<String, DataSet<?>>>();

    /** the broadcast variables for the sum function **/
    private final List<Tuple2<String, DataSet<?>>> bcVarsSum = new ArrayList<Tuple2<String, DataSet<?>>>();

    /** the broadcast variables for the apply function **/
    private final List<Tuple2<String, DataSet<?>>> bcVarsApply = new ArrayList<Tuple2<String, DataSet<?>>>();

    /** Creates a configuration with no broadcast sets registered. */
    public GSAConfiguration() {}

    /**
     * Adds a data set as a broadcast set to the gather function.
     *
     * @param name The name under which the broadcast data is available in the gather function.
     * @param data The data set to be broadcasted.
     */
    public void addBroadcastSetForGatherFunction(String name, DataSet<?> data) {
        this.bcVarsGather.add(new Tuple2<String, DataSet<?>>(name, data));
    }

    /**
     * Adds a data set as a broadcast set to the sum function.
     *
     * @param name The name under which the broadcast data is available in the sum function.
     * @param data The data set to be broadcasted.
     */
    public void addBroadcastSetForSumFunction(String name, DataSet<?> data) {
        this.bcVarsSum.add(new Tuple2<String, DataSet<?>>(name, data));
    }

    /**
     * Adds a data set as a broadcast set to the apply function.
     *
     * @param name The name under which the broadcast data is available in the apply function.
     * @param data The data set to be broadcasted.
     */
    public void addBroadcastSetForApplyFunction(String name, DataSet<?> data) {
        this.bcVarsApply.add(new Tuple2<String, DataSet<?>>(name, data));
    }

    /**
     * Get the broadcast variables of the GatherFunction.
     *
     * @return a List of Tuple2, where the first field is the broadcast variable name
     * and the second field is the broadcast DataSet. The list is the one held by this
     * configuration, not a copy.
     */
    public List<Tuple2<String, DataSet<?>>> getGatherBcastVars() {
        return this.bcVarsGather;
    }

    /**
     * Get the broadcast variables of the SumFunction.
     *
     * @return a List of Tuple2, where the first field is the broadcast variable name
     * and the second field is the broadcast DataSet. The list is the one held by this
     * configuration, not a copy.
     */
    public List<Tuple2<String, DataSet<?>>> getSumBcastVars() {
        return this.bcVarsSum;
    }

    /**
     * Get the broadcast variables of the ApplyFunction.
     *
     * @return a List of Tuple2, where the first field is the broadcast variable name
     * and the second field is the broadcast DataSet. The list is the one held by this
     * configuration, not a copy.
     */
    public List<Tuple2<String, DataSet<?>>> getApplyBcastVars() {
        return this.bcVarsApply;
    }
}
......@@ -18,9 +18,12 @@
package org.apache.flink.graph.gsa;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.types.Value;
import java.io.Serializable;
import java.util.Collection;
@SuppressWarnings("serial")
public abstract class GatherFunction<VV, EV, M> implements Serializable {
......@@ -50,6 +53,38 @@ public abstract class GatherFunction<VV, EV, M> implements Serializable {
return this.runtimeContext.getSuperstepNumber();
}
/**
 * Looks up the iteration aggregator that was registered under the given name. An iteration
 * aggregator combines all aggregates globally once per superstep and exposes them in the
 * following superstep.
 *
 * @param name The name the aggregator was registered under.
 * @return The aggregator registered under this name, or null if no aggregator was registered.
 */
public <T extends Aggregator<?>> T getIterationAggregator(String name) {
    final T aggregator = this.runtimeContext.<T>getIterationAggregator(name);
    return aggregator;
}
/**
 * Returns the value that the named aggregator computed in the previous iteration.
 *
 * @param name The name of the aggregator.
 * @return The aggregated value of the previous iteration.
 */
public <T extends Value> T getPreviousIterationAggregate(String name) {
    final T previousAggregate = this.runtimeContext.<T>getPreviousIterationAggregate(name);
    return previousAggregate;
}
/**
 * Returns the broadcast data set that was registered under the given name. Broadcast
 * data sets are available on all parallel instances of a function.
 *
 * @param name The name under which the broadcast set is registered.
 * @return The broadcast data set.
 */
public <T> Collection<T> getBroadcastSet(String name) {
    final Collection<T> broadcastSet = this.runtimeContext.<T>getBroadcastVariable(name);
    return broadcastSet;
}
// --------------------------------------------------------------------------------------------
// Internal methods
// --------------------------------------------------------------------------------------------
......
......@@ -18,6 +18,7 @@
package org.apache.flink.graph.gsa;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
......@@ -29,6 +30,8 @@ import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsSec
import org.apache.flink.api.java.operators.CustomUnaryOperation;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
......@@ -37,6 +40,7 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Vertex;
import org.apache.flink.util.Collector;
import java.util.Map;
import com.google.common.base.Preconditions;
......@@ -59,6 +63,8 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
private final ApplyFunction<K, VV, M> apply;
private final int maximumNumberOfIterations;
private GSAConfiguration configuration;
// ----------------------------------------------------------------------------------
private GatherSumApplyIteration(GatherFunction<VV, EV, M> gather, SumFunction<VV, EV, M> sum,
......@@ -119,20 +125,69 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
final DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration =
vertexDataSet.iterateDelta(vertexDataSet, maximumNumberOfIterations, zeroKeyPos);
// set up the iteration operator
if (this.configuration != null) {
iteration.name(this.configuration.getName(
"Gather-sum-apply iteration (" + gather + " | " + sum + " | " + apply + ")"));
iteration.parallelism(this.configuration.getParallelism());
iteration.setSolutionSetUnManaged(this.configuration.isSolutionSetUnmanagedMemory());
// register all aggregators
for (Map.Entry<String, Aggregator<?>> entry : this.configuration.getAggregators().entrySet()) {
iteration.registerAggregator(entry.getKey(), entry.getValue());
}
}
else {
// no configuration provided; set default name
iteration.name("Gather-sum-apply iteration (" + gather + " | " + sum + " | " + apply + ")");
}
// Prepare the neighbors
DataSet<Tuple2<K, Neighbor<VV, EV>>> neighbors = iteration
.getWorkset().join(edgeDataSet)
.where(0).equalTo(0).with(new ProjectKeyWithNeighbor<K, VV, EV>());
// Gather, sum and apply
DataSet<Tuple2<K, M>> gatheredSet = neighbors.map(gatherUdf);
DataSet<Tuple2<K, M>> summedSet = gatheredSet.groupBy(0).reduce(sumUdf);
MapOperator<Tuple2<K, Neighbor<VV, EV>>, Tuple2<K, M>> gatherMapOperator = neighbors.map(gatherUdf);
// configure map gather function with name and broadcast variables
gatherMapOperator = gatherMapOperator.name("Gather");
if (this.configuration != null) {
for (Tuple2<String, DataSet<?>> e : this.configuration.getGatherBcastVars()) {
gatherMapOperator = gatherMapOperator.withBroadcastSet(e.f1, e.f0);
}
}
DataSet<Tuple2<K, M>> gatheredSet = gatherMapOperator;
ReduceOperator<Tuple2<K, M>> sumReduceOperator = gatheredSet.groupBy(0).reduce(sumUdf);
// configure reduce sum function with name and broadcast variables
sumReduceOperator = sumReduceOperator.name("Sum");
if (this.configuration != null) {
for (Tuple2<String, DataSet<?>> e : this.configuration.getSumBcastVars()) {
sumReduceOperator = sumReduceOperator.withBroadcastSet(e.f1, e.f0);
}
}
DataSet<Tuple2<K, M>> summedSet = sumReduceOperator;
JoinOperator<?, ?, Vertex<K, VV>> appliedSet = summedSet
.join(iteration.getSolutionSet())
.where(0)
.equalTo(0)
.with(applyUdf);
// configure join apply function with name and broadcast variables
appliedSet = appliedSet.name("Apply");
if (this.configuration != null) {
for (Tuple2<String, DataSet<?>> e : this.configuration.getApplyBcastVars()) {
appliedSet = appliedSet.withBroadcastSet(e.f1, e.f0);
}
}
// let the operator know that we preserve the key field
appliedSet.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0");
......@@ -293,4 +348,19 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
}
}
/**
 * Sets the configuration parameters of this gather-sum-apply iteration.
 *
 * @param parameters the configuration parameters
 */
public void configure(GSAConfiguration parameters) {
    configuration = parameters;
}
/**
 * Returns the configuration currently set on this gather-sum-apply iteration.
 *
 * @return the configuration parameters of this gather-sum-apply iteration
 */
public GSAConfiguration getIterationConfiguration() {
    return configuration;
}
}
......@@ -18,9 +18,12 @@
package org.apache.flink.graph.gsa;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.types.Value;
import java.io.Serializable;
import java.util.Collection;
@SuppressWarnings("serial")
public abstract class SumFunction<VV, EV, M> implements Serializable {
......@@ -50,6 +53,38 @@ public abstract class SumFunction<VV, EV, M> implements Serializable {
return this.runtimeContext.getSuperstepNumber();
}
/**
 * Looks up the iteration aggregator that was registered under the given name. An iteration
 * aggregator combines all aggregates globally once per superstep and exposes them in the
 * following superstep.
 *
 * @param name The name the aggregator was registered under.
 * @return The aggregator registered under this name, or null if no aggregator was registered.
 */
public <T extends Aggregator<?>> T getIterationAggregator(String name) {
    final T aggregator = this.runtimeContext.<T>getIterationAggregator(name);
    return aggregator;
}
/**
 * Returns the value that the named aggregator computed in the previous iteration.
 *
 * @param name The name of the aggregator.
 * @return The aggregated value of the previous iteration.
 */
public <T extends Value> T getPreviousIterationAggregate(String name) {
    final T previousAggregate = this.runtimeContext.<T>getPreviousIterationAggregate(name);
    return previousAggregate;
}
/**
 * Returns the broadcast data set that was registered under the given name. Broadcast
 * data sets are available on all parallel instances of a function.
 *
 * @param name The name under which the broadcast set is registered.
 * @return The broadcast data set.
 */
public <T> Collection<T> getBroadcastSet(String name) {
    final Collection<T> broadcastSet = this.runtimeContext.<T>getBroadcastVariable(name);
    return broadcastSet;
}
// --------------------------------------------------------------------------------------------
// Internal methods
// --------------------------------------------------------------------------------------------
......
......@@ -136,7 +136,7 @@ public abstract class MessagingFunction<VertexKey, VertexValue, Message, EdgeVal
}
/**
* Gets the iteration aggregator registered under the given name. The iteration aggregator is combines
* Gets the iteration aggregator registered under the given name. The iteration aggregator combines
* all aggregates globally once per superstep and makes them available in the next superstep.
*
* @param name The name of the aggregator.
......@@ -159,7 +159,7 @@ public abstract class MessagingFunction<VertexKey, VertexValue, Message, EdgeVal
/**
* Gets the broadcast data set registered under the given name. Broadcast data sets
* are available on all parallel instances of a function. They can be registered via
* {@link VertexCentricIteration#addBroadcastSetForMessagingFunction(String, org.apache.flink.api.java.DataSet)}.
* {@link org.apache.flink.graph.spargel.VertexCentricConfiguration#addBroadcastSetForMessagingFunction(String, org.apache.flink.api.java.DataSet)}.
*
* @param name The name under which the broadcast set is registered.
* @return The broadcast data set.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.spargel;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.graph.IterationConfiguration;
import java.util.ArrayList;
import java.util.List;
/**
 * Configuration for a vertex-centric iteration.
 *
 * <p>A VertexCentricConfiguration object can be used to set the iteration name and
 * degree of parallelism, to register aggregators and to use broadcast sets in
 * the {@link org.apache.flink.graph.spargel.VertexUpdateFunction} and
 * {@link org.apache.flink.graph.spargel.MessagingFunction}.
 *
 * <p>The VertexCentricConfiguration object is passed as the last argument to
 * {@link org.apache.flink.graph.Graph#runVertexCentricIteration(org.apache.flink.graph.spargel.VertexUpdateFunction,
 * org.apache.flink.graph.spargel.MessagingFunction, int,
 * org.apache.flink.graph.spargel.VertexCentricConfiguration)}.
 */
public class VertexCentricConfiguration extends IterationConfiguration {

    /** the broadcast variables for the update function **/
    private final List<Tuple2<String, DataSet<?>>> bcVarsUpdate = new ArrayList<Tuple2<String, DataSet<?>>>();

    /** the broadcast variables for the messaging function **/
    private final List<Tuple2<String, DataSet<?>>> bcVarsMessaging = new ArrayList<Tuple2<String, DataSet<?>>>();

    /** Creates a configuration with no broadcast sets registered. */
    public VertexCentricConfiguration() {}

    /**
     * Adds a data set as a broadcast set to the messaging function.
     *
     * @param name The name under which the broadcast data is available in the messaging function.
     * @param data The data set to be broadcasted.
     */
    public void addBroadcastSetForMessagingFunction(String name, DataSet<?> data) {
        this.bcVarsMessaging.add(new Tuple2<String, DataSet<?>>(name, data));
    }

    /**
     * Adds a data set as a broadcast set to the vertex update function.
     *
     * @param name The name under which the broadcast data is available in the vertex update function.
     * @param data The data set to be broadcasted.
     */
    public void addBroadcastSetForUpdateFunction(String name, DataSet<?> data) {
        this.bcVarsUpdate.add(new Tuple2<String, DataSet<?>>(name, data));
    }

    /**
     * Get the broadcast variables of the VertexUpdateFunction.
     *
     * @return a List of Tuple2, where the first field is the broadcast variable name
     * and the second field is the broadcast DataSet. The list is the one held by this
     * configuration, not a copy.
     */
    public List<Tuple2<String, DataSet<?>>> getUpdateBcastVars() {
        return this.bcVarsUpdate;
    }

    /**
     * Get the broadcast variables of the MessagingFunction.
     *
     * @return a List of Tuple2, where the first field is the broadcast variable name
     * and the second field is the broadcast DataSet. The list is the one held by this
     * configuration, not a copy.
     */
    public List<Tuple2<String, DataSet<?>>> getMessagingBcastVars() {
        return this.bcVarsMessaging;
    }
}
......@@ -60,9 +60,7 @@ import com.google.common.base.Preconditions;
* </ul>
* <p>
* Vertex-centric graph iterations are instantiated by the
* {@link #withPlainEdges(DataSet, VertexUpdateFunction, MessagingFunction, int)} method, or the
* {@link #withValuedEdges(DataSet, VertexUpdateFunction, MessagingFunction, int)} method, depending on whether
* the graph's edges are carrying values.
* {@link #withEdges(DataSet, VertexUpdateFunction, MessagingFunction, int)} method.
*
* @param <VertexKey> The type of the vertex key (the vertex identifier).
* @param <VertexValue> The type of the vertex value (the state of the vertex).
......@@ -84,7 +82,7 @@ public class VertexCentricIteration<VertexKey, VertexValue, Message, EdgeValue>
private DataSet<Vertex<VertexKey, VertexValue>> initialVertices;
private IterationConfiguration configuration;
private VertexCentricConfiguration configuration;
// ----------------------------------------------------------------------------------
......@@ -362,14 +360,14 @@ public class VertexCentricIteration<VertexKey, VertexValue, Message, EdgeValue>
*
* @param parameters the configuration parameters
*/
public void configure(IterationConfiguration parameters) {
public void configure(VertexCentricConfiguration parameters) {
this.configuration = parameters;
}
/**
* @return the configuration parameters of this vertex-centric iteration
*/
public IterationConfiguration getIterationConfiguration() {
public VertexCentricConfiguration getIterationConfiguration() {
return this.configuration;
}
}
......@@ -114,7 +114,7 @@ public abstract class VertexUpdateFunction<VertexKey, VertexValue, Message> impl
/**
* Gets the broadcast data set registered under the given name. Broadcast data sets
* are available on all parallel instances of a function. They can be registered via
* {@link VertexCentricIteration#addBroadcastSetForUpdateFunction(String, org.apache.flink.api.java.DataSet)}.
* {@link org.apache.flink.graph.spargel.VertexCentricConfiguration#addBroadcastSetForUpdateFunction(String, org.apache.flink.api.java.DataSet)}.
*
* @param name The name under which the broadcast set is registered.
* @return The broadcast data set.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.test;
import org.apache.flink.api.common.aggregators.LongSumAggregator;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GSAConfiguration;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.GatherSumApplyIteration;
import org.apache.flink.graph.gsa.Neighbor;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.IterationConfiguration;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.types.LongValue;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.List;
/**
 * Integration tests for running a Gather-Sum-Apply iteration with a {@link GSAConfiguration}.
 *
 * <p>Every test writes its output to {@code resultPath} and records the expected lines in
 * {@code expectedResult}; {@link #after()} then compares the two.
 */
@RunWith(Parameterized.class)
public class GatherSumApplyConfigurationITCase extends MultipleProgramsTestBase {

	public GatherSumApplyConfigurationITCase(TestExecutionMode mode) {
		super(mode);
	}

	/** Location the test program writes its result to; created freshly in {@link #before()}. */
	private String resultPath;

	/** Expected content of {@code resultPath}; assigned by each test, checked in {@link #after()}. */
	private String expectedResult;

	@Rule
	public TemporaryFolder tempFolder = new TemporaryFolder();

	/** Creates a new temporary file and stores its URI as the result path. */
	@Before
	public void before() throws Exception {
		resultPath = tempFolder.newFile().toURI().toString();
	}

	/** Compares the written result with the expectation set by the test that just ran. */
	@After
	public void after() throws Exception {
		compareResultsByLinesInMemory(expectedResult, resultPath);
	}

	/**
	 * Tests Graph's runGatherSumApplyIteration when configuration parameters
	 * (broadcast sets for all three functions and an aggregator) are provided.
	 */
	@Test
	public void testRunWithConfiguration() throws Exception {
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		Graph<Long, Long, Long> graph = Graph.fromCollection(TestGraphUtils.getLongLongVertices(),
				TestGraphUtils.getLongLongEdges(), env).mapVertices(new AssignOneMapper());

		// create the configuration object
		GSAConfiguration parameters = new GSAConfiguration();

		parameters.addBroadcastSetForGatherFunction("gatherBcastSet", env.fromElements(1, 2, 3));
		parameters.addBroadcastSetForSumFunction("sumBcastSet", env.fromElements(4, 5, 6));
		parameters.addBroadcastSetForApplyFunction("applyBcastSet", env.fromElements(7, 8, 9));
		parameters.registerAggregator("superstepAggregator", new LongSumAggregator());

		Graph<Long, Long, Long> result = graph.runGatherSumApplyIteration(new Gather(), new Sum(),
				new Apply(), 10, parameters);

		// fields are separated by "\t", so the expected lines use the same delimiter
		result.getVertices().writeAsCsv(resultPath, "\n", "\t");
		env.execute();

		expectedResult = "1\t11\n" +
				"2\t11\n" +
				"3\t11\n" +
				"4\t11\n" +
				"5\t11";
	}

	/**
	 * Tests the name, parallelism and solutionSetUnmanaged parameters: they are set on a
	 * {@link GSAConfiguration}, applied to the iteration, and read back from it.
	 */
	@Test
	public void testIterationConfiguration() throws Exception {
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		GatherSumApplyIteration<Long, Long, Long, Long> iteration = GatherSumApplyIteration
				.withEdges(TestGraphUtils.getLongLongEdgeData(env), new DummyGather(),
						new DummySum(), new DummyApply(), 10);

		GSAConfiguration parameters = new GSAConfiguration();
		parameters.setName("gelly iteration");
		parameters.setParallelism(2);
		parameters.setSolutionSetUnmanagedMemory(true);

		iteration.configure(parameters);

		Assert.assertEquals("gelly iteration", iteration.getIterationConfiguration().getName(""));
		Assert.assertEquals(2, iteration.getIterationConfiguration().getParallelism());
		Assert.assertTrue(iteration.getIterationConfiguration().isSolutionSetUnmanagedMemory());

		DataSet<Vertex<Long, Long>> result = TestGraphUtils.getLongLongVertexData(env).runOperation(iteration);

		// fields are separated by "\t", so the expected lines use the same delimiter
		result.writeAsCsv(resultPath, "\n", "\t");
		env.execute();

		expectedResult = "1\t11\n" +
				"2\t12\n" +
				"3\t13\n" +
				"4\t14\n" +
				"5\t15";
	}

	/**
	 * Gather function that checks its broadcast set and the aggregate of the previous
	 * superstep before each superstep, and gathers the neighbor's value.
	 */
	@SuppressWarnings("serial")
	private static final class Gather extends GatherFunction<Long, Long, Long> {

		@Override
		public void preSuperstep() {
			// test bcast variable
			// NOTE(review): the set is registered from int literals in the test above; the
			// Tuple1 element type below is an unchecked cast that is not verified at runtime
			// -- confirm the intended element type.
			@SuppressWarnings("unchecked")
			List<Tuple1<Integer>> bcastSet = (List<Tuple1<Integer>>)(List<?>)getBroadcastSet("gatherBcastSet");
			Assert.assertEquals(1, bcastSet.get(0));
			Assert.assertEquals(2, bcastSet.get(1));
			Assert.assertEquals(3, bcastSet.get(2));

			// test aggregator: value aggregated during the first superstep
			if (getSuperstepNumber() == 2) {
				long aggrValue = ((LongValue)getPreviousIterationAggregate("superstepAggregator")).getValue();
				Assert.assertEquals(7, aggrValue);
			}
		}

		@Override
		public Long gather(Neighbor<Long, Long> neighbor) {
			return neighbor.getNeighborValue();
		}
	}

	/**
	 * Sum function that checks its broadcast set before each superstep and adds the
	 * current superstep number to the shared aggregator on every call; always yields 0.
	 */
	@SuppressWarnings("serial")
	private static final class Sum extends SumFunction<Long, Long, Long> {

		// replaced by the iteration's registered aggregator in preSuperstep()
		private LongSumAggregator aggregator = new LongSumAggregator();

		@Override
		public void preSuperstep() {
			// test bcast variable
			// NOTE(review): Tuple1 element type is an unchecked cast -- see registration in the test.
			@SuppressWarnings("unchecked")
			List<Tuple1<Integer>> bcastSet = (List<Tuple1<Integer>>)(List<?>)getBroadcastSet("sumBcastSet");
			Assert.assertEquals(4, bcastSet.get(0));
			Assert.assertEquals(5, bcastSet.get(1));
			Assert.assertEquals(6, bcastSet.get(2));

			// test aggregator
			aggregator = getIterationAggregator("superstepAggregator");
		}

		@Override
		public Long sum(Long newValue, Long currentValue) {
			long superstep = getSuperstepNumber();
			aggregator.aggregate(superstep);
			return 0L;
		}
	}

	/**
	 * Apply function that checks its broadcast set before each superstep, adds the current
	 * superstep number to the shared aggregator, and increments the vertex value by one.
	 */
	@SuppressWarnings("serial")
	private static final class Apply extends ApplyFunction<Long, Long, Long> {

		// replaced by the iteration's registered aggregator in preSuperstep()
		private LongSumAggregator aggregator = new LongSumAggregator();

		@Override
		public void preSuperstep() {
			// test bcast variable
			// NOTE(review): Tuple1 element type is an unchecked cast -- see registration in the test.
			@SuppressWarnings("unchecked")
			List<Tuple1<Integer>> bcastSet = (List<Tuple1<Integer>>)(List<?>)getBroadcastSet("applyBcastSet");
			Assert.assertEquals(7, bcastSet.get(0));
			Assert.assertEquals(8, bcastSet.get(1));
			Assert.assertEquals(9, bcastSet.get(2));

			// test aggregator
			aggregator = getIterationAggregator("superstepAggregator");
		}

		@Override
		public void apply(Long summedValue, Long origValue) {
			long superstep = getSuperstepNumber();
			aggregator.aggregate(superstep);
			setResult(origValue + 1);
		}
	}

	/** Gather function without checks: returns the neighbor's value. */
	@SuppressWarnings("serial")
	private static final class DummyGather extends GatherFunction<Long, Long, Long> {

		@Override
		public Long gather(Neighbor<Long, Long> neighbor) {
			return neighbor.getNeighborValue();
		}
	}

	/** Sum function without checks: always yields 0. */
	@SuppressWarnings("serial")
	private static final class DummySum extends SumFunction<Long, Long, Long> {

		@Override
		public Long sum(Long newValue, Long currentValue) {
			return 0L;
		}
	}

	/** Apply function without checks: increments the vertex value by one. */
	@SuppressWarnings("serial")
	private static final class DummyApply extends ApplyFunction<Long, Long, Long> {

		@Override
		public void apply(Long summedValue, Long origValue) {
			setResult(origValue + 1);
		}
	}

	/** Maps every vertex to the constant value 1. */
	@SuppressWarnings("serial")
	public static final class AssignOneMapper implements MapFunction<Vertex<Long, Long>, Long> {

		@Override
		public Long map(Vertex<Long, Long> value) {
			return 1L;
		}
	}
}
......@@ -27,12 +27,12 @@ import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.IterationConfiguration;
import org.apache.flink.graph.IterationConfiguration;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexCentricConfiguration;
import org.apache.flink.graph.spargel.VertexCentricIteration;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.utils.VertexToTuple2Map;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.types.LongValue;
import org.junit.After;
......@@ -78,7 +78,7 @@ public class VertexCentricConfigurationITCase extends MultipleProgramsTestBase {
TestGraphUtils.getLongLongEdges(), env).mapVertices(new AssignOneMapper());
// create the configuration object
IterationConfiguration parameters = new IterationConfiguration();
VertexCentricConfiguration parameters = new VertexCentricConfiguration();
parameters.addBroadcastSetForUpdateFunction("updateBcastSet", env.fromElements(1, 2, 3));
parameters.addBroadcastSetForMessagingFunction("messagingBcastSet", env.fromElements(4, 5, 6));
......@@ -87,7 +87,7 @@ public class VertexCentricConfigurationITCase extends MultipleProgramsTestBase {
Graph<Long, Long, Long> result = graph.runVertexCentricIteration(
new UpdateFunction(), new MessageFunction(), 10, parameters);
result.getVertices().map(new VertexToTuple2Map<Long, Long>()).writeAsCsv(resultPath, "\n", "\t");
result.getVertices().writeAsCsv(resultPath, "\n", "\t");
env.execute();
expectedResult = "1 11\n" +
"2 11\n" +
......@@ -108,7 +108,7 @@ public class VertexCentricConfigurationITCase extends MultipleProgramsTestBase {
.withEdges(TestGraphUtils.getLongLongEdgeData(env), new DummyUpdateFunction(),
new DummyMessageFunction(), 10);
IterationConfiguration parameters = new IterationConfiguration();
VertexCentricConfiguration parameters = new VertexCentricConfiguration();
parameters.setName("gelly iteration");
parameters.setParallelism(2);
parameters.setSolutionSetUnmanagedMemory(true);
......@@ -121,7 +121,7 @@ public class VertexCentricConfigurationITCase extends MultipleProgramsTestBase {
DataSet<Vertex<Long, Long>> result = TestGraphUtils.getLongLongVertexData(env).runOperation(iteration);
result.map(new VertexToTuple2Map<Long, Long>()).writeAsCsv(resultPath, "\n", "\t");
result.writeAsCsv(resultPath, "\n", "\t");
env.execute();
expectedResult = "1 11\n" +
"2 12\n" +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册