Commit 2ac09c08 authored by: G Greg Hogan

[FLINK-7199] [gelly] Graph simplification does not set parallelism

The Simplify parameter should accept and set the parallelism when
calling the Simplify algorithms.

The LocalClusteringCoefficient "count triangles" reduce now uses the
assigned ("little") parallelism as this computation is proportional to
the number of vertices (the combine computation is proportional to the
potentially much larger number of triangles).

The ignored CombineHint on the HITS all-reduces have been removed.

This closes #4346
Parent: 9437a0ff
......@@ -123,6 +123,8 @@ extends ParameterizedBase {
.addClass(Hash.class)
.addClass(Print.class);
// parameters
private final ParameterTool parameters;
private final BooleanParameter disableObjectReuse = new BooleanParameter(this, "__disable_object_reuse");
......@@ -133,6 +135,18 @@ extends ParameterizedBase {
private StringParameter jobName = new StringParameter(this, "__job_name")
.setDefaultValue(null);
// state
private ExecutionEnvironment env;
private DataSet result;
private String executionName;
private Driver algorithm;
private Output output;
/**
* Create an algorithm runner from the given arguments.
*
......@@ -147,6 +161,26 @@ extends ParameterizedBase {
return this.getClass().getSimpleName();
}
/**
 * Accessor for the job's ExecutionEnvironment. The environment is created
 * lazily, so this is only valid once {@link Runner#run()} has been called.
 *
 * @return the ExecutionEnvironment
 */
public ExecutionEnvironment getExecutionEnvironment() {
	return this.env;
}
/**
 * Accessor for the algorithm's result DataSet. The result is assigned
 * during {@link Runner#run()} and is therefore unavailable before that call.
 *
 * @return the result DataSet
 */
public DataSet getResult() {
	return this.result;
}
/**
* List available algorithms. This is displayed to the user when no valid
* algorithm is given in the program parameterization.
......@@ -246,9 +280,17 @@ extends ParameterizedBase {
}
}
public void run() throws Exception {
/**
* Setup the Flink job with the graph input, algorithm, and output.
*
* <p>To then execute the job call {@link #execute}.
*
* @return this
* @throws Exception on error
*/
public Runner run() throws Exception {
// Set up the execution environment
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env = ExecutionEnvironment.getExecutionEnvironment();
ExecutionConfig config = env.getConfig();
// should not have any non-Flink data types
......@@ -282,7 +324,7 @@ extends ParameterizedBase {
}
String algorithmName = parameters.get(ALGORITHM);
Driver algorithm = driverFactory.get(algorithmName);
algorithm = driverFactory.get(algorithmName);
if (algorithm == null) {
throw new ProgramParametrizationException("Unknown algorithm name: " + algorithmName);
......@@ -314,7 +356,7 @@ extends ParameterizedBase {
}
String outputName = parameters.get(OUTPUT);
Output output = outputFactory.get(outputName);
output = outputFactory.get(outputName);
if (output == null) {
throw new ProgramParametrizationException("Unknown output type: " + outputName);
......@@ -358,10 +400,10 @@ extends ParameterizedBase {
}
// Run algorithm
DataSet results = algorithm.plan(graph);
result = algorithm.plan(graph);
// Output
String executionName = jobName.getValue() != null ? jobName.getValue() + ": " : "";
executionName = jobName.getValue() != null ? jobName.getValue() + ": " : "";
executionName += input.getIdentity() + " ⇨ " + algorithmName + " ⇨ " + output.getName();
......@@ -386,18 +428,29 @@ extends ParameterizedBase {
throw new ProgramParametrizationException(ex.getMessage());
}
if (results == null) {
env.execute(executionName);
} else {
if (result != null) {
// Transform output if algorithm returned result DataSet
if (transforms.size() > 0) {
Collections.reverse(transforms);
for (Transform transform : transforms) {
results = (DataSet) transform.transformResult(results);
result = (DataSet) transform.transformResult(result);
}
}
}
return this;
}
output.write(executionName.toString(), System.out, results);
/**
* Execute the Flink job.
*
* @throws Exception on error
*/
private void execute() throws Exception {
if (result == null) {
env.execute(executionName);
} else {
output.write(executionName.toString(), System.out, result);
}
System.out.println();
......@@ -450,7 +503,7 @@ extends ParameterizedBase {
}
public static void main(String[] args) throws Exception {
	// configure the Flink job from the program arguments, then submit it
	Runner runner = new Runner(args);
	runner.run();
	runner.execute();
}
/**
......
......@@ -99,6 +99,6 @@ extends InputBase<K, NullValue, NullValue> {
throw new ProgramParametrizationException("Unknown type '" + type.getValue() + "'");
}
return simplify.simplify(graph);
return simplify.simplify(graph, parallelism.getValue().intValue());
}
}
......@@ -18,7 +18,6 @@
package org.apache.flink.graph.drivers.input;
import org.apache.flink.graph.drivers.parameter.LongParameter;
import org.apache.flink.graph.drivers.transform.GraphKeyTypeTransform;
import org.apache.flink.graph.drivers.transform.Transform;
import org.apache.flink.graph.drivers.transform.Transformable;
......@@ -27,8 +26,6 @@ import org.apache.flink.types.NullValue;
import java.util.Arrays;
import java.util.List;
import static org.apache.flink.api.common.ExecutionConfig.PARALLELISM_DEFAULT;
/**
* Base class for generated graphs.
*
......@@ -43,9 +40,6 @@ implements Transformable {
return Arrays.<Transform>asList(new GraphKeyTypeTransform(vertexCount()));
}
protected LongParameter parallelism = new LongParameter(this, "__parallelism")
.setDefaultValue(PARALLELISM_DEFAULT);
/**
* The vertex count is verified to be no greater than the capacity of the
* selected data type. All vertices must be counted even if skipped or
......
......@@ -50,7 +50,7 @@ extends GeneratedGraph<LongValue> {
// simplify after the translation to improve the performance of the
// simplify operators by processing smaller data types
return simplify.simplify(graph);
return simplify.simplify(graph, parallelism.getValue().intValue());
}
public abstract Graph<LongValue, NullValue, NullValue> generate(ExecutionEnvironment env) throws Exception;
......
......@@ -19,8 +19,11 @@
package org.apache.flink.graph.drivers.input;
import org.apache.flink.graph.drivers.parameter.LongParameter;
import org.apache.flink.graph.drivers.parameter.ParameterizedBase;
import static org.apache.flink.api.common.ExecutionConfig.PARALLELISM_DEFAULT;
/**
* Base class for inputs.
*
......@@ -31,4 +34,14 @@ import org.apache.flink.graph.drivers.parameter.ParameterizedBase;
public abstract class InputBase<K, VV, EV>
extends ParameterizedBase
implements Input<K, VV, EV> {
protected LongParameter parallelism = new LongParameter(this, "__parallelism")
.setDefaultValue(PARALLELISM_DEFAULT)
.setMinimumValue(1)
.setMaximumValue(Integer.MAX_VALUE);
@Override
public String getName() {
	// display the concrete input's simple class name in usage listings
	return getClass().getSimpleName();
}
}
......@@ -21,7 +21,10 @@ package org.apache.flink.graph.drivers.parameter;
import org.apache.flink.api.java.utils.ParameterTool;
/**
* A {@link Parameter} storing a {@link Long}.
* A {@link Parameter} storing a {@link Long} within <tt>min</tt> and
* <tt>max</tt> bounds (inclusive).
*
* <p>Note that the default value may be outside of these bounds.
*/
public class LongParameter
extends SimpleParameter<Long> {
......@@ -46,36 +49,29 @@ extends SimpleParameter<Long> {
/**
* Set the default value.
*
* <p>The default may set to any value and is not restricted by setting the
* minimum or maximum values.
*
* @param defaultValue the default value.
* @return this
*/
public LongParameter setDefaultValue(long defaultValue) {
super.setDefaultValue(defaultValue);
if (hasMinimumValue) {
Util.checkParameter(defaultValue >= minimumValue,
"Default value (" + defaultValue + ") must be greater than or equal to minimum (" + minimumValue + ")");
}
if (hasMaximumValue) {
Util.checkParameter(defaultValue <= maximumValue,
"Default value (" + defaultValue + ") must be less than or equal to maximum (" + maximumValue + ")");
}
return this;
}
/**
* Set the minimum value.
*
* <p>If a maximum value has been set then the minimum value must not be
* greater than the maximum value.
*
* @param minimumValue the minimum value
* @return this
*/
public LongParameter setMinimumValue(long minimumValue) {
if (hasDefaultValue) {
Util.checkParameter(minimumValue <= defaultValue,
"Minimum value (" + minimumValue + ") must be less than or equal to default (" + defaultValue + ")");
} else if (hasMaximumValue) {
if (hasMaximumValue) {
Util.checkParameter(minimumValue <= maximumValue,
"Minimum value (" + minimumValue + ") must be less than or equal to maximum (" + maximumValue + ")");
}
......@@ -89,14 +85,14 @@ extends SimpleParameter<Long> {
/**
* Set the maximum value.
*
* <p>If a minimum value has been set then the maximum value must not be
* less than the minimum value.
*
* @param maximumValue the maximum value
* @return this
*/
public LongParameter setMaximumValue(long maximumValue) {
if (hasDefaultValue) {
Util.checkParameter(maximumValue >= defaultValue,
"Maximum value (" + maximumValue + ") must be greater than or equal to default (" + defaultValue + ")");
} else if (hasMinimumValue) {
if (hasMinimumValue) {
Util.checkParameter(maximumValue >= minimumValue,
"Maximum value (" + maximumValue + ") must be greater than or equal to minimum (" + minimumValue + ")");
}
......@@ -109,16 +105,21 @@ extends SimpleParameter<Long> {
@Override
public void configure(ParameterTool parameterTool) {
value = hasDefaultValue ? parameterTool.getLong(name, defaultValue) : parameterTool.getLong(name);
if (hasMinimumValue) {
Util.checkParameter(value >= minimumValue,
name + " must be greater than or equal to " + minimumValue);
}
if (hasMaximumValue) {
Util.checkParameter(value <= maximumValue,
name + " must be less than or equal to " + maximumValue);
if (hasDefaultValue && !parameterTool.has(name)) {
// skip checks for min and max when using default value
value = defaultValue;
} else {
value = parameterTool.getLong(name);
if (hasMinimumValue) {
Util.checkParameter(value >= minimumValue,
name + " must be greater than or equal to " + minimumValue);
}
if (hasMaximumValue) {
Util.checkParameter(value <= maximumValue,
name + " must be less than or equal to " + maximumValue);
}
}
}
......
......@@ -109,21 +109,23 @@ implements Parameter<Ordering> {
* @return output graph
* @throws Exception on error
*/
public <T extends Comparable<T>> Graph<T, NullValue, NullValue> simplify(Graph<T, NullValue, NullValue> graph)
public <T extends Comparable<T>> Graph<T, NullValue, NullValue> simplify(Graph<T, NullValue, NullValue> graph, int parallelism)
throws Exception {
switch (value) {
case DIRECTED:
graph = graph
.run(new org.apache.flink.graph.asm.simple.directed.Simplify<>());
.run(new org.apache.flink.graph.asm.simple.directed.Simplify<T, NullValue, NullValue>()
.setParallelism(parallelism));
break;
case UNDIRECTED:
graph = graph
.run(new org.apache.flink.graph.asm.simple.undirected.Simplify<>(false));
.run(new org.apache.flink.graph.asm.simple.undirected.Simplify<T, NullValue, NullValue>(false)
.setParallelism(parallelism));
break;
case UNDIRECTED_CLIP_AND_FLIP:
graph = graph
.run(new org.apache.flink.graph.asm.simple.undirected.Simplify<>(true));
.run(new org.apache.flink.graph.asm.simple.undirected.Simplify<T, NullValue, NullValue>(true)
.setParallelism(parallelism));
break;
}
......
......@@ -62,4 +62,12 @@ public class AdamicAdarITCase extends CopyableValueDriverBaseITCase {
expectedCount(parameters(8, "print"), 39276);
}
@Test
public void testParallelism() throws Exception {
	// All operators must run at the configured ("little") parallelism except
	// the listed large-scale operators, which are permitted full parallelism.
	TestUtils.verifyParallelism(parameters(8, "print"),
		"FlatMap \\(Mirror results\\)",
		"GroupReduce \\(Compute scores\\)",
		"GroupReduce \\(Generate group pairs\\)");
}
}
......@@ -73,4 +73,18 @@ public class ClusteringCoefficientITCase extends CopyableValueDriverBaseITCase {
expectedOutput(parameters(8, "undirected", "undirected", "hash"),
"\n" + new Checksum(233, 0x000000743ef6d14bL) + expected);
}
@Test
public void testParallelism() throws Exception {
	// Operators whose size is proportional to the number of triangles
	// (potentially much larger than the vertex count) are permitted to run
	// at full parallelism; all other operators must honor the configured
	// ("little") parallelism.
	String[] largeOperators = new String[]{
		"Combine \\(Count triangles\\)",
		"FlatMap \\(Split triangle vertices\\)",
		"Join \\(Triangle listing\\)",
		"GroupReduce \\(Generate triplets\\)",
		"DataSink \\(Count\\)"};

	// verify each combination of input direction and algorithm direction
	TestUtils.verifyParallelism(parameters(8, "directed", "directed", "print"), largeOperators);
	TestUtils.verifyParallelism(parameters(8, "directed", "undirected", "print"), largeOperators);
	TestUtils.verifyParallelism(parameters(8, "undirected", "undirected", "print"), largeOperators);
}
}
......@@ -56,11 +56,15 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
ProgramParametrizationException.class);
}
// CirculantGraph

/** Common program arguments for CirculantGraph tests with the given output type. */
private String[] getCirculantGraphParameters(String output) {
	String[] arguments = parameters("CirculantGraph", output, "--vertex_count", "42", "--range0", "13:4");
	return arguments;
}
@Test
public void testHashWithCirculantGraph() throws Exception {
expectedChecksum(
parameters("CirculantGraph", "hash", "--vertex_count", "42", "--range0", "13:4"),
168, 0x000000000001ae80);
expectedChecksum(getCirculantGraphParameters("hash"), 168, 0x000000000001ae80);
}
@Test
......@@ -68,16 +72,23 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("CirculantGraph", "print", "--vertex_count", "42", "--range0", "13:4"),
new Checksum(168, 0x0000004bdcc52cbcL));
expectedOutputChecksum(getCirculantGraphParameters("print"), new Checksum(168, 0x0000004bdcc52cbcL));
}
@Test
public void testParallelismWithCirculantGraph() throws Exception {
TestUtils.verifyParallelism(getCirculantGraphParameters("print"));
}
// CompleteGraph
private String[] getCompleteGraphParameters(String output) {
return parameters("CompleteGraph", output, "--vertex_count", "42");
}
@Test
public void testHashWithCompleteGraph() throws Exception {
expectedChecksum(
parameters("CompleteGraph", "hash", "--vertex_count", "42"),
1722, 0x0000000000113ca0L);
expectedChecksum(getCompleteGraphParameters("hash"), 1722, 0x0000000000113ca0L);
}
@Test
......@@ -85,16 +96,23 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("CompleteGraph", "print", "--vertex_count", "42"),
new Checksum(1722, 0x0000031109a0c398L));
expectedOutputChecksum(getCompleteGraphParameters("print"), new Checksum(1722, 0x0000031109a0c398L));
}
@Test
public void testParallelismWithCompleteGraph() throws Exception {
TestUtils.verifyParallelism(getCompleteGraphParameters("print"));
}
// CycleGraph
private String[] getCycleGraphParameters(String output) {
return parameters("CycleGraph", output, "--vertex_count", "42");
}
@Test
public void testHashWithCycleGraph() throws Exception {
expectedChecksum(
parameters("CycleGraph", "hash", "--vertex_count", "42"),
84, 0x000000000000d740L);
expectedChecksum(getCycleGraphParameters("hash"), 84, 0x000000000000d740L);
}
@Test
......@@ -102,16 +120,23 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("CycleGraph", "print", "--vertex_count", "42"),
new Checksum(84, 0x000000272a136fcaL));
expectedOutputChecksum(getCycleGraphParameters("print"), new Checksum(84, 0x000000272a136fcaL));
}
@Test
public void testParallelismWithCycleGraph() throws Exception {
TestUtils.verifyParallelism(getCycleGraphParameters("print"));
}
// EchoGraph
private String[] getEchoGraphParameters(String output) {
return parameters("EchoGraph", output, "--vertex_count", "42", "--vertex_degree", "13");
}
@Test
public void testHashWithEchoGraph() throws Exception {
expectedChecksum(
parameters("EchoGraph", "hash", "--vertex_count", "42", "--vertex_degree", "13"),
546, 0x0000000000057720L);
expectedChecksum(getEchoGraphParameters("hash"), 546, 0x0000000000057720L);
}
@Test
......@@ -119,23 +144,44 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("EchoGraph", "print", "--vertex_count", "42", "--vertex_degree", "13"),
new Checksum(546, 0x000000f7190b8fcaL));
expectedOutputChecksum(getEchoGraphParameters("print"), new Checksum(546, 0x000000f7190b8fcaL));
}
@Test
public void testParallelismWithEchoGraph() throws Exception {
TestUtils.verifyParallelism(getEchoGraphParameters("print"));
}
// EmptyGraph
private String[] getEmptyGraphParameters(String output) {
return parameters("EmptyGraph", output, "--vertex_count", "42");
}
@Test
public void testHashWithEmptyGraph() throws Exception {
expectedChecksum(
parameters("EmptyGraph", "hash", "--vertex_count", "42"),
0, 0x0000000000000000L);
expectedChecksum(getEmptyGraphParameters("hash"), 0, 0x0000000000000000L);
}
@Test
public void testPrintWithEmptyGraph() throws Exception {
expectedOutputChecksum(getEmptyGraphParameters("print"), new Checksum(0, 0x0000000000000000L));
}
@Test
public void testParallelismWithEmptyGraph() throws Exception {
TestUtils.verifyParallelism(getEmptyGraphParameters("print"));
}
// GridGraph
private String[] getGridGraphParameters(String output) {
return parameters("GridGraph", output, "--dim0", "2:true", "--dim1", "3:false", "--dim2", "5:true");
}
@Test
public void testHashWithGridGraph() throws Exception {
expectedChecksum(
parameters("GridGraph", "hash", "--dim0", "2:true", "--dim1", "3:false", "--dim2", "5:true"),
130, 0x000000000000eba0L);
expectedChecksum(getGridGraphParameters("hash"), 130, 0x000000000000eba0L);
}
@Test
......@@ -143,16 +189,23 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("GridGraph", "print", "--dim0", "2:true", "--dim1", "3:false", "--dim2", "5:true"),
new Checksum(130, 0x00000033237d24eeL));
expectedOutputChecksum(getGridGraphParameters("print"), new Checksum(130, 0x00000033237d24eeL));
}
@Test
public void testParallelismWithGridGraph() throws Exception {
TestUtils.verifyParallelism(getGridGraphParameters("print"));
}
// HypercubeGraph
private String[] getHypercubeGraphParameters(String output) {
return parameters("HypercubeGraph", output, "--dimensions", "7");
}
@Test
public void testHashWithHypercubeGraph() throws Exception {
expectedChecksum(
parameters("HypercubeGraph", "hash", "--dimensions", "7"),
896, 0x00000000001bc800L);
expectedChecksum(getHypercubeGraphParameters("hash"), 896, 0x00000000001bc800L);
}
@Test
......@@ -160,16 +213,23 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("HypercubeGraph", "print", "--dimensions", "7"),
new Checksum(896, 0x000001f243ee33b2L));
expectedOutputChecksum(getHypercubeGraphParameters("print"), new Checksum(896, 0x000001f243ee33b2L));
}
@Test
public void testParallelismWithHypercubeGraph() throws Exception {
TestUtils.verifyParallelism(getHypercubeGraphParameters("print"));
}
// PathGraph
private String[] getPathGraphParameters(String output) {
return parameters("PathGraph", output, "--vertex_count", "42");
}
@Test
public void testHashWithPathGraph() throws Exception {
expectedChecksum(
parameters("PathGraph", "hash", "--vertex_count", "42"),
82, 0x000000000000d220L);
expectedChecksum(getPathGraphParameters("hash"), 82, 0x000000000000d220L);
}
@Test
......@@ -177,16 +237,27 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("PathGraph", "print", "--vertex_count", "42"),
new Checksum(82, 0x000000269be2d4c2L));
expectedOutputChecksum(getPathGraphParameters("print"), new Checksum(82, 0x000000269be2d4c2L));
}
@Test
public void testParallelismWithPathGraph() throws Exception {
TestUtils.verifyParallelism(getPathGraphParameters("print"));
}
// RMatGraph

/**
 * Common program arguments for RMatGraph tests; the "--simplify" argument
 * pair is omitted when {@code simplify} is null.
 */
private String[] getRMatGraphParameters(String output, String simplify) {
	return (simplify == null)
		? parameters("RMatGraph", output, "--scale", "7")
		: parameters("RMatGraph", output, "--scale", "7", "--simplify", simplify);
}
@Test
public void testHashWithRMatGraph() throws Exception {
expectedChecksum(
parameters("RMatGraph", "hash", "--scale", "7"),
2048, 0x00000000001ee529);
expectedChecksum(getRMatGraphParameters("hash", null), 2048, 0x00000000001ee529);
}
@Test
......@@ -194,16 +265,17 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("RMatGraph", "print", "--scale", "7"),
new Checksum(2048, 0x000002f737939f05L));
expectedOutputChecksum(getRMatGraphParameters("print", null), new Checksum(2048, 0x000002f737939f05L));
}
@Test
public void testParallelismWithRMatGraph() throws Exception {
TestUtils.verifyParallelism(getRMatGraphParameters("print", null));
}
@Test
public void testHashWithDirectedRMatGraph() throws Exception {
expectedChecksum(
parameters("RMatGraph", "hash", "--scale", "7", "--simplify", "directed"),
1168, 0x00000000001579bdL);
expectedChecksum(getRMatGraphParameters("hash", "directed"), 1168, 0x00000000001579bdL);
}
@Test
......@@ -211,16 +283,17 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("RMatGraph", "print", "--scale", "7", "--simplify", "directed"),
new Checksum(1168, 0x0000020e35b0f35dL));
expectedOutputChecksum(getRMatGraphParameters("print", "directed"), new Checksum(1168, 0x0000020e35b0f35dL));
}
@Test
public void testParallelismWithDirectedRMatGraph() throws Exception {
TestUtils.verifyParallelism(getRMatGraphParameters("print", "directed"));
}
@Test
public void testHashWithUndirectedRMatGraph() throws Exception {
expectedChecksum(
parameters("RMatGraph", "hash", "--scale", "7", "--simplify", "undirected"),
1854, 0x0000000000242920L);
expectedChecksum(getRMatGraphParameters("hash", "undirected"), 1854, 0x0000000000242920L);
}
@Test
......@@ -228,16 +301,23 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("RMatGraph", "print", "--scale", "7", "--simplify", "undirected"),
new Checksum(1854, 0x0000036fe5802162L));
expectedOutputChecksum(getRMatGraphParameters("print", "undirected"), new Checksum(1854, 0x0000036fe5802162L));
}
@Test
public void testParallelismWithUndirectedRMatGraph() throws Exception {
TestUtils.verifyParallelism(getRMatGraphParameters("print", "undirected"));
}
// SingletonEdgeGraph
private String[] getSingletonEdgeGraphParameters(String output) {
return parameters("SingletonEdgeGraph", output, "--vertex_pair_count", "42");
}
@Test
public void testHashWithSingletonEdgeGraph() throws Exception {
expectedChecksum(
parameters("SingletonEdgeGraph", "hash", "--vertex_pair_count", "42"),
84, 0x000000000001b3c0L);
expectedChecksum(getSingletonEdgeGraphParameters("hash"), 84, 0x000000000001b3c0L);
}
@Test
......@@ -245,16 +325,23 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("SingletonEdgeGraph", "print", "--vertex_pair_count", "42"),
new Checksum(84, 0x0000002e59e10d9aL));
expectedOutputChecksum(getSingletonEdgeGraphParameters("print"), new Checksum(84, 0x0000002e59e10d9aL));
}
@Test
public void testParallelismWithSingletonEdgeGraph() throws Exception {
TestUtils.verifyParallelism(getSingletonEdgeGraphParameters("print"));
}
// StarGraph
private String[] getStarGraphParameters(String output) {
return parameters("StarGraph", output, "--vertex_count", "42");
}
@Test
public void testHashWithStarGraph() throws Exception {
expectedChecksum(
parameters("StarGraph", "hash", "--vertex_count", "42"),
82, 0x0000000000006ba0L);
expectedChecksum(getStarGraphParameters("hash"), 82, 0x0000000000006ba0L);
}
@Test
......@@ -262,8 +349,11 @@ public class EdgeListITCase extends NonTransformableDriverBaseITCase {
// skip 'char' since it is not printed as a number
Assume.assumeFalse(idType.equals("char") || idType.equals("nativeChar"));
expectedOutputChecksum(
parameters("StarGraph", "print", "--vertex_count", "42"),
new Checksum(82, 0x00000011ec3faee8L));
expectedOutputChecksum(getStarGraphParameters("print"), new Checksum(82, 0x00000011ec3faee8L));
}
@Test
public void testParallelismWithStarGraph() throws Exception {
TestUtils.verifyParallelism(getStarGraphParameters("print"));
}
}
......@@ -98,4 +98,10 @@ public class GraphMetricsITCase extends DriverBaseITCase {
expectedOutput(parameters(7, "undirected", "hash"), expected);
expectedOutput(parameters(7, "undirected", "print"), expected);
}
@Test
public void testParallelism() throws Exception {
	// no full-parallelism exclusions: every operator must honor the
	// configured ("little") parallelism, for both graph directions
	TestUtils.verifyParallelism(parameters(8, "directed", "print"));
	TestUtils.verifyParallelism(parameters(8, "undirected", "print"));
}
}
......@@ -59,4 +59,9 @@ public class HITSITCase extends DriverBaseITCase {
expectedCount(parameters(8, "print"), 233);
}
@Test
public void testParallelism() throws Exception {
	// no full-parallelism exclusions: every HITS operator must honor the
	// configured ("little") parallelism
	TestUtils.verifyParallelism(parameters(8, "print"));
}
}
......@@ -68,4 +68,12 @@ public class JaccardIndexITCase extends CopyableValueDriverBaseITCase {
expectedOutputChecksum(parameters(8, "print"), new Checksum(39276, 0x00004c5a726220c0L));
}
@Test
public void testParallelism() throws Exception {
	// the listed operators scale with the (potentially much larger) number
	// of vertex pairs and are permitted full parallelism
	TestUtils.verifyParallelism(parameters(8, "print"),
		"FlatMap \\(Mirror results\\)",
		"GroupReduce \\(Compute scores\\)",
		"GroupReduce \\(Generate group pairs\\)");
}
}
......@@ -59,4 +59,9 @@ public class PageRankITCase extends DriverBaseITCase {
expectedCount(parameters(8, "print"), 233);
}
@Test
public void testParallelism() throws Exception {
	// no full-parallelism exclusions: every PageRank operator must honor
	// the configured ("little") parallelism
	TestUtils.verifyParallelism(parameters(8, "print"));
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.flink.graph.drivers;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Runner;
import org.apache.flink.optimizer.Optimizer;
import org.apache.flink.optimizer.costs.DefaultCostEstimator;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.PlanNode;
import org.apache.commons.lang3.ArrayUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
import static org.junit.Assert.assertTrue;
/**
 * Utility methods for testing graph algorithm drivers.
 */
public class TestUtils {

	/**
	 * Verify algorithm driver parallelism: compile the job plan and assert
	 * that no operator (other than the named exclusions) runs at a
	 * parallelism above the value passed to the runner.
	 *
	 * <p>Based on {@code org.apache.flink.graph.generator.TestUtils}.
	 *
	 * @param arguments program arguments
	 * @param fullParallelismOperatorNames list of regex strings matching the names of full parallelism operators
	 * @throws Exception on error
	 */
	static void verifyParallelism(String[] arguments, String... fullParallelismOperatorNames) throws Exception {
		// set a reduced parallelism for the algorithm runner
		final int parallelism = 8;
		arguments = ArrayUtils.addAll(arguments, "--__parallelism", Integer.toString(parallelism));

		// configure the runner but do not execute
		Runner runner = new Runner(arguments).run();

		// we cannot use the actual DataSink since DataSet#writeAsCsv also
		// executes the program; instead, we receive the DataSet and configure
		// with a DiscardingOutputFormat
		DataSet result = runner.getResult();
		if (result != null) {
			result.output(new DiscardingOutputFormat());
		}

		// set the default parallelism higher than the expected parallelism;
		// any operator not explicitly limited would inherit this and fail
		ExecutionEnvironment env = runner.getExecutionEnvironment();
		env.setParallelism(2 * parallelism);

		// add default regex exclusions for the added DiscardingOutputFormat
		// and also for any preceding GraphKeyTypeTransform
		List<Pattern> patterns = new ArrayList<>();
		patterns.add(Pattern.compile("DataSink \\(org\\.apache\\.flink\\.api\\.java\\.io\\.DiscardingOutputFormat@[0-9a-f]{1,8}\\)"));
		patterns.add(Pattern.compile("FlatMap \\(Translate results IDs\\)"));

		// add user regex patterns
		for (String largeOperatorName : fullParallelismOperatorNames) {
			patterns.add(Pattern.compile(largeOperatorName));
		}

		Optimizer compiler = new Optimizer(null, new DefaultCostEstimator(), new Configuration());
		OptimizedPlan optimizedPlan = compiler.compile(env.createProgramPlan());

		// walk the job plan from sinks to sources, using the list as a stack
		List<PlanNode> queue = new ArrayList<>(optimizedPlan.getDataSinks());

		while (!queue.isEmpty()) {
			PlanNode node = queue.remove(queue.size() - 1);

			// skip operators matching an exclusion pattern; these are the
			// large-scale operators which run at full parallelism
			boolean matched = false;
			for (Pattern pattern : patterns) {
				if (pattern.matcher(node.getNodeName()).matches()) {
					matched = true;
					break;
				}
			}

			if (!matched) {
				// Data sources may have parallelism of 1, so simply check that the node
				// parallelism has not been increased by setting the default parallelism
				assertTrue("Wrong parallelism for " + node, node.getParallelism() <= parallelism);
			}

			for (Channel channel : node.getInputs()) {
				queue.add(channel.getSource());
			}
		}
	}
}
......@@ -113,4 +113,15 @@ public class TriangleListingITCase extends CopyableValueDriverBaseITCase {
expectedOutputChecksum(parameters(8, "undirected", "print"), new Checksum(61410, 0x0000780ffcb6838eL));
}
@Test
public void testParallelism() throws Exception {
	// Operators proportional to the number of triangles (potentially much
	// larger than the vertex count) are permitted full parallelism.
	String[] largeOperators = new String[]{
		"FlatMap \\(Permute triangle vertices\\)",
		"Join \\(Triangle listing\\)",
		"GroupReduce \\(Generate triplets\\)"};

	// verify for both directed and undirected listings
	TestUtils.verifyParallelism(parameters(8, "directed", "print"), largeOperators);
	TestUtils.verifyParallelism(parameters(8, "undirected", "print"), largeOperators);
}
}
......@@ -47,34 +47,6 @@ extends ParameterTestBase {
// Test configuration
@Test
public void testDefaultValueBelowMinimum() {
parameter.setMinimumValue(1);
expectedException.expect(ProgramParametrizationException.class);
expectedException.expectMessage("Default value (0) must be greater than or equal to minimum (1)");
parameter.setDefaultValue(0);
}
@Test
public void testDefaultValueBetweenMinAndMax() {
parameter.setMinimumValue(-1);
parameter.setMaximumValue(1);
parameter.setDefaultValue(0);
}
@Test
public void testDefaultValueAboveMaximum() {
parameter.setMaximumValue(-1);
expectedException.expect(ProgramParametrizationException.class);
expectedException.expectMessage("Default value (0) must be less than or equal to maximum (-1)");
parameter.setDefaultValue(0);
}
@Test
public void testMinimumValueAboveMaximum() {
parameter.setMaximumValue(0);
......@@ -85,16 +57,6 @@ extends ParameterTestBase {
parameter.setMinimumValue(1);
}
@Test
public void testMinimumValueAboveDefault() {
	// Raising the minimum above an already-configured default is invalid:
	// expect a ProgramParametrizationException with a descriptive message.
	parameter.setDefaultValue(0);
	// The ExpectedException rule must be armed BEFORE the statement that throws.
	expectedException.expect(ProgramParametrizationException.class);
	expectedException.expectMessage("Minimum value (1) must be less than or equal to default (0)");
	parameter.setMinimumValue(1);
}
@Test
public void testMaximumValueBelowMinimum() {
parameter.setMinimumValue(0);
......@@ -105,16 +67,6 @@ extends ParameterTestBase {
parameter.setMaximumValue(-1);
}
@Test
public void testMaximumValueBelowDefault() {
	// Lowering the maximum below an already-configured default is invalid:
	// expect a ProgramParametrizationException with a descriptive message.
	parameter.setDefaultValue(0);
	// The ExpectedException rule must be armed BEFORE the statement that throws.
	expectedException.expect(ProgramParametrizationException.class);
	expectedException.expectMessage("Maximum value (-1) must be greater than or equal to default (0)");
	parameter.setMaximumValue(-1);
}
// With default
@Test
......
......@@ -127,7 +127,8 @@ extends GraphAlgorithmWrappingDataSet<K, VV, EV, Result<K>> {
.groupBy(0)
.reduce(new CountTriangles<>())
.setCombineHint(CombineHint.HASH)
.name("Count triangles");
.name("Count triangles")
.setParallelism(parallelism);
// u, deg(u)
DataSet<Vertex<K, Degrees>> vertexDegree = input
......
......@@ -126,7 +126,8 @@ extends GraphAlgorithmWrappingDataSet<K, VV, EV, Result<K>> {
.groupBy(0)
.reduce(new CountTriangles<>())
.setCombineHint(CombineHint.HASH)
.name("Count triangles");
.name("Count triangles")
.setParallelism(parallelism);
// u, deg(u)
DataSet<Vertex<K, LongValue>> vertexDegree = input
......
......@@ -169,7 +169,6 @@ extends GraphAlgorithmWrappingDataSet<K, VV, EV, Result<K>> {
.setParallelism(parallelism)
.name("Square")
.reduce(new Sum())
.setCombineHint(CombineHint.HASH)
.setParallelism(parallelism)
.name("Sum");
......@@ -193,7 +192,6 @@ extends GraphAlgorithmWrappingDataSet<K, VV, EV, Result<K>> {
.setParallelism(parallelism)
.name("Square")
.reduce(new Sum())
.setCombineHint(CombineHint.HASH)
.setParallelism(parallelism)
.name("Sum");
......
......@@ -18,9 +18,11 @@
package org.apache.flink.graph.asm.simple.directed;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.asm.AsmTestBase;
import org.apache.flink.graph.generator.TestUtils;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.NullValue;
......@@ -34,13 +36,14 @@ import java.util.List;
/**
* Tests for {@link Simplify}.
*/
public class SimplifyTest {
public class SimplifyTest extends AsmTestBase {
protected Graph<IntValue, NullValue, NullValue> graph;
@Before
public void setup() {
ExecutionEnvironment env = ExecutionEnvironment.createCollectionsEnvironment();
@Override
public void setup() throws Exception{
super.setup();
Object[][] edges = new Object[][]{
new Object[]{0, 0},
......@@ -73,4 +76,17 @@ public class SimplifyTest {
TestBaseUtils.compareResultAsText(simpleGraph.getEdges().collect(), expectedResult);
}
@Test
public void testParallelism() throws Exception {
	int parallelism = 2;

	// The algorithm under test must be assigned the "little" parallelism;
	// without this the test exercises only the environment default and
	// cannot verify that Simplify propagates its configured parallelism.
	Simplify<IntValue, NullValue, NullValue> simplify = new Simplify<>();
	simplify.setParallelism(parallelism);

	Graph<IntValue, NullValue, NullValue> simpleGraph = graph
		.run(simplify);

	// Attach sinks so the plan is complete before it is compiled and inspected.
	simpleGraph.getVertices().output(new DiscardingOutputFormat<>());
	simpleGraph.getEdges().output(new DiscardingOutputFormat<>());

	// Asserts that non-excluded operators do not exceed the expected parallelism.
	TestUtils.verifyParallelism(env, parallelism);
}
}
......@@ -18,9 +18,11 @@
package org.apache.flink.graph.asm.simple.undirected;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.asm.AsmTestBase;
import org.apache.flink.graph.generator.TestUtils;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.NullValue;
......@@ -34,13 +36,14 @@ import java.util.List;
/**
* Tests for {@link Simplify}.
*/
public class SimplifyTest {
public class SimplifyTest extends AsmTestBase {
protected Graph<IntValue, NullValue, NullValue> graph;
@Before
public void setup() {
ExecutionEnvironment env = ExecutionEnvironment.createCollectionsEnvironment();
@Override
public void setup() throws Exception {
super.setup();
Object[][] edges = new Object[][]{
new Object[]{0, 0},
......@@ -86,4 +89,17 @@ public class SimplifyTest {
TestBaseUtils.compareResultAsText(simpleGraph.getEdges().collect(), expectedResult);
}
@Test
public void testParallelism() throws Exception {
	int parallelism = 2;

	// The algorithm under test must be assigned the "little" parallelism;
	// without this the test exercises only the environment default and
	// cannot verify that Simplify propagates its configured parallelism.
	// The 'true' argument enables clip-and-flip of the undirected edges.
	Simplify<IntValue, NullValue, NullValue> simplify = new Simplify<>(true);
	simplify.setParallelism(parallelism);

	Graph<IntValue, NullValue, NullValue> simpleGraph = graph
		.run(simplify);

	// Attach sinks so the plan is complete before it is compiled and inspected.
	simpleGraph.getVertices().output(new DiscardingOutputFormat<>());
	simpleGraph.getEdges().output(new DiscardingOutputFormat<>());

	// Asserts that non-excluded operators do not exceed the expected parallelism.
	TestUtils.verifyParallelism(env, parallelism);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册