Commit e64237d8 authored by Greg Hogan

[FLINK-3879] [gelly] Native implementation of HITS algorithm

This closes #1967
Parent 40749ddc
@@ -2042,14 +2042,13 @@ computes two interdependent scores for every vertex in a directed graph. Good hubs are those which point to many
good authorities and good authorities are those pointed to by many good hubs.
#### Details
HITS ranking relies on an iterative method that converges to a stationary solution. Each vertex in the directed graph is assigned the same non-negative
hub and authority scores. The algorithm then iteratively updates the scores until termination. The current implementation divides each iteration
into two phases: authority scores can be computed only once hub scores have been updated and normalized, and hub scores can be computed
only once authority scores have been updated and normalized.
Every vertex is assigned the same initial hub and authority scores. The algorithm then iteratively updates the scores
until termination. During each iteration new hub scores are computed from the authority scores, then new authority
scores are computed from the new hub scores. The scores are then normalized and optionally tested for convergence.
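As a sketch of the standard update rules (using $h$ and $a$ for the hub and authority score vectors; these symbols are illustrative and do not appear in the code), each iteration computes

$$h_u = \sum_{(u,v) \in E} a_v \qquad a_v = \sum_{(u,v) \in E} h_u$$

and then scales both vectors to unit length, $h \leftarrow h / \|h\|_2$ and $a \leftarrow a / \|a\|_2$, which corresponds to the sum-of-squares normalization performed by the implementation below.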
#### Usage
The algorithm takes a directed graph as input and outputs a `DataSet` of vertices, where the vertex value is a `Tuple2`
containing the hub and authority score after the maximum number of iterations.
The algorithm takes a directed graph as input and outputs a `DataSet` of `Tuple3` containing the vertex ID, hub score,
and authority score.
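For example, a minimal sketch of running the library algorithm on an existing directed `Graph` (the `graph` variable and the iteration count are assumed here):

```java
import org.apache.flink.api.java.DataSet;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.library.link_analysis.HITS;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.NullValue;

// 'graph' is an assumed directed Graph<LongValue, NullValue, NullValue>
DataSet<HITS.Result<LongValue>> scores = graph
    .run(new HITS<LongValue, NullValue, NullValue>(10)); // run for 10 iterations
```

Each `Result` exposes the vertex ID along with `getHubScore()` and `getAuthorityScore()`.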
### Summarization
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.examples;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.commons.lang3.text.WordUtils;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.CsvOutputFormat;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.asm.simple.directed.Simplify;
import org.apache.flink.graph.asm.translate.LongValueToIntValue;
import org.apache.flink.graph.asm.translate.TranslateGraphIds;
import org.apache.flink.graph.generator.RMatGraph;
import org.apache.flink.graph.generator.random.JDKRandomGeneratorFactory;
import org.apache.flink.graph.generator.random.RandomGenerableFactory;
import org.apache.flink.graph.library.link_analysis.HITS.Result;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.NullValue;
import java.text.NumberFormat;
/**
* Driver for the library implementation of HITS (Hubs and Authorities).
*
* This example reads a simple directed graph from a CSV file or generates
* a directed RMat graph with the given scale and edge factor, then calculates
* hub and authority scores for each vertex.
*
* @see org.apache.flink.graph.library.link_analysis.HITS
*/
public class HITS {
public static final int DEFAULT_ITERATIONS = 10;
public static final int DEFAULT_SCALE = 10;
public static final int DEFAULT_EDGE_FACTOR = 16;
private static void printUsage() {
System.out.println(WordUtils.wrap("", 80));
System.out.println();
System.out.println(WordUtils.wrap("", 80));
System.out.println();
System.out.println("usage: HITS --input <csv | rmat [options]> --output <print | hash | csv [options]");
System.out.println();
System.out.println("options:");
System.out.println(" --input csv --input_filename FILENAME [--input_line_delimiter LINE_DELIMITER] [--input_field_delimiter FIELD_DELIMITER]");
System.out.println(" --input rmat [--scale SCALE] [--edge_factor EDGE_FACTOR]");
System.out.println();
System.out.println(" --output print");
System.out.println(" --output hash");
System.out.println(" --output csv --output_filename FILENAME [--output_line_delimiter LINE_DELIMITER] [--output_field_delimiter FIELD_DELIMITER]");
}
public static void main(String[] args) throws Exception {
// Set up the execution environment
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.getConfig().enableObjectReuse();
ParameterTool parameters = ParameterTool.fromArgs(args);
int iterations = parameters.getInt("iterations", DEFAULT_ITERATIONS);
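// raw type: the element type of the result differs between the branches below
// (Result<LongValue> for CSV and large RMat graphs, Result<IntValue> otherwise)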
DataSet hits;
switch (parameters.get("input", "")) {
case "csv": {
String lineDelimiter = StringEscapeUtils.unescapeJava(
parameters.get("input_line_delimiter", CsvOutputFormat.DEFAULT_LINE_DELIMITER));
String fieldDelimiter = StringEscapeUtils.unescapeJava(
parameters.get("input_field_delimiter", CsvOutputFormat.DEFAULT_FIELD_DELIMITER));
Graph<LongValue, NullValue, NullValue> graph = Graph
.fromCsvReader(parameters.get("input_filename"), env)
.ignoreCommentsEdges("#")
.lineDelimiterEdges(lineDelimiter)
.fieldDelimiterEdges(fieldDelimiter)
.keyType(LongValue.class);
hits = graph
.run(new org.apache.flink.graph.library.link_analysis.HITS<LongValue, NullValue, NullValue>(iterations));
} break;
case "rmat": {
int scale = parameters.getInt("scale", DEFAULT_SCALE);
int edgeFactor = parameters.getInt("edge_factor", DEFAULT_EDGE_FACTOR);
RandomGenerableFactory<JDKRandomGenerator> rnd = new JDKRandomGeneratorFactory();
long vertexCount = 1L << scale;
long edgeCount = vertexCount * edgeFactor;
Graph<LongValue, NullValue, NullValue> graph = new RMatGraph<>(env, rnd, vertexCount, edgeCount)
.generate();
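// For scale <= 32 the vertex IDs fit into 32 bits, so translate them to IntValue
// to reduce memory use; larger graphs keep the original LongValue IDs.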
if (scale > 32) {
hits = graph
.run(new Simplify<LongValue, NullValue, NullValue>())
.run(new org.apache.flink.graph.library.link_analysis.HITS<LongValue, NullValue, NullValue>(iterations));
} else {
hits = graph
.run(new TranslateGraphIds<LongValue, IntValue, NullValue, NullValue>(new LongValueToIntValue()))
.run(new Simplify<IntValue, NullValue, NullValue>())
.run(new org.apache.flink.graph.library.link_analysis.HITS<IntValue, NullValue, NullValue>(iterations));
}
} break;
default:
printUsage();
return;
}
switch (parameters.get("output", "")) {
case "print":
for (Object e: hits.collect()) {
System.out.println(((Result)e).toVerboseString());
}
break;
case "hash":
System.out.println(DataSetUtils.checksumHashCode(hits));
break;
case "csv":
String filename = parameters.get("output_filename");
String lineDelimiter = StringEscapeUtils.unescapeJava(
parameters.get("output_line_delimiter", CsvOutputFormat.DEFAULT_LINE_DELIMITER));
String fieldDelimiter = StringEscapeUtils.unescapeJava(
parameters.get("output_field_delimiter", CsvOutputFormat.DEFAULT_FIELD_DELIMITER));
hits.writeAsCsv(filename, lineDelimiter, fieldDelimiter);
env.execute();
break;
default:
printUsage();
return;
}
JobExecutionResult result = env.getLastJobExecutionResult();
NumberFormat nf = NumberFormat.getInstance();
System.out.println("Execution runtime: " + nf.format(result.getNetRuntime()) + " ms");
}
}
@@ -16,7 +16,7 @@
* limitations under the License.
*/
package org.apache.flink.graph.library;
package org.apache.flink.graph.examples;
import org.apache.flink.api.common.aggregators.DoubleSumAggregator;
import org.apache.flink.api.common.functions.MapFunction;
@@ -49,7 +49,7 @@ import org.apache.flink.util.Preconditions;
public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataSet<Vertex<K, Tuple2<DoubleValue, DoubleValue>>>> {
private final static int MAXIMUMITERATION = (Integer.MAX_VALUE - 1) / 2;
private final static double MINIMUMTHRESHOLD = 1e-9;
private final static double MINIMUMTHRESHOLD = Double.MIN_VALUE;
private int maxIterations;
private double convergeThreshold;
@@ -179,7 +179,7 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataSet<Vertex<K, Tuple2<DoubleValue, DoubleValue>>>> {
double previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / getNumberOfVertices();
// count the diff value of sum of authority scores
diffSumAggregator.aggregate((previousAuthAverage - newAuthorityValue.getValue()));
diffSumAggregator.aggregate(previousAuthAverage - newAuthorityValue.getValue());
}
setNewVertexValue(new Tuple2<>(newHubValue, newAuthorityValue));
} else if (getSuperstepNumber() == maxIteration) {
......
@@ -119,15 +119,16 @@ public class JaccardIndex {
boolean clipAndFlip = parameters.getBoolean("clip_and_flip", DEFAULT_CLIP_AND_FLIP);
Graph<LongValue, NullValue, NullValue> graph = new RMatGraph<>(env, rnd, vertexCount, edgeCount)
.generate()
.run(new Simplify<LongValue, NullValue, NullValue>(clipAndFlip));
.generate();
if (scale > 32) {
ji = graph
.run(new Simplify<LongValue, NullValue, NullValue>(clipAndFlip))
.run(new org.apache.flink.graph.library.similarity.JaccardIndex<LongValue, NullValue, NullValue>());
} else {
ji = graph
.run(new TranslateGraphIds<LongValue, IntValue, NullValue, NullValue>(new LongValueToIntValue()))
.run(new Simplify<IntValue, NullValue, NullValue>(clipAndFlip))
.run(new org.apache.flink.graph.library.similarity.JaccardIndex<IntValue, NullValue, NullValue>());
}
} break;
......
@@ -22,6 +22,7 @@ import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.examples.HITSAlgorithm;
import org.apache.flink.graph.examples.data.HITSData;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.types.DoubleValue;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.library.link_analysis;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.aggregators.DoubleSumAggregator;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsFirst;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsSecond;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.utils.Murmur3_32;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import java.util.Collection;
import static org.apache.flink.api.common.ExecutionConfig.PARALLELISM_DEFAULT;
/**
* http://www.cs.cornell.edu/home/kleinber/auth.pdf
*
* Hyperlink-Induced Topic Search computes two interdependent scores for every
* vertex in a directed graph. A good "hub" links to good "authorities" and
* good "authorities" are linked from good "hubs".
*
* This algorithm can be configured to terminate either by a limit on the number
* of iterations, a convergence threshold, or both.
*
* @param <K> graph ID type
* @param <VV> vertex value type
* @param <EV> edge value type
*/
public class HITS<K, VV, EV>
implements GraphAlgorithm<K, VV, EV, DataSet<HITS.Result<K>>> {
private static final String CHANGE_IN_SCORES = "change in scores";
private static final String HUBBINESS_SUM_SQUARED = "hubbiness sum squared";
private static final String AUTHORITY_SUM_SQUARED = "authority sum squared";
// Required configuration
private int maxIterations;
private double convergenceThreshold;
// Optional configuration
private int parallelism = PARALLELISM_DEFAULT;
/**
* Hyperlink-Induced Topic Search with a fixed number of iterations.
*
* @param iterations fixed number of iterations
*/
public HITS(int iterations) {
this(iterations, Double.MAX_VALUE);
}
/**
* Hyperlink-Induced Topic Search with a convergence threshold. The algorithm
* terminates when the total change in hub and authority scores over all
* vertices falls to or below the given threshold value.
*
* @param convergenceThreshold convergence threshold for sum of scores
*/
public HITS(double convergenceThreshold) {
this(Integer.MAX_VALUE, convergenceThreshold);
}
/**
* Hyperlink-Induced Topic Search with a convergence threshold and a maximum
* iteration count. The algorithm terminates after either the given number
* of iterations or when the total change in hub and authority scores over all
* vertices falls to or below the given threshold value.
*
* @param maxIterations maximum number of iterations
* @param convergenceThreshold convergence threshold for sum of scores
*/
public HITS(int maxIterations, double convergenceThreshold) {
Preconditions.checkArgument(maxIterations > 0, "Number of iterations must be greater than zero");
Preconditions.checkArgument(convergenceThreshold > 0.0, "Convergence threshold must be greater than zero");
this.maxIterations = maxIterations;
this.convergenceThreshold = convergenceThreshold;
}
/**
* Override the operator parallelism.
*
* @param parallelism operator parallelism
* @return this
*/
public HITS<K, VV, EV> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
}
@Override
public DataSet<Result<K>> run(Graph<K, VV, EV> input)
throws Exception {
DataSet<Tuple2<K, K>> edges = input
.getEdges()
.flatMap(new ExtractEdgeIDs<K, EV>())
.setParallelism(parallelism)
.name("Extract edge IDs");
// ID, hub, authority
DataSet<Tuple3<K, DoubleValue, DoubleValue>> initialScores = edges
.map(new InitializeScores<K>())
.setParallelism(parallelism)
.name("Initial scores")
.groupBy(0)
.reduce(new SumScores<K>())
.setParallelism(parallelism)
.name("Sum");
IterativeDataSet<Tuple3<K, DoubleValue, DoubleValue>> iterative = initialScores
.iterate(maxIterations);
// ID, hubbiness
DataSet<Tuple2<K, DoubleValue>> hubbiness = iterative
.coGroup(edges)
.where(0)
.equalTo(1)
.with(new Hubbiness<K>())
.setParallelism(parallelism)
.name("Hub")
.groupBy(0)
.reduce(new SumScore<K>())
.setParallelism(parallelism)
.name("Sum");
// sum-of-hubbiness-squared
DataSet<DoubleValue> hubbinessSumSquared = hubbiness
.map(new Square<K>())
.setParallelism(parallelism)
.name("Square")
.reduce(new Sum())
.setParallelism(parallelism)
.name("Sum");
// ID, new authority
DataSet<Tuple2<K, DoubleValue>> authority = hubbiness
.coGroup(edges)
.where(0)
.equalTo(0)
.with(new Authority<K>())
.setParallelism(parallelism)
.name("Authority")
.groupBy(0)
.reduce(new SumScore<K>())
.setParallelism(parallelism)
.name("Sum");
// sum-of-authority-squared
DataSet<DoubleValue> authoritySumSquared = authority
.map(new Square<K>())
.setParallelism(parallelism)
.name("Square")
.reduce(new Sum())
.setParallelism(parallelism)
.name("Sum");
// ID, normalized hubbiness, normalized authority
DataSet<Tuple3<K, DoubleValue, DoubleValue>> scores = hubbiness
.fullOuterJoin(authority, JoinHint.REPARTITION_SORT_MERGE)
.where(0)
.equalTo(0)
.with(new JoinAndNormalizeHubAndAuthority<K>())
.withBroadcastSet(hubbinessSumSquared, HUBBINESS_SUM_SQUARED)
.withBroadcastSet(authoritySumSquared, AUTHORITY_SUM_SQUARED)
.setParallelism(parallelism)
.name("Join scores");
DataSet<Tuple3<K, DoubleValue, DoubleValue>> passThrough;
if (convergenceThreshold < Double.MAX_VALUE) {
passThrough = iterative
.fullOuterJoin(scores, JoinHint.REPARTITION_SORT_MERGE)
.where(0)
.equalTo(0)
.with(new ChangeInScores<K>())
.setParallelism(parallelism)
.name("Change in scores");
iterative.registerAggregationConvergenceCriterion(CHANGE_IN_SCORES, new DoubleSumAggregator(), new ScoreConvergence(convergenceThreshold));
} else {
passThrough = scores;
}
return iterative
.closeWith(passThrough)
.map(new TranslateResult<K>())
.setParallelism(parallelism)
.name("Map result");
}
/**
* Map edges and remove the edge value.
*
* @param <T> ID type
* @param <ET> edge value type
*
* @see Graph.ExtractEdgeIDsMapper
*/
@ForwardedFields("0; 1")
private static class ExtractEdgeIDs<T, ET>
implements FlatMapFunction<Edge<T, ET>, Tuple2<T, T>> {
private Tuple2<T, T> output = new Tuple2<>();
@Override
public void flatMap(Edge<T, ET> value, Collector<Tuple2<T, T>> out)
throws Exception {
output.f0 = value.f0;
output.f1 = value.f1;
out.collect(output);
}
}
/**
* Initialize scores from the edge set: each edge contributes an authority
* score of 1.0 to its target vertex, so after summation a vertex's initial
* authority score equals its in-degree. Hub scores are initialized to zero
* since they are computed from the authority scores in the first superstep.
*
* The initial scores are non-normalized.
*
* @param <T> ID type
*/
@ForwardedFields("1->0")
private static class InitializeScores<T>
implements MapFunction<Tuple2<T, T>, Tuple3<T, DoubleValue, DoubleValue>> {
private Tuple3<T, DoubleValue, DoubleValue> output = new Tuple3<>(null, new DoubleValue(0.0), new DoubleValue(1.0));
@Override
public Tuple3<T, DoubleValue, DoubleValue> map(Tuple2<T, T> value) throws Exception {
output.f0 = value.f1;
return output;
}
}
/**
* Sum vertices' hub and authority scores.
*
* @param <T> ID type
*/
@ForwardedFieldsFirst("0")
@ForwardedFieldsSecond("0")
private static class SumScores<T>
implements ReduceFunction<Tuple3<T, DoubleValue, DoubleValue>> {
@Override
public Tuple3<T, DoubleValue, DoubleValue> reduce(Tuple3<T, DoubleValue, DoubleValue> left, Tuple3<T, DoubleValue, DoubleValue> right)
throws Exception {
left.f1.setValue(left.f1.getValue() + right.f1.getValue());
left.f2.setValue(left.f2.getValue() + right.f2.getValue());
return left;
}
}
/**
* The hub score is the sum of authority scores of vertices on out-edges.
*
* @param <T> ID type
*/
@ForwardedFieldsFirst("2->1")
@ForwardedFieldsSecond("0")
private static class Hubbiness<T>
implements CoGroupFunction<Tuple3<T, DoubleValue, DoubleValue>, Tuple2<T, T>, Tuple2<T, DoubleValue>> {
private Tuple2<T, DoubleValue> output = new Tuple2<>();
@Override
public void coGroup(Iterable<Tuple3<T, DoubleValue, DoubleValue>> vertex, Iterable<Tuple2<T, T>> edges, Collector<Tuple2<T, DoubleValue>> out)
throws Exception {
output.f1 = vertex.iterator().next().f2;
for (Tuple2<T, T> edge : edges) {
output.f0 = edge.f0;
out.collect(output);
}
}
}
/**
* Sum vertices' scores.
*
* @param <T> ID type
*/
@ForwardedFieldsFirst("0")
@ForwardedFieldsSecond("0")
private static class SumScore<T>
implements ReduceFunction<Tuple2<T, DoubleValue>> {
@Override
public Tuple2<T, DoubleValue> reduce(Tuple2<T, DoubleValue> left, Tuple2<T, DoubleValue> right)
throws Exception {
left.f1.setValue(left.f1.getValue() + right.f1.getValue());
return left;
}
}
/**
* The authority score is the sum of hub scores of vertices on in-edges.
*
* @param <T> ID type
*/
@ForwardedFieldsFirst("1")
@ForwardedFieldsSecond("1->0")
private static class Authority<T>
implements CoGroupFunction<Tuple2<T, DoubleValue>, Tuple2<T, T>, Tuple2<T, DoubleValue>> {
private Tuple2<T, DoubleValue> output = new Tuple2<>();
@Override
public void coGroup(Iterable<Tuple2<T, DoubleValue>> vertex, Iterable<Tuple2<T, T>> edges, Collector<Tuple2<T, DoubleValue>> out)
throws Exception {
output.f1 = vertex.iterator().next().f1;
for (Tuple2<T, T> edge : edges) {
output.f0 = edge.f1;
out.collect(output);
}
}
}
/**
* Compute the square of each score.
*
* @param <T> ID type
*/
private static class Square<T>
implements MapFunction<Tuple2<T, DoubleValue>, DoubleValue> {
private DoubleValue output = new DoubleValue();
@Override
public DoubleValue map(Tuple2<T, DoubleValue> value)
throws Exception {
double val = value.f1.getValue();
output.setValue(val * val);
return output;
}
}
/**
* Sum over values. This specialized function is used in place of generic aggregation.
*/
private static class Sum
implements ReduceFunction<DoubleValue> {
@Override
public DoubleValue reduce(DoubleValue first, DoubleValue second)
throws Exception {
first.setValue(first.getValue() + second.getValue());
return first;
}
}
/**
* Join and normalize the hub and authority scores.
*
* @param <T> ID type
*/
@ForwardedFieldsFirst("0")
@ForwardedFieldsSecond("0")
private static class JoinAndNormalizeHubAndAuthority<T>
extends RichJoinFunction<Tuple2<T, DoubleValue>, Tuple2<T, DoubleValue>, Tuple3<T, DoubleValue, DoubleValue>> {
private Tuple3<T, DoubleValue, DoubleValue> output = new Tuple3<>(null, new DoubleValue(), new DoubleValue());
private double hubbinessRootSumSquared;
private double authorityRootSumSquared;
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
Collection<DoubleValue> var;
var = getRuntimeContext().getBroadcastVariable(HUBBINESS_SUM_SQUARED);
hubbinessRootSumSquared = Math.sqrt(var.iterator().next().getValue());
var = getRuntimeContext().getBroadcastVariable(AUTHORITY_SUM_SQUARED);
authorityRootSumSquared = Math.sqrt(var.iterator().next().getValue());
}
@Override
public Tuple3<T, DoubleValue, DoubleValue> join(Tuple2<T, DoubleValue> hubbiness, Tuple2<T, DoubleValue> authority)
throws Exception {
output.f0 = (authority == null) ? hubbiness.f0 : authority.f0;
output.f1.setValue(hubbiness == null ? 0.0 : hubbiness.f1.getValue() / hubbinessRootSumSquared);
output.f2.setValue(authority == null ? 0.0 : authority.f1.getValue() / authorityRootSumSquared);
return output;
}
}
/**
* Computes the total sum of the change in hub and authority scores over
* all vertices between iterations. A negative score is emitted after the
* first iteration to prevent premature convergence.
*
* @param <T> ID type
*/
@ForwardedFieldsFirst("0")
@ForwardedFieldsSecond("*")
private static class ChangeInScores<T>
extends RichJoinFunction<Tuple3<T, DoubleValue, DoubleValue>, Tuple3<T, DoubleValue, DoubleValue>, Tuple3<T, DoubleValue, DoubleValue>> {
private boolean isInitialSuperstep;
private double changeInScores;
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
isInitialSuperstep = (getIterationRuntimeContext().getSuperstepNumber() == 1);
changeInScores = (isInitialSuperstep) ? -1.0 : 0.0;
}
@Override
public void close() throws Exception {
super.close();
DoubleSumAggregator agg = getIterationRuntimeContext().getIterationAggregator(CHANGE_IN_SCORES);
agg.aggregate(changeInScores);
}
@Override
public Tuple3<T, DoubleValue, DoubleValue> join(Tuple3<T, DoubleValue, DoubleValue> first, Tuple3<T, DoubleValue, DoubleValue> second)
throws Exception {
if (! isInitialSuperstep) {
changeInScores += Math.abs(second.f1.getValue() - first.f1.getValue());
changeInScores += Math.abs(second.f2.getValue() - first.f2.getValue());
}
return second;
}
}
/**
* Monitors the total change in hub and authority scores over all vertices.
* The iteration terminates when the change in scores compared against the
* prior iteration falls below the given convergence threshold.
*
* An optimization of this implementation of HITS is to leave the initial
* scores non-normalized; therefore, the change in scores after the first
* superstep cannot be measured and a negative value is emitted to signal
* that the iteration should continue.
*/
private static class ScoreConvergence
implements ConvergenceCriterion<DoubleValue> {
private double convergenceThreshold;
public ScoreConvergence(double convergenceThreshold) {
this.convergenceThreshold = convergenceThreshold;
}
@Override
public boolean isConverged(int iteration, DoubleValue value) {
double val = value.getValue();
return (0 <= val && val <= convergenceThreshold);
}
}
/**
* Map the Tuple result to the return type.
*
* @param <T> ID type
*/
@ForwardedFields("0")
private static class TranslateResult<T>
implements MapFunction<Tuple3<T, DoubleValue, DoubleValue>, Result<T>> {
private Result<T> output = new Result<>();
@Override
public Result<T> map(Tuple3<T, DoubleValue, DoubleValue> value) throws Exception {
output.f0 = value.f0;
output.f1.f0 = value.f1;
output.f1.f1 = value.f2;
return output;
}
}
/**
* Wraps the vertex type to encapsulate results from the HITS algorithm.
*
* @param <T> ID type
*/
public static class Result<T>
extends Vertex<T, Tuple2<DoubleValue, DoubleValue>> {
public static final int HASH_SEED = 0xc7e39a63;
private Murmur3_32 hasher = new Murmur3_32(HASH_SEED);
public Result() {
f1 = new Tuple2<>();
}
/**
* Get the hub score. Good hubs link to good authorities.
*
* @return the hub score
*/
public DoubleValue getHubScore() {
return f1.f0;
}
/**
* Get the authority score. Good authorities are pointed to by good hubs.
*
* @return the authority score
*/
public DoubleValue getAuthorityScore() {
return f1.f1;
}
public String toVerboseString() {
return "Vertex ID: " + f0
+ ", hub score: " + getHubScore()
+ ", authority score: " + getAuthorityScore();
}
@Override
public int hashCode() {
return hasher.reset()
.hash(f0.hashCode())
.hash(f1.f0.getValue())
.hash(f1.f1.getValue())
.hash();
}
}
}
@@ -40,9 +40,6 @@ extends AsmTestBase {
@Test
public void testSimpleGraph()
throws Exception {
DataSet<Result<IntValue>> cc = directedSimpleGraph
.run(new LocalClusteringCoefficient<IntValue, NullValue, NullValue>());
String expectedResult =
"(0,(2,1))\n" +
"(1,(3,2))\n" +
@@ -51,6 +48,9 @@ extends AsmTestBase {
"(4,(1,0))\n" +
"(5,(1,0))";
DataSet<Result<IntValue>> cc = directedSimpleGraph
.run(new LocalClusteringCoefficient<IntValue, NullValue, NullValue>());
TestBaseUtils.compareResultAsText(cc.collect(), expectedResult);
}
......
@@ -40,9 +40,6 @@ extends AsmTestBase {
@Test
public void testSimpleGraph()
throws Exception {
DataSet<Result<IntValue>> cc = undirectedSimpleGraph
.run(new LocalClusteringCoefficient<IntValue, NullValue, NullValue>());
String expectedResult =
"(0,(2,1))\n" +
"(1,(3,2))\n" +
@@ -51,6 +48,9 @@ extends AsmTestBase {
"(4,(1,0))\n" +
"(5,(1,0))";
DataSet<Result<IntValue>> cc = undirectedSimpleGraph
.run(new LocalClusteringCoefficient<IntValue, NullValue, NullValue>());
TestBaseUtils.compareResultAsText(cc.collect(), expectedResult);
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.library.link_analysis;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.Utils.ChecksumHashCode;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.graph.asm.AsmTestBase;
import org.apache.flink.graph.library.link_analysis.HITS.Result;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.NullValue;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class HITSTest
extends AsmTestBase {
@Test
public void testWithSimpleGraph()
throws Exception {
DataSet<Result<IntValue>> hits = new HITS<IntValue, NullValue, NullValue>(10)
.run(directedSimpleGraph);
List<Tuple2<Double, Double>> expectedResults = new ArrayList<>();
expectedResults.add(Tuple2.of(0.5446287864731747, 0.0));
expectedResults.add(Tuple2.of(0.0, 0.8363240238999012));
expectedResults.add(Tuple2.of(0.6072453524686667, 0.26848532437604833));
expectedResults.add(Tuple2.of(0.5446287864731747, 0.39546603929699625));
expectedResults.add(Tuple2.of(0.0, 0.26848532437604833));
expectedResults.add(Tuple2.of(0.194966796646811, 0.0));
for (Result<IntValue> result : hits.collect()) {
int id = result.f0.getValue();
assertEquals(expectedResults.get(id).f0, result.getHubScore().getValue(), 0.000001);
assertEquals(expectedResults.get(id).f1, result.getAuthorityScore().getValue(), 0.000001);
}
}
@Test
public void testWithCompleteGraph()
throws Exception {
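// In a complete graph all vertices receive equal scores, and L2 normalization
// of n equal scores yields 1 / sqrt(n) for each.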
double expectedScore = 1.0 / Math.sqrt(completeGraphVertexCount);
DataSet<Result<LongValue>> hits = new HITS<LongValue, NullValue, NullValue>(0.000001)
.run(completeGraph);
List<Result<LongValue>> results = hits.collect();
assertEquals(completeGraphVertexCount, results.size());
for (Result<LongValue> result : results) {
assertEquals(expectedScore, result.getHubScore().getValue(), 0.000001);
assertEquals(expectedScore, result.getAuthorityScore().getValue(), 0.000001);
}
}
@Test
public void testWithRMatGraph()
throws Exception {
ChecksumHashCode checksum = DataSetUtils.checksumHashCode(directedRMatGraph
.run(new HITS<LongValue, NullValue, NullValue>(0.000001)));
assertEquals(902, checksum.getCount());
assertEquals(0x000001cbba6dbcd0L, checksum.getChecksum());
}
}