Commit 918e5d0c authored by Greg Hogan

[FLINK-3618] [gelly] Rename abstract UDF classes in Scatter-Gather implementation

Rename MessagingFunction to ScatterFunction
and VertexUpdateFunction to GatherFunction.

Change the parameter order in
  Graph.runScatterGatherIteration(VertexUpdateFunction, MessagingFunction)
to
  Graph.runScatterGatherIteration(ScatterFunction, GatherFunction)

This closes #2184
Parent 6c6b17b4
......@@ -1083,10 +1083,10 @@ final class Compute extends ComputeFunction {
### Scatter-Gather Iterations
The scatter-gather model, also known as "signal/collect" model, expresses computation from the perspective of a vertex in the graph. The computation proceeds in synchronized iteration steps, called supersteps. In each superstep, a vertex produces messages for other vertices and updates its value based on the messages it receives. To use scatter-gather iterations in Gelly, the user only needs to define how a vertex behaves in each superstep:
* <strong>Messaging</strong>: corresponds to the scatter phase and produces the messages that a vertex will send to other vertices.
* <strong>Value Update</strong>: corresponds to the gather phase and updates the vertex value using the received messages.
* <strong>Scatter</strong>: produces the messages that a vertex will send to other vertices.
* <strong>Gather</strong>: updates the vertex value using received messages.
Gelly provides methods for scatter-gather iterations. The user only needs to implement two functions, corresponding to the scatter and gather phases. The first function is a `MessagingFunction`, which allows a vertex to send out messages for other vertices. Messages are recieved during the same superstep as they are sent. The second function is `VertexUpdateFunction`, which defines how a vertex will update its value based on the received messages.
Gelly provides methods for scatter-gather iterations. The user only needs to implement two functions, corresponding to the scatter and gather phases. The first function is a `ScatterFunction`, which allows a vertex to send out messages to other vertices. Messages are received during the same superstep as they are sent. The second function is `GatherFunction`, which defines how a vertex will update its value based on the received messages.
These functions and the maximum number of iterations to run are given as parameters to Gelly's `runScatterGatherIteration`. This method will execute the scatter-gather iteration on the input Graph and return a new Graph, with updated vertex values.
A scatter-gather iteration can be extended with information such as the total number of vertices, the in degree and out degree.
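The superstep model described above can be illustrated without Flink. The following is a minimal sketch in plain Java for single-source shortest paths. It does not use the Gelly API: the names `ScatterGatherSketch`, `Edge`, and `shortestPaths` are hypothetical, and the rule that only vertices updated in the previous superstep send messages is an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration of scatter-gather supersteps; NOT the Gelly API.
final class ScatterGatherSketch {

    // Hypothetical directed, weighted edge (source -> target).
    static final class Edge {
        final int source;
        final int target;
        final double weight;

        Edge(int source, int target, double weight) {
            this.source = source;
            this.target = target;
            this.weight = weight;
        }
    }

    // Runs at most maxIterations supersteps and returns each vertex's distance from src.
    static double[] shortestPaths(int numVertices, List<Edge> edges, int src, int maxIterations) {
        double[] distance = new double[numVertices];
        Arrays.fill(distance, Double.POSITIVE_INFINITY);
        distance[src] = 0.0;

        // Assumption of this sketch: only vertices updated in the previous superstep send messages.
        boolean[] active = new boolean[numVertices];
        active[src] = true;

        for (int superstep = 1; superstep <= maxIterations; superstep++) {
            // Scatter phase: each active vertex sends "own distance + edge weight" along its out-edges.
            List<List<Double>> inbox = new ArrayList<>();
            for (int v = 0; v < numVertices; v++) {
                inbox.add(new ArrayList<>());
            }
            for (Edge e : edges) {
                if (active[e.source]) {
                    inbox.get(e.target).add(distance[e.source] + e.weight);
                }
            }

            // Gather phase: each vertex keeps the minimum of its current value and its messages.
            boolean anyUpdate = false;
            Arrays.fill(active, false);
            for (int v = 0; v < numVertices; v++) {
                for (double candidate : inbox.get(v)) {
                    if (candidate < distance[v]) {
                        distance[v] = candidate;
                        active[v] = true;
                        anyUpdate = true;
                    }
                }
            }
            if (!anyUpdate) {
                break; // converged: no vertex changed its value in this superstep
            }
        }
        return distance;
    }

    public static void main(String[] args) {
        List<Edge> edges = Arrays.asList(
            new Edge(0, 1, 1.0), new Edge(1, 2, 2.0), new Edge(0, 2, 5.0));
        System.out.println(Arrays.toString(shortestPaths(3, edges, 0, 10)));
    }
}
```

In Gelly the two phases are supplied as separate user-defined functions (a `ScatterFunction` and a `GatherFunction`) rather than inlined in one loop; the sketch only mirrors the order in which they run within a superstep.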
......@@ -1109,7 +1109,7 @@ int maxIterations = 10;
// Execute the scatter-gather iteration
Graph<Long, Double, Double> result = graph.runScatterGatherIteration(
new VertexDistanceUpdater(), new MinDistanceMessenger(), maxIterations);
new MinDistanceMessenger(), new VertexDistanceUpdater(), maxIterations);
// Extract the vertices as the result
DataSet<Vertex<Long, Double>> singleSourceShortestPaths = result.getVertices();
......@@ -1118,7 +1118,7 @@ DataSet<Vertex<Long, Double>> singleSourceShortestPaths = result.getVertices();
// - - - UDFs - - - //
// scatter: messaging
public static final class MinDistanceMessenger extends MessagingFunction<Long, Double, Double, Double> {
public static final class MinDistanceMessenger extends ScatterFunction<Long, Double, Double, Double> {
public void sendMessages(Vertex<Long, Double> vertex) {
for (Edge<Long, Double> edge : getEdges()) {
......@@ -1128,7 +1128,7 @@ public static final class MinDistanceMessenger extends MessagingFunction<Long, D
}
// gather: vertex update
public static final class VertexDistanceUpdater extends VertexUpdateFunction<Long, Double, Double> {
public static final class VertexDistanceUpdater extends GatherFunction<Long, Double, Double> {
public void updateVertex(Vertex<Long, Double> vertex, MessageIterator<Double> inMessages) {
Double minDistance = Double.MAX_VALUE;
......@@ -1157,7 +1157,7 @@ val graph: Graph[Long, Double, Double] = ...
val maxIterations = 10
// Execute the scatter-gather iteration
val result = graph.runScatterGatherIteration(new VertexDistanceUpdater, new MinDistanceMessenger, maxIterations)
val result = graph.runScatterGatherIteration(new MinDistanceMessenger, new VertexDistanceUpdater, maxIterations)
// Extract the vertices as the result
val singleSourceShortestPaths = result.getVertices
......@@ -1166,7 +1166,7 @@ val singleSourceShortestPaths = result.getVertices
// - - - UDFs - - - //
// messaging
final class MinDistanceMessenger extends MessagingFunction[Long, Double, Double, Double] {
final class MinDistanceMessenger extends ScatterFunction[Long, Double, Double, Double] {
override def sendMessages(vertex: Vertex[Long, Double]) = {
for (edge: Edge[Long, Double] <- getEdges) {
......@@ -1176,7 +1176,7 @@ final class MinDistanceMessenger extends MessagingFunction[Long, Double, Double,
}
// vertex update
final class VertexDistanceUpdater extends VertexUpdateFunction[Long, Double, Double] {
final class VertexDistanceUpdater extends GatherFunction[Long, Double, Double] {
override def updateVertex(vertex: Vertex[Long, Double], inMessages: MessageIterator[Double]) = {
var minDistance = Double.MaxValue
......@@ -1211,9 +1211,9 @@ and can be specified using the `setName()` method.
* <strong>Solution set in unmanaged memory</strong>: Defines whether the solution set is kept in managed memory (Flink's internal way of keeping objects in serialized form) or as a simple object map. By default, the solution set runs in managed memory. This property can be set using the `setSolutionSetUnmanagedMemory()` method.
* <strong>Aggregators</strong>: Iteration aggregators can be registered using the `registerAggregator()` method. An iteration aggregator combines
all aggregates globally once per superstep and makes them available in the next superstep. Registered aggregators can be accessed inside the user-defined `VertexUpdateFunction` and `MessagingFunction`.
all aggregates globally once per superstep and makes them available in the next superstep. Registered aggregators can be accessed inside the user-defined `ScatterFunction` and `GatherFunction`.
* <strong>Broadcast Variables</strong>: DataSets can be added as [Broadcast Variables]({{site.baseurl}}/apis/batch/index.html#broadcast-variables) to the `VertexUpdateFunction` and `MessagingFunction`, using the `addBroadcastSetForUpdateFunction()` and `addBroadcastSetForMessagingFunction()` methods, respectively.
* <strong>Broadcast Variables</strong>: DataSets can be added as [Broadcast Variables]({{site.baseurl}}/apis/batch/index.html#broadcast-variables) to the `ScatterFunction` and `GatherFunction`, using the `addBroadcastSetForMessagingFunction()` and `addBroadcastSetForUpdateFunction()` methods, respectively.
* <strong>Number of Vertices</strong>: Accessing the total number of vertices within the iteration. This property can be set using the `setOptNumVertices()` method.
The number of vertices can then be accessed in the vertex update function and in the messaging function using the `getNumberOfVertices()` method. If the option is not set in the configuration, this method will return -1.
......@@ -1245,10 +1245,12 @@ parameters.registerAggregator("sumAggregator", new LongSumAggregator());
// run the scatter-gather iteration, also passing the configuration parameters
Graph<Long, Double, Double> result =
graph.runScatterGatherIteration(
new VertexUpdater(), new Messenger(), maxIterations, parameters);
new Messenger(), new VertexUpdater(), maxIterations, parameters);
// user-defined functions
public static final class VertexUpdater extends VertexUpdateFunction {
public static final class Messenger extends ScatterFunction {...}
public static final class VertexUpdater extends GatherFunction {
LongSumAggregator aggregator = new LongSumAggregator();
......@@ -1272,8 +1274,6 @@ public static final class VertexUpdater extends VertexUpdateFunction {
}
}
public static final class Messenger extends MessagingFunction {...}
{% endhighlight %}
</div>
......@@ -1294,10 +1294,12 @@ parameters.setParallelism(16)
parameters.registerAggregator("sumAggregator", new LongSumAggregator)
// run the scatter-gather iteration, also passing the configuration parameters
val result = graph.runScatterGatherIteration(new VertexUpdater, new Messenger, maxIterations, parameters)
val result = graph.runScatterGatherIteration(new Messenger, new VertexUpdater, maxIterations, parameters)
// user-defined functions
final class VertexUpdater extends VertexUpdateFunction {
final class Messenger extends ScatterFunction {...}
final class VertexUpdater extends GatherFunction {
var aggregator = new LongSumAggregator
......@@ -1321,8 +1323,6 @@ final class VertexUpdater extends VertexUpdateFunction {
}
}
final class Messenger extends MessagingFunction {...}
{% endhighlight %}
</div>
</div>
......@@ -1347,20 +1347,20 @@ parameters.setOptDegrees(true);
// run the scatter-gather iteration, also passing the configuration parameters
Graph<Long, Double, Double> result =
graph.runScatterGatherIteration(
new VertexUpdater(), new Messenger(), maxIterations, parameters);
new Messenger(), new VertexUpdater(), maxIterations, parameters);
// user-defined functions
public static final class VertexUpdater {
public static final class Messenger extends ScatterFunction {
...
// get the number of vertices
long numVertices = getNumberOfVertices();
// retrieve the vertex out-degree
outDegree = getOutDegree();
...
}
public static final class Messenger {
public static final class VertexUpdater extends GatherFunction {
...
// retrieve the vertex out-degree
outDegree = getOutDegree();
// get the number of vertices
long numVertices = getNumberOfVertices();
...
}
......@@ -1382,20 +1382,20 @@ parameters.setOptNumVertices(true)
parameters.setOptDegrees(true)
// run the scatter-gather iteration, also passing the configuration parameters
val result = graph.runScatterGatherIteration(new VertexUpdater, new Messenger, maxIterations, parameters)
val result = graph.runScatterGatherIteration(new Messenger, new VertexUpdater, maxIterations, parameters)
// user-defined functions
final class VertexUpdater {
final class Messenger extends ScatterFunction {
...
// get the number of vertices
val numVertices = getNumberOfVertices
// retrieve the vertex out-degree
val outDegree = getOutDegree
...
}
final class Messenger {
final class VertexUpdater extends GatherFunction {
...
// retrieve the vertex out-degree
val outDegree = getOutDegree
// get the number of vertices
val numVertices = getNumberOfVertices
...
}
......@@ -1419,13 +1419,13 @@ parameters.setDirection(EdgeDirection.IN);
// run the scatter-gather iteration, also passing the configuration parameters
DataSet<Vertex<Long, HashSet<Long>>> result =
graph.runScatterGatherIteration(
new VertexUpdater(), new Messenger(), maxIterations, parameters)
new Messenger(), new VertexUpdater(), maxIterations, parameters)
.getVertices();
// user-defined functions
public static final class VertexUpdater {...}
public static final class Messenger extends ScatterFunction {...}
public static final class Messenger {...}
public static final class VertexUpdater extends GatherFunction {...}
{% endhighlight %}
</div>
......@@ -1441,13 +1441,13 @@ val parameters = new ScatterGatherConfiguration
parameters.setDirection(EdgeDirection.IN)
// run the scatter-gather iteration, also passing the configuration parameters
val result = graph.runScatterGatherIteration(new VertexUpdater, new Messenger, maxIterations, parameters)
val result = graph.runScatterGatherIteration(new Messenger, new VertexUpdater, maxIterations, parameters)
.getVertices
// user-defined functions
final class VertexUpdater {...}
final class Messenger extends ScatterFunction {...}
final class Messenger {...}
final class VertexUpdater extends GatherFunction {...}
{% endhighlight %}
</div>
......
......@@ -21,17 +21,16 @@ package org.apache.flink.graph.examples;
import org.apache.flink.api.common.aggregators.DoubleSumAggregator;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.graph.spargel.ScatterGatherConfiguration;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.util.Preconditions;
......@@ -100,17 +99,55 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
parameter.registerAggregator("diffValueSum", new DoubleSumAggregator());
return newGraph
.runScatterGatherIteration(new VertexUpdate<K>(maxIterations, convergeThreshold),
new MessageUpdate<K>(maxIterations), maxIterations, parameter)
.runScatterGatherIteration(new MessageUpdate<K>(maxIterations),
new VertexUpdate<K>(maxIterations, convergeThreshold), maxIterations, parameter)
.getVertices();
}
/**
* Distributes the value of a vertex among all neighbor vertices and sums all the
* values in every superstep.
*/
private static final class MessageUpdate<K> extends ScatterFunction<K, Tuple2<DoubleValue, DoubleValue>, Double, Boolean> {
private int maxIteration;
public MessageUpdate(int maxIteration) {
this.maxIteration = maxIteration;
}
@Override
public void sendMessages(Vertex<K, Tuple2<DoubleValue, DoubleValue>> vertex) {
// in the first iteration, no aggregation to call, init sum with value of vertex
double iterationValueSum = 1.0;
if (getSuperstepNumber() > 1) {
iterationValueSum = Math.sqrt(((DoubleValue) getPreviousIterationAggregate("updatedValueSum")).getValue());
}
for (Edge<K, Boolean> edge : getEdges()) {
if (getSuperstepNumber() != maxIteration) {
if (getSuperstepNumber() % 2 == 1) {
if (edge.getValue()) {
sendMessageTo(edge.getTarget(), vertex.getValue().f0.getValue() / iterationValueSum);
}
} else {
if (!edge.getValue()) {
sendMessageTo(edge.getTarget(), vertex.getValue().f1.getValue() / iterationValueSum);
}
}
} else {
if (!edge.getValue()) {
sendMessageTo(edge.getTarget(), iterationValueSum);
}
}
}
}
}
/**
* Function that updates the value of a vertex by summing up the partial
* values from all messages and normalize the value.
*/
@SuppressWarnings("serial")
public static final class VertexUpdate<K> extends VertexUpdateFunction<K, Tuple2<DoubleValue, DoubleValue>, Double> {
private static final class VertexUpdate<K> extends GatherFunction<K, Tuple2<DoubleValue, DoubleValue>, Double> {
private int maxIteration;
private double convergeThreshold;
private DoubleSumAggregator updatedValueSumAggregator;
......@@ -162,7 +199,7 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
diffValueSum = ((DoubleValue) getPreviousIterationAggregate("diffValueSum")).getValue();
}
authoritySumAggregator.aggregate(previousAuthAverage);
if (diffValueSum > convergeThreshold) {
newHubValue.setValue(newHubValue.getValue() / iterationValueSum);
newAuthorityValue.setValue(updateValue);
......@@ -191,77 +228,24 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
}
}
/**
* Distributes the value of a vertex among all neighbor vertices and sum all the
* value in every superstep.
*/
@SuppressWarnings("serial")
public static final class MessageUpdate<K> extends MessagingFunction<K, Tuple2<DoubleValue, DoubleValue>, Double, Boolean> {
private int maxIteration;
public MessageUpdate(int maxIteration) {
this.maxIteration = maxIteration;
}
@Override
public void sendMessages(Vertex<K, Tuple2<DoubleValue, DoubleValue>> vertex) {
// in the first iteration, no aggregation to call, init sum with value of vertex
double iterationValueSum = 1.0;
if (getSuperstepNumber() > 1) {
iterationValueSum = Math.sqrt(((DoubleValue) getPreviousIterationAggregate("updatedValueSum")).getValue());
}
for (Edge<K, Boolean> edge : getEdges()) {
if (getSuperstepNumber() != maxIteration) {
if (getSuperstepNumber() % 2 == 1) {
if (edge.getValue()) {
sendMessageTo(edge.getTarget(), vertex.getValue().f0.getValue() / iterationValueSum);
}
} else {
if (!edge.getValue()) {
sendMessageTo(edge.getTarget(), vertex.getValue().f1.getValue() / iterationValueSum);
}
}
} else {
if (!edge.getValue()) {
sendMessageTo(edge.getTarget(), iterationValueSum);
}
}
}
}
}
public static class VertexInitMapper<K, VV> implements MapFunction<Vertex<K, VV>, Tuple2<DoubleValue, DoubleValue>> {
private static final long serialVersionUID = 1L;
private static class VertexInitMapper<K, VV> implements MapFunction<Vertex<K, VV>, Tuple2<DoubleValue, DoubleValue>> {
private Tuple2<DoubleValue, DoubleValue> initVertexValue = new Tuple2<>(new DoubleValue(1.0), new DoubleValue(1.0));
public Tuple2<DoubleValue, DoubleValue> map(Vertex<K, VV> value) {
//init hub and authority value of each vertex
return initVertexValue;
}
}
public static class AuthorityEdgeMapper<K, EV> implements MapFunction<Edge<K, EV>, Boolean> {
private static final long serialVersionUID = 1L;
private static class AuthorityEdgeMapper<K, EV> implements MapFunction<Edge<K, EV>, Boolean> {
public Boolean map(Edge<K, EV> edge) {
// mark edge as true for authority updating
return true;
}
}
public static class HubEdgeMapper<K, EV> implements MapFunction<Edge<K, EV>, Boolean> {
private static final long serialVersionUID = 1L;
private static class HubEdgeMapper<K, EV> implements MapFunction<Edge<K, EV>, Boolean> {
public Boolean map(Edge<K, EV> edge) {
// mark edge as false for hub updating
return false;
}
......
......@@ -27,10 +27,10 @@ import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.examples.data.IncrementalSSSPData;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.graph.spargel.ScatterGatherConfiguration;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
/**
* This example illustrates how to
......@@ -97,8 +97,8 @@ public class IncrementalSSSP implements ProgramDescription {
parameters.setOptDegrees(true);
// run the scatter-gather iteration to propagate info
Graph<Long, Double, Double> result = ssspGraph.runScatterGatherIteration(new VertexDistanceUpdater(),
new InvalidateMessenger(edgeToBeRemoved), maxIterations, parameters);
Graph<Long, Double, Double> result = ssspGraph.runScatterGatherIteration(new InvalidateMessenger(edgeToBeRemoved),
new VertexDistanceUpdater(), maxIterations, parameters);
DataSet<Vertex<Long, Double>> resultedVertices = result.getVertices();
......@@ -147,22 +147,7 @@ public class IncrementalSSSP implements ProgramDescription {
}).count() > 0;
}
public static final class VertexDistanceUpdater extends VertexUpdateFunction<Long, Double, Double> {
@Override
public void updateVertex(Vertex<Long, Double> vertex, MessageIterator<Double> inMessages) throws Exception {
if (inMessages.hasNext()) {
Long outDegree = getOutDegree() - 1;
// check if the vertex has another SP-Edge
if (outDegree <= 0) {
// set own value to infinity
setNewVertexValue(Double.MAX_VALUE);
}
}
}
}
public static final class InvalidateMessenger extends MessagingFunction<Long, Double, Double, Double> {
public static final class InvalidateMessenger extends ScatterFunction<Long, Double, Double, Double> {
private Edge<Long, Double> edgeToBeRemoved;
......@@ -190,6 +175,21 @@ public class IncrementalSSSP implements ProgramDescription {
}
}
public static final class VertexDistanceUpdater extends GatherFunction<Long, Double, Double> {
@Override
public void updateVertex(Vertex<Long, Double> vertex, MessageIterator<Double> inMessages) throws Exception {
if (inMessages.hasNext()) {
Long outDegree = getOutDegree() - 1;
// check if the vertex has another SP-Edge
if (outDegree <= 0) {
// set own value to infinity
setNewVertexValue(Double.MAX_VALUE);
}
}
}
}
// ******************************************************************************************************************
// UTIL METHODS
// ******************************************************************************************************************
......
......@@ -26,9 +26,9 @@ import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.graph.utils.Tuple3ToEdgeMap;
/**
......@@ -62,7 +62,7 @@ public class SingleSourceShortestPaths implements ProgramDescription {
// Execute the scatter-gather iteration
Graph<Long, Double, Double> result = graph.runScatterGatherIteration(
new VertexDistanceUpdater(), new MinDistanceMessenger(), maxIterations);
new MinDistanceMessenger(), new VertexDistanceUpdater(), maxIterations);
// Extract the vertices as the result
DataSet<Vertex<Long, Double>> singleSourceShortestPaths = result.getVertices();
......@@ -102,12 +102,29 @@ public class SingleSourceShortestPaths implements ProgramDescription {
}
}
/**
* Distributes the minimum distance associated with a given vertex among all
* the target vertices summed up with the edge's value.
*/
@SuppressWarnings("serial")
private static final class MinDistanceMessenger extends ScatterFunction<Long, Double, Double, Double> {
@Override
public void sendMessages(Vertex<Long, Double> vertex) {
if (vertex.getValue() < Double.POSITIVE_INFINITY) {
for (Edge<Long, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), vertex.getValue() + edge.getValue());
}
}
}
}
/**
* Function that updates the value of a vertex by picking the minimum
* distance from all incoming messages.
*/
@SuppressWarnings("serial")
public static final class VertexDistanceUpdater extends VertexUpdateFunction<Long, Double, Double> {
private static final class VertexDistanceUpdater extends GatherFunction<Long, Double, Double> {
@Override
public void updateVertex(Vertex<Long, Double> vertex, MessageIterator<Double> inMessages) {
......@@ -126,23 +143,6 @@ public class SingleSourceShortestPaths implements ProgramDescription {
}
}
/**
* Distributes the minimum distance associated with a given vertex among all
* the target vertices summed up with the edge's value.
*/
@SuppressWarnings("serial")
public static final class MinDistanceMessenger extends MessagingFunction<Long, Double, Double, Double> {
@Override
public void sendMessages(Vertex<Long, Double> vertex) {
if (vertex.getValue() < Double.POSITIVE_INFINITY) {
for (Edge<Long, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), vertex.getValue() + edge.getValue());
}
}
}
}
// ******************************************************************************************************************
// UTIL METHODS
// ******************************************************************************************************************
......
......@@ -22,11 +22,10 @@ import org.apache.flink.api.scala._
import org.apache.flink.graph.scala._
import org.apache.flink.graph.Edge
import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.graph.spargel.VertexUpdateFunction
import org.apache.flink.graph.spargel.MessageIterator
import org.apache.flink.graph.spargel.{MessageIterator, ScatterFunction, GatherFunction}
import org.apache.flink.graph.Vertex
import org.apache.flink.graph.spargel.MessagingFunction
import org.apache.flink.graph.examples.data.SingleSourceShortestPathsData
import scala.collection.JavaConversions._
import org.apache.flink.graph.scala.utils.Tuple3ToEdgeMap
......@@ -55,8 +54,8 @@ object SingleSourceShortestPaths {
val graph = Graph.fromDataSet[Long, Double, Double](edges, new InitVertices(srcVertexId), env)
// Execute the scatter-gather iteration
val result = graph.runScatterGatherIteration(new VertexDistanceUpdater,
new MinDistanceMessenger, maxIterations)
val result = graph.runScatterGatherIteration(new MinDistanceMessenger,
new VertexDistanceUpdater, maxIterations)
// Extract the vertices as the result
val singleSourceShortestPaths = result.getVertices
......@@ -86,10 +85,26 @@ object SingleSourceShortestPaths {
}
/**
* Function that updates the value of a vertex by picking the minimum
* distance from all incoming messages.
* Distributes the minimum distance associated with a given vertex among all
* the target vertices summed up with the edge's value.
*/
private final class VertexDistanceUpdater extends VertexUpdateFunction[Long, Double, Double] {
private final class MinDistanceMessenger extends
ScatterFunction[Long, Double, Double, Double] {
override def sendMessages(vertex: Vertex[Long, Double]) {
if (vertex.getValue < Double.PositiveInfinity) {
for (edge: Edge[Long, Double] <- getEdges) {
sendMessageTo(edge.getTarget, vertex.getValue + edge.getValue)
}
}
}
}
/**
* Function that updates the value of a vertex by picking the minimum
* distance from all incoming messages.
*/
private final class VertexDistanceUpdater extends GatherFunction[Long, Double, Double] {
override def updateVertex(vertex: Vertex[Long, Double], inMessages: MessageIterator[Double]) {
var minDistance = Double.MaxValue
......@@ -105,22 +120,6 @@ object SingleSourceShortestPaths {
}
}
/**
* Distributes the minimum distance associated with a given vertex among all
* the target vertices summed up with the edge's value.
*/
private final class MinDistanceMessenger extends
MessagingFunction[Long, Double, Double, Double] {
override def sendMessages(vertex: Vertex[Long, Double]) {
if (vertex.getValue < Double.PositiveInfinity) {
for (edge: Edge[Long, Double] <- getEdges) {
sendMessageTo(edge.getTarget, vertex.getValue + edge.getValue)
}
}
}
}
// ****************************************************************************
// UTIL METHODS
// ****************************************************************************
......
......@@ -111,8 +111,8 @@ public class IncrementalSSSPITCase extends MultipleProgramsTestBase {
// run the scatter gather iteration to propagate info
Graph<Long, Double, Double> result = ssspGraph.runScatterGatherIteration(
new IncrementalSSSP.VertexDistanceUpdater(),
new IncrementalSSSP.InvalidateMessenger(edgeToBeRemoved),
new IncrementalSSSP.VertexDistanceUpdater(),
IncrementalSSSPData.NUM_VERTICES, parameters);
DataSet<Vertex<Long, Double>> resultedVertices = result.getVertices();
......
......@@ -45,17 +45,15 @@ import org.apache.flink.graph.asm.translate.TranslateGraphIds;
import org.apache.flink.graph.asm.translate.TranslateVertexValues;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GSAConfiguration;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.GatherSumApplyIteration;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.pregel.ComputeFunction;
import org.apache.flink.graph.pregel.MessageCombiner;
import org.apache.flink.graph.pregel.VertexCentricConfiguration;
import org.apache.flink.graph.pregel.VertexCentricIteration;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.graph.spargel.ScatterGatherConfiguration;
import org.apache.flink.graph.spargel.ScatterGatherIteration;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.utils.EdgeToTuple3Map;
import org.apache.flink.graph.utils.Tuple2ToVertexMap;
import org.apache.flink.graph.utils.Tuple3ToEdgeMap;
......@@ -1652,27 +1650,27 @@ public class Graph<K, VV, EV> {
* Runs a ScatterGather iteration on the graph.
* No configuration options are provided.
*
* @param vertexUpdateFunction the vertex update function
* @param messagingFunction the messaging function
* @param scatterFunction the scatter function
* @param gatherFunction the gather function
* @param maximumNumberOfIterations maximum number of iterations to perform
*
* @return the updated Graph after the scatter-gather iteration has converged or
* after maximumNumberOfIterations.
*/
public <M> Graph<K, VV, EV> runScatterGatherIteration(
VertexUpdateFunction<K, VV, M> vertexUpdateFunction,
MessagingFunction<K, VV, M, EV> messagingFunction,
ScatterFunction<K, VV, M, EV> scatterFunction,
org.apache.flink.graph.spargel.GatherFunction<K, VV, M> gatherFunction,
int maximumNumberOfIterations) {
return this.runScatterGatherIteration(vertexUpdateFunction, messagingFunction,
return this.runScatterGatherIteration(scatterFunction, gatherFunction,
maximumNumberOfIterations, null);
}
/**
* Runs a ScatterGather iteration on the graph with configuration options.
*
* @param vertexUpdateFunction the vertex update function
* @param messagingFunction the messaging function
*
* @param scatterFunction the scatter function
* @param gatherFunction the gather function
* @param maximumNumberOfIterations maximum number of iterations to perform
* @param parameters the iteration configuration parameters
*
......@@ -1680,12 +1678,12 @@ public class Graph<K, VV, EV> {
* after maximumNumberOfIterations.
*/
public <M> Graph<K, VV, EV> runScatterGatherIteration(
VertexUpdateFunction<K, VV, M> vertexUpdateFunction,
MessagingFunction<K, VV, M, EV> messagingFunction,
ScatterFunction<K, VV, M, EV> scatterFunction,
org.apache.flink.graph.spargel.GatherFunction<K, VV, M> gatherFunction,
int maximumNumberOfIterations, ScatterGatherConfiguration parameters) {
ScatterGatherIteration<K, VV, M, EV> iteration = ScatterGatherIteration.withEdges(
edges, vertexUpdateFunction, messagingFunction, maximumNumberOfIterations);
edges, scatterFunction, gatherFunction, maximumNumberOfIterations);
iteration.configure(parameters);
......@@ -1708,7 +1706,7 @@ public class Graph<K, VV, EV> {
* after maximumNumberOfIterations.
*/
public <M> Graph<K, VV, EV> runGatherSumApplyIteration(
GatherFunction<VV, EV, M> gatherFunction, SumFunction<VV, EV, M> sumFunction,
org.apache.flink.graph.gsa.GatherFunction<VV, EV, M> gatherFunction, SumFunction<VV, EV, M> sumFunction,
ApplyFunction<K, VV, M> applyFunction, int maximumNumberOfIterations) {
return this.runGatherSumApplyIteration(gatherFunction, sumFunction, applyFunction,
......@@ -1729,7 +1727,7 @@ public class Graph<K, VV, EV> {
* after maximumNumberOfIterations.
*/
public <M> Graph<K, VV, EV> runGatherSumApplyIteration(
GatherFunction<VV, EV, M> gatherFunction, SumFunction<VV, EV, M> sumFunction,
org.apache.flink.graph.gsa.GatherFunction<VV, EV, M> gatherFunction, SumFunction<VV, EV, M> sumFunction,
ApplyFunction<K, VV, M> applyFunction, int maximumNumberOfIterations,
GSAConfiguration parameters) {
......
......@@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.util.Preconditions;
/**
......@@ -133,8 +134,8 @@ public abstract class IterationConfiguration {
/**
* Registers a new aggregator. Aggregators registered here are available during the execution of the vertex updates
* via {@link org.apache.flink.graph.spargel.VertexUpdateFunction#getIterationAggregator(String)} and
* {@link org.apache.flink.graph.spargel.VertexUpdateFunction#getPreviousIterationAggregate(String)}.
* via {@link GatherFunction#getIterationAggregator(String)} and
* {@link GatherFunction#getPreviousIterationAggregate(String)}.
*
* @param name The name of the aggregator, used to retrieve it and its aggregates during execution.
* @param aggregator The aggregator.
......
......@@ -26,9 +26,9 @@ import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import java.util.Map;
import java.util.TreeMap;
......@@ -73,19 +73,33 @@ public class CommunityDetection<K> implements GraphAlgorithm<K, Long, Double, Gr
public Graph<K, Long, Double> run(Graph<K, Long, Double> graph) {
DataSet<Vertex<K, Tuple2<Long, Double>>> initializedVertices = graph.getVertices()
.map(new AddScoreToVertexValuesMapper<K>());
.map(new AddScoreToVertexValuesMapper<K>());
Graph<K, Tuple2<Long, Double>, Double> graphWithScoredVertices =
Graph.fromDataSet(initializedVertices, graph.getEdges(), graph.getContext()).getUndirected();
Graph.fromDataSet(initializedVertices, graph.getEdges(), graph.getContext()).getUndirected();
return graphWithScoredVertices.runScatterGatherIteration(new VertexLabelUpdater<K>(delta),
new LabelMessenger<K>(), maxIterations)
return graphWithScoredVertices.runScatterGatherIteration(new LabelMessenger<K>(),
new VertexLabelUpdater<K>(delta), maxIterations)
.mapVertices(new RemoveScoreFromVertexValuesMapper<K>());
}
@SuppressWarnings("serial")
public static final class VertexLabelUpdater<K> extends VertexUpdateFunction<
K, Tuple2<Long, Double>, Tuple2<Long, Double>> {
public static final class LabelMessenger<K> extends ScatterFunction<K, Tuple2<Long, Double>,
Tuple2<Long, Double>, Double> {
@Override
public void sendMessages(Vertex<K, Tuple2<Long, Double>> vertex) throws Exception {
for(Edge<K, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), new Tuple2<Long, Double>(vertex.getValue().f0,
vertex.getValue().f1 * edge.getValue()));
}
}
}
@SuppressWarnings("serial")
public static final class VertexLabelUpdater<K> extends GatherFunction<
K, Tuple2<Long, Double>, Tuple2<Long, Double>> {
private Double delta;
......@@ -153,20 +167,6 @@ public class CommunityDetection<K> implements GraphAlgorithm<K, Long, Double, Gr
}
}
@SuppressWarnings("serial")
public static final class LabelMessenger<K> extends MessagingFunction<K, Tuple2<Long, Double>,
Tuple2<Long, Double>, Double> {
@Override
public void sendMessages(Vertex<K, Tuple2<Long, Double>> vertex) throws Exception {
for(Edge<K, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), new Tuple2<Long, Double>(vertex.getValue().f0,
vertex.getValue().f1 * edge.getValue()));
}
}
}
@SuppressWarnings("serial")
@ForwardedFields("f0")
public static final class AddScoreToVertexValuesMapper<K> implements MapFunction<
......@@ -174,7 +174,7 @@ public class CommunityDetection<K> implements GraphAlgorithm<K, Long, Double, Gr
public Vertex<K, Tuple2<Long, Double>> map(Vertex<K, Long> vertex) {
return new Vertex<K, Tuple2<Long, Double>>(
vertex.getId(), new Tuple2<Long, Double>(vertex.getValue(), 1.0));
vertex.getId(), new Tuple2<Long, Double>(vertex.getValue(), 1.0));
}
}
......
......@@ -25,9 +25,9 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.graph.utils.NullValueEdgeMapper;
import org.apache.flink.types.NullValue;
......@@ -76,39 +76,16 @@ public class ConnectedComponents<K, VV extends Comparable<VV>, EV>
.getUndirected();
return undirectedGraph.runScatterGatherIteration(
new CCUpdater<K, VV>(),
new CCMessenger<K, VV>(valueTypeInfo),
new CCUpdater<K, VV>(),
maxIterations).getVertices();
}
/**
* Updates the value of a vertex by picking the minimum neighbor value out of all the incoming messages.
*/
public static final class CCUpdater<K, VV extends Comparable<VV>>
extends VertexUpdateFunction<K, VV, VV> {
@Override
public void updateVertex(Vertex<K, VV> vertex, MessageIterator<VV> messages) throws Exception {
VV current = vertex.getValue();
VV min = current;
for (VV msg : messages) {
if (msg.compareTo(min) < 0) {
min = msg;
}
}
if (!min.equals(current)) {
setNewVertexValue(min);
}
}
}
/**
* Sends the current vertex value to all adjacent vertices.
*/
public static final class CCMessenger<K, VV extends Comparable<VV>>
extends MessagingFunction<K, VV, VV, NullValue>
extends ScatterFunction<K, VV, VV, NullValue>
implements ResultTypeQueryable<VV> {
private final TypeInformation<VV> typeInformation;
......@@ -128,4 +105,27 @@ public class ConnectedComponents<K, VV extends Comparable<VV>, EV>
return typeInformation;
}
}
/**
* Updates the value of a vertex by picking the minimum neighbor value out of all the incoming messages.
*/
public static final class CCUpdater<K, VV extends Comparable<VV>>
extends GatherFunction<K, VV, VV> {
@Override
public void updateVertex(Vertex<K, VV> vertex, MessageIterator<VV> messages) throws Exception {
VV current = vertex.getValue();
VV min = current;
for (VV msg : messages) {
if (msg.compareTo(min) < 0) {
min = msg;
}
}
if (!min.equals(current)) {
setNewVertexValue(min);
}
}
}
}
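The `CCMessenger`/`CCUpdater` pair above can be illustrated without Flink. The following is a minimal plain-Java sketch, not Gelly code: the class `MinLabelSketch` and its `run` method are hypothetical names, and it assumes every vertex sends its label in every superstep (a simplification of the scatter phase).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical plain-Java simulation of min-label propagation; not part of Gelly.
class MinLabelSketch {

    /** Runs synchronous scatter/gather supersteps over undirected edges and returns the final labels. */
    static Map<Integer, Long> run(Map<Integer, Long> labels, int[][] edges, int maxIterations) {
        Map<Integer, Long> current = new HashMap<>(labels);
        for (int step = 0; step < maxIterations; step++) {
            // Scatter: each vertex sends its current label along every edge, in both directions.
            Map<Integer, List<Long>> inbox = new HashMap<>();
            for (int[] e : edges) {
                inbox.computeIfAbsent(e[1], k -> new ArrayList<>()).add(current.get(e[0]));
                inbox.computeIfAbsent(e[0], k -> new ArrayList<>()).add(current.get(e[1]));
            }
            // Gather: each vertex keeps the minimum of its own label and the received labels.
            boolean changed = false;
            Map<Integer, Long> next = new HashMap<>(current);
            for (Map.Entry<Integer, List<Long>> entry : inbox.entrySet()) {
                long min = current.get(entry.getKey());
                for (long msg : entry.getValue()) {
                    if (msg < min) {
                        min = msg;
                    }
                }
                if (min != current.get(entry.getKey())) {
                    next.put(entry.getKey(), min);
                    changed = true;
                }
            }
            current = next;
            if (!changed) {
                break; // no vertex value changed, so the iteration has converged
            }
        }
        return current;
    }
}
```

With edges 1-2, 2-3 and 4-5 and each vertex initially labeled with its own ID, vertices 1 to 3 converge to label 1 and vertices 4 and 5 to label 4.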
......@@ -27,8 +27,8 @@ import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.gsa.Neighbor;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.utils.NullValueEdgeMapper;
import org.apache.flink.types.NullValue;
......
......@@ -25,8 +25,8 @@ import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.gsa.Neighbor;
import org.apache.flink.graph.gsa.SumFunction;
/**
* This is an implementation of the Single Source Shortest Paths algorithm, using a gather-sum-apply iteration
......
......@@ -25,9 +25,9 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.graph.utils.NullValueEdgeMapper;
import org.apache.flink.types.NullValue;
......@@ -78,15 +78,38 @@ public class LabelPropagation<K, VV extends Comparable<VV>, EV>
return input
.mapEdges(new NullValueEdgeMapper<K, EV>())
.runScatterGatherIteration(
new UpdateVertexLabel<K, VV>(), new SendNewLabelToNeighbors<K, VV>(valueType), maxIterations)
new SendNewLabelToNeighbors<K, VV>(valueType), new UpdateVertexLabel<K, VV>(), maxIterations)
.getVertices();
}
/**
* Sends the vertex label to all out-neighbors
*/
public static final class SendNewLabelToNeighbors<K, VV extends Comparable<VV>>
extends ScatterFunction<K, VV, VV, NullValue>
implements ResultTypeQueryable<VV> {
private final TypeInformation<VV> typeInformation;
public SendNewLabelToNeighbors(TypeInformation<VV> typeInformation) {
this.typeInformation = typeInformation;
}
public void sendMessages(Vertex<K, VV> vertex) {
sendMessageToAllNeighbors(vertex.getValue());
}
@Override
public TypeInformation<VV> getProducedType() {
return typeInformation;
}
}
/**
* Function that updates the value of a vertex by adopting the most frequent
* label among its in-neighbors
*/
public static final class UpdateVertexLabel<K, VV extends Comparable<VV>> extends VertexUpdateFunction<K, VV, VV> {
public static final class UpdateVertexLabel<K, VV extends Comparable<VV>> extends GatherFunction<K, VV, VV> {
public void updateVertex(Vertex<K, VV> vertex, MessageIterator<VV> inMessages) {
Map<VV, Long> labelsWithFrequencies = new HashMap<VV, Long>();
......@@ -119,27 +142,4 @@ public class LabelPropagation<K, VV extends Comparable<VV>, EV>
setNewVertexValue(mostFrequentLabel);
}
}
/**
* Sends the vertex label to all out-neighbors
*/
public static final class SendNewLabelToNeighbors<K, VV extends Comparable<VV>>
extends MessagingFunction<K, VV, VV, NullValue>
implements ResultTypeQueryable<VV> {
private final TypeInformation<VV> typeInformation;
public SendNewLabelToNeighbors(TypeInformation<VV> typeInformation) {
this.typeInformation = typeInformation;
}
public void sendMessages(Vertex<K, VV> vertex) {
sendMessageToAllNeighbors(vertex.getValue());
}
@Override
public TypeInformation<VV> getProducedType() {
return typeInformation;
}
}
}
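The gather step described for `UpdateVertexLabel` (adopt the most frequent incoming label) can be sketched in plain Java. This is an illustration only: `MostFrequentLabelSketch` is a hypothetical name, and breaking ties in favor of the larger label is an assumption of this sketch, since the tie-breaking code is elided in the hunk above.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper; not part of Gelly.
class MostFrequentLabelSketch {

    /** Returns the most frequent label; ties go to the larger label (an assumption of this sketch). */
    static long mostFrequent(List<Long> messages) {
        if (messages.isEmpty()) {
            throw new IllegalArgumentException("at least one message is required");
        }
        // Count how often each label occurs among the incoming messages.
        Map<Long, Long> frequencies = new HashMap<>();
        for (long label : messages) {
            frequencies.merge(label, 1L, Long::sum);
        }
        long best = Long.MIN_VALUE;
        long bestCount = 0;
        for (Map.Entry<Long, Long> entry : frequencies.entrySet()) {
            long label = entry.getKey();
            long count = entry.getValue();
            if (count > bestCount || (count == bestCount && label > best)) {
                best = label;
                bestCount = count;
            }
        }
        return best;
    }
}
```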
......@@ -25,10 +25,10 @@ import org.apache.flink.graph.EdgeJoinFunction;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.graph.spargel.ScatterGatherConfiguration;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.types.LongValue;
/**
......@@ -65,18 +65,37 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
ScatterGatherConfiguration parameters = new ScatterGatherConfiguration();
parameters.setOptNumVertices(true);
return networkWithWeights.runScatterGatherIteration(new VertexRankUpdater<K>(beta),
new RankMessenger<K>(), maxIterations, parameters)
return networkWithWeights.runScatterGatherIteration(new RankMessenger<K>(),
new VertexRankUpdater<K>(beta), maxIterations, parameters)
.getVertices();
}
/**
* Distributes the rank of a vertex among all target vertices according to
* the transition probability, which is associated with an edge as the edge
* value.
*/
@SuppressWarnings("serial")
public static final class RankMessenger<K> extends ScatterFunction<K, Double, Double, Double> {
@Override
public void sendMessages(Vertex<K, Double> vertex) {
if (getSuperstepNumber() == 1) {
// initialize vertex ranks
vertex.setValue(1.0 / this.getNumberOfVertices());
}
for (Edge<K, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), vertex.getValue() * edge.getValue());
}
}
}
/**
* Function that updates the rank of a vertex by summing up the partial
* ranks from all incoming messages and then applying the damping formula.
*/
@SuppressWarnings("serial")
public static final class VertexRankUpdater<K> extends VertexUpdateFunction<K, Double, Double> {
public static final class VertexRankUpdater<K> extends GatherFunction<K, Double, Double> {
private final double beta;
public VertexRankUpdater(double beta) {
......@@ -96,30 +115,8 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
}
}
/**
* Distributes the rank of a vertex among all target vertices according to
* the transition probability, which is associated with an edge as the edge
* value.
*/
@SuppressWarnings("serial")
public static final class RankMessenger<K> extends MessagingFunction<K, Double, Double, Double> {
@Override
public void sendMessages(Vertex<K, Double> vertex) {
if (getSuperstepNumber() == 1) {
// initialize vertex ranks
vertex.setValue(1.0 / this.getNumberOfVertices());
}
for (Edge<K, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), vertex.getValue() * edge.getValue());
}
}
}
@SuppressWarnings("serial")
private static final class InitWeights implements EdgeJoinFunction<Double, LongValue> {
public Double edgeJoin(Double edgeValue, LongValue inputValue) {
return edgeValue / (double) inputValue.getValue();
}
......
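One superstep of the `RankMessenger`/`VertexRankUpdater` pair can be sketched in plain Java. The body of `updateVertex` is elided in the hunk above, so the update `beta * sum + (1 - beta) / n` is an assumption based on the usual PageRank formulation; `RankStepSketch` and its parameters are hypothetical names. Unlike a real scatter-gather iteration, this sketch also updates vertices that receive no messages.

```java
// Hypothetical plain-Java sketch of one PageRank superstep; not part of Gelly.
class RankStepSketch {

    /**
     * edges[i] = {source, target}; weights[i] is the transition probability of edge i.
     * Returns the ranks after one scatter phase followed by one gather phase.
     */
    static double[] superstep(double[] ranks, int[][] edges, double[] weights, double beta) {
        int n = ranks.length;
        // Scatter: each source sends rank * edge weight to the target, summed per target.
        double[] sums = new double[n];
        for (int i = 0; i < edges.length; i++) {
            sums[edges[i][1]] += ranks[edges[i][0]] * weights[i];
        }
        // Gather: apply the assumed damping formula to the summed partial ranks.
        double[] next = new double[n];
        for (int v = 0; v < n; v++) {
            next[v] = beta * sums[v] + (1 - beta) / n;
        }
        return next;
    }
}
```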
......@@ -24,9 +24,9 @@ import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
/**
* This is an implementation of the Single-Source-Shortest Paths algorithm, using a scatter-gather iteration.
......@@ -52,7 +52,7 @@ public class SingleSourceShortestPaths<K> implements GraphAlgorithm<K, Double, D
public DataSet<Vertex<K, Double>> run(Graph<K, Double, Double> input) {
return input.mapVertices(new InitVerticesMapper<K>(srcVertexId))
.runScatterGatherIteration(new VertexDistanceUpdater<K>(), new MinDistanceMessenger<K>(),
.runScatterGatherIteration(new MinDistanceMessenger<K>(), new VertexDistanceUpdater<K>(),
maxIterations).getVertices();
}
......@@ -73,13 +73,31 @@ public class SingleSourceShortestPaths<K> implements GraphAlgorithm<K, Double, D
}
}
/**
* Distributes the minimum distance associated with a given vertex among all
* the target vertices summed up with the edge's value.
*
* @param <K>
*/
public static final class MinDistanceMessenger<K> extends ScatterFunction<K, Double, Double, Double> {
@Override
public void sendMessages(Vertex<K, Double> vertex) {
if (vertex.getValue() < Double.POSITIVE_INFINITY) {
for (Edge<K, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), vertex.getValue() + edge.getValue());
}
}
}
}
/**
* Function that updates the value of a vertex by picking the minimum
* distance from all incoming messages.
*
* @param <K>
*/
public static final class VertexDistanceUpdater<K> extends VertexUpdateFunction<K, Double, Double> {
public static final class VertexDistanceUpdater<K> extends GatherFunction<K, Double, Double> {
@Override
public void updateVertex(Vertex<K, Double> vertex,
......@@ -98,22 +116,4 @@ public class SingleSourceShortestPaths<K> implements GraphAlgorithm<K, Double, D
}
}
}
/**
* Distributes the minimum distance associated with a given vertex among all
* the target vertices summed up with the edge's value.
*
* @param <K>
*/
public static final class MinDistanceMessenger<K> extends MessagingFunction<K, Double, Double, Double> {
@Override
public void sendMessages(Vertex<K, Double> vertex) {
if (vertex.getValue() < Double.POSITIVE_INFINITY) {
for (Edge<K, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), vertex.getValue() + edge.getValue());
}
}
}
}
}
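The interplay of `MinDistanceMessenger` and `VertexDistanceUpdater` can be sketched in plain Java. This is a simplified simulation under stated assumptions, not Gelly code: `ShortestPathSketch` is a hypothetical name, vertices are numbered `0..n-1`, and every reachable vertex sends messages in every superstep.

```java
import java.util.Arrays;

// Hypothetical plain-Java simulation of scatter-gather shortest paths; not part of Gelly.
class ShortestPathSketch {

    /** edges[i] = {source, target}; weights[i] is the length of edge i. */
    static double[] run(int n, int source, int[][] edges, double[] weights, int maxIterations) {
        double[] dist = new double[n];
        Arrays.fill(dist, Double.POSITIVE_INFINITY);
        dist[source] = 0.0;
        for (int step = 0; step < maxIterations; step++) {
            double[] best = Arrays.copyOf(dist, n);
            for (int i = 0; i < edges.length; i++) {
                int src = edges[i][0];
                int trg = edges[i][1];
                // Scatter: only vertices with a finite distance send messages.
                if (dist[src] < Double.POSITIVE_INFINITY) {
                    double candidate = dist[src] + weights[i];
                    // Gather: the target keeps the minimum of all candidate distances.
                    if (candidate < best[trg]) {
                        best[trg] = candidate;
                    }
                }
            }
            if (Arrays.equals(best, dist)) {
                break; // no distance improved in this superstep
            }
            dist = best;
        }
        return dist;
    }
}
```

Vertices that are unreachable from the source keep the distance `Double.POSITIVE_INFINITY`.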
......@@ -18,11 +18,11 @@
package org.apache.flink.graph.library;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
import org.apache.flink.api.java.DataSet;
......
......@@ -18,9 +18,6 @@
package org.apache.flink.graph.spargel;
import java.io.Serializable;
import java.util.Collection;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.java.tuple.Tuple3;
......@@ -28,6 +25,9 @@ import org.apache.flink.graph.Vertex;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import java.io.Serializable;
import java.util.Collection;
/**
* This class must be extended by functions that compute the state of the vertex depending on the old state and the
* incoming messages. The central method is {@link #updateVertex(Vertex, MessageIterator)}, which is
......@@ -37,7 +37,7 @@ import org.apache.flink.util.Collector;
* {@code <VV>} The vertex value type.
* {@code <Message>} The message type.
*/
public abstract class VertexUpdateFunction<K, VV, Message> implements Serializable {
public abstract class GatherFunction<K, VV, Message> implements Serializable {
private static final long serialVersionUID = 1L;
......@@ -76,11 +76,11 @@ public abstract class VertexUpdateFunction<K, VV, Message> implements Serializab
// --------------------------------------------------------------------------------------------
// Public API Methods
// --------------------------------------------------------------------------------------------
/**
* This method is invoked once per vertex per superstep. It receives the current state of the vertex, as well as
* the incoming messages. It may set a new vertex state via {@link #setNewVertexValue(Object)}. If the vertex
* state is changed, it will trigger the sending of messages via the {@link MessagingFunction}.
* state is changed, it will trigger the sending of messages via the {@link ScatterFunction}.
*
* @param vertex The vertex.
* @param inMessages The incoming messages to this vertex.
......@@ -88,21 +88,21 @@ public abstract class VertexUpdateFunction<K, VV, Message> implements Serializab
* @throws Exception The computation may throw exceptions, which causes the superstep to fail.
*/
public abstract void updateVertex(Vertex<K, VV> vertex, MessageIterator<Message> inMessages) throws Exception;
/**
* This method is executed one per superstep before the vertex update function is invoked for each vertex.
* This method is executed once per superstep before the gather function is invoked for each vertex.
*
* @throws Exception Exceptions in the pre-superstep phase cause the superstep to fail.
*/
public void preSuperstep() throws Exception {}
/**
* This method is executed one per superstep after the vertex update function has been invoked for each vertex.
* This method is executed once per superstep after the gather function has been invoked for each vertex.
*
* @throws Exception Exceptions in the post-superstep phase cause the superstep to fail.
*/
public void postSuperstep() throws Exception {}
/**
* Sets the new value of this vertex. Setting a new value triggers the sending of outgoing messages from this vertex.
*
......@@ -123,7 +123,7 @@ public abstract class VertexUpdateFunction<K, VV, Message> implements Serializab
out.collect(outVal);
}
}
/**
* Gets the number of the superstep, starting at <tt>1</tt>.
*
......@@ -132,7 +132,7 @@ public abstract class VertexUpdateFunction<K, VV, Message> implements Serializab
public int getSuperstepNumber() {
return this.runtimeContext.getSuperstepNumber();
}
/**
* Gets the iteration aggregator registered under the given name. The iteration aggregator combines
* all aggregates globally once per superstep and makes them available in the next superstep.
......@@ -143,7 +143,7 @@ public abstract class VertexUpdateFunction<K, VV, Message> implements Serializab
public <T extends Aggregator<?>> T getIterationAggregator(String name) {
return this.runtimeContext.<T>getIterationAggregator(name);
}
/**
* Get the aggregated value that an aggregator computed in the previous iteration.
*
......@@ -153,11 +153,11 @@ public abstract class VertexUpdateFunction<K, VV, Message> implements Serializab
public <T extends Value> T getPreviousIterationAggregate(String name) {
return this.runtimeContext.<T>getPreviousIterationAggregate(name);
}
/**
* Gets the broadcast data set registered under the given name. Broadcast data sets
* are available on all parallel instances of a function. They can be registered via
* {@link org.apache.flink.graph.spargel.ScatterGatherConfiguration#addBroadcastSetForUpdateFunction(String, org.apache.flink.api.java.DataSet)}.
* {@link org.apache.flink.graph.spargel.ScatterGatherConfiguration#addBroadcastSetForGatherFunction(String, org.apache.flink.api.java.DataSet)}.
*
* @param name The name under which the broadcast set is registered.
* @return The broadcast data set.
......@@ -165,15 +165,15 @@ public abstract class VertexUpdateFunction<K, VV, Message> implements Serializab
public <T> Collection<T> getBroadcastSet(String name) {
return this.runtimeContext.<T>getBroadcastVariable(name);
}
// --------------------------------------------------------------------------------------------
// internal methods
// --------------------------------------------------------------------------------------------
private IterationRuntimeContext runtimeContext;
private Collector<Vertex<K, VV>> out;
private Collector<Vertex<K, Tuple3<VV, Long, Long>>> outWithDegrees;
private Vertex<K, VV> outVal;
......
......@@ -18,10 +18,6 @@
package org.apache.flink.graph.spargel;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.java.tuple.Tuple;
......@@ -32,6 +28,10 @@ import org.apache.flink.graph.Vertex;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
/**
* The base class for functions that produce messages between vertices as a part of a {@link ScatterGatherIteration}.
*
......@@ -40,7 +40,7 @@ import org.apache.flink.util.Collector;
* @param <Message> The type of the message sent between vertices along the edges.
* @param <EV> The type of the values that are associated with the edges.
*/
public abstract class MessagingFunction<K, VV, Message, EV> implements Serializable {
public abstract class ScatterFunction<K, VV, Message, EV> implements Serializable {
private static final long serialVersionUID = 1L;
......@@ -96,22 +96,22 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
* @throws Exception The computation may throw exceptions, which causes the superstep to fail.
*/
public abstract void sendMessages(Vertex<K, VV> vertex) throws Exception;
/**
* This method is executed once per superstep before the vertex update function is invoked for each vertex.
* This method is executed once per superstep before the scatter function is invoked for each vertex.
*
* @throws Exception Exceptions in the pre-superstep phase cause the superstep to fail.
*/
public void preSuperstep() throws Exception {}
/**
* This method is executed once per superstep after the vertex update function has been invoked for each vertex.
* This method is executed once per superstep after the scatter function has been invoked for each vertex.
*
* @throws Exception Exceptions in the post-superstep phase cause the superstep to fail.
*/
public void postSuperstep() throws Exception {}
/**
* Gets an {@link java.lang.Iterable} with all edges. This method is mutually exclusive with
* {@link #sendMessageToAllNeighbors(Object)} and may be called only once.
......@@ -147,17 +147,17 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
throw new IllegalStateException("Can use either 'getEdges()' or 'sendMessageToAllNeighbors()'"
+ " exactly once.");
}
edgesUsed = true;
outValue.f1 = m;
while (edges.hasNext()) {
Tuple next = (Tuple) edges.next();
/*
* When EdgeDirection is OUT, the edges iterator only has the out-edges
* of the vertex, i.e. the ones where this vertex is src.
* next.getField(1) gives the neighbor of the vertex running this MessagingFunction.
* next.getField(1) gives the neighbor of the vertex running this ScatterFunction.
*/
if (getDirection().equals(EdgeDirection.OUT)) {
outValue.f0 = next.getField(1);
......@@ -165,7 +165,7 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
/*
* When EdgeDirection is IN, the edges iterator only has the in-edges
* of the vertex, i.e. the ones where this vertex is trg.
* next.getField(10) gives the neighbor of the vertex running this MessagingFunction.
* next.getField(0) gives the neighbor of the vertex running this ScatterFunction.
*/
else if (getDirection().equals(EdgeDirection.IN)) {
outValue.f0 = next.getField(0);
......@@ -184,7 +184,7 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
out.collect(outValue);
}
}
/**
* Sends the given message to the vertex identified by the given key. If the target vertex does not exist,
* the next superstep will cause an exception due to a non-deliverable message.
......@@ -199,7 +199,7 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
}
// --------------------------------------------------------------------------------------------
/**
* Gets the number of the superstep, starting at <tt>1</tt>.
*
......@@ -208,7 +208,7 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
public int getSuperstepNumber() {
return this.runtimeContext.getSuperstepNumber();
}
/**
* Gets the iteration aggregator registered under the given name. The iteration aggregator combines
* all aggregates globally once per superstep and makes them available in the next superstep.
......@@ -219,7 +219,7 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
public <T extends Aggregator<?>> T getIterationAggregator(String name) {
return this.runtimeContext.<T>getIterationAggregator(name);
}
/**
* Get the aggregated value that an aggregator computed in the previous iteration.
*
......@@ -229,11 +229,11 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
public <T extends Value> T getPreviousIterationAggregate(String name) {
return this.runtimeContext.<T>getPreviousIterationAggregate(name);
}
/**
* Gets the broadcast data set registered under the given name. Broadcast data sets
* are available on all parallel instances of a function. They can be registered via
* {@link org.apache.flink.graph.spargel.ScatterGatherConfiguration#addBroadcastSetForMessagingFunction(String, org.apache.flink.api.java.DataSet)}.
* {@link org.apache.flink.graph.spargel.ScatterGatherConfiguration#addBroadcastSetForScatterFunction(String, org.apache.flink.api.java.DataSet)}.
*
* @param name The name under which the broadcast set is registered.
* @return The broadcast data set.
......@@ -245,49 +245,49 @@ public abstract class MessagingFunction<K, VV, Message, EV> implements Serializa
// --------------------------------------------------------------------------------------------
// internal methods and state
// --------------------------------------------------------------------------------------------
private Tuple2<K, Message> outValue;
private IterationRuntimeContext runtimeContext;
private Iterator<?> edges;
private Collector<Tuple2<K, Message>> out;
private K vertexId;
private EdgesIterator<K, EV> edgeIterator;
private boolean edgesUsed;
private long inDegree = -1;
private long outDegree = -1;
void init(IterationRuntimeContext context) {
this.runtimeContext = context;
this.outValue = new Tuple2<K, Message>();
this.edgeIterator = new EdgesIterator<K, EV>();
}
void set(Iterator<?> edges, Collector<Tuple2<K, Message>> out, K id) {
this.edges = edges;
this.out = out;
this.vertexId = id;
this.edgesUsed = false;
}
private static final class EdgesIterator<K, EV>
implements Iterator<Edge<K, EV>>, Iterable<Edge<K, EV>>
{
private Iterator<Edge<K, EV>> input;
private Edge<K, EV> edge = new Edge<K, EV>();
void set(Iterator<Edge<K, EV>> input) {
this.input = input;
}
@Override
public boolean hasNext() {
return input.hasNext();
......
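The direction handling in `sendMessageToAllNeighbors` can be summarized with a small plain-Java sketch. The `OUT` and `IN` cases follow the comments above (edge field 1 is the target, field 0 is the source). The `ALL` case is elided in the hunk, so picking whichever endpoint is not the current vertex is an assumption here; `NeighborSketch` and its `Direction` enum are hypothetical stand-ins, not Gelly classes.

```java
// Hypothetical sketch of neighbor resolution by messaging direction; not part of Gelly.
class NeighborSketch {

    enum Direction { IN, OUT, ALL }

    /** Returns the ID of the vertex that receives a message sent along the edge (source, target). */
    static long neighbor(long vertexId, long source, long target, Direction direction) {
        switch (direction) {
            case OUT:
                return target; // this vertex is the edge source
            case IN:
                return source; // this vertex is the edge target
            default:
                // ALL (assumed): choose the endpoint that is not the current vertex
                return source == vertexId ? target : source;
        }
    }
}
```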
......@@ -29,20 +29,20 @@ import java.util.List;
/**
* A ScatterGatherConfiguration object can be used to set the iteration name and
* degree of parallelism, to register aggregators and use broadcast sets in
* the {@link org.apache.flink.graph.spargel.VertexUpdateFunction} and {@link org.apache.flink.graph.spargel.MessagingFunction}
* the {@link GatherFunction} and {@link ScatterFunction}
*
* The ScatterGatherConfiguration object is passed as an argument to
* {@link org.apache.flink.graph.Graph#runScatterGatherIteration (
* org.apache.flink.graph.spargel.VertexUpdateFunction, org.apache.flink.graph.spargel.MessagingFunction, int,
* org.apache.flink.graph.spargel.ScatterFunction, org.apache.flink.graph.spargel.GatherFunction, int,
* ScatterGatherConfiguration)}.
*/
public class ScatterGatherConfiguration extends IterationConfiguration {
/** the broadcast variables for the update function **/
private List<Tuple2<String, DataSet<?>>> bcVarsUpdate = new ArrayList<Tuple2<String,DataSet<?>>>();
/** the broadcast variables for the scatter function **/
private List<Tuple2<String, DataSet<?>>> bcVarsScatter = new ArrayList<>();
/** the broadcast variables for the messaging function **/
private List<Tuple2<String, DataSet<?>>> bcVarsMessaging = new ArrayList<Tuple2<String,DataSet<?>>>();
/** the broadcast variables for the gather function **/
private List<Tuple2<String, DataSet<?>>> bcVarsGather = new ArrayList<>();
/** flag that defines whether the degrees option is set **/
private boolean optDegrees = false;
......@@ -53,43 +53,43 @@ public class ScatterGatherConfiguration extends IterationConfiguration {
public ScatterGatherConfiguration() {}
/**
* Adds a data set as a broadcast set to the messaging function.
* Adds a data set as a broadcast set to the scatter function.
*
* @param name The name under which the broadcast data is available in the messaging function.
* @param name The name under which the broadcast data is available in the scatter function.
* @param data The data set to be broadcast.
*/
public void addBroadcastSetForMessagingFunction(String name, DataSet<?> data) {
this.bcVarsMessaging.add(new Tuple2<String, DataSet<?>>(name, data));
public void addBroadcastSetForScatterFunction(String name, DataSet<?> data) {
this.bcVarsScatter.add(new Tuple2<String, DataSet<?>>(name, data));
}
/**
* Adds a data set as a broadcast set to the vertex update function.
* Adds a data set as a broadcast set to the gather function.
*
* @param name The name under which the broadcast data is available in the vertex update function.
* @param name The name under which the broadcast data is available in the gather function.
* @param data The data set to be broadcast.
*/
public void addBroadcastSetForUpdateFunction(String name, DataSet<?> data) {
this.bcVarsUpdate.add(new Tuple2<String, DataSet<?>>(name, data));
public void addBroadcastSetForGatherFunction(String name, DataSet<?> data) {
this.bcVarsGather.add(new Tuple2<String, DataSet<?>>(name, data));
}
/**
* Get the broadcast variables of the VertexUpdateFunction.
* Get the broadcast variables of the ScatterFunction.
*
* @return a List of Tuple2, where the first field is the broadcast variable name
* and the second field is the broadcast DataSet.
*/
public List<Tuple2<String, DataSet<?>>> getUpdateBcastVars() {
return this.bcVarsUpdate;
public List<Tuple2<String, DataSet<?>>> getScatterBcastVars() {
return this.bcVarsScatter;
}
/**
* Get the broadcast variables of the MessagingFunction.
* Get the broadcast variables of the GatherFunction.
*
* @return a List of Tuple2, where the first field is the broadcast variable name
* and the second field is the broadcast DataSet.
*/
public List<Tuple2<String, DataSet<?>>> getMessagingBcastVars() {
return this.bcVarsMessaging;
public List<Tuple2<String, DataSet<?>>> getGatherBcastVars() {
return this.bcVarsGather;
}
/**
......@@ -113,7 +113,7 @@ public class ScatterGatherConfiguration extends IterationConfiguration {
}
/**
* Gets the direction in which messages are sent in the MessagingFunction.
* Gets the direction in which messages are sent in the ScatterFunction.
* By default the messaging direction is OUT.
*
* @return an EdgeDirection, which can be either IN, OUT or ALL.
......@@ -123,7 +123,7 @@ public class ScatterGatherConfiguration extends IterationConfiguration {
}
/**
* Sets the direction in which messages are sent in the MessagingFunction.
* Sets the direction in which messages are sent in the ScatterFunction.
* By default the messaging direction is OUT.
*
* @param direction - IN, OUT or ALL
......
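The renamed configuration methods above keep two independent lists of named broadcast sets, one per phase. The following is a minimal plain-Java sketch of that bookkeeping, not the Flink class itself: `BroadcastRegistrySketch` is a hypothetical name, `List<?>` stands in for `DataSet<?>`, and `Map.Entry` stands in for `Tuple2`.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the configuration class; it does not use any Flink types.
final class BroadcastRegistrySketch {
    // One list per phase, mirroring the bcVarsScatter / bcVarsGather fields in the diff.
    private final List<Map.Entry<String, List<?>>> bcVarsScatter = new ArrayList<>();
    private final List<Map.Entry<String, List<?>>> bcVarsGather = new ArrayList<>();

    void addBroadcastSetForScatterFunction(String name, List<?> data) {
        bcVarsScatter.add(new SimpleImmutableEntry<String, List<?>>(name, data));
    }

    void addBroadcastSetForGatherFunction(String name, List<?> data) {
        bcVarsGather.add(new SimpleImmutableEntry<String, List<?>>(name, data));
    }

    List<Map.Entry<String, List<?>>> getScatterBcastVars() {
        return bcVarsScatter;
    }

    List<Map.Entry<String, List<?>>> getGatherBcastVars() {
        return bcVarsGather;
    }

    public static void main(String[] args) {
        BroadcastRegistrySketch parameters = new BroadcastRegistrySketch();
        List<Long> bcVar = Arrays.asList(1L);
        // The same data can be registered for both phases under different names.
        parameters.addBroadcastSetForScatterFunction("messages", bcVar);
        parameters.addBroadcastSetForGatherFunction("updates", bcVar);
        System.out.println(parameters.getScatterBcastVars().size()
                + " scatter set(s), " + parameters.getGatherBcastVars().size() + " gather set(s)");
    }
}
```

Registering the same data for both phases corresponds to the forked broadcast variable case exercised in `SpargelTranslationTest`.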
......@@ -75,8 +75,8 @@ public class SpargelCompilerTest extends CompilerTestBase {
Graph<Long, Long, NullValue> graph = Graph.fromDataSet(initialVertices, edges, env);
DataSet<Vertex<Long, Long>> result = graph.runScatterGatherIteration(
new ConnectedComponents.CCUpdater<Long, Long>(),
new ConnectedComponents.CCMessenger<Long, Long>(BasicTypeInfo.LONG_TYPE_INFO), 100)
new ConnectedComponents.CCMessenger<Long, Long>(BasicTypeInfo.LONG_TYPE_INFO),
new ConnectedComponents.CCUpdater<Long, Long>(), 100)
.getVertices();
result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
......@@ -157,12 +157,12 @@ public class SpargelCompilerTest extends CompilerTestBase {
Graph<Long, Long, NullValue> graph = Graph.fromDataSet(initialVertices, edges, env);
ScatterGatherConfiguration parameters = new ScatterGatherConfiguration();
parameters.addBroadcastSetForMessagingFunction(BC_VAR_NAME, bcVar);
parameters.addBroadcastSetForUpdateFunction(BC_VAR_NAME, bcVar);
parameters.addBroadcastSetForScatterFunction(BC_VAR_NAME, bcVar);
parameters.addBroadcastSetForGatherFunction(BC_VAR_NAME, bcVar);
DataSet<Vertex<Long, Long>> result = graph.runScatterGatherIteration(
new ConnectedComponents.CCUpdater<Long, Long>(),
new ConnectedComponents.CCMessenger<Long, Long>(BasicTypeInfo.LONG_TYPE_INFO), 100)
new ConnectedComponents.CCMessenger<Long, Long>(BasicTypeInfo.LONG_TYPE_INFO),
new ConnectedComponents.CCUpdater<Long, Long>(), 100)
.getVertices();
result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
......
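The compiler test above passes a messenger as the scatter function and an updater as the gather function. As a rough illustration of how the two phases interact, here is a plain-Java sketch of minimum-label propagation. It assumes the messenger sends the vertex value to its neighbors and the updater keeps the smallest value seen, and it treats edges as undirected. `MinLabelSketch` and its methods are hypothetical names and do not use the Gelly API.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Plain-Java sketch of scatter-gather label propagation; not the Gelly implementation.
final class MinLabelSketch {

    /** Propagates the smallest label through each connected component. */
    static Map<Long, Long> run(Map<Long, Long> initialLabels, long[][] edges, int maxSupersteps) {
        Map<Long, Long> labels = new HashMap<>(initialLabels);
        // Assumption: every vertex is active in the first superstep.
        Set<Long> active = new HashSet<>(labels.keySet());

        for (int superstep = 1; superstep <= maxSupersteps && !active.isEmpty(); superstep++) {
            // Scatter: each active vertex offers its label to its neighbors.
            Map<Long, Long> minIncoming = new HashMap<>();
            for (long[] edge : edges) {
                offer(labels, active, minIncoming, edge[0], edge[1]);
                offer(labels, active, minIncoming, edge[1], edge[0]);
            }

            // Gather: a vertex adopts the smallest incoming label if it improves on its own.
            Set<Long> updated = new HashSet<>();
            for (Map.Entry<Long, Long> message : minIncoming.entrySet()) {
                if (message.getValue() < labels.get(message.getKey())) {
                    labels.put(message.getKey(), message.getValue());
                    updated.add(message.getKey());
                }
            }
            // Only vertices whose value changed send messages in the next superstep.
            active = updated;
        }
        return labels;
    }

    private static void offer(Map<Long, Long> labels, Set<Long> active,
                              Map<Long, Long> minIncoming, long from, long to) {
        if (active.contains(from)) {
            minIncoming.merge(to, labels.get(from), Long::min);
        }
    }

    public static void main(String[] args) {
        Map<Long, Long> labels = new HashMap<>();
        for (long id = 1; id <= 5; id++) {
            labels.put(id, id);
        }
        long[][] edges = {{1, 2}, {2, 3}, {4, 5}};
        System.out.println(run(labels, edges, 100));
    }
}
```

The loop stops early once no vertex changes its value, which is why the iteration count passed to the real API (100 in the test) acts only as an upper bound.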
......@@ -16,28 +16,27 @@
* limitations under the License.
*/
package org.apache.flink.graph.spargel;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import org.junit.Test;
import org.apache.flink.api.common.aggregators.LongSumAggregator;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.DeltaIterationResultSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.TwoInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.types.NullValue;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@SuppressWarnings("serial")
public class SpargelTranslationTest {
......@@ -46,28 +45,28 @@ public class SpargelTranslationTest {
public void testTranslationPlainEdges() {
try {
final String ITERATION_NAME = "Test Name";
final String AGGREGATOR_NAME = "AggregatorName";
final String BC_SET_MESSAGES_NAME = "borat messages";
final String BC_SET_UPDATES_NAME = "borat updates";
final int NUM_ITERATIONS = 13;
final int ITERATION_parallelism = 77;
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Long> bcMessaging = env.fromElements(1L);
DataSet<Long> bcUpdate = env.fromElements(1L);
DataSet<Vertex<String, Double>> result;
// ------------ construct the test program ------------------
{
DataSet<Tuple2<String, Double>> initialVertices = env.fromElements(new Tuple2<>("abc", 3.44));
DataSet<Tuple2<String, String>> edges = env.fromElements(new Tuple2<>("a", "c"));
......@@ -83,41 +82,41 @@ public class SpargelTranslationTest {
ScatterGatherConfiguration parameters = new ScatterGatherConfiguration();
parameters.addBroadcastSetForMessagingFunction(BC_SET_MESSAGES_NAME, bcMessaging);
parameters.addBroadcastSetForUpdateFunction(BC_SET_UPDATES_NAME, bcUpdate);
parameters.addBroadcastSetForScatterFunction(BC_SET_MESSAGES_NAME, bcMessaging);
parameters.addBroadcastSetForGatherFunction(BC_SET_UPDATES_NAME, bcUpdate);
parameters.setName(ITERATION_NAME);
parameters.setParallelism(ITERATION_parallelism);
parameters.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
result = graph.runScatterGatherIteration(new UpdateFunction(), new MessageFunctionNoEdgeValue(),
result = graph.runScatterGatherIteration(new MessageFunctionNoEdgeValue(), new UpdateFunction(),
NUM_ITERATIONS, parameters).getVertices();
result.output(new DiscardingOutputFormat<Vertex<String, Double>>());
}
// ------------- validate the java program ----------------
assertTrue(result instanceof DeltaIterationResultSet);
DeltaIterationResultSet<?, ?> resultSet = (DeltaIterationResultSet<?, ?>) result;
DeltaIteration<?, ?> iteration = resultSet.getIterationHead();
// check the basic iteration properties
assertEquals(NUM_ITERATIONS, resultSet.getMaxIterations());
assertArrayEquals(new int[] {0}, resultSet.getKeyPositions());
assertEquals(ITERATION_parallelism, iteration.getParallelism());
assertEquals(ITERATION_NAME, iteration.getName());
assertEquals(AGGREGATOR_NAME, iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());
// validate that the semantic properties are set as they should
TwoInputUdfOperator<?, ?, ?, ?> solutionSetJoin = (TwoInputUdfOperator<?, ?, ?, ?>) resultSet.getNextWorkset();
assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(0, 0).contains(0));
assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(1, 0).contains(0));
TwoInputUdfOperator<?, ?, ?, ?> edgesJoin = (TwoInputUdfOperator<?, ?, ?, ?>) solutionSetJoin.getInput1();
// validate that the broadcast sets are forwarded
assertEquals(bcUpdate, solutionSetJoin.getBroadcastSets().get(BC_SET_UPDATES_NAME));
assertEquals(bcMessaging, edgesJoin.getBroadcastSets().get(BC_SET_MESSAGES_NAME));
......@@ -128,29 +127,29 @@ public class SpargelTranslationTest {
fail(e.getMessage());
}
}
@Test
public void testTranslationPlainEdgesWithForkedBroadcastVariable() {
try {
final String ITERATION_NAME = "Test Name";
final String AGGREGATOR_NAME = "AggregatorName";
final String BC_SET_MESSAGES_NAME = "borat messages";
final String BC_SET_UPDATES_NAME = "borat updates";
final int NUM_ITERATIONS = 13;
final int ITERATION_parallelism = 77;
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Long> bcVar = env.fromElements(1L);
DataSet<Vertex<String, Double>> result;
// ------------ construct the test program ------------------
{
......@@ -169,41 +168,41 @@ public class SpargelTranslationTest {
ScatterGatherConfiguration parameters = new ScatterGatherConfiguration();
parameters.addBroadcastSetForMessagingFunction(BC_SET_MESSAGES_NAME, bcVar);
parameters.addBroadcastSetForUpdateFunction(BC_SET_UPDATES_NAME, bcVar);
parameters.addBroadcastSetForScatterFunction(BC_SET_MESSAGES_NAME, bcVar);
parameters.addBroadcastSetForGatherFunction(BC_SET_UPDATES_NAME, bcVar);
parameters.setName(ITERATION_NAME);
parameters.setParallelism(ITERATION_parallelism);
parameters.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
result = graph.runScatterGatherIteration(new UpdateFunction(), new MessageFunctionNoEdgeValue(),
result = graph.runScatterGatherIteration(new MessageFunctionNoEdgeValue(), new UpdateFunction(),
NUM_ITERATIONS, parameters).getVertices();
result.output(new DiscardingOutputFormat<Vertex<String, Double>>());
}
// ------------- validate the java program ----------------
assertTrue(result instanceof DeltaIterationResultSet);
DeltaIterationResultSet<?, ?> resultSet = (DeltaIterationResultSet<?, ?>) result;
DeltaIteration<?, ?> iteration = resultSet.getIterationHead();
// check the basic iteration properties
assertEquals(NUM_ITERATIONS, resultSet.getMaxIterations());
assertArrayEquals(new int[] {0}, resultSet.getKeyPositions());
assertEquals(ITERATION_parallelism, iteration.getParallelism());
assertEquals(ITERATION_NAME, iteration.getName());
assertEquals(AGGREGATOR_NAME, iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());
// validate that the semantic properties are set as they should
TwoInputUdfOperator<?, ?, ?, ?> solutionSetJoin = (TwoInputUdfOperator<?, ?, ?, ?>) resultSet.getNextWorkset();
assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(0, 0).contains(0));
assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(1, 0).contains(0));
TwoInputUdfOperator<?, ?, ?, ?> edgesJoin = (TwoInputUdfOperator<?, ?, ?, ?>) solutionSetJoin.getInput1();
// validate that the broadcast sets are forwarded
assertEquals(bcVar, solutionSetJoin.getBroadcastSets().get(BC_SET_UPDATES_NAME));
assertEquals(bcVar, edgesJoin.getBroadcastSets().get(BC_SET_MESSAGES_NAME));
......@@ -214,18 +213,18 @@ public class SpargelTranslationTest {
fail(e.getMessage());
}
}
// --------------------------------------------------------------------------------------------
public static class UpdateFunction extends VertexUpdateFunction<String, Double, Long> {
private static class MessageFunctionNoEdgeValue extends ScatterFunction<String, Double, Long, NullValue> {
@Override
public void updateVertex(Vertex<String, Double> vertex, MessageIterator<Long> inMessages) {}
public void sendMessages(Vertex<String, Double> vertex) {}
}
public static class MessageFunctionNoEdgeValue extends MessagingFunction<String, Double, Long, NullValue> {
private static class UpdateFunction extends GatherFunction<String, Double, Long> {
@Override
public void sendMessages(Vertex<String, Double> vertex) {}
public void updateVertex(Vertex<String, Double> vertex, MessageIterator<Long> inMessages) {}
}
}
......@@ -24,9 +24,9 @@ import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.GatherFunction;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.spargel.ScatterFunction;
import org.apache.flink.graph.utils.VertexToTuple2Map;
import org.junit.Assert;
import org.junit.Test;
......@@ -36,47 +36,46 @@ public class CollectionModeSuperstepITCase {
/**
* Dummy iteration to test that the supersteps are correctly incremented
* and can be retrieved from inside the updated and messaging functions.
* and can be retrieved from inside the scatter and gather functions.
* All vertices start with value 1 and increase their value by 1
* in each iteration.
*/
@Test
public void testProgram() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.createCollectionsEnvironment();
Graph<Long, Long, Long> graph = Graph.fromCollection(TestGraphUtils.getLongLongVertices(),
Graph<Long, Long, Long> graph = Graph.fromCollection(TestGraphUtils.getLongLongVertices(),
TestGraphUtils.getLongLongEdges(), env).mapVertices(new AssignOneMapper());
Graph<Long, Long, Long> result = graph.runScatterGatherIteration(
new UpdateFunction(), new MessageFunction(), 10);
new MessageFunction(), new UpdateFunction(), 10);
result.getVertices().map(
new VertexToTuple2Map<Long, Long>()).output(
new DiscardingOutputFormat<Tuple2<Long, Long>>());
env.execute();
}
public static final class UpdateFunction extends VertexUpdateFunction<Long, Long, Long> {
private static final class MessageFunction extends ScatterFunction<Long, Long, Long, Long> {
@Override
public void updateVertex(Vertex<Long, Long> vertex, MessageIterator<Long> inMessages) {
public void sendMessages(Vertex<Long, Long> vertex) {
long superstep = getSuperstepNumber();
Assert.assertEquals(true, vertex.getValue() == superstep);
setNewVertexValue(vertex.getValue() + 1);
//send message to keep vertices active
sendMessageToAllNeighbors(vertex.getValue());
}
}
public static final class MessageFunction extends MessagingFunction<Long, Long, Long, Long> {
private static final class UpdateFunction extends GatherFunction<Long, Long, Long> {
@Override
public void sendMessages(Vertex<Long, Long> vertex) {
public void updateVertex(Vertex<Long, Long> vertex, MessageIterator<Long> inMessages) {
long superstep = getSuperstepNumber();
Assert.assertEquals(true, vertex.getValue() == superstep);
//send message to keep vertices active
sendMessageToAllNeighbors(vertex.getValue());
setNewVertexValue(vertex.getValue() + 1);
}
}
public static final class AssignOneMapper implements MapFunction<Vertex<Long, Long>, Long> {
private static final class AssignOneMapper implements MapFunction<Vertex<Long, Long>, Long> {
public Long map(Vertex<Long, Long> value) {
return 1L;
}
......
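The invariant checked by `CollectionModeSuperstepITCase` (every vertex value equals the current superstep number in both phases) can be reproduced with a small plain-Java simulation. This is a sketch under stated assumptions, not the Gelly runtime: superstep numbering starts at 1, only vertices updated in the previous superstep send messages, and the ring graph in `main` is illustrative. `SuperstepSketch` and `check` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Plain-Java simulation of the superstep loop; it does not use the Flink/Gelly API.
final class SuperstepSketch {

    /** Runs scatter then gather per superstep and returns the final vertex values. */
    static Map<Long, Long> run(Map<Long, Long> initialValues, long[][] edges, int maxSupersteps) {
        Map<Long, Long> values = new HashMap<>(initialValues);
        Set<Long> active = new HashSet<>(values.keySet());

        for (long superstep = 1; superstep <= maxSupersteps && !active.isEmpty(); superstep++) {
            // Scatter phase: active vertices send their value along outgoing edges.
            Map<Long, List<Long>> inbox = new HashMap<>();
            for (long[] edge : edges) {
                if (!active.contains(edge[0])) {
                    continue;
                }
                long value = values.get(edge[0]);
                check(value == superstep, "scatter", edge[0], superstep);
                inbox.computeIfAbsent(edge[1], k -> new ArrayList<>()).add(value);
            }

            // Gather phase: vertices that received messages increment their value.
            for (Long vertex : inbox.keySet()) {
                long value = values.get(vertex);
                check(value == superstep, "gather", vertex, superstep);
                values.put(vertex, value + 1);
            }
            active = new HashSet<>(inbox.keySet());
        }
        return values;
    }

    private static void check(boolean condition, String phase, long vertex, long superstep) {
        if (!condition) {
            throw new IllegalStateException(
                    phase + ": vertex " + vertex + " out of sync in superstep " + superstep);
        }
    }

    public static void main(String[] args) {
        Map<Long, Long> start = new HashMap<>();
        for (long id = 1; id <= 3; id++) {
            start.put(id, 1L); // all vertices start with value 1, as in the test
        }
        long[][] ring = {{1, 2}, {2, 3}, {3, 1}};
        System.out.println(run(start, ring, 10));
    }
}
```

With every vertex on a cycle, each vertex receives a message in every superstep, so after 10 supersteps every value has been incremented 10 times.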