Commit ebba20df authored by zentol, committed by Fabian Hueske

[FLINK-2901] Remove Record API dependencies from flink-tests #2

This closes #1306
Parent 8abae2c2
@@ -16,62 +16,62 @@
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans.udfs;
import java.io.Serializable;
import java.util.Collection;
package org.apache.flink.test.accumulators;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.test.util.JavaProgramTestBase;
import org.apache.flink.util.Collector;
import org.junit.Assert;
/**
* Cross PACT computes the distance of all data points to all cluster
* centers.
*/
@SuppressWarnings("deprecation")
@ConstantFieldsFirst({0,1})
public class ComputeDistanceParameterized extends MapFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final DoubleValue distance = new DoubleValue();
private Collection<Record> clusterCenters;
public class AccumulatorIterativeITCase extends JavaProgramTestBase {
private static final int NUM_ITERATIONS = 3;
private static final int NUM_SUBTASKS = 1;
private static final String ACC_NAME = "test";
@Override
public void open(Configuration parameters) throws Exception {
this.clusterCenters = this.getRuntimeContext().getBroadcastVariable("centers");
protected boolean skipCollectionExecution() {
return true;
}
/**
* Computes the distance of one data point to one cluster center.
*
* Output Format:
* 0: pointID
* 1: pointVector
* 2: clusterID
* 3: distance
*/
@Override
public void map(Record dataPointRecord, Collector<Record> out) {
protected void testProgram() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(NUM_SUBTASKS);
CoordVector dataPoint = dataPointRecord.getField(1, CoordVector.class);
IterativeDataSet<Integer> iteration = env.fromElements(1, 2, 3).iterate(NUM_ITERATIONS);
for (Record clusterCenterRecord : this.clusterCenters) {
IntValue clusterCenterId = clusterCenterRecord.getField(0, IntValue.class);
CoordVector clusterPoint = clusterCenterRecord.getField(1, CoordVector.class);
iteration.closeWith(iteration.reduceGroup(new SumReducer())).output(new DiscardingOutputFormat());
this.distance.setValue(dataPoint.computeEuclidianDistance(clusterPoint));
// add cluster center id and distance to the data point record
dataPointRecord.setField(2, clusterCenterId);
dataPointRecord.setField(3, this.distance);
out.collect(dataPointRecord);
Assert.assertEquals(Integer.valueOf(NUM_ITERATIONS * 6), (Integer)env.execute().getAccumulatorResult(ACC_NAME));
}
static final class SumReducer extends RichGroupReduceFunction<Integer, Integer> {
private static final long serialVersionUID = 1L;
private IntCounter testCounter = new IntCounter();
@Override
public void open(Configuration config) throws Exception {
getRuntimeContext().addAccumulator(ACC_NAME, this.testCounter);
}
@Override
public void reduce(Iterable<Integer> values, Collector<Integer> out) {
// Compute the sum
int sum = 0;
for (Integer value : values) {
sum += value;
testCounter.add(value);
}
out.collect(sum);
}
}
}
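A worked check of the assertion above (illustrative, not part of the diff): in the first superstep the partial solution {1, 2, 3} passes through SumReducer, which adds 1 + 2 + 3 = 6 to the counter and emits the single element 6; each of the two remaining supersteps then feeds that 6 back through the reducer and adds 6 again, so after NUM_ITERATIONS = 3 supersteps the accumulator holds 6 + 6 + 6 = 18 = NUM_ITERATIONS * 6.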
@@ -49,9 +49,9 @@ import scala.concurrent.duration.FiniteDuration;
/**
*
*/
public abstract class CancellingTestBase extends TestLogger {
public abstract class CancelingTestBase extends TestLogger {
private static final Logger LOG = LoggerFactory.getLogger(CancellingTestBase.class);
private static final Logger LOG = LoggerFactory.getLogger(CancelingTestBase.class);
private static final int MINIMUM_HEAP_SIZE_MB = 192;
......
@@ -30,12 +30,11 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.operators.testutils.UniformIntTupleGenerator;
import org.apache.flink.test.util.InfiniteIntegerTupleInputFormat;
import org.apache.flink.test.util.UniformIntTupleGeneratorInputFormat;
import org.junit.Test;
public class MatchJoinCancelingITCase extends CancellingTestBase {
public class JoinCancelingITCase extends CancelingTestBase {
private static final int parallelism = 4;
public MatchJoinCancelingITCase() {
public JoinCancelingITCase() {
setTaskManagerNumSlots(parallelism);
}
......
@@ -26,7 +26,7 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.test.util.InfiniteIntegerInputFormat;
import org.junit.Test;
public class MapCancelingITCase extends CancellingTestBase {
public class MapCancelingITCase extends CancelingTestBase {
private static final int parallelism = 4;
public MapCancelingITCase() {
......
@@ -16,53 +16,59 @@
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.relational.query1Util;
package org.apache.flink.test.iterative;
import java.util.Iterator;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.test.recordJobs.util.Tuple;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.StringValue;
import java.util.List;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.test.util.JavaProgramTestBase;
import static org.apache.flink.test.util.TestBaseUtils.containsResultAsText;
import org.apache.flink.util.Collector;
@SuppressWarnings("deprecation")
public class GroupByReturnFlag extends ReduceFunction {
private static final long serialVersionUID = 1L;
public class IterationTerminationWithTwoTails extends JavaProgramTestBase {
private static final String EXPECTED = "22\n";
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) throws Exception {
Record outRecord = new Record();
Tuple returnTuple = new Tuple();
long quantity = 0;
double extendedPriceSum = 0.0;
boolean first = true;
while(records.hasNext()) {
Record rec = records.next();
Tuple t = rec.getField(1, Tuple.class);
if(first) {
first = false;
rec.copyTo(outRecord);
returnTuple.addAttribute(rec.getField(0, StringValue.class).toString());
protected void testProgram() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
DataSet<String> initialInput = env.fromElements("1", "2", "3", "4", "5").name("input");
IterativeDataSet<String> iteration = initialInput.iterate(5).name("Loop");
DataSet<String> sumReduce = iteration.reduceGroup(new SumReducer()).name("Compute sum (GroupReduce)");
DataSet<String> terminationFilter = iteration.filter(new TerminationFilter()).name("Compute termination criterion (Map)");
List<String> result = iteration.closeWith(sumReduce, terminationFilter).collect();
containsResultAsText(result, EXPECTED);
}
public static final class SumReducer implements GroupReduceFunction<String, String> {
private static final long serialVersionUID = 1L;
@Override
public void reduce(Iterable<String> values, Collector<String> out) throws Exception {
int sum = 0;
for (String value : values) {
sum += Integer.parseInt(value) + 1;
}
long tupleQuantity = Long.parseLong(t.getStringValueAt(4));
quantity += tupleQuantity;
double extendedPricePerTuple = Double.parseDouble(t.getStringValueAt(5));
extendedPriceSum += extendedPricePerTuple;
out.collect("" + sum);
}
}
public static class TerminationFilter implements FilterFunction<String> {
private static final long serialVersionUID = 1L;
@Override
public boolean filter(String value) throws Exception {
return Integer.parseInt(value) < 21;
}
LongValue pactQuantity = new LongValue(quantity);
returnTuple.addAttribute("" + pactQuantity);
returnTuple.addAttribute("" + extendedPriceSum);
outRecord.setField(1, returnTuple);
out.collect(outRecord);
}
}
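The test above closes the loop with two tails: sumReduce becomes the next partial solution, and terminationFilter is a separate termination criterion, so the iteration stops as soon as the criterion data set becomes empty (or after the bound of 5 supersteps). A minimal, self-contained sketch of that two-argument closeWith, assuming the same legacy DataSet API; every name below is illustrative and not taken from the diff:

import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;

public class TwoTailIterationSketch {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // start a bulk iteration with an upper bound of 100 supersteps
        IterativeDataSet<Long> loop = env.generateSequence(1, 5).iterate(100);

        // tail 1: the next partial solution
        DataSet<Long> doubled = loop.map(new MapFunction<Long, Long>() {
            @Override
            public Long map(Long value) {
                return value * 2;
            }
        });

        // tail 2: the termination criterion; the loop ends once this data set is empty
        DataSet<Long> criterion = doubled.filter(new FilterFunction<Long>() {
            @Override
            public boolean filter(Long value) {
                return value < 1000L;
            }
        });

        loop.closeWith(doubled, criterion).print();
    }
}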
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.iterative;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.examples.java.clustering.KMeans;
import org.apache.flink.examples.java.clustering.KMeans.Point;
import org.apache.flink.examples.java.clustering.KMeans.Centroid;
import org.apache.flink.test.util.JavaProgramTestBase;
import org.apache.flink.test.testdata.KMeansData;
import java.util.List;
public class KMeansWithBroadcastSetITCase extends JavaProgramTestBase {
@SuppressWarnings("serial")
@Override
protected void testProgram() throws Exception {
String[] points = KMeansData.DATAPOINTS_2D.split("\n");
String[] centers = KMeansData.INITIAL_CENTERS_2D.split("\n");
// set up execution environment
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
// get input data
DataSet<Point> pointsSet = env.fromElements(points)
.map(new MapFunction<String, Point>() {
public Point map(String p) {
String[] fields = p.split("\\|");
return new Point(
Double.parseDouble(fields[1]),
Double.parseDouble(fields[2]));
}
});
DataSet <Centroid> centroidsSet = env.fromElements(centers)
.map(new MapFunction<String, Centroid>() {
public Centroid map(String c) {
String[] fields = c.split("\\|");
return new Centroid(
Integer.parseInt(fields[0]),
Double.parseDouble(fields[1]),
Double.parseDouble(fields[2]));
}
});
// set number of bulk iterations for KMeans algorithm
IterativeDataSet<Centroid> loop = centroidsSet.iterate(20);
DataSet<Centroid> newCentroids = pointsSet
// compute closest centroid for each point
.map(new KMeans.SelectNearestCenter()).withBroadcastSet(loop, "centroids")
// count and sum point coordinates for each centroid
.map(new KMeans.CountAppender())
.groupBy(0).reduce(new KMeans.CentroidAccumulator())
// compute new centroids from point counts and coordinate sums
.map(new KMeans.CentroidAverager());
// feed new centroids back into next iteration
DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
DataSet<String> stringCentroids = finalCentroids
.map(new MapFunction<Centroid, String>() {
@Override
public String map(Centroid c) throws Exception {
return String.format("%d|%.2f|%.2f|", c.id, c.x, c.y);
}
});
List<String> result = stringCentroids.collect();
KMeansData.checkResultsWithDelta(
KMeansData.CENTERS_2D_AFTER_20_ITERATIONS_DOUBLE_DIGIT,
result,
0.01);
}
}
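The "centroids" broadcast set above is the DataSet-API pattern that replaces the Record API's broadcast variables removed elsewhere in this commit: the set is attached with withBroadcastSet(...) and read in open() via getRuntimeContext().getBroadcastVariable(...). A small self-contained sketch of the pattern with a deliberately simplified one-dimensional nearest-center computation; class and variable names are illustrative, not the diff's code:

import java.util.Collection;

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.configuration.Configuration;

public class BroadcastSetSketch {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Integer> points = env.fromElements(1, 5, 9);
        DataSet<Integer> centers = env.fromElements(0, 10);

        DataSet<Integer> nearest = points
            .map(new RichMapFunction<Integer, Integer>() {
                private Collection<Integer> broadcastCenters;

                @Override
                public void open(Configuration parameters) {
                    // the broadcast set is materialized at each task and read once in open()
                    broadcastCenters = getRuntimeContext().getBroadcastVariable("centers");
                }

                @Override
                public Integer map(Integer point) {
                    int best = 0;
                    int bestDistance = Integer.MAX_VALUE;
                    for (int center : broadcastCenters) {
                        int distance = Math.abs(point - center);
                        if (distance < bestDistance) {
                            bestDistance = distance;
                            best = center;
                        }
                    }
                    return best;
                }
            })
            .withBroadcastSet(centers, "centers");

        nearest.print();
    }
}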
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.test.recordJobs.graph.triangleEnumUtil.EdgeInputFormat;
import org.apache.flink.test.recordJobs.graph.triangleEnumUtil.EdgeWithDegreesOutputFormat;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
@SuppressWarnings("deprecation")
public class ComputeEdgeDegrees implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
// --------------------------------------------------------------------------------------------
// Vertex Degree Computation
// --------------------------------------------------------------------------------------------
/**
* Duplicates each edge such that: (u,v) becomes (u,v),(v,u)
*/
public static final class ProjectEdge extends MapFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final Record copy = new Record();
@Override
public void map(Record record, Collector<Record> out) throws Exception {
this.copy.setField(0, record.getField(1, IntValue.class));
this.copy.setField(1, record.getField(0, IntValue.class));
out.collect(this.copy);
out.collect(record);
}
}
/**
* Creates for all records in the group a record of the form (v1, v2, c1, c2), where
* v1 is the lexicographically smaller vertex id and the count for the vertex that
* was the key contains the number of edges associated with it. The other count is zero.
* This reducer also eliminates duplicate edges.
*/
public static final class CountEdges extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final Record result = new Record();
private final IntValue firstVertex = new IntValue();
private final IntValue secondVertex = new IntValue();
private final IntValue firstCount = new IntValue();
private final IntValue secondCount = new IntValue();
private int[] vals = new int[1024];
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) throws Exception {
int[] vals = this.vals;
int len = 0;
int key = -1;
// collect all values
while (records.hasNext()) {
final Record rec = records.next();
final int id = rec.getField(1, IntValue.class).getValue();
if (key == -1) {
key = rec.getField(0, IntValue.class).getValue();
}
if (len >= vals.length) {
vals = new int[vals.length * 2];
System.arraycopy(this.vals, 0, vals, 0, this.vals.length);
this.vals = vals;
}
vals[len++] = id;
}
// sort the values and remove duplicates
Arrays.sort(vals, 0, len);
int k = 0;
for (int curr = -1, i = 0; i < len; i++) {
int val = vals[i];
if (val != curr) {
curr = val;
vals[k] = vals[i];
k++;
}
else {
vals[k] = vals[i];
}
}
len = k;
// emit records such that the vertex with the lower id always comes first;
// both vertices carry a count, which is zero for the non-key vertex
for (int i = 0; i < len; i++) {
final int e2 = vals[i];
if (key <= e2) {
firstVertex.setValue(key);
secondVertex.setValue(e2);
firstCount.setValue(len);
secondCount.setValue(0);
} else {
firstVertex.setValue(e2);
secondVertex.setValue(key);
firstCount.setValue(0);
secondCount.setValue(len);
}
this.result.setField(0, firstVertex);
this.result.setField(1, secondVertex);
this.result.setField(2, firstCount);
this.result.setField(3, secondCount);
out.collect(result);
}
}
}
/**
* Takes the two separate edge entries (v1, v2, c1, 0) and (v1, v2, 0, c2)
* and creates an entry (v1, v2, c1, c2).
*/
public static final class JoinCountsAndUniquify extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final IntValue count1 = new IntValue();
private final IntValue count2 = new IntValue();
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) throws Exception {
Record rec = null;
int c1 = 0, c2 = 0;
int numValues = 0;
while (records.hasNext()) {
rec = records.next();
final int f1 = rec.getField(2, IntValue.class).getValue();
final int f2 = rec.getField(3, IntValue.class).getValue();
c1 += f1;
c2 += f2;
numValues++;
}
if (numValues != 2 || c1 == 0 || c2 == 0) {
throw new RuntimeException("JoinCountsAndUniquify Problem: key1=" +
rec.getField(0, IntValue.class).getValue() + ", key2=" +
rec.getField(1, IntValue.class).getValue() +
"values=" + numValues + ", c1=" + c1 + ", c2=" + c2);
}
count1.setValue(c1);
count2.setValue(c2);
rec.setField(2, count1);
rec.setField(3, count2);
out.collect(rec);
}
}
// --------------------------------------------------------------------------------------------
// Triangle Enumeration
// --------------------------------------------------------------------------------------------
/**
* Assembles the Plan of the triangle enumeration example Pact program.
*/
@Override
public Plan getPlan(String... args) {
// parse job parameters
final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
final String edgeInput = args.length > 1 ? args[1] : "";
final String output = args.length > 2 ? args[2] : "";
final char delimiter = args.length > 3 ? (char) Integer.parseInt(args[3]) : ',';
FileDataSource edges = new FileDataSource(new EdgeInputFormat(), edgeInput, "Input Edges");
edges.setParameter(EdgeInputFormat.ID_DELIMITER_CHAR, delimiter);
MapOperator projectEdge = MapOperator.builder(new ProjectEdge())
.input(edges).name("Project Edge").build();
ReduceOperator edgeCounter = ReduceOperator.builder(new CountEdges(), IntValue.class, 0)
.input(projectEdge).name("Count Edges for Vertex").build();
ReduceOperator countJoiner = ReduceOperator.builder(new JoinCountsAndUniquify())
.keyField(IntValue.class, 0)
.keyField(IntValue.class, 1)
.input(edgeCounter)
.name("Join Counts")
.build();
FileDataSink triangles = new FileDataSink(new EdgeWithDegreesOutputFormat(), output, countJoiner, "Unique Edges With Degrees");
Plan p = new Plan(triangles, "Normalize Edges and compute Vertex Degrees");
p.setDefaultParallelism(numSubTasks);
return p;
}
@Override
public String getDescription() {
return "Parameters: [noSubStasks] [input file] [output file] [vertex delimiter]";
}
}
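ComputeEdgeDegrees above is one of the Record-API jobs this commit removes. As a point of comparison, a hedged sketch of the core "degree per vertex" step written against the DataSet API; it mirrors ProjectEdge and CountEdges but is simplified to plain degree counting (no duplicate-edge elimination and no (v1, v2, c1, c2) output), and all names are illustrative:

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;

public class VertexDegreeSketch {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple2<Integer, Integer>> edges = env.fromElements(
            new Tuple2<>(1, 2), new Tuple2<>(1, 3), new Tuple2<>(2, 3));

        DataSet<Tuple2<Integer, Integer>> degrees = edges
            // mirror ProjectEdge: emit one degree contribution per endpoint
            .flatMap(new FlatMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>() {
                @Override
                public void flatMap(Tuple2<Integer, Integer> edge, Collector<Tuple2<Integer, Integer>> out) {
                    out.collect(new Tuple2<>(edge.f0, 1));
                    out.collect(new Tuple2<>(edge.f1, 1));
                }
            })
            // mirror CountEdges: sum the contributions per vertex
            .groupBy(0)
            .sum(1);

        degrees.print();
    }
}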
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph;
import java.io.Serializable;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.java.record.functions.CoGroupFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsSecond;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.CsvOutputFormat;
import org.apache.flink.api.java.record.operators.CoGroupOperator;
import org.apache.flink.api.java.record.operators.DeltaIteration;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.test.recordJobs.graph.WorksetConnectedComponents.DuplicateLongMap;
import org.apache.flink.test.recordJobs.graph.WorksetConnectedComponents.NeighborWithComponentIDJoin;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
@SuppressWarnings("deprecation")
public class ConnectedComponentsWithCoGroup implements Program {
private static final long serialVersionUID = 1L;
@ConstantFieldsFirst(0)
@ConstantFieldsSecond(0)
public static final class MinIdAndUpdate extends CoGroupFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final LongValue newComponentId = new LongValue();
@Override
public void coGroup(Iterator<Record> candidates, Iterator<Record> current, Collector<Record> out) throws Exception {
if (!current.hasNext()) {
throw new Exception("Error: Id not encountered before.");
}
Record old = current.next();
long oldId = old.getField(1, LongValue.class).getValue();
long minimumComponentID = Long.MAX_VALUE;
while (candidates.hasNext()) {
long candidateComponentID = candidates.next().getField(1, LongValue.class).getValue();
if (candidateComponentID < minimumComponentID) {
minimumComponentID = candidateComponentID;
}
}
if (minimumComponentID < oldId) {
newComponentId.setValue(minimumComponentID);
old.setField(1, newComponentId);
out.collect(old);
}
}
}
@SuppressWarnings("unchecked")
@Override
public Plan getPlan(String... args) {
// parse job parameters
final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
final String verticesInput = (args.length > 1 ? args[1] : "");
final String edgeInput = (args.length > 2 ? args[2] : "");
final String output = (args.length > 3 ? args[3] : "");
final int maxIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 1);
// data source for initial vertices
FileDataSource initialVertices = new FileDataSource(new CsvInputFormat(' ', LongValue.class), verticesInput, "Vertices");
MapOperator verticesWithId = MapOperator.builder(DuplicateLongMap.class).input(initialVertices).name("Assign Vertex Ids").build();
DeltaIteration iteration = new DeltaIteration(0, "Connected Components Iteration");
iteration.setInitialSolutionSet(verticesWithId);
iteration.setInitialWorkset(verticesWithId);
iteration.setMaximumNumberOfIterations(maxIterations);
// create DataSourceContract for the edges
FileDataSource edges = new FileDataSource(new CsvInputFormat(' ', LongValue.class, LongValue.class), edgeInput, "Edges");
// create JoinOperator to join the workset (candidate component ids) with the edges
JoinOperator joinWithNeighbors = JoinOperator.builder(new NeighborWithComponentIDJoin(), LongValue.class, 0, 0)
.input1(iteration.getWorkset())
.input2(edges)
.name("Join Candidate Id With Neighbor")
.build();
CoGroupOperator minAndUpdate = CoGroupOperator.builder(new MinIdAndUpdate(), LongValue.class, 0, 0)
.input1(joinWithNeighbors)
.input2(iteration.getSolutionSet())
.name("Min Id and Update")
.build();
iteration.setNextWorkset(minAndUpdate);
iteration.setSolutionSetDelta(minAndUpdate);
// create DataSinkContract for writing the connected components result
FileDataSink result = new FileDataSink(new CsvOutputFormat(), output, iteration, "Result");
CsvOutputFormat.configureRecordFormat(result)
.recordDelimiter('\n')
.fieldDelimiter(' ')
.field(LongValue.class, 0)
.field(LongValue.class, 1);
// return the PACT plan
Plan plan = new Plan(result, "Workset Connected Components");
plan.setDefaultParallelism(numSubTasks);
return plan;
}
}
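The Record-API DeltaIteration assembled above corresponds to iterateDelta in the DataSet API. A self-contained sketch of the same workset connected-components shape, with the solution set keyed on the vertex id and the workset carrying candidate component ids; the tiny example graph and all names are illustrative, not taken from the diff:

import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;

public class DeltaIterationSketch {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // (vertexId, componentId) and (sourceId, targetId)
        DataSet<Tuple2<Long, Long>> vertices = env.fromElements(
            new Tuple2<>(1L, 1L), new Tuple2<>(2L, 2L), new Tuple2<>(3L, 3L));
        DataSet<Tuple2<Long, Long>> edges = env.fromElements(
            new Tuple2<>(1L, 2L), new Tuple2<>(2L, 3L));

        // solution set keyed on field 0 (the vertex id), at most 10 supersteps
        DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
            vertices.iterateDelta(vertices, 10, 0);

        // send each workset vertex's component id to its neighbors and keep the minimum per vertex
        DataSet<Tuple2<Long, Long>> candidates = iteration.getWorkset()
            .join(edges).where(0).equalTo(0)
            .with(new JoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>>() {
                @Override
                public Tuple2<Long, Long> join(Tuple2<Long, Long> vertex, Tuple2<Long, Long> edge) {
                    return new Tuple2<>(edge.f1, vertex.f1);
                }
            })
            .groupBy(0).min(1);

        // compare against the solution set and only emit real improvements (the MinIdAndUpdate role)
        DataSet<Tuple2<Long, Long>> delta = candidates
            .join(iteration.getSolutionSet()).where(0).equalTo(0)
            .with(new FlatJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>>() {
                @Override
                public void join(Tuple2<Long, Long> candidate, Tuple2<Long, Long> current,
                        Collector<Tuple2<Long, Long>> out) {
                    if (candidate.f1 < current.f1) {
                        out.collect(candidate);
                    }
                }
            });

        // the delta updates the solution set and also feeds the next workset
        iteration.closeWith(delta, delta).print();
    }
}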
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.JoinFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFields;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsSecond;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.CsvOutputFormat;
import org.apache.flink.api.java.record.operators.DeltaIteration;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator.Combinable;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
@SuppressWarnings("deprecation")
public class DeltaPageRankWithInitialDeltas implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
@ConstantFieldsSecond(0)
public static final class RankComparisonMatch extends JoinFunction {
private static final long serialVersionUID = 1L;
private final DoubleValue newRank = new DoubleValue();
@Override
public void join(Record vertexWithDelta, Record vertexWithOldRank, Collector<Record> out) throws Exception {
DoubleValue deltaVal = vertexWithDelta.getField(1, DoubleValue.class);
DoubleValue currentVal = vertexWithOldRank.getField(1, DoubleValue.class);
newRank.setValue(deltaVal.getValue() + currentVal.getValue());
vertexWithOldRank.setField(1, newRank);
out.collect(vertexWithOldRank);
}
}
@Combinable
@ConstantFields(0)
public static final class UpdateRankReduceDelta extends ReduceFunction {
private static final long serialVersionUID = 1L;
private final DoubleValue newRank = new DoubleValue();
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) {
double rankSum = 0.0;
double rank;
Record rec = null;
while (records.hasNext()) {
rec = records.next();
rank = rec.getField(1, DoubleValue.class).getValue();
rankSum += rank;
}
// ignore small deltas
if (Math.abs(rankSum) > 0.00001) {
newRank.setValue(rankSum);
rec.setField(1, newRank);
out.collect(rec);
}
}
}
public class PRDependenciesComputationMatchDelta extends JoinFunction {
private static final long serialVersionUID = 1L;
private final Record result = new Record();
private final DoubleValue partRank = new DoubleValue();
/*
* (srcId, trgId, weight) x (vId, rank) => (trgId, rank / weight)
*/
@Override
public void join(Record vertexWithRank, Record edgeWithWeight, Collector<Record> out) throws Exception {
result.setField(0, edgeWithWeight.getField(1, LongValue.class));
final long outLinks = edgeWithWeight.getField(2, LongValue.class).getValue();
final double rank = vertexWithRank.getField(1, DoubleValue.class).getValue();
partRank.setValue(rank / (double) outLinks);
result.setField(1, partRank);
out.collect(result);
}
}
@SuppressWarnings("unchecked")
@Override
public Plan getPlan(String... args) {
// parse job parameters
final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
final String solutionSetInput = (args.length > 1 ? args[1] : "");
final String deltasInput = (args.length > 2 ? args[2] : "");
final String dependencySetInput = (args.length > 3 ? args[3] : "");
final String output = (args.length > 4 ? args[4] : "");
final int maxIterations = (args.length > 5 ? Integer.parseInt(args[5]) : 1);
// create DataSourceContract for the initial solution set
FileDataSource initialSolutionSet = new FileDataSource(new CsvInputFormat(' ', LongValue.class, DoubleValue.class), solutionSetInput, "Initial Solution Set");
// create DataSourceContract for the initial delta set
FileDataSource initialDeltaSet = new FileDataSource(new CsvInputFormat(' ', LongValue.class, DoubleValue.class), deltasInput, "Initial DeltaSet");
// create DataSourceContract for the edges
FileDataSource dependencySet = new FileDataSource(new CsvInputFormat(' ', LongValue.class, LongValue.class, LongValue.class), dependencySetInput, "Dependency Set");
DeltaIteration iteration = new DeltaIteration(0, "Delta PageRank");
iteration.setInitialSolutionSet(initialSolutionSet);
iteration.setInitialWorkset(initialDeltaSet);
iteration.setMaximumNumberOfIterations(maxIterations);
JoinOperator dependenciesMatch = JoinOperator.builder(PRDependenciesComputationMatchDelta.class,
LongValue.class, 0, 0)
.input1(iteration.getWorkset())
.input2(dependencySet)
.name("calculate dependencies")
.build();
ReduceOperator updateRanks = ReduceOperator.builder(UpdateRankReduceDelta.class, LongValue.class, 0)
.input(dependenciesMatch)
.name("update ranks")
.build();
JoinOperator oldRankComparison = JoinOperator.builder(RankComparisonMatch.class, LongValue.class, 0, 0)
.input1(updateRanks)
.input2(iteration.getSolutionSet())
.name("comparison with old ranks")
.build();
iteration.setNextWorkset(updateRanks);
iteration.setSolutionSetDelta(oldRankComparison);
// create DataSinkContract for writing the final ranks
FileDataSink result = new FileDataSink(CsvOutputFormat.class, output, iteration, "Final Ranks");
CsvOutputFormat.configureRecordFormat(result)
.recordDelimiter('\n')
.fieldDelimiter(' ')
.field(LongValue.class, 0)
.field(DoubleValue.class, 1);
// return the PACT plan
Plan plan = new Plan(result, "Delta PageRank");
plan.setDefaultParallelism(numSubTasks);
return plan;
}
@Override
public String getDescription() {
return "Parameters: <numberOfSubTasks> <initialSolutionSet(pageId, rank)> <deltas(pageId, delta)> <dependencySet(srcId, trgId, out_links)> <out> <maxIterations>";
}
}
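A brief worked example of what one superstep above computes (illustrative numbers): PRDependenciesComputationMatchDelta sends, for a vertex with rank delta 0.3 and out-degree 3, the value 0.3 / 3 = 0.1 to each outgoing neighbor; UpdateRankReduceDelta then sums the incoming values per target vertex, e.g. 0.1 + 0.05 = 0.15, keeps the sum because |0.15| > 0.00001, and RankComparisonMatch adds it to that vertex's old rank in the solution set.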
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph;
import java.io.Serializable;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.JoinFunction;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.test.recordJobs.graph.triangleEnumUtil.EdgeWithDegreesInputFormat;
import org.apache.flink.test.recordJobs.graph.triangleEnumUtil.TriangleOutputFormat;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
/**
* An implementation of the triangle enumeration, which expects its input to
* encode the degrees of the vertices. The algorithm selects the lower-degree vertex for the
* enumeration of open triads.
*/
@SuppressWarnings("deprecation")
public class EnumTrianglesOnEdgesWithDegrees implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
// --------------------------------------------------------------------------------------------
// Triangle Enumeration
// --------------------------------------------------------------------------------------------
public static final class ProjectOutCounts extends MapFunction implements Serializable {
private static final long serialVersionUID = 1L;
@Override
public void map(Record record, Collector<Record> out) throws Exception {
record.setNumFields(2);
out.collect(record);
}
}
public static final class ProjectToLowerDegreeVertex extends MapFunction implements Serializable {
private static final long serialVersionUID = 1L;
@Override
public void map(Record record, Collector<Record> out) throws Exception {
final int d1 = record.getField(2, IntValue.class).getValue();
final int d2 = record.getField(3, IntValue.class).getValue();
if (d1 > d2) {
IntValue first = record.getField(1, IntValue.class);
IntValue second = record.getField(0, IntValue.class);
record.setField(0, first);
record.setField(1, second);
}
record.setNumFields(2);
out.collect(record);
}
}
public static final class BuildTriads extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final IntValue firstVertex = new IntValue();
private final IntValue secondVertex = new IntValue();
private int[] edgeCache = new int[1024];
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) throws Exception {
int len = 0;
Record rec = null;
while (records.hasNext()) {
rec = records.next();
final int e1 = rec.getField(1, IntValue.class).getValue();
for (int i = 0; i < len; i++) {
final int e2 = this.edgeCache[i];
if (e1 <= e2) {
firstVertex.setValue(e1);
secondVertex.setValue(e2);
} else {
firstVertex.setValue(e2);
secondVertex.setValue(e1);
}
rec.setField(1, firstVertex);
rec.setField(2, secondVertex);
out.collect(rec);
}
if (len >= this.edgeCache.length) {
int[] na = new int[this.edgeCache.length * 2];
System.arraycopy(this.edgeCache, 0, na, 0, this.edgeCache.length);
this.edgeCache = na;
}
this.edgeCache[len++] = e1;
}
}
}
public static class CloseTriads extends JoinFunction implements Serializable {
private static final long serialVersionUID = 1L;
@Override
public void join(Record triangle, Record missingEdge, Collector<Record> out) throws Exception {
out.collect(triangle);
}
}
/**
* Assembles the Plan of the triangle enumeration example Pact program.
*/
@Override
public Plan getPlan(String... args) {
// parse job parameters
int numSubTasks = args.length > 0 ? Integer.parseInt(args[0]) : 1;
String edgeInput = args.length > 1 ? args[1] : "";
String output = args.length > 2 ? args[2] : "";
FileDataSource edges = new FileDataSource(new EdgeWithDegreesInputFormat(), edgeInput, "Input Edges with Degrees");
edges.setParameter(EdgeWithDegreesInputFormat.VERTEX_DELIMITER_CHAR, '|');
edges.setParameter(EdgeWithDegreesInputFormat.DEGREE_DELIMITER_CHAR, ',');
// =========================== Triangle Enumeration ============================
MapOperator toLowerDegreeEdge = MapOperator.builder(new ProjectToLowerDegreeVertex())
.input(edges)
.name("Select lower-degree Edge")
.build();
MapOperator projectOutCounts = MapOperator.builder(new ProjectOutCounts())
.input(edges)
.name("Project to vertex Ids only")
.build();
ReduceOperator buildTriads = ReduceOperator.builder(new BuildTriads(), IntValue.class, 0)
.input(toLowerDegreeEdge)
.name("Build Triads")
.build();
JoinOperator closeTriads = JoinOperator.builder(new CloseTriads(), IntValue.class, 1, 0)
.keyField(IntValue.class, 2, 1)
.input1(buildTriads)
.input2(projectOutCounts)
.name("Close Triads")
.build();
closeTriads.setParameter("INPUT_SHIP_STRATEGY", "SHIP_REPARTITION_HASH");
closeTriads.setParameter("LOCAL_STRATEGY", "LOCAL_STRATEGY_HASH_BUILD_SECOND");
FileDataSink triangles = new FileDataSink(new TriangleOutputFormat(), output, closeTriads, "Triangles");
Plan p = new Plan(triangles, "Enumerate Triangles");
p.setDefaultParallelism(numSubTasks);
return p;
}
@Override
public String getDescription() {
return "Parameters: [noSubStasks] [input file] [output file]";
}
}
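A short worked note on why ProjectToLowerDegreeVertex matters for BuildTriads above (illustrative numbers): BuildTriads emits one open triad per pair of edges that share their first vertex, i.e. d * (d - 1) / 2 triads when d edges are grouped on that vertex. Orienting every edge toward its lower-degree endpoint keeps high-degree hubs out of the grouping position; a hub with 1000 grouped edges would otherwise contribute 1000 * 999 / 2 = 499500 candidate triads on its own, every one of which CloseTriads would still have to probe against the edge set.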
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.JoinFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFields;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirstExcept;
import org.apache.flink.api.java.record.io.CsvOutputFormat;
import org.apache.flink.api.java.record.io.DelimitedInputFormat;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.types.Record;
import org.apache.flink.types.StringValue;
import org.apache.flink.util.Collector;
/**
* Implementation of the triangle enumeration example Pact program.
* The program expects a file with RDF triples (in XML serialization) as input. Triples must be separated by line breaks.
*
* The program filters for foaf:knows predicates to identify relationships between two entities (typically persons).
* Relationships are interpreted as edges in a social graph. Then the program enumerates all triangles that are built
* from edges in that graph.
*
* Usually, triangle enumeration is used as a pre-processing step to identify highly connected subgraphs.
* The algorithm was published as a MapReduce job by J. Cohen in "Graph Twiddling in a MapReduce World".
* The Pact version was described in "MapReduce and PACT - Comparing Data Parallel Programming Models" (BTW 2011).
*/
@SuppressWarnings("deprecation")
public class EnumTrianglesRdfFoaf implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
/**
* Reads RDF triples and filters on the foaf:knows RDF predicate.
* The foaf:knows RDF predicate indicates that the RDF subject and object (typically of type foaf:person) know each
* other.
* Therefore, the foaf:knows connections between people are extracted and handled as graph edges.
* The EdgeInFormat filters all RDF triples with foaf:knows predicates. The subject and object URLs are
* compared.
* The lexicographically smaller URL is set as the first field of the output record, the greater one as the second field.
*/
public static class EdgeInFormat extends DelimitedInputFormat {
private static final long serialVersionUID = 1L;
private final StringValue rdfSubj = new StringValue();
private final StringValue rdfPred = new StringValue();
private final StringValue rdfObj = new StringValue();
@Override
public Record readRecord(Record target, byte[] bytes, int offset, int numBytes) {
final int limit = offset + numBytes;
int startPos = offset;
// read RDF subject
startPos = parseVarLengthEncapsulatedStringField(bytes, startPos, limit, ' ', rdfSubj, '"');
if (startPos < 0) {
// invalid record, exit
return null;
}
// read RDF predicate
startPos = parseVarLengthEncapsulatedStringField(bytes, startPos, limit, ' ', rdfPred, '"');
if (startPos < 0 || !rdfPred.getValue().equals("<http://xmlns.com/foaf/0.1/knows>")) {
// invalid record or predicate is not a foaf-knows predicate, exit
return null;
}
// read RDF object
startPos = parseVarLengthEncapsulatedStringField(bytes, startPos, limit, ' ', rdfObj, '"');
if (startPos < 0) {
// invalid record, exit
return null;
}
// compare RDF subject and object
if (rdfSubj.compareTo(rdfObj) <= 0) {
// subject is smaller, subject becomes first attribute, object second
target.setField(0, rdfSubj);
target.setField(1, rdfObj);
} else {
// object is smaller, object becomes first attribute, subject second
target.setField(0, rdfObj);
target.setField(1, rdfSubj);
}
return target;
}
/*
* Utility method to efficiently parse encapsulated, variable length strings
*/
private int parseVarLengthEncapsulatedStringField(byte[] bytes, int startPos, int limit, char delim, StringValue field, char encaps) {
boolean isEncaps = false;
// check whether string is encapsulated
if (bytes[startPos] == encaps) {
isEncaps = true;
}
if (isEncaps) {
// string is encapsulated
for (int i = startPos; i < limit; i++) {
if (bytes[i] == encaps) {
if (bytes[i+1] == delim) {
field.setValueAscii(bytes, startPos, i-startPos+1);
return i+2;
}
}
}
return -1;
} else {
// string is not encapsulated
int i;
for (i = startPos; i < limit; i++) {
if (bytes[i] == delim) {
field.setValueAscii(bytes, startPos, i-startPos);
return i+1;
}
}
if (i == limit) {
field.setValueAscii(bytes, startPos, i-startPos);
return i+1;
} else {
return -1;
}
}
}
}
/**
* Builds triads (open triangles) from all pairs of edges that share a vertex.
* The common vertex is the grouping key and is set at field position 0 of the output record.
*/
@ConstantFields(0)
public static class BuildTriads extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
// list of non-matching vertices
private final ArrayList<StringValue> otherVertices = new ArrayList<StringValue>(32);
// matching vertex
private final StringValue matchVertex = new StringValue();
// mutable output record
private final Record result = new Record();
// initialize list of non-matching vertices for one vertex
public BuildTriads() {
this.otherVertices.add(new StringValue());
}
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) throws Exception {
// read the first edge
final Record rec = records.next();
// read the matching vertex
rec.getFieldInto(0, this.matchVertex);
// read the non-matching vertex and add it to the list
rec.getFieldInto(1, this.otherVertices.get(0));
// set the matching vertex in the output record
this.result.setField(0, this.matchVertex);
int numEdges = 1;
// while there are more edges
while (records.hasNext()) {
// read the next edge
final Record next = records.next();
final StringValue myVertex;
// obtain an object to store the non-matching vertex
if (numEdges >= this.otherVertices.size()) {
// we need an additional vertex object
// create the object
myVertex = new StringValue();
// and put it in the list
this.otherVertices.add(myVertex);
} else {
// we reuse a previously created object from the list
myVertex = this.otherVertices.get(numEdges);
}
// read the non-matching vertex into the obtained object
next.getFieldInto(1, myVertex);
// combine the current edge with all vertices in the non-matching vertex list
for (int i = 0; i < numEdges; i++) {
// get the other non-matching vertex
final StringValue otherVertex = this.otherVertices.get(i);
// add my and other vertex to the output record depending on their ordering
if (otherVertex.compareTo(myVertex) < 0) {
this.result.setField(1, otherVertex);
this.result.setField(2, myVertex);
out.collect(this.result);
} else {
next.setField(2, otherVertex);
out.collect(next);
}
}
numEdges++;
}
}
}
/**
* Matches all missing edges with existing edges from input.
* If the missing edge for a triad is found, the triad is transformed to a triangle by adding the missing edge.
*/
@ConstantFieldsFirstExcept({})
public static class CloseTriads extends JoinFunction implements Serializable {
private static final long serialVersionUID = 1L;
@Override
public void join(Record triad, Record missingEdge, Collector<Record> out) throws Exception {
// emit triangle (already contains the missing edge at field 0)
out.collect(triad);
}
}
/**
* Assembles the Plan of the triangle enumeration example Pact program.
*/
@Override
public Plan getPlan(String... args) {
// parse job parameters
int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
String edgeInput = (args.length > 1 ? args[1] : "");
String output = (args.length > 2 ? args[2] : "");
FileDataSource edges = new FileDataSource(new EdgeInFormat(), edgeInput, "BTC Edges");
ReduceOperator buildTriads = ReduceOperator.builder(new BuildTriads(), StringValue.class, 0)
.name("Build Triads")
.build();
JoinOperator closeTriads = JoinOperator.builder(new CloseTriads(), StringValue.class, 1, 0)
.keyField(StringValue.class, 2, 1)
.name("Close Triads")
.build();
closeTriads.setParameter("INPUT_LEFT_SHIP_STRATEGY", "SHIP_REPARTITION_HASH");
closeTriads.setParameter("INPUT_RIGHT_SHIP_STRATEGY", "SHIP_REPARTITION_HASH");
closeTriads.setParameter("LOCAL_STRATEGY", "LOCAL_STRATEGY_HASH_BUILD_SECOND");
FileDataSink triangles = new FileDataSink(new CsvOutputFormat(), output, "Output");
CsvOutputFormat.configureRecordFormat(triangles)
.recordDelimiter('\n')
.fieldDelimiter(' ')
.field(StringValue.class, 0)
.field(StringValue.class, 1)
.field(StringValue.class, 2);
triangles.setInput(closeTriads);
closeTriads.setSecondInput(edges);
closeTriads.setFirstInput(buildTriads);
buildTriads.setInput(edges);
Plan plan = new Plan(triangles, "Enumerate Triangles");
plan.setDefaultParallelism(numSubTasks);
return plan;
}
@Override
public String getDescription() {
return "Parameters: [numSubStasks] [inputRDFTriples] [outputTriangles]";
}
}
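For comparison with the removed Record-API reducer above, a compact sketch of the BuildTriads idea in the DataSet API: edges arrive as (smallerVertex, largerVertex) pairs, are grouped on the shared vertex, and every pair of neighbors is emitted as an open triad. The sketch is self-contained and all names are illustrative; closing the triads would be a further join against the edge set, as in the plan above:

import java.util.ArrayList;
import java.util.List;

import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;

public class BuildTriadsSketch {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // edges as (smallerVertex, largerVertex), the shape EdgeInFormat produces
        DataSet<Tuple2<String, String>> edges = env.fromElements(
            new Tuple2<>("a", "b"), new Tuple2<>("a", "c"), new Tuple2<>("b", "c"));

        // group on the shared vertex and emit every pair of neighbors as an open triad
        DataSet<Tuple3<String, String, String>> triads = edges
            .groupBy(0)
            .reduceGroup(new GroupReduceFunction<Tuple2<String, String>, Tuple3<String, String, String>>() {
                @Override
                public void reduce(Iterable<Tuple2<String, String>> group,
                        Collector<Tuple3<String, String, String>> out) {
                    List<String> neighbors = new ArrayList<>();
                    for (Tuple2<String, String> edge : group) {
                        for (String other : neighbors) {
                            // keep the two open ends ordered, as BuildTriads does
                            if (other.compareTo(edge.f1) < 0) {
                                out.collect(new Tuple3<>(edge.f0, other, edge.f1));
                            } else {
                                out.collect(new Tuple3<>(edge.f0, edge.f1, other));
                            }
                        }
                        neighbors.add(edge.f1);
                    }
                }
            });

        triads.print();
    }
}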
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.test.recordJobs.graph.ComputeEdgeDegrees.CountEdges;
import org.apache.flink.test.recordJobs.graph.ComputeEdgeDegrees.JoinCountsAndUniquify;
import org.apache.flink.test.recordJobs.graph.ComputeEdgeDegrees.ProjectEdge;
import org.apache.flink.test.recordJobs.graph.EnumTrianglesOnEdgesWithDegrees.BuildTriads;
import org.apache.flink.test.recordJobs.graph.EnumTrianglesOnEdgesWithDegrees.CloseTriads;
import org.apache.flink.test.recordJobs.graph.EnumTrianglesOnEdgesWithDegrees.ProjectOutCounts;
import org.apache.flink.test.recordJobs.graph.EnumTrianglesOnEdgesWithDegrees.ProjectToLowerDegreeVertex;
import org.apache.flink.test.recordJobs.graph.triangleEnumUtil.EdgeInputFormat;
import org.apache.flink.test.recordJobs.graph.triangleEnumUtil.TriangleOutputFormat;
import org.apache.flink.types.IntValue;
/**
* An implementation of the triangle enumeration, which includes the pre-processing step
* to compute the degrees of the vertices and to select the lower-degree vertex for the
* enumeration of open triads.
*/
@SuppressWarnings("deprecation")
public class EnumTrianglesWithDegrees implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
@Override
public Plan getPlan(String... args) {
// parse job parameters
final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
final String edgeInput = args.length > 1 ? args[1] : "";
final String output = args.length > 2 ? args[2] : "";
final char delimiter = args.length > 3 ? (char) Integer.parseInt(args[3]) : ',';
FileDataSource edges = new FileDataSource(new EdgeInputFormat(), edgeInput, "Input Edges");
edges.setParameter(EdgeInputFormat.ID_DELIMITER_CHAR, delimiter);
// =========================== Vertex Degree ============================
MapOperator projectEdge = MapOperator.builder(new ProjectEdge())
.input(edges).name("Project Edge").build();
ReduceOperator edgeCounter = ReduceOperator.builder(new CountEdges(), IntValue.class, 0)
.input(projectEdge).name("Count Edges for Vertex").build();
ReduceOperator countJoiner = ReduceOperator.builder(new JoinCountsAndUniquify(), IntValue.class, 0)
.keyField(IntValue.class, 1)
.input(edgeCounter).name("Join Counts").build();
// =========================== Triangle Enumeration ============================
MapOperator toLowerDegreeEdge = MapOperator.builder(new ProjectToLowerDegreeVertex())
.input(countJoiner).name("Select lower-degree Edge").build();
MapOperator projectOutCounts = MapOperator.builder(new ProjectOutCounts())
.input(countJoiner).name("Project out Counts").build();
ReduceOperator buildTriads = ReduceOperator.builder(new BuildTriads(), IntValue.class, 0)
.input(toLowerDegreeEdge).name("Build Triads").build();
JoinOperator closeTriads = JoinOperator.builder(new CloseTriads(), IntValue.class, 1, 0)
.keyField(IntValue.class, 2, 1)
.input1(buildTriads).input2(projectOutCounts)
.name("Close Triads").build();
closeTriads.setParameter("INPUT_SHIP_STRATEGY", "SHIP_REPARTITION_HASH");
closeTriads.setParameter("LOCAL_STRATEGY", "LOCAL_STRATEGY_HASH_BUILD_SECOND");
FileDataSink triangles = new FileDataSink(new TriangleOutputFormat(), output, closeTriads, "Triangles");
Plan p = new Plan(triangles, "Enumerate Triangles");
p.setDefaultParallelism(numSubTasks);
return p;
}
@Override
public String getDescription() {
return "Parameters: [noSubStasks] [input file] [output file]";
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph;
import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.CoGroupFunction;
import org.apache.flink.api.java.record.functions.JoinFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsSecond;
import org.apache.flink.api.java.record.io.DelimitedInputFormat;
import org.apache.flink.api.java.record.io.FileOutputFormat;
import org.apache.flink.api.java.record.operators.CoGroupOperator;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.StringValue;
import org.apache.flink.util.Collector;
/**
* Implementation of the Pairwise Shortest Path example PACT program.
* The program implements one iteration of the algorithm and must be run multiple times until no changes are computed.
*
* The pairwise shortest path algorithm comes from the domain of graph problems. The goal is to find all shortest paths
* between any two transitively connected nodes in a graph. In this implementation edges are interpreted as directed and weighted.
*
* For the first iteration, the program allows two input formats:
* 1) RDF triples with foaf:knows predicates. A triple is interpreted as an edge from the RDF subject to the RDF object with weight 1.
* 2) The program's text serialization for paths (see PathInFormat and PathOutFormat).
*
* The RDF input format is used if the 4th parameter of the getPlan() method is set to "true". If set to "false" the path input format is used.
*/
@SuppressWarnings("deprecation")
public class PairwiseSP implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
/**
* Reads RDF triples and filters on the foaf:knows RDF predicate. The triple's elements must be separated by whitespace.
* The foaf:knows RDF predicate indicates that the RDF subject knows the object (typically of type foaf:person).
* The connections between people are extracted and handled as graph edges. For the Pairwise Shortest Path algorithm the
* connection is interpreted as a directed edge, i.e. subject knows object, but the object does not necessarily know the subject.
*
* The RDFTripleInFormat filters all RDF triples with foaf:knows predicates.
* For each triple with foaf:knows predicate, a record is emitted with
* - from-node being the RDF subject at field position 0,
* - to-node being the RDF object at field position 1,
* - length being 1 at field position 2,
* - hop count being 0 at field position 3, and
* - hopList being a blank (single-space) string at field position 4.
*
*/
public static class RDFTripleInFormat extends DelimitedInputFormat {
private static final long serialVersionUID = 1L;
private final StringValue fromNode = new StringValue();
private final StringValue toNode = new StringValue();
private final IntValue pathLength = new IntValue(1);
private final IntValue hopCnt = new IntValue(0);
private final StringValue hopList = new StringValue(" ");
@Override
public Record readRecord(Record target, byte[] bytes, int offset, int numBytes) {
String lineStr = new String(bytes, offset, numBytes);
// collapse repeated whitespace and trim
lineStr = lineStr.replaceAll("\\s+", " ").trim();
// build whitespace tokenizer
StringTokenizer st = new StringTokenizer(lineStr, " ");
// line must have at least three elements
if (st.countTokens() < 3) {
return null;
}
String rdfSubj = st.nextToken();
String rdfPred = st.nextToken();
String rdfObj = st.nextToken();
// we only want foaf:knows predicates
if (!rdfPred.equals("<http://xmlns.com/foaf/0.1/knows>")) {
return null;
}
// build node pair from subject and object
fromNode.setValue(rdfSubj);
toNode.setValue(rdfObj);
target.setField(0, fromNode);
target.setField(1, toNode);
target.setField(2, pathLength);
target.setField(3, hopCnt);
target.setField(4, hopList);
return target;
}
}
/**
 * The PathInFormat reads paths consisting of a from-node, a to-node, a length, a hop count, and a hop list serialized as a string.
 * All five elements of the path must be separated by the pipe character ('|') and may not contain any pipe characters themselves.
 *
 * PathInFormat returns records with:
 * - the from-node at field position 0,
 * - the to-node at field position 1,
 * - the length at field position 2,
 * - the hop count at field position 3, and
 * - the hop list at field position 4.
*/
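// Illustrative example (hypothetical nodes): the line "A|D|2|1|B" is parsed into the record
// (from-node A, to-node D, length 2, hop count 1, hop list "B").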
public static class PathInFormat extends DelimitedInputFormat {
private static final long serialVersionUID = 1L;
private final StringValue fromNode = new StringValue();
private final StringValue toNode = new StringValue();
private final IntValue length = new IntValue();
private final IntValue hopCnt = new IntValue();
private final StringValue hopList = new StringValue();
@Override
public Record readRecord(Record target, byte[] bytes, int offset, int numBytes) {
String lineStr = new String(bytes, offset, numBytes);
StringTokenizer st = new StringTokenizer(lineStr, "|");
// path must have exactly 5 tokens (fromNode, toNode, length, hopCnt, hopList)
if (st.countTokens() != 5) {
return null;
}
this.fromNode.setValue(st.nextToken());
this.toNode.setValue(st.nextToken());
this.length.setValue(Integer.parseInt(st.nextToken()));
this.hopCnt.setValue(Integer.parseInt(st.nextToken()));
this.hopList.setValue(st.nextToken());
target.setField(0, fromNode);
target.setField(1, toNode);
target.setField(2, length);
target.setField(3, hopCnt);
target.setField(4, hopList);
return target;
}
}
/**
 * The PathOutFormat serializes paths to text.
 * The from-node, the to-node, the length, the hop count, and the hop list are written out in this order.
 * Elements are separated by the pipe character ('|').
*/
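// Illustrative example (hypothetical nodes): the record (A, D, 2, 1, "B") is written as the line "A|D|2|1|B|".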
public static class PathOutFormat extends FileOutputFormat {
private static final long serialVersionUID = 1L;
@Override
public void writeRecord(Record record) throws IOException {
StringBuilder line = new StringBuilder();
// append from-node
line.append(record.getField(0, StringValue.class).toString());
line.append("|");
// append to-node
line.append(record.getField(1, StringValue.class).toString());
line.append("|");
// append length
line.append(record.getField(2, IntValue.class).toString());
line.append("|");
// append hopCnt
line.append(record.getField(3, IntValue.class).toString());
line.append("|");
// append hopList
line.append(record.getField(4, StringValue.class).toString());
line.append("|");
line.append("\n");
stream.write(line.toString().getBytes());
}
}
/**
* Concatenates two paths where the from-node of the first path and the to-node of the second path are the same.
* The second input path becomes the first part and the first input path the second part of the output path.
* The length of the output path is the sum of both input paths.
 * The output path's hop list is built from both paths' hop lists and the common node.
*/
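// Illustrative example (hypothetical nodes): joining the path B -> D (length 2, hop list "C") as the first
// input with the path A -> B (length 1, empty hop list) as the second input emits the concatenated path
// A -> D with length 3, hop count 2, and hop list "B C".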
@ConstantFieldsFirst(1)
@ConstantFieldsSecond(0)
public static class ConcatPaths extends JoinFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final Record outputRecord = new Record();
private final IntValue length = new IntValue();
private final IntValue hopCnt = new IntValue();
private final StringValue hopList = new StringValue();
@Override
public void join(Record rec1, Record rec2, Collector<Record> out) throws Exception {
// rec1 has matching start, rec2 matching end
// Therefore, rec2's end node and rec1's start node are identical
// First half of new path will be rec2, second half will be rec1
// Get from-node and to-node of new path
final StringValue fromNode = rec2.getField(0, StringValue.class);
final StringValue toNode = rec1.getField(1, StringValue.class);
// Check whether from-node = to-node to prevent circles!
if (fromNode.equals(toNode)) {
return;
}
// Create new path
outputRecord.setField(0, fromNode);
outputRecord.setField(1, toNode);
// Compute length of new path
length.setValue(rec1.getField(2, IntValue.class).getValue() + rec2.getField(2, IntValue.class).getValue());
outputRecord.setField(2, length);
// compute hop count
int hops = rec1.getField(3, IntValue.class).getValue() + 1 + rec2.getField(3, IntValue.class).getValue();
hopCnt.setValue(hops);
outputRecord.setField(3, hopCnt);
// Concatenate hops lists and insert matching node
StringBuilder sb = new StringBuilder();
// first path
sb.append(rec2.getField(4, StringValue.class).getValue());
sb.append(" ");
// common node
sb.append(rec1.getField(0, StringValue.class).getValue());
// second path
sb.append(" ");
sb.append(rec1.getField(4, StringValue.class).getValue());
hopList.setValue(sb.toString().trim());
outputRecord.setField(4, hopList);
out.collect(outputRecord);
}
}
/**
 * Receives two groups of paths for the same from-node/to-node combination and emits the shortest path(s) for that combination.
 * If more than one shortest path exists for a combination, all shortest paths are emitted.
*/
@ConstantFieldsFirst({0,1})
@ConstantFieldsSecond({0,1})
public static class FindShortestPath extends CoGroupFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final Record outputRecord = new Record();
private final Set<StringValue> shortestPaths = new HashSet<StringValue>();
private final Map<StringValue,IntValue> hopCnts = new HashMap<StringValue,IntValue>();
private final IntValue minLength = new IntValue();
@Override
public void coGroup(Iterator<Record> inputRecords, Iterator<Record> concatRecords, Collector<Record> out) {
// init minimum length and minimum path
Record pathRec = null;
StringValue path = null;
if(inputRecords.hasNext()) {
// path is in input paths
pathRec = inputRecords.next();
} else {
// path must be in concat paths
pathRec = concatRecords.next();
}
// get from node (common for all paths)
StringValue fromNode = pathRec.getField(0, StringValue.class);
// get to node (common for all paths)
StringValue toNode = pathRec.getField(1, StringValue.class);
// get length of path
minLength.setValue(pathRec.getField(2, IntValue.class).getValue());
// store path and hop count
path = new StringValue(pathRec.getField(4, StringValue.class));
shortestPaths.add(path);
hopCnts.put(path, new IntValue(pathRec.getField(3, IntValue.class).getValue()));
// find shortest path of all input paths
while (inputRecords.hasNext()) {
pathRec = inputRecords.next();
IntValue length = pathRec.getField(2, IntValue.class);
if (length.getValue() == minLength.getValue()) {
// path also has the minimum length; add it to the list
path = new StringValue(pathRec.getField(4, StringValue.class));
if(shortestPaths.add(path)) {
hopCnts.put(path, new IntValue(pathRec.getField(3, IntValue.class).getValue()));
}
} else if (length.getValue() < minLength.getValue()) {
// path is strictly shorter; it becomes the new minimum
minLength.setValue(length.getValue());
// clear lists
hopCnts.clear();
shortestPaths.clear();
// get path and add path and hop count
path = new StringValue(pathRec.getField(4, StringValue.class));
shortestPaths.add(path);
hopCnts.put(path, new IntValue(pathRec.getField(3, IntValue.class).getValue()));
}
}
// find shortest path of all input and concatenated paths
while (concatRecords.hasNext()) {
pathRec = concatRecords.next();
IntValue length = pathRec.getField(2, IntValue.class);
if (length.getValue() == minLength.getValue()) {
// path also has the minimum length; add it to the list
path = new StringValue(pathRec.getField(4, StringValue.class));
if(shortestPaths.add(path)) {
hopCnts.put(path, new IntValue(pathRec.getField(3, IntValue.class).getValue()));
}
} else if (length.getValue() < minLength.getValue()) {
// path is strictly shorter; it becomes the new minimum
minLength.setValue(length.getValue());
// clear lists
hopCnts.clear();
shortestPaths.clear();
// get path and add path and hop count
path = new StringValue(pathRec.getField(4, StringValue.class));
shortestPaths.add(path);
hopCnts.put(path, new IntValue(pathRec.getField(3, IntValue.class).getValue()));
}
}
outputRecord.setField(0, fromNode);
outputRecord.setField(1, toNode);
outputRecord.setField(2, minLength);
// emit all shortest paths
for(StringValue shortestPath : shortestPaths) {
outputRecord.setField(3, hopCnts.get(shortestPath));
outputRecord.setField(4, shortestPath);
out.collect(outputRecord);
}
hopCnts.clear();
shortestPaths.clear();
}
}
/**
* Assembles the Plan of the Pairwise Shortest Paths example Pact program.
* The program computes one iteration of the Pairwise Shortest Paths algorithm.
*
* For the first iteration, two input formats can be chosen:
* 1) RDF triples with foaf:knows predicates
* 2) Text-serialized paths (see PathInFormat and PathOutFormat)
*
 * To choose 1), set the fourth parameter to "true". If set to "false", 2) will be used.
*
*/
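// Example invocation (hypothetical paths): getPlan("4", "file:///tmp/paths", "file:///tmp/result", "false")
// builds a plan with parallelism 4 that reads text-serialized paths and writes the paths of the next iteration.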
@Override
public Plan getPlan(String... args) {
// parse job parameters
int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
String paths = (args.length > 1 ? args[1] : "");
String output = (args.length > 2 ? args[2] : "");
boolean rdfInput = (args.length > 3 && Boolean.parseBoolean(args[3]));
FileDataSource pathsInput;
if(rdfInput) {
pathsInput = new FileDataSource(new RDFTripleInFormat(), paths, "RDF Triples");
} else {
pathsInput = new FileDataSource(new PathInFormat(), paths, "Paths");
}
pathsInput.setParallelism(numSubTasks);
JoinOperator concatPaths =
JoinOperator.builder(new ConcatPaths(), StringValue.class, 0, 1)
.name("Concat Paths")
.build();
concatPaths.setParallelism(numSubTasks);
CoGroupOperator findShortestPaths =
CoGroupOperator.builder(new FindShortestPath(), StringValue.class, 0, 0)
.keyField(StringValue.class, 1, 1)
.name("Find Shortest Paths")
.build();
findShortestPaths.setParallelism(numSubTasks);
FileDataSink result = new FileDataSink(new PathOutFormat(),output, "New Paths");
result.setParallelism(numSubTasks);
result.setInput(findShortestPaths);
findShortestPaths.setFirstInput(pathsInput);
findShortestPaths.setSecondInput(concatPaths);
concatPaths.setFirstInput(pathsInput);
concatPaths.setSecondInput(pathsInput);
return new Plan(result, "Pairwise Shortest Paths");
}
@Override
public String getDescription() {
return "Parameters: [numSubStasks], [inputPaths], [outputPaths], [RDFInputFlag]";
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph;
import java.io.Serializable;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.JoinFunction;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFields;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.CsvOutputFormat;
import org.apache.flink.api.java.record.operators.DeltaIteration;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator.Combinable;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
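/**
 * Connected Components on the (deprecated) Record API, implemented as a delta ("workset") iteration:
 * every vertex starts with its own ID as component ID, candidate component IDs are propagated along the
 * edges, and a vertex enters the workset of the next iteration only if it received a smaller component ID.
 */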
@SuppressWarnings("deprecation")
public class WorksetConnectedComponents implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
public static final class DuplicateLongMap extends MapFunction implements Serializable {
private static final long serialVersionUID = 1L;
@Override
public void map(Record record, Collector<Record> out) throws Exception {
record.setField(1, record.getField(0, LongValue.class));
out.collect(record);
}
}
/**
* UDF that joins a (Vertex-ID, Component-ID) pair that represents the current component that
 * a vertex is associated with, with a (Source-Vertex-ID, Target-Vertex-ID) edge. The function
* produces a (Target-vertex-ID, Component-ID) pair.
*/
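// Illustrative example: joining the vertex record (vertex 1, component 7) with the edge record (1, 4)
// emits the candidate record (vertex 4, component 7).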
public static final class NeighborWithComponentIDJoin extends JoinFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final Record result = new Record();
@Override
public void join(Record vertexWithComponent, Record edge, Collector<Record> out) {
this.result.setField(0, edge.getField(1, LongValue.class));
this.result.setField(1, vertexWithComponent.getField(1, LongValue.class));
out.collect(this.result);
}
}
/**
* Minimum aggregation over (Vertex-ID, Component-ID) pairs, selecting the pair with the smallest Component-ID.
*/
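// Illustrative example: reducing the candidates (4, 7), (4, 2), (4, 9) for vertex 4 emits (4, 2).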
@Combinable
@ConstantFields(0)
public static final class MinimumComponentIDReduce extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final Record result = new Record();
private final LongValue vertexId = new LongValue();
private final LongValue minComponentId = new LongValue();
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) {
final Record first = records.next();
final long vertexID = first.getField(0, LongValue.class).getValue();
long minimumComponentID = first.getField(1, LongValue.class).getValue();
while (records.hasNext()) {
long candidateComponentID = records.next().getField(1, LongValue.class).getValue();
if (candidateComponentID < minimumComponentID) {
minimumComponentID = candidateComponentID;
}
}
this.vertexId.setValue(vertexID);
this.minComponentId.setValue(minimumComponentID);
this.result.setField(0, this.vertexId);
this.result.setField(1, this.minComponentId);
out.collect(this.result);
}
}
/**
* UDF that joins a candidate (Vertex-ID, Component-ID) pair with another (Vertex-ID, Component-ID) pair.
* Returns the candidate pair, if the candidate's Component-ID is smaller.
*/
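// Illustrative example: for the candidate (4, 2) and the current solution-set entry (4, 7) the candidate
// is forwarded; for the candidate (4, 9) and the entry (4, 7) nothing is emitted.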
@ConstantFieldsFirst(0)
public static final class UpdateComponentIdMatch extends JoinFunction implements Serializable {
private static final long serialVersionUID = 1L;
@Override
public void join(Record newVertexWithComponent, Record currentVertexWithComponent, Collector<Record> out){
long candidateComponentID = newVertexWithComponent.getField(1, LongValue.class).getValue();
long currentComponentID = currentVertexWithComponent.getField(1, LongValue.class).getValue();
if (candidateComponentID < currentComponentID) {
out.collect(newVertexWithComponent);
}
}
}
@SuppressWarnings("unchecked")
@Override
public Plan getPlan(String... args) {
// parse job parameters
final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
final String verticesInput = (args.length > 1 ? args[1] : "");
final String edgeInput = (args.length > 2 ? args[2] : "");
final String output = (args.length > 3 ? args[3] : "");
final int maxIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 1);
// data source for initial vertices
FileDataSource initialVertices = new FileDataSource(new CsvInputFormat(' ', LongValue.class), verticesInput, "Vertices");
MapOperator verticesWithId = MapOperator.builder(DuplicateLongMap.class).input(initialVertices).name("Assign Vertex Ids").build();
// the loop takes the vertices as the solution set and changed vertices as the workset
// initially, all vertices are changed
DeltaIteration iteration = new DeltaIteration(0, "Connected Components Iteration");
iteration.setInitialSolutionSet(verticesWithId);
iteration.setInitialWorkset(verticesWithId);
iteration.setMaximumNumberOfIterations(maxIterations);
// data source for the edges
FileDataSource edges = new FileDataSource(new CsvInputFormat(' ', LongValue.class, LongValue.class), edgeInput, "Edges");
// join workset (changed vertices) with the edges to propagate changes to neighbors
JoinOperator joinWithNeighbors = JoinOperator.builder(new NeighborWithComponentIDJoin(), LongValue.class, 0, 0)
.input1(iteration.getWorkset())
.input2(edges)
.name("Join Candidate Id With Neighbor")
.build();
// find for each neighbor the smallest of all candidates
ReduceOperator minCandidateId = ReduceOperator.builder(new MinimumComponentIDReduce(), LongValue.class, 0)
.input(joinWithNeighbors)
.name("Find Minimum Candidate Id")
.build();
// join candidates with the solution set and update if the candidate component-id is smaller
JoinOperator updateComponentId = JoinOperator.builder(new UpdateComponentIdMatch(), LongValue.class, 0, 0)
.input1(minCandidateId)
.input2(iteration.getSolutionSet())
.name("Update Component Id")
.build();
iteration.setNextWorkset(updateComponentId);
iteration.setSolutionSetDelta(updateComponentId);
// sink is the iteration result
FileDataSink result = new FileDataSink(new CsvOutputFormat(), output, iteration, "Result");
CsvOutputFormat.configureRecordFormat(result)
.recordDelimiter('\n')
.fieldDelimiter(' ')
.field(LongValue.class, 0)
.field(LongValue.class, 1);
Plan plan = new Plan(result, "Workset Connected Components");
plan.setDefaultParallelism(numSubTasks);
return plan;
}
@Override
public String getDescription() {
return "Parameters: <numberOfSubTasks> <vertices> <edges> <out> <maxIterations>";
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph.triangleEnumUtil;
import org.apache.flink.api.java.record.io.DelimitedInputFormat;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
/**
 * Input format that reads graph edges as pairs of positive integer vertex IDs separated by a configurable
 * delimiter character (default: ','). Lines with non-positive IDs or self-loops are skipped.
 */
public final class EdgeInputFormat extends DelimitedInputFormat {
private static final long serialVersionUID = 1L;
public static final String ID_DELIMITER_CHAR = "edgeinput.delimiter";
private final IntValue i1 = new IntValue();
private final IntValue i2 = new IntValue();
private char delimiter;
// --------------------------------------------------------------------------------------------
@Override
public Record readRecord(Record target, byte[] bytes, int offset, int numBytes) {
final int limit = offset + numBytes;
int first = 0, second = 0;
final char delimiter = this.delimiter;
int pos = offset;
while (pos < limit && bytes[pos] != delimiter) {
first = first * 10 + (bytes[pos++] - '0');
}
pos += 1;// skip the delimiter
while (pos < limit) {
second = second * 10 + (bytes[pos++] - '0');
}
if (first <= 0 || second <= 0 || first == second) {
return null;
}
this.i1.setValue(first);
this.i2.setValue(second);
target.setField(0, this.i1);
target.setField(1, this.i2);
return target;
}
// --------------------------------------------------------------------------------------------
@Override
public void configure(Configuration parameters) {
super.configure(parameters);
this.delimiter = (char) parameters.getInteger(ID_DELIMITER_CHAR, ',');
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph.triangleEnumUtil;
import org.apache.flink.api.java.record.io.DelimitedInputFormat;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
/**
* Input format that reads edges augmented with vertex degrees. The data to be read is assumed to be in
* the format <code>v1,d1|v2,d2\n</code>, where <code>v1</code> and <code>v2</code> are the IDs of the first and
* second vertex, while <code>d1</code> and <code>d2</code> are the vertex degrees.
* <p>
* The result record holds the fields in the sequence <code>(v1, v2, d1, d2)</code>.
* <p>
* The delimiters are configurable. The default delimiter between vertex ID and
* vertex degree is the comma (<code>,</code>). The default delimiter between the two vertices is
* the vertical bar (<code>|</code>).
*/
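// Illustrative example: with the default delimiters, the line "1,3|2,5" is parsed into the record
// (v1 = 1, v2 = 2, d1 = 3, d2 = 5); lines with non-positive vertex IDs or self-loops are skipped.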
public final class EdgeWithDegreesInputFormat extends DelimitedInputFormat {
private static final long serialVersionUID = 1L;
public static final String VERTEX_DELIMITER_CHAR = "edgeinput.vertexdelimiter";
public static final String DEGREE_DELIMITER_CHAR = "edgeinput.degreedelimiter";
private final IntValue v1 = new IntValue();
private final IntValue v2 = new IntValue();
private final IntValue d1 = new IntValue();
private final IntValue d2 = new IntValue();
private char vertexDelimiter;
private char degreeDelimiter;
// --------------------------------------------------------------------------------------------
@Override
public Record readRecord(Record target, byte[] bytes, int offset, int numBytes) {
final int limit = offset + numBytes;
int firstV = 0, secondV = 0;
int firstD = 0, secondD = 0;
final char vertexDelimiter = this.vertexDelimiter;
final char degreeDelimiter = this.degreeDelimiter;
int pos = offset;
// read the first vertex ID
while (pos < limit && bytes[pos] != degreeDelimiter) {
firstV = firstV * 10 + (bytes[pos++] - '0');
}
pos += 1;// skip the delimiter
// read the first vertex degree
while (pos < limit && bytes[pos] != vertexDelimiter) {
firstD = firstD * 10 + (bytes[pos++] - '0');
}
pos += 1;// skip the delimiter
// read the second vertex ID
while (pos < limit && bytes[pos] != degreeDelimiter) {
secondV = secondV * 10 + (bytes[pos++] - '0');
}
pos += 1;// skip the delimiter
// read the second vertex degree
while (pos < limit) {
secondD = secondD * 10 + (bytes[pos++] - '0');
}
if (firstV <= 0 || secondV <= 0 || firstV == secondV) {
return null;
}
v1.setValue(firstV);
v2.setValue(secondV);
d1.setValue(firstD);
d2.setValue(secondD);
target.setField(0, v1);
target.setField(1, v2);
target.setField(2, d1);
target.setField(3, d2);
return target;
}
// --------------------------------------------------------------------------------------------
@Override
public void configure(Configuration parameters) {
super.configure(parameters);
this.vertexDelimiter = (char) parameters.getInteger(VERTEX_DELIMITER_CHAR, '|');
this.degreeDelimiter = (char) parameters.getInteger(DEGREE_DELIMITER_CHAR, ',');
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph.triangleEnumUtil;
import org.apache.flink.api.java.record.io.DelimitedOutputFormat;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
/**
 * Output format that writes an edge augmented with vertex degrees as a line of the form
 * <code>v1,d1|v2,d2</code>, mirroring the input expected by {@link EdgeWithDegreesInputFormat}.
 */
public final class EdgeWithDegreesOutputFormat extends DelimitedOutputFormat {
private static final long serialVersionUID = 1L;
private final StringBuilder line = new StringBuilder();
@Override
public int serializeRecord(Record rec, byte[] target) throws Exception {
final int e1 = rec.getField(0, IntValue.class).getValue();
final int e2 = rec.getField(1, IntValue.class).getValue();
final int e3 = rec.getField(2, IntValue.class).getValue();
final int e4 = rec.getField(3, IntValue.class).getValue();
this.line.setLength(0);
this.line.append(e1);
this.line.append(',');
this.line.append(e3);
this.line.append('|');
this.line.append(e2);
this.line.append(',');
this.line.append(e4);
if (target.length >= line.length()) {
for (int i = 0; i < line.length(); i++) {
target[i] = (byte) line.charAt(i);
}
return line.length();
}
else {
return -line.length();
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.graph.triangleEnumUtil;
import org.apache.flink.api.java.record.io.DelimitedOutputFormat;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
/**
 * Output format that writes a triangle as a line of three comma-separated vertex IDs.
 */
public final class TriangleOutputFormat extends DelimitedOutputFormat {
private static final long serialVersionUID = 1L;
private final StringBuilder line = new StringBuilder();
@Override
public int serializeRecord(Record rec, byte[] target) throws Exception {
final int e1 = rec.getField(0, IntValue.class).getValue();
final int e2 = rec.getField(1, IntValue.class).getValue();
final int e3 = rec.getField(2, IntValue.class).getValue();
this.line.setLength(0);
this.line.append(e1);
this.line.append(',');
this.line.append(e2);
this.line.append(',');
this.line.append(e3);
if (target.length >= line.length()) {
for (int i = 0; i < line.length(); i++) {
target[i] = (byte) line.charAt(i);
}
return line.length();
} else {
return -line.length();
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.FileOutputFormat;
import org.apache.flink.api.java.record.operators.BulkIteration;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator.Combinable;
import org.apache.flink.client.LocalExecutor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
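/**
 * K-Means on the (deprecated) Record API using a bulk iteration: the current cluster centers form the
 * partial solution and are handed to the distance computation as a broadcast variable named "centers".
 */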
@SuppressWarnings("deprecation")
public class KMeansBroadcast implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
@Override
public Plan getPlan(String... args) {
// parse job parameters
int parallelism = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
String dataPointInput = (args.length > 1 ? args[1] : "");
String clusterInput = (args.length > 2 ? args[2] : "");
String output = (args.length > 3 ? args[3] : "");
int numIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 2);
// data source data point input
@SuppressWarnings("unchecked")
FileDataSource pointsSource = new FileDataSource(new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class), dataPointInput, "Data Points");
// data source for cluster center input
@SuppressWarnings("unchecked")
FileDataSource clustersSource = new FileDataSource(new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class), clusterInput, "Centers");
MapOperator dataPoints = MapOperator.builder(new PointBuilder()).name("Build data points").input(pointsSource).build();
MapOperator clusterPoints = MapOperator.builder(new PointBuilder()).name("Build cluster points").input(clustersSource).build();
// ---------------------- Begin K-Means Loop ---------------------
BulkIteration iter = new BulkIteration("k-means loop");
iter.setInput(clusterPoints);
iter.setMaximumNumberOfIterations(numIterations);
// compute the distances and select the closest center
MapOperator findNearestClusterCenters = MapOperator.builder(new SelectNearestCenter())
.setBroadcastVariable("centers", iter.getPartialSolution())
.input(dataPoints)
.name("Find Nearest Centers")
.build();
// computing the new cluster positions
ReduceOperator recomputeClusterCenter = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0)
.input(findNearestClusterCenters)
.name("Recompute Center Positions")
.build();
iter.setNextPartialSolution(recomputeClusterCenter);
// ---------------------- End K-Means Loop ---------------------
// create DataSinkContract for writing the new cluster positions
FileDataSink newClusterPoints = new FileDataSink(new PointOutFormat(), output, iter, "New Center Positions");
Plan plan = new Plan(newClusterPoints, "K-Means");
plan.setDefaultParallelism(parallelism);
return plan;
}
@Override
public String getDescription() {
return "Parameters: <numSubStasks> <dataPoints> <clusterCenters> <output> <numIterations>";
}
// --------------------------------------------------------------------------------------------
// Data Types and UDFs
// --------------------------------------------------------------------------------------------
/**
* A simple three-dimensional point.
*/
public static final class Point implements Value {
private static final long serialVersionUID = 1L;
public double x, y, z;
public Point() {}
public Point(double x, double y, double z) {
this.x = x;
this.y = y;
this.z = z;
}
public void add(Point other) {
x += other.x;
y += other.y;
z += other.z;
}
public Point div(long val) {
x /= val;
y /= val;
z /= val;
return this;
}
public double euclideanDistance(Point other) {
return Math.sqrt((x-other.x)*(x-other.x) + (y-other.y)*(y-other.y) + (z-other.z)*(z-other.z));
}
public void clear() {
x = y = z = 0.0;
}
@Override
public void write(DataOutputView out) throws IOException {
out.writeDouble(x);
out.writeDouble(y);
out.writeDouble(z);
}
@Override
public void read(DataInputView in) throws IOException {
x = in.readDouble();
y = in.readDouble();
z = in.readDouble();
}
@Override
public String toString() {
return "(" + x + "|" + y + "|" + z + ")";
}
}
public static final class PointWithId {
public int id;
public Point point;
public PointWithId(int id, Point p) {
this.id = id;
this.point = p;
}
}
/**
* Determines the closest cluster center for a data point.
*/
public static final class SelectNearestCenter extends MapFunction {
private static final long serialVersionUID = 1L;
private final IntValue one = new IntValue(1);
private final Record result = new Record(3);
private List<PointWithId> centers = new ArrayList<PointWithId>();
/**
* Reads all the center values from the broadcast variable into a collection.
*/
@Override
public void open(Configuration parameters) throws Exception {
List<Record> clusterCenters = this.getRuntimeContext().getBroadcastVariable("centers");
centers.clear();
synchronized (clusterCenters) {
for (Record r : clusterCenters) {
centers.add(new PointWithId(r.getField(0, IntValue.class).getValue(), r.getField(1, Point.class)));
}
}
}
/**
* Computes a minimum aggregation on the distance of a data point to cluster centers.
*
* Output Format:
* 0: centerID
* 1: pointVector
* 2: constant(1) (to enable combinable average computation in the following reducer)
*/
@Override
public void map(Record dataPointRecord, Collector<Record> out) {
Point p = dataPointRecord.getField(1, Point.class);
double nearestDistance = Double.MAX_VALUE;
int centerId = -1;
// check all cluster centers
for (PointWithId center : centers) {
// compute distance
double distance = p.euclideanDistance(center.point);
// update nearest cluster if necessary
if (distance < nearestDistance) {
nearestDistance = distance;
centerId = center.id;
}
}
// emit a new record with the center id and the data point. add a one to ease the
// implementation of the average function with a combiner
result.setField(0, new IntValue(centerId));
result.setField(1, p);
result.setField(2, one);
out.collect(result);
}
}
@Combinable
public static final class RecomputeClusterCenter extends ReduceFunction {
private static final long serialVersionUID = 1L;
private final Point p = new Point();
/**
* Compute the new position (coordinate vector) of a cluster center.
*/
@Override
public void reduce(Iterator<Record> points, Collector<Record> out) {
Record sum = sumPointsAndCount(points);
sum.setField(1, sum.getField(1, Point.class).div(sum.getField(2, IntValue.class).getValue()));
out.collect(sum);
}
/**
* Computes a pre-aggregated average value of a coordinate vector.
*/
@Override
public void combine(Iterator<Record> points, Collector<Record> out) {
out.collect(sumPointsAndCount(points));
}
private final Record sumPointsAndCount(Iterator<Record> dataPoints) {
Record next = null;
p.clear();
int count = 0;
// compute coordinate vector sum and count
while (dataPoints.hasNext()) {
next = dataPoints.next();
p.add(next.getField(1, Point.class));
count += next.getField(2, IntValue.class).getValue();
}
next.setField(1, p);
next.setField(2, new IntValue(count));
return next;
}
}
public static final class PointBuilder extends MapFunction {
private static final long serialVersionUID = 1L;
@Override
public void map(Record record, Collector<Record> out) throws Exception {
double x = record.getField(1, DoubleValue.class).getValue();
double y = record.getField(2, DoubleValue.class).getValue();
double z = record.getField(3, DoubleValue.class).getValue();
record.setField(1, new Point(x, y, z));
out.collect(record);
}
}
public static final class PointOutFormat extends FileOutputFormat {
private static final long serialVersionUID = 1L;
private static final String format = "%d|%.1f|%.1f|%.1f|\n";
@Override
public void writeRecord(Record record) throws IOException {
int id = record.getField(0, IntValue.class).getValue();
Point p = record.getField(1, Point.class);
byte[] bytes = String.format(format, id, p.x, p.y, p.z).getBytes();
this.stream.write(bytes);
}
}
public static void main(String[] args) throws Exception {
System.out.println(LocalExecutor.optimizerPlanAsJSON(new KMeansBroadcast().getPlan("4", "/dev/random", "/dev/random", "/tmp", "20")));
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.operators.BulkIteration;
import org.apache.flink.api.java.record.operators.CrossOperator;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.client.LocalExecutor;
import org.apache.flink.test.recordJobs.kmeans.udfs.ComputeDistance;
import org.apache.flink.test.recordJobs.kmeans.udfs.FindNearestCenter;
import org.apache.flink.test.recordJobs.kmeans.udfs.PointInFormat;
import org.apache.flink.test.recordJobs.kmeans.udfs.PointOutFormat;
import org.apache.flink.test.recordJobs.kmeans.udfs.RecomputeClusterCenter;
import org.apache.flink.types.IntValue;
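/**
 * K-Means on the (deprecated) Record API: a bulk iteration refines the cluster centers with a
 * cross/reduce combination, and a final pass assigns every data point to its nearest final center.
 */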
@SuppressWarnings("deprecation")
public class KMeansCross implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
@Override
public Plan getPlan(String... args) {
// parse job parameters
final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
final String dataPointInput = (args.length > 1 ? args[1] : "");
final String clusterInput = (args.length > 2 ? args[2] : "");
final String output = (args.length > 3 ? args[3] : "");
final int numIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 1);
// create DataSourceContract for cluster center input
FileDataSource initialClusterPoints = new FileDataSource(new PointInFormat(), clusterInput, "Centers");
initialClusterPoints.setParallelism(1);
BulkIteration iteration = new BulkIteration("K-Means Loop");
iteration.setInput(initialClusterPoints);
iteration.setMaximumNumberOfIterations(numIterations);
// create DataSourceContract for data point input
FileDataSource dataPoints = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points");
// create CrossOperator for distance computation
CrossOperator computeDistance = CrossOperator.builder(new ComputeDistance())
.input1(dataPoints)
.input2(iteration.getPartialSolution())
.name("Compute Distances")
.build();
// create ReduceOperator for finding the nearest cluster centers
ReduceOperator findNearestClusterCenters = ReduceOperator.builder(new FindNearestCenter(), IntValue.class, 0)
.input(computeDistance)
.name("Find Nearest Centers")
.build();
// create ReduceOperator for computing new cluster positions
ReduceOperator recomputeClusterCenter = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0)
.input(findNearestClusterCenters)
.name("Recompute Center Positions")
.build();
iteration.setNextPartialSolution(recomputeClusterCenter);
// create DataSourceContract for data point input
FileDataSource dataPoints2 = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points 2");
// compute distance of points to final clusters
CrossOperator computeFinalDistance = CrossOperator.builder(new ComputeDistance())
.input1(dataPoints2)
.input2(iteration)
.name("Compute Final Distances")
.build();
// find nearest final cluster for point
ReduceOperator findNearestFinalCluster = ReduceOperator.builder(new FindNearestCenter(), IntValue.class, 0)
.input(computeFinalDistance)
.name("Find Nearest Final Centers")
.build();
// create DataSinkContract for writing the new cluster positions
FileDataSink finalClusters = new FileDataSink(new PointOutFormat(), output+"/centers", iteration, "Cluster Positions");
// write assigned clusters
FileDataSink clusterAssignments = new FileDataSink(new PointOutFormat(), output+"/points", findNearestFinalCluster, "Cluster Assignments");
List<FileDataSink> sinks = new ArrayList<FileDataSink>();
sinks.add(finalClusters);
sinks.add(clusterAssignments);
// return the PACT plan
Plan plan = new Plan(sinks, "Iterative KMeans");
plan.setDefaultParallelism(numSubTasks);
return plan;
}
@Override
public String getDescription() {
return "Parameters: <numSubStasks> <dataPoints> <clusterCenters> <output> <numIterations>";
}
public static void main(String[] args) throws Exception {
KMeansCross kmi = new KMeansCross();
if (args.length < 5) {
System.err.println(kmi.getDescription());
System.exit(1);
}
Plan plan = kmi.getPlan(args);
// This will execute the kMeans clustering job embedded in a local context.
LocalExecutor.execute(plan);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.FileOutputFormat;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator.Combinable;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
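/**
 * A single K-Means step on the (deprecated) Record API: each data point is assigned to its nearest
 * center (read from a broadcast variable) and the centers are recomputed once as the average of their points.
 */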
@SuppressWarnings("deprecation")
public class KMeansSingleStep implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
@Override
public Plan getPlan(String... args) {
// parse job parameters
int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
String dataPointInput = (args.length > 1 ? args[1] : "");
String clusterInput = (args.length > 2 ? args[2] : "");
String output = (args.length > 3 ? args[3] : "");
// create DataSourceContract for data point input
@SuppressWarnings("unchecked")
FileDataSource pointsSource = new FileDataSource(new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class), dataPointInput, "Data Points");
// create DataSourceContract for cluster center input
@SuppressWarnings("unchecked")
FileDataSource clustersSource = new FileDataSource(new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class), clusterInput, "Centers");
MapOperator dataPoints = MapOperator.builder(new PointBuilder()).name("Build data points").input(pointsSource).build();
MapOperator clusterPoints = MapOperator.builder(new PointBuilder()).name("Build cluster points").input(clustersSource).build();
// the mapper computes the distance of each data point to all cluster centers, which it draws from a broadcast variable
MapOperator findNearestClusterCenters = MapOperator.builder(new SelectNearestCenter())
.setBroadcastVariable("centers", clusterPoints)
.input(dataPoints)
.name("Find Nearest Centers")
.build();
// create a reducer that recomputes the cluster centers as the average of all associated data points
ReduceOperator recomputeClusterCenter = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0)
.input(findNearestClusterCenters)
.name("Recompute Center Positions")
.build();
// create DataSinkContract for writing the new cluster positions
FileDataSink newClusterPoints = new FileDataSink(new PointOutFormat(), output, recomputeClusterCenter, "New Center Positions");
// return the plan
Plan plan = new Plan(newClusterPoints, "KMeans Iteration");
plan.setDefaultParallelism(numSubTasks);
return plan;
}
@Override
public String getDescription() {
return "Parameters: <numSubStasks> <dataPoints> <clusterCenters> <output>";
}
public static final class Point implements Value {
private static final long serialVersionUID = 1L;
public double x, y, z;
public Point() {}
public Point(double x, double y, double z) {
this.x = x;
this.y = y;
this.z = z;
}
public void add(Point other) {
x += other.x;
y += other.y;
z += other.z;
}
public Point div(long val) {
x /= val;
y /= val;
z /= val;
return this;
}
public double euclideanDistance(Point other) {
return Math.sqrt((x-other.x)*(x-other.x) + (y-other.y)*(y-other.y) + (z-other.z)*(z-other.z));
}
public void clear() {
x = y = z = 0.0;
}
@Override
public void write(DataOutputView out) throws IOException {
out.writeDouble(x);
out.writeDouble(y);
out.writeDouble(z);
}
@Override
public void read(DataInputView in) throws IOException {
x = in.readDouble();
y = in.readDouble();
z = in.readDouble();
}
@Override
public String toString() {
return "(" + x + "|" + y + "|" + z + ")";
}
}
public static final class PointWithId {
public int id;
public Point point;
public PointWithId(int id, Point p) {
this.id = id;
this.point = p;
}
}
/**
* Determines the closest cluster center for a data point.
*/
public static final class SelectNearestCenter extends MapFunction {
private static final long serialVersionUID = 1L;
private final IntValue one = new IntValue(1);
private final Record result = new Record(3);
private List<PointWithId> centers = new ArrayList<PointWithId>();
/**
* Reads all the center values from the broadcast variable into a collection.
*/
@Override
public void open(Configuration parameters) throws Exception {
Collection<Record> clusterCenters = this.getRuntimeContext().getBroadcastVariable("centers");
centers.clear();
for (Record r : clusterCenters) {
centers.add(new PointWithId(r.getField(0, IntValue.class).getValue(), r.getField(1, Point.class)));
}
}
/**
* Computes a minimum aggregation on the distance of a data point to cluster centers.
*
* Output Format:
* 0: centerID
* 1: pointVector
* 2: constant(1) (to enable combinable average computation in the following reducer)
*/
@Override
public void map(Record dataPointRecord, Collector<Record> out) {
Point p = dataPointRecord.getField(1, Point.class);
double nearestDistance = Double.MAX_VALUE;
int centerId = -1;
// check all cluster centers
for (PointWithId center : centers) {
// compute distance
double distance = p.euclideanDistance(center.point);
// update nearest cluster if necessary
if (distance < nearestDistance) {
nearestDistance = distance;
centerId = center.id;
}
}
// emit a new record with the center id and the data point. add a one to ease the
// implementation of the average function with a combiner
result.setField(0, new IntValue(centerId));
result.setField(1, p);
result.setField(2, one);
out.collect(result);
}
}
@Combinable
public static final class RecomputeClusterCenter extends ReduceFunction {
private static final long serialVersionUID = 1L;
private final Point p = new Point();
/**
* Compute the new position (coordinate vector) of a cluster center.
*/
@Override
public void reduce(Iterator<Record> points, Collector<Record> out) {
Record sum = sumPointsAndCount(points);
sum.setField(1, sum.getField(1, Point.class).div(sum.getField(2, IntValue.class).getValue()));
out.collect(sum);
}
/**
* Computes a pre-aggregated average value of a coordinate vector.
*/
@Override
public void combine(Iterator<Record> points, Collector<Record> out) {
out.collect(sumPointsAndCount(points));
}
private final Record sumPointsAndCount(Iterator<Record> dataPoints) {
Record next = null;
p.clear();
int count = 0;
// compute coordinate vector sum and count
while (dataPoints.hasNext()) {
next = dataPoints.next();
p.add(next.getField(1, Point.class));
count += next.getField(2, IntValue.class).getValue();
}
next.setField(1, p);
next.setField(2, new IntValue(count));
return next;
}
}
public static final class PointBuilder extends MapFunction {
private static final long serialVersionUID = 1L;
@Override
public void map(Record record, Collector<Record> out) throws Exception {
double x = record.getField(1, DoubleValue.class).getValue();
double y = record.getField(2, DoubleValue.class).getValue();
double z = record.getField(3, DoubleValue.class).getValue();
record.setField(1, new Point(x, y, z));
out.collect(record);
}
}
public static final class PointOutFormat extends FileOutputFormat {
private static final long serialVersionUID = 1L;
private static final String format = "%d|%.1f|%.1f|%.1f|\n";
@Override
public void writeRecord(Record record) throws IOException {
int id = record.getField(0, IntValue.class).getValue();
Point p = record.getField(1, Point.class);
byte[] bytes = String.format(format, id, p.x, p.y, p.z).getBytes();
this.stream.write(bytes);
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans.udfs;
import java.io.Serializable;
import org.apache.flink.api.java.record.functions.CrossFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
/**
* Cross PACT computes the distance of all data points to all cluster
* centers.
*/
@SuppressWarnings("deprecation")
@ConstantFieldsFirst({0,1})
public class ComputeDistance extends CrossFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final DoubleValue distance = new DoubleValue();
/**
* Computes the distance of one data point to one cluster center.
*
* Output Format:
* 0: pointID
* 1: pointVector
* 2: clusterID
* 3: distance
*/
@Override
public Record cross(Record dataPointRecord, Record clusterCenterRecord) throws Exception {
CoordVector dataPoint = dataPointRecord.getField(1, CoordVector.class);
IntValue clusterCenterId = clusterCenterRecord.getField(0, IntValue.class);
CoordVector clusterPoint = clusterCenterRecord.getField(1, CoordVector.class);
this.distance.setValue(dataPoint.computeEuclidianDistance(clusterPoint));
// add cluster center id and distance to the data point record
dataPointRecord.setField(2, clusterCenterId);
dataPointRecord.setField(3, this.distance);
return dataPointRecord;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans.udfs;
import java.io.IOException;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.types.Key;
/**
* Implements a feature vector as a multi-dimensional point. Coordinates of that point
 * (= the features) are stored as double values. The distance between two feature vectors is
 * the Euclidean distance between the points.
*/
public final class CoordVector implements Key<CoordVector> {
private static final long serialVersionUID = 1L;
// coordinate array
private double[] coordinates;
/**
* Initializes a blank coordinate vector. Required for deserialization!
*/
public CoordVector() {
coordinates = null;
}
/**
* Initializes a coordinate vector.
*
* @param coordinates The coordinate vector of a multi-dimensional point.
*/
public CoordVector(Double[] coordinates) {
this.coordinates = new double[coordinates.length];
for (int i = 0; i < coordinates.length; i++) {
this.coordinates[i] = coordinates[i];
}
}
/**
* Initializes a coordinate vector.
*
* @param coordinates The coordinate vector of a multi-dimensional point.
*/
public CoordVector(double[] coordinates) {
this.coordinates = coordinates;
}
/**
* Returns the coordinate vector of a multi-dimensional point.
*
* @return The coordinate vector of a multi-dimensional point.
*/
public double[] getCoordinates() {
return this.coordinates;
}
/**
* Sets the coordinate vector of a multi-dimensional point.
*
* @param coordinates The dimension values of the point.
*/
public void setCoordinates(double[] coordinates) {
this.coordinates = coordinates;
}
/**
 * Computes the Euclidean distance between this coordinate vector and a
* second coordinate vector.
*
* @param cv The coordinate vector to which the distance is computed.
 * @return The Euclidean distance to coordinate vector cv. If cv has a
* different length than this coordinate vector, -1 is returned.
*/
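// Illustrative example: for the vectors (1.0, 2.0) and (4.0, 6.0) the method returns
// sqrt((1-4)^2 + (2-6)^2) = sqrt(9 + 16) = 5.0.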
public double computeEuclidianDistance(CoordVector cv) {
// check coordinate vector lengths
if (cv.coordinates.length != this.coordinates.length) {
return -1.0;
}
double quadSum = 0.0;
for (int i = 0; i < this.coordinates.length; i++) {
double diff = this.coordinates[i] - cv.coordinates[i];
quadSum += diff*diff;
}
return Math.sqrt(quadSum);
}
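// Worked example (illustrative): for this = (1.0, 2.0) and cv = (4.0, 6.0),
// quadSum = (1-4)^2 + (2-6)^2 = 9 + 16 = 25, so the returned distance is sqrt(25) = 5.0.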
@Override
public void read(DataInputView in) throws IOException {
int length = in.readInt();
this.coordinates = new double[length];
for (int i = 0; i < length; i++) {
this.coordinates[i] = in.readDouble();
}
}
@Override
public void write(DataOutputView out) throws IOException {
out.writeInt(this.coordinates.length);
for (int i = 0; i < this.coordinates.length; i++) {
out.writeDouble(this.coordinates[i]);
}
}
/**
* Compares this coordinate vector to another key.
*
* @return If the length of the other coordinate vector differs from the length
* of this coordinate vector, -1 is returned if this coordinate vector
* is shorter and 1 if it is longer. If both coordinate vectors
* have the same length, the coordinates of both are compared.
* If a coordinate of this coordinate vector is smaller than the
* corresponding coordinate of the other vector, -1 is returned
* and 1 otherwise. If all coordinates are identical, 0 is
* returned.
*/
@Override
public int compareTo(CoordVector o) {
// check if both coordinate vectors have identical lengths
if (o.coordinates.length > this.coordinates.length) {
return -1;
}
else if (o.coordinates.length < this.coordinates.length) {
return 1;
}
// compare all coordinates
for (int i = 0; i < this.coordinates.length; i++) {
if (o.coordinates[i] > this.coordinates[i]) {
return -1;
} else if (o.coordinates[i] < this.coordinates[i]) {
return 1;
}
}
return 0;
}
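// Illustrative examples: comparing (1.0, 2.0) to (1.0, 3.0) returns -1 because the second
// coordinate of this vector is smaller; comparing (1.0, 2.0) to (1.0) returns 1 because
// this vector is longer.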
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans.udfs;
import java.io.Serializable;
import java.util.Iterator;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFields;
import org.apache.flink.api.java.record.operators.ReduceOperator.Combinable;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
/**
* Reduce PACT determines the closest cluster center for a data point. This
* is a minimum aggregation. Hence, a combiner can easily be implemented.
*/
@SuppressWarnings("deprecation")
@Combinable
@ConstantFields(1)
public class FindNearestCenter extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final IntValue centerId = new IntValue();
private final CoordVector position = new CoordVector();
private final IntValue one = new IntValue(1);
private final Record result = new Record(3);
/**
* Computes a minimum aggregation on the distance of a data point to
* cluster centers.
*
* Output Format:
* 0: centerID
* 1: pointVector
* 2: constant(1) (to enable combinable average computation in the following reducer)
*/
@Override
public void reduce(Iterator<Record> pointsWithDistance, Collector<Record> out) {
double nearestDistance = Double.MAX_VALUE;
int nearestClusterId = 0;
// check all cluster centers
while (pointsWithDistance.hasNext()) {
Record res = pointsWithDistance.next();
double distance = res.getField(3, DoubleValue.class).getValue();
// compare distances
if (distance < nearestDistance) {
// if distance is smaller than smallest till now, update nearest cluster
nearestDistance = distance;
nearestClusterId = res.getField(2, IntValue.class).getValue();
res.getFieldInto(1, this.position);
}
}
// emit a new record with the center id and the data point. add a one to ease the
// implementation of the average function with a combiner
this.centerId.setValue(nearestClusterId);
this.result.setField(0, this.centerId);
this.result.setField(1, this.position);
this.result.setField(2, this.one);
out.collect(this.result);
}
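// Illustrative example: for the grouped records (pointID, pointVector, clusterID, distance) =
// (7, p, 1, 2.5), (7, p, 2, 1.0), (7, p, 3, 4.2), the nearest center is 2 and the emitted
// record is (2, p, 1).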
// ----------------------------------------------------------------------------------------
private final Record nearest = new Record();
/**
* Computes a minimum aggregation on the distance of a data point to
* cluster centers.
*/
@Override
public void combine(Iterator<Record> pointsWithDistance, Collector<Record> out) {
double nearestDistance = Double.MAX_VALUE;
// check all cluster centers
while (pointsWithDistance.hasNext()) {
Record res = pointsWithDistance.next();
double distance = res.getField(3, DoubleValue.class).getValue();
// compare distances
if (distance < nearestDistance) {
nearestDistance = distance;
res.copyTo(this.nearest);
}
}
// emit nearest one
out.collect(this.nearest);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans.udfs;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.java.record.io.DelimitedInputFormat;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
/**
* Generates records with an id and a CoordVector.
* The input format is line-based, i.e. one record is read from one line
* which is terminated by '\n'. Within a line the first '|' character separates
* the id from the CoordVector. The vector consists of decimal values that are
* also separated by '|'. The id identifies a data point or
* cluster center and the CoordVector is the corresponding position (coordinate
* vector) of the data point or cluster center. Example line:
* "42|23.23|52.57|74.43|" represents id 42 and coordinate vector (23.23, 52.57, 74.43).
*/
public class PointInFormat extends DelimitedInputFormat {
private static final long serialVersionUID = 1L;
private final IntValue idInteger = new IntValue();
private final CoordVector point = new CoordVector();
private final List<Double> dimensionValues = new ArrayList<Double>();
private double[] pointValues = new double[0];
@Override
public Record readRecord(Record record, byte[] line, int offset, int numBytes) {
final int limit = offset + numBytes;
int id = -1;
int value = 0;
int fractionValue = 0;
int fractionChars = 0;
boolean negative = false;
this.dimensionValues.clear();
for (int pos = offset; pos < limit; pos++) {
if (line[pos] == '|') {
// check if id was already set
if (id == -1) {
id = value;
}
else {
double v = value + ((double) fractionValue) * Math.pow(10, (-1 * (fractionChars - 1)));
this.dimensionValues.add(negative ? -v : v);
}
// reset value
value = 0;
fractionValue = 0;
fractionChars = 0;
negative = false;
} else if (line[pos] == '.') {
fractionChars = 1;
} else if (line[pos] == '-') {
negative = true;
} else {
if (fractionChars == 0) {
value *= 10;
value += line[pos] - '0';
} else {
fractionValue *= 10;
fractionValue += line[pos] - '0';
fractionChars++;
}
}
}
// set the ID
this.idInteger.setValue(id);
record.setField(0, this.idInteger);
// set the data points
if (this.pointValues.length != this.dimensionValues.size()) {
this.pointValues = new double[this.dimensionValues.size()];
}
for (int i = 0; i < this.pointValues.length; i++) {
this.pointValues[i] = this.dimensionValues.get(i);
}
this.point.setCoordinates(this.pointValues);
record.setField(1, this.point);
return record;
}
}
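// Usage sketch (illustrative; the path below is a placeholder): the format plugs into a
// Record-API plan like any other delimited input format, e.g.
//   FileDataSource points = new FileDataSource(new PointInFormat(), "file:///path/to/points", "Points");
// A line such as "42|23.23|52.57|74.43|" then becomes a Record with
// field 0 = IntValue(42) and field 1 = CoordVector(23.23, 52.57, 74.43).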
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans.udfs;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import org.apache.flink.api.java.record.io.DelimitedOutputFormat;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
/**
* Writes records that contain an id and a CoordVector.
* The output format is line-based, i.e. one record is written to
* a line and terminated with '\n'. Within a line the first '|' character
* separates the id from the CoordVector. The vector consists of decimal values
* that are also separated by '|'. The id identifies a data
* point or cluster center and the vector is the corresponding position
* (coordinate vector) of the data point or cluster center. Example line:
* "42|23.23|52.57|74.43|" represents id 42 and coordinate vector (23.23, 52.57, 74.43).
*/
public class PointOutFormat extends DelimitedOutputFormat {
private static final long serialVersionUID = 1L;
private final DecimalFormat df = new DecimalFormat("####0.00");
private final StringBuilder line = new StringBuilder();
public PointOutFormat() {
DecimalFormatSymbols dfSymbols = new DecimalFormatSymbols();
dfSymbols.setDecimalSeparator('.');
this.df.setDecimalFormatSymbols(dfSymbols);
}
@Override
public int serializeRecord(Record record, byte[] target) {
line.setLength(0);
IntValue centerId = record.getField(0, IntValue.class);
CoordVector centerPos = record.getField(1, CoordVector.class);
line.append(centerId.getValue());
for (double coord : centerPos.getCoordinates()) {
line.append('|');
line.append(df.format(coord));
}
line.append('|');
byte[] byteString = line.toString().getBytes();
if (byteString.length <= target.length) {
System.arraycopy(byteString, 0, target, 0, byteString.length);
return byteString.length;
}
else {
return -byteString.length;
}
}
}
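// Usage sketch (illustrative; path and input operator are placeholders): the format serves as
// the output format of a file sink, e.g.
//   FileDataSink centers = new FileDataSink(new PointOutFormat(), "file:///path/to/centers", someInput, "Centers");
// A Record with field 0 = IntValue(42) and field 1 = CoordVector(23.23, 52.57, 74.43) is then
// written as the line "42|23.23|52.57|74.43|".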
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.kmeans.udfs;
import java.io.Serializable;
import java.util.Iterator;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFields;
import org.apache.flink.api.java.record.operators.ReduceOperator.Combinable;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
/**
* Reduce PACT computes the new position (coordinate vector) of a cluster
* center. This is an average computation. Hence, the reducer is annotated as
* Combinable and the combine method is implemented.
*
* Output Format:
* 0: clusterID
* 1: clusterVector
*/
@SuppressWarnings("deprecation")
@Combinable
@ConstantFields(0)
public class RecomputeClusterCenter extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final IntValue count = new IntValue();
/**
* Compute the new position (coordinate vector) of a cluster center.
*/
@Override
public void reduce(Iterator<Record> dataPoints, Collector<Record> out) {
Record next = null;
// initialize coordinate vector sum and count
CoordVector coordinates = new CoordVector();
double[] coordinateSum = null;
int count = 0;
// compute coordinate vector sum and count
while (dataPoints.hasNext()) {
next = dataPoints.next();
// get the coordinates and the count from the record
double[] thisCoords = next.getField(1, CoordVector.class).getCoordinates();
int thisCount = next.getField(2, IntValue.class).getValue();
if (coordinateSum == null) {
if (coordinates.getCoordinates() != null) {
coordinateSum = coordinates.getCoordinates();
}
else {
coordinateSum = new double[thisCoords.length];
}
}
addToCoordVector(coordinateSum, thisCoords);
count += thisCount;
}
// compute new coordinate vector (position) of cluster center
for (int i = 0; i < coordinateSum.length; i++) {
coordinateSum[i] /= count;
}
coordinates.setCoordinates(coordinateSum);
next.setField(1, coordinates);
next.setNull(2);
// emit new position of cluster center
out.collect(next);
}
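// Illustrative example: for the grouped records (clusterID, vector, count) =
// (1, (1.0, 2.0), 1) and (1, (3.0, 4.0), 1), the summed vector is (4.0, 6.0) with count 2,
// so the emitted new center position is (1, (2.0, 3.0)).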
/**
* Computes a pre-aggregated average value of a coordinate vector.
*/
@Override
public void combine(Iterator<Record> dataPoints, Collector<Record> out) {
Record next = null;
// initialize coordinate vector sum and count
CoordVector coordinates = new CoordVector();
double[] coordinateSum = null;
int count = 0;
// compute coordinate vector sum and count
while (dataPoints.hasNext()) {
next = dataPoints.next();
// get the coordinates and the count from the record
double[] thisCoords = next.getField(1, CoordVector.class).getCoordinates();
int thisCount = next.getField(2, IntValue.class).getValue();
if (coordinateSum == null) {
if (coordinates.getCoordinates() != null) {
coordinateSum = coordinates.getCoordinates();
}
else {
coordinateSum = new double[thisCoords.length];
}
}
addToCoordVector(coordinateSum, thisCoords);
count += thisCount;
}
coordinates.setCoordinates(coordinateSum);
this.count.setValue(count);
next.setField(1, coordinates);
next.setField(2, this.count);
// emit partial sum and partial count for average computation
out.collect(next);
}
/**
* Adds two coordinate vectors by summing up each of their coordinates.
*
* @param cvToAddTo
* The coordinate vector to which the other vector is added.
* This vector is modified in place.
* @param cvToBeAdded
* The coordinate vector which is added to the other vector.
* This vector is not modified.
*/
private void addToCoordVector(double[] cvToAddTo, double[] cvToBeAdded) {
// check if both vectors have same length
if (cvToAddTo.length != cvToBeAdded.length) {
throw new IllegalArgumentException("The given coordinate vectors are not of equal length.");
}
// sum coordinate vectors coordinate-wise
for (int i = 0; i < cvToAddTo.length; i++) {
cvToAddTo[i] += cvToBeAdded[i];
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.relational;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.java.record.functions.JoinFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsExcept;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirstExcept;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.CsvOutputFormat;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
@SuppressWarnings("deprecation")
public class MergeOnlyJoin implements Program {
private static final long serialVersionUID = 1L;
@ConstantFieldsFirstExcept(2)
public static class JoinInputs extends JoinFunction {
private static final long serialVersionUID = 1L;
@Override
public void join(Record input1, Record input2, Collector<Record> out) {
input1.setField(2, input2.getField(1, IntValue.class));
out.collect(input1);
}
}
@ConstantFieldsExcept({})
public static class DummyReduce extends ReduceFunction {
private static final long serialVersionUID = 1L;
@Override
public void reduce(Iterator<Record> values, Collector<Record> out) {
while (values.hasNext()) {
out.collect(values.next());
}
}
}
@Override
public Plan getPlan(final String... args) {
// parse program parameters
int numSubtasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
String input1Path = (args.length > 1 ? args[1] : "");
String input2Path = (args.length > 2 ? args[2] : "");
String output = (args.length > 3 ? args[3] : "");
int numSubtasksInput2 = (args.length > 4 ? Integer.parseInt(args[4]) : 1);
// create DataSourceContract for Orders input
@SuppressWarnings("unchecked")
CsvInputFormat format1 = new CsvInputFormat('|', IntValue.class, IntValue.class);
FileDataSource input1 = new FileDataSource(format1, input1Path, "Input 1");
ReduceOperator aggInput1 = ReduceOperator.builder(DummyReduce.class, IntValue.class, 0)
.input(input1)
.name("AggOrders")
.build();
// create DataSourceContract for the second input
@SuppressWarnings("unchecked")
CsvInputFormat format2 = new CsvInputFormat('|', IntValue.class, IntValue.class);
FileDataSource input2 = new FileDataSource(format2, input2Path, "Input 2");
input2.setParallelism(numSubtasksInput2);
ReduceOperator aggInput2 = ReduceOperator.builder(DummyReduce.class, IntValue.class, 0)
.input(input2)
.name("AggLines")
.build();
aggInput2.setParallelism(numSubtasksInput2);
// create JoinOperator for joining Orders and LineItems
JoinOperator joinLiO = JoinOperator.builder(JoinInputs.class, IntValue.class, 0, 0)
.input1(aggInput1)
.input2(aggInput2)
.name("JoinLiO")
.build();
// create DataSinkContract for writing the result
FileDataSink result = new FileDataSink(new CsvOutputFormat(), output, joinLiO, "Output");
CsvOutputFormat.configureRecordFormat(result)
.recordDelimiter('\n')
.fieldDelimiter('|')
.lenient(true)
.field(IntValue.class, 0)
.field(IntValue.class, 1)
.field(IntValue.class, 2);
// assemble the PACT plan
Plan plan = new Plan(result, "Merge Only Join");
plan.setDefaultParallelism(numSubtasks);
return plan;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.relational;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.test.recordJobs.relational.query1Util.GroupByReturnFlag;
import org.apache.flink.test.recordJobs.relational.query1Util.LineItemFilter;
import org.apache.flink.test.recordJobs.util.IntTupleDataInFormat;
import org.apache.flink.test.recordJobs.util.StringTupleDataOutFormat;
import org.apache.flink.types.StringValue;
@SuppressWarnings("deprecation")
public class TPCHQuery1 implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
private int parallelism = 1;
private String lineItemInputPath;
private String outputPath;
@Override
public Plan getPlan(String... args) throws IllegalArgumentException {
if (args.length != 3) {
this.parallelism = 1;
this.lineItemInputPath = "";
this.outputPath = "";
} else {
this.parallelism = Integer.parseInt(args[0]);
this.lineItemInputPath = args[1];
this.outputPath = args[2];
}
FileDataSource lineItems =
new FileDataSource(new IntTupleDataInFormat(), this.lineItemInputPath, "LineItems");
lineItems.setParallelism(this.parallelism);
FileDataSink result =
new FileDataSink(new StringTupleDataOutFormat(), this.outputPath, "Output");
result.setParallelism(this.parallelism);
MapOperator lineItemFilter =
MapOperator.builder(new LineItemFilter())
.name("LineItem Filter")
.build();
lineItemFilter.setParallelism(this.parallelism);
ReduceOperator groupByReturnFlag =
ReduceOperator.builder(new GroupByReturnFlag(), StringValue.class, 0)
.name("groupyBy")
.build();
lineItemFilter.setInput(lineItems);
groupByReturnFlag.setInput(lineItemFilter);
result.setInput(groupByReturnFlag);
return new Plan(result, "TPC-H 1");
}
@Override
public String getDescription() {
return "Parameters: [parallelism] [lineitem-input] [output]";
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.relational;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.JoinFunction;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.io.FileOutputFormat;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.test.recordJobs.util.IntTupleDataInFormat;
import org.apache.flink.test.recordJobs.util.Tuple;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.StringValue;
import org.apache.flink.util.Collector;
@SuppressWarnings({"serial", "deprecation"})
public class TPCHQuery10 implements Program, ProgramDescription {
// --------------------------------------------------------------------------------------------
// Local Filters and Projections
// --------------------------------------------------------------------------------------------
/**
* Forwards (0 = orderkey, 1 = custkey).
*/
public static class FilterO extends MapFunction
{
private static final int YEAR_FILTER = 1990;
private final IntValue custKey = new IntValue();
@Override
public void map(Record record, Collector<Record> out) throws Exception {
Tuple t = record.getField(1, Tuple.class);
if (Integer.parseInt(t.getStringValueAt(4).substring(0, 4)) > FilterO.YEAR_FILTER) {
// project
this.custKey.setValue((int) t.getLongValueAt(1));
record.setField(1, this.custKey);
out.collect(record);
}
}
}
/**
* Forwards (0 = orderkey, 1 = tuple (extendedprice, discount) ).
*/
public static class FilterLI extends MapFunction
{
private final Tuple tuple = new Tuple();
@Override
public void map(Record record, Collector<Record> out) throws Exception
{
Tuple t = record.getField(1, this.tuple);
if (t.getStringValueAt(8).equals("R")) {
t.project(0x60); // l_extendedprice, l_discount
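// (0x60 == 0b1100000, i.e. bits 5 and 6 are set, which matches the two retained columns)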
record.setField(1, t);
out.collect(record);
}
}
}
/**
* Returns (0 = custkey, 1 = custName, 2 = NULL, 3 = balance, 4 = nationkey, 5 = address, 6 = phone, 7 = comment)
*/
public static class ProjectC extends MapFunction {
private final Tuple tuple = new Tuple();
private final StringValue custName = new StringValue();
private final StringValue balance = new StringValue();
private final IntValue nationKey = new IntValue();
private final StringValue address = new StringValue();
private final StringValue phone = new StringValue();
private final StringValue comment = new StringValue();
@Override
public void map(Record record, Collector<Record> out) throws Exception
{
final Tuple t = record.getField(1, this.tuple);
this.custName.setValue(t.getStringValueAt(1));
this.address.setValue(t.getStringValueAt(2));
this.nationKey.setValue((int) t.getLongValueAt(3));
this.phone.setValue(t.getStringValueAt(4));
this.balance.setValue(t.getStringValueAt(5));
this.comment.setValue(t.getStringValueAt(7));
record.setField(1, this.custName);
record.setField(3, this.balance);
record.setField(4, this.nationKey);
record.setField(5, this.address);
record.setField(6, this.phone);
record.setField(7, this.comment);
out.collect(record);
}
}
/**
* Returns (0 = nationkey, 1 = nation_name)
*/
public static class ProjectN extends MapFunction
{
private final Tuple tuple = new Tuple();
private final StringValue nationName = new StringValue();
@Override
public void map(Record record, Collector<Record> out) throws Exception
{
final Tuple t = record.getField(1, this.tuple);
this.nationName.setValue(t.getStringValueAt(1));
record.setField(1, this.nationName);
out.collect(record);
}
}
// --------------------------------------------------------------------------------------------
// Joins
// --------------------------------------------------------------------------------------------
/**
* Returns (0 = custKey, 1 = tuple (extendedprice, discount) )
*/
public static class JoinOL extends JoinFunction
{
@Override
public void join(Record order, Record lineitem, Collector<Record> out) throws Exception {
lineitem.setField(0, order.getField(1, IntValue.class));
out.collect(lineitem);
}
}
/**
* Returns (0 = custkey, 1 = custName, 2 = extPrice * (1-discount), 3 = balance, 4 = nationkey, 5 = address, 6 = phone, 7 = comment)
*/
public static class JoinCOL extends JoinFunction
{
private final DoubleValue d = new DoubleValue();
@Override
public void join(Record custRecord, Record olRecord, Collector<Record> out) throws Exception
{
final Tuple t = olRecord.getField(1, Tuple.class);
final double extPrice = Double.parseDouble(t.getStringValueAt(0));
final double discount = Double.parseDouble(t.getStringValueAt(1));
this.d.setValue(extPrice * (1 - discount));
custRecord.setField(2, this.d);
out.collect(custRecord);
}
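// Illustrative example: extendedprice "1000.00" with discount "0.05" yields
// 1000.00 * (1 - 0.05) = 950.0 in field 2 of the emitted record.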
}
/**
* Returns (0 = custkey, 1 = custName, 2 = extPrice * (1-discount), 3 = balance, 4 = nationName, 5 = address, 6 = phone, 7 = comment)
*/
public static class JoinNCOL extends JoinFunction
{
@Override
public void join(Record colRecord, Record nation, Collector<Record> out) throws Exception {
colRecord.setField(4, nation.getField(1, StringValue.class));
out.collect(colRecord);
}
}
@ReduceOperator.Combinable
public static class Sum extends ReduceFunction
{
private final DoubleValue d = new DoubleValue();
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) throws Exception
{
Record record = null;
double sum = 0;
while (records.hasNext()) {
record = records.next();
sum += record.getField(2, DoubleValue.class).getValue();
}
this.d.setValue(sum);
record.setField(2, this.d);
out.collect(record);
}
@Override
public void combine(Iterator<Record> records, Collector<Record> out) throws Exception {
reduce(records,out);
}
}
public static class TupleOutputFormat extends FileOutputFormat {
private static final long serialVersionUID = 1L;
private final DecimalFormat formatter;
private final StringBuilder buffer = new StringBuilder();
public TupleOutputFormat() {
DecimalFormatSymbols decimalFormatSymbol = new DecimalFormatSymbols();
decimalFormatSymbol.setDecimalSeparator('.');
this.formatter = new DecimalFormat("#.####");
this.formatter.setDecimalFormatSymbols(decimalFormatSymbol);
}
@Override
public void writeRecord(Record record) throws IOException
{
this.buffer.setLength(0);
this.buffer.append(record.getField(0, IntValue.class).toString()).append('|');
this.buffer.append(record.getField(1, StringValue.class).toString()).append('|');
this.buffer.append(this.formatter.format(record.getField(2, DoubleValue.class).getValue())).append('|');
this.buffer.append(record.getField(3, StringValue.class).toString()).append('|');
this.buffer.append(record.getField(4, StringValue.class).toString()).append('|');
this.buffer.append(record.getField(5, StringValue.class).toString()).append('|');
this.buffer.append(record.getField(6, StringValue.class).toString()).append('|');
this.buffer.append(record.getField(7, StringValue.class).toString()).append('|');
this.buffer.append('\n');
final byte[] bytes = this.buffer.toString().getBytes();
this.stream.write(bytes);
}
}
@Override
public String getDescription() {
return "TPC-H Query 10";
}
@Override
public Plan getPlan(String... args) throws IllegalArgumentException {
final String ordersPath;
final String lineitemsPath;
final String customersPath;
final String nationsPath;
final String resultPath;
final int parallelism;
if (args.length < 6) {
throw new IllegalArgumentException("Invalid number of parameters");
} else {
parallelism = Integer.parseInt(args[0]);
ordersPath = args[1];
lineitemsPath = args[2];
customersPath = args[3];
nationsPath = args[4];
resultPath = args[5];
}
FileDataSource orders = new FileDataSource(new IntTupleDataInFormat(), ordersPath, "Orders");
// orders.setOutputContract(UniqueKey.class);
// orders.getCompilerHints().setAvgNumValuesPerKey(1);
FileDataSource lineitems = new FileDataSource(new IntTupleDataInFormat(), lineitemsPath, "LineItems");
// lineitems.getCompilerHints().setAvgNumValuesPerKey(4);
FileDataSource customers = new FileDataSource(new IntTupleDataInFormat(), customersPath, "Customers");
FileDataSource nations = new FileDataSource(new IntTupleDataInFormat(), nationsPath, "Nations");
MapOperator mapO = MapOperator.builder(FilterO.class)
.name("FilterO")
.build();
MapOperator mapLi = MapOperator.builder(FilterLI.class)
.name("FilterLi")
.build();
MapOperator projectC = MapOperator.builder(ProjectC.class)
.name("ProjectC")
.build();
MapOperator projectN = MapOperator.builder(ProjectN.class)
.name("ProjectN")
.build();
JoinOperator joinOL = JoinOperator.builder(JoinOL.class, IntValue.class, 0, 0)
.name("JoinOL")
.build();
JoinOperator joinCOL = JoinOperator.builder(JoinCOL.class, IntValue.class, 0, 0)
.name("JoinCOL")
.build();
JoinOperator joinNCOL = JoinOperator.builder(JoinNCOL.class, IntValue.class, 4, 0)
.name("JoinNCOL")
.build();
ReduceOperator reduce = ReduceOperator.builder(Sum.class)
.keyField(IntValue.class, 0)
.keyField(StringValue.class, 1)
.keyField(StringValue.class, 3)
.keyField(StringValue.class, 4)
.keyField(StringValue.class, 5)
.keyField(StringValue.class, 6)
.keyField(StringValue.class, 7)
.name("Reduce")
.build();
FileDataSink result = new FileDataSink(new TupleOutputFormat(), resultPath, "Output");
result.setInput(reduce);
reduce.setInput(joinNCOL);
joinNCOL.setFirstInput(joinCOL);
joinNCOL.setSecondInput(projectN);
joinCOL.setFirstInput(projectC);
joinCOL.setSecondInput(joinOL);
joinOL.setFirstInput(mapO);
joinOL.setSecondInput(mapLi);
projectC.setInput(customers);
projectN.setInput(nations);
mapLi.setInput(lineitems);
mapO.setInput(orders);
// return the PACT plan
Plan p = new Plan(result, "TPCH Q10");
p.setDefaultParallelism(parallelism);
return p;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.relational;
import java.io.Serializable;
import java.util.Iterator;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.JoinFunction;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFields;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.CsvOutputFormat;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator.Combinable;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.StringValue;
import org.apache.flink.util.Collector;
/**
* The TPC-H is a decision support benchmark on relational data.
* Its documentation and the data generator (DBGEN) can be found
* at http://www.tpc.org/tpch/. This implementation is tested with
* the DB2 data format.
*
* This program implements a modified version of query 3 of
* the TPC-H benchmark including one join, some filtering and an
* aggregation.
*
* SELECT l_orderkey, o_shippriority, sum(l_extendedprice) as revenue
* FROM orders, lineitem
* WHERE l_orderkey = o_orderkey
* AND o_orderstatus = "X"
* AND YEAR(o_orderdate) > Y
* AND o_orderpriority LIKE "Z%"
* GROUP BY l_orderkey, o_shippriority;
*/
@SuppressWarnings("deprecation")
public class TPCHQuery3 implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
public static final String YEAR_FILTER = "parameter.YEAR_FILTER";
public static final String PRIO_FILTER = "parameter.PRIO_FILTER";
/**
* Map PACT implements the selection and projection on the orders table.
*/
@ConstantFields({0,1})
public static class FilterO extends MapFunction implements Serializable {
private static final long serialVersionUID = 1L;
private String prioFilter; // filter literal for the order priority
private int yearFilter; // filter literal for the year
// reusable objects for the fields touched in the mapper
private StringValue orderStatus;
private StringValue orderDate;
private StringValue orderPrio;
/**
* Reads the filter literals from the configuration.
*
* @see org.apache.flink.api.common.functions.RichFunction#open(org.apache.flink.configuration.Configuration)
*/
@Override
public void open(Configuration parameters) {
this.yearFilter = parameters.getInteger(YEAR_FILTER, 1990);
this.prioFilter = parameters.getString(PRIO_FILTER, "0");
}
/**
* Filters the orders table by year, order status and order priority.
*
* o_orderstatus = "X"
* AND YEAR(o_orderdate) > Y
* AND o_orderpriority LIKE "Z"
*
* Output Schema:
* 0:ORDERKEY,
* 1:SHIPPRIORITY
*/
@Override
public void map(final Record record, final Collector<Record> out) {
orderStatus = record.getField(2, StringValue.class);
if (!orderStatus.getValue().equals("F")) {
return;
}
orderPrio = record.getField(4, StringValue.class);
if(!orderPrio.getValue().startsWith(this.prioFilter)) {
return;
}
orderDate = record.getField(3, StringValue.class);
if (!(Integer.parseInt(orderDate.getValue().substring(0, 4)) > this.yearFilter)) {
return;
}
record.setNumFields(2);
out.collect(record);
}
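// Illustrative example: with the parameters set in getPlan() (yearFilter = 1993, prioFilter = "5"),
// an order with status "F", priority "5-LOW" and order date "1995-03-15" passes all three
// filters and is emitted with only fields 0 (ORDERKEY) and 1 (SHIPPRIORITY) retained.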
}
/**
* Match PACT realizes the join between LineItem and Order table.
*
*/
@ConstantFieldsFirst({0,1})
public static class JoinLiO extends JoinFunction implements Serializable {
private static final long serialVersionUID = 1L;
/**
* Implements the join between LineItem and Order table on the order key.
*
* Output Schema:
* 0:ORDERKEY
* 1:SHIPPRIORITY
* 2:EXTENDEDPRICE
*/
@Override
public void join(Record order, Record lineitem, Collector<Record> out) {
order.setField(2, lineitem.getField(1, DoubleValue.class));
out.collect(order);
}
}
/**
* Reduce PACT implements the sum aggregation.
* The Combinable annotation is set because the partial sums can already be
* calculated in the combiner.
*
*/
@Combinable
@ConstantFields({0,1})
public static class AggLiO extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final DoubleValue extendedPrice = new DoubleValue();
/**
* Implements the sum aggregation.
*
* Output Schema:
* 0:ORDERKEY
* 1:SHIPPRIORITY
* 2:SUM(EXTENDEDPRICE)
*/
@Override
public void reduce(Iterator<Record> values, Collector<Record> out) {
Record rec = null;
double partExtendedPriceSum = 0;
while (values.hasNext()) {
rec = values.next();
partExtendedPriceSum += rec.getField(2, DoubleValue.class).getValue();
}
this.extendedPrice.setValue(partExtendedPriceSum);
rec.setField(2, this.extendedPrice);
out.collect(rec);
}
/**
* Creates partial sums on the price attribute for each data batch.
*/
@Override
public void combine(Iterator<Record> values, Collector<Record> out) {
reduce(values, out);
}
}
@Override
public Plan getPlan(final String... args) {
// parse program parameters
final int numSubtasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
final String ordersPath = (args.length > 1 ? args[1] : "");
final String lineitemsPath = (args.length > 2 ? args[2] : "");
final String output = (args.length > 3 ? args[3] : "");
// create DataSourceContract for Orders input
FileDataSource orders = new FileDataSource(new CsvInputFormat(), ordersPath, "Orders");
CsvInputFormat.configureRecordFormat(orders)
.recordDelimiter('\n')
.fieldDelimiter('|')
.field(LongValue.class, 0) // order id
.field(IntValue.class, 7) // ship prio
.field(StringValue.class, 2, 2) // order status
.field(StringValue.class, 4, 10) // order date
.field(StringValue.class, 5, 8); // order prio
// create DataSourceContract for LineItems input
FileDataSource lineitems = new FileDataSource(new CsvInputFormat(), lineitemsPath, "LineItems");
CsvInputFormat.configureRecordFormat(lineitems)
.recordDelimiter('\n')
.fieldDelimiter('|')
.field(LongValue.class, 0) // order id
.field(DoubleValue.class, 5); // extended price
// create MapOperator for filtering Orders tuples
MapOperator filterO = MapOperator.builder(new FilterO())
.input(orders)
.name("FilterO")
.build();
// filter configuration
filterO.setParameter(YEAR_FILTER, 1993);
filterO.setParameter(PRIO_FILTER, "5");
// compiler hints
filterO.getCompilerHints().setFilterFactor(0.05f);
// create JoinOperator for joining Orders and LineItems
JoinOperator joinLiO = JoinOperator.builder(new JoinLiO(), LongValue.class, 0, 0)
.input1(filterO)
.input2(lineitems)
.name("JoinLiO")
.build();
// create ReduceOperator for aggregating the result
// the reducer has a composite key, consisting of the fields 0 and 1
ReduceOperator aggLiO = ReduceOperator.builder(new AggLiO())
.keyField(LongValue.class, 0)
.keyField(StringValue.class, 1)
.input(joinLiO)
.name("AggLio")
.build();
// create DataSinkContract for writing the result
FileDataSink result = new FileDataSink(new CsvOutputFormat(), output, aggLiO, "Output");
CsvOutputFormat.configureRecordFormat(result)
.recordDelimiter('\n')
.fieldDelimiter('|')
.lenient(true)
.field(LongValue.class, 0)
.field(IntValue.class, 1)
.field(DoubleValue.class, 2);
// assemble the PACT plan
Plan plan = new Plan(result, "TPCH Q3");
plan.setDefaultParallelism(numSubtasks);
return plan;
}
@Override
public String getDescription() {
return "Parameters: [numSubStasks], [orders], [lineitem], [output]";
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.relational;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.CsvOutputFormat;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.JoinOperator;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.test.recordJobs.relational.TPCHQuery3.AggLiO;
import org.apache.flink.test.recordJobs.relational.TPCHQuery3.FilterO;
import org.apache.flink.test.recordJobs.relational.TPCHQuery3.JoinLiO;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.StringValue;
/**
* The TPC-H is a decision support benchmark on relational data.
* Its documentation and the data generator (DBGEN) can be found
* at http://www.tpc.org/tpch/. This implementation is tested with
* the DB2 data format.
* The PACT program implements a modified version of query 3 of
* the TPC-H benchmark including one join, some filtering and an
* aggregation.
*
* SELECT l_orderkey, o_shippriority, sum(l_extendedprice) as revenue
* FROM orders, lineitem
* WHERE l_orderkey = o_orderkey
* AND o_orderstatus = "X"
* AND YEAR(o_orderdate) > Y
* AND o_orderpriority LIKE "Z%"
* GROUP BY l_orderkey, o_shippriority;
*/
@SuppressWarnings("deprecation")
public class TPCHQuery3Unioned implements Program, ProgramDescription {
private static final long serialVersionUID = 1L;
@Override
public Plan getPlan(final String... args) {
// parse program parameters
final int numSubtasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
String orders1Path = (args.length > 1 ? args[1] : "");
String orders2Path = (args.length > 2 ? args[2] : "");
String partJoin1Path = (args.length > 3 ? args[3] : "");
String partJoin2Path = (args.length > 4 ? args[4] : "");
String lineitemsPath = (args.length > 5 ? args[5] : "");
String output = (args.length > 6 ? args[6] : "");
// create DataSourceContract for Orders input
FileDataSource orders1 = new FileDataSource(new CsvInputFormat(), orders1Path, "Orders 1");
CsvInputFormat.configureRecordFormat(orders1)
.recordDelimiter('\n')
.fieldDelimiter('|')
.field(LongValue.class, 0) // order id
.field(IntValue.class, 7) // ship prio
.field(StringValue.class, 2, 2) // order status
.field(StringValue.class, 4, 10) // order date
.field(StringValue.class, 5, 8); // order prio
FileDataSource orders2 = new FileDataSource(new CsvInputFormat(), orders2Path, "Orders 2");
CsvInputFormat.configureRecordFormat(orders2)
.recordDelimiter('\n')
.fieldDelimiter('|')
.field(LongValue.class, 0) // order id
.field(IntValue.class, 7) // ship prio
.field(StringValue.class, 2, 2) // order status
.field(StringValue.class, 4, 10) // order date
.field(StringValue.class, 5, 8); // order prio
// create DataSourceContract for LineItems input
FileDataSource lineitems = new FileDataSource(new CsvInputFormat(), lineitemsPath, "LineItems");
CsvInputFormat.configureRecordFormat(lineitems)
.recordDelimiter('\n')
.fieldDelimiter('|')
.field(LongValue.class, 0)
.field(DoubleValue.class, 5);
// create MapOperator for filtering Orders tuples
MapOperator filterO1 = MapOperator.builder(new FilterO())
.name("FilterO")
.input(orders1)
.build();
// filter configuration
filterO1.setParameter(TPCHQuery3.YEAR_FILTER, 1993);
filterO1.setParameter(TPCHQuery3.PRIO_FILTER, "5");
filterO1.getCompilerHints().setFilterFactor(0.05f);
// create MapOperator for filtering Orders tuples
MapOperator filterO2 = MapOperator.builder(new FilterO())
.name("FilterO")
.input(orders2)
.build();
// filter configuration
filterO2.setParameter(TPCHQuery3.YEAR_FILTER, 1993);
filterO2.setParameter(TPCHQuery3.PRIO_FILTER, "5");
// create JoinOperator for joining Orders and LineItems
@SuppressWarnings("unchecked")
JoinOperator joinLiO = JoinOperator.builder(new JoinLiO(), LongValue.class, 0, 0)
.input1(filterO2, filterO1)
.input2(lineitems)
.name("JoinLiO")
.build();
FileDataSource partJoin1 = new FileDataSource(new CsvInputFormat(), partJoin1Path, "Part Join 1");
CsvInputFormat.configureRecordFormat(partJoin1)
.recordDelimiter('\n')
.fieldDelimiter('|')
.field(LongValue.class, 0)
.field(IntValue.class, 1)
.field(DoubleValue.class, 2);
FileDataSource partJoin2 = new FileDataSource(new CsvInputFormat(), partJoin2Path, "Part Join 2");
CsvInputFormat.configureRecordFormat(partJoin2)
.recordDelimiter('\n')
.fieldDelimiter('|')
.field(LongValue.class, 0)
.field(IntValue.class, 1)
.field(DoubleValue.class, 2);
// create ReduceOperator for aggregating the result
// the reducer has a composite key, consisting of the fields 0 and 1
@SuppressWarnings("unchecked")
ReduceOperator aggLiO = ReduceOperator.builder(new AggLiO())
.keyField(LongValue.class, 0)
.keyField(StringValue.class, 1)
.input(joinLiO, partJoin2, partJoin1)
.name("AggLio")
.build();
// create DataSinkContract for writing the result
FileDataSink result = new FileDataSink(new CsvOutputFormat(), output, aggLiO, "Output");
CsvOutputFormat.configureRecordFormat(result)
.recordDelimiter('\n')
.fieldDelimiter('|')
.lenient(true)
.field(LongValue.class, 0)
.field(IntValue.class, 1)
.field(DoubleValue.class, 2);
// assemble the PACT plan
Plan plan = new Plan(result, "TPCH Q3 Unioned");
plan.setDefaultParallelism(numSubtasks);
return plan;
}
@Override
public String getDescription() {
return "Parameters: [numSubStasks], [orders1], [orders2], [partJoin1], [partJoin2], [lineitem], [output]";
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.recordJobs.relational.query1Util;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.test.recordJobs.util.Tuple;
import org.apache.flink.types.Record;
import org.apache.flink.types.StringValue;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Filters the line item tuples according to the filter condition
* l_shipdate <= date '1998-12-01' - interval '[DELTA]' day (3)
* TODO: add parametrisation; first version uses a static interval = 90
*
* In preparation for the following reduce step (see {@link GroupByReturnFlag}), the key has to be set to &quot;return flag&quot;.
*/
@SuppressWarnings("deprecation")
public class LineItemFilter extends MapFunction {
private static final long serialVersionUID = 1L;
private static final Logger LOG = LoggerFactory.getLogger(LineItemFilter.class);
private static final String DATE_CONSTANT = "1998-09-03";
private static final DateFormat format = new SimpleDateFormat("yyyy-MM-dd");
private final Date constantDate;
public LineItemFilter() {
try {
this.constantDate = format.parse(DATE_CONSTANT);
}
catch (ParseException e) {
LOG.error("Date constant could not be parsed.", e);
throw new RuntimeException("Date constant could not be parsed.");
}
}
@Override
public void map(Record record, Collector<Record> out) throws Exception {
Tuple value = record.getField(1, Tuple.class);
if (value != null && value.getNumberOfColumns() >= 11) {
String shipDateString = value.getStringValueAt(10);
try {
Date shipDate = format.parse(shipDateString);
if (shipDate.before(constantDate)) {
String returnFlag = value.getStringValueAt(8);
record.setField(0, new StringValue(returnFlag));
out.collect(record);
}
}
catch (ParseException e) {
LOG.warn("ParseException while parsing the shipping date.", e);
}
}
}
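// Illustrative example: a line item with ship date "1998-08-15" (before the constant date
// "1998-09-03") and return flag "N" is emitted with field 0 set to StringValue("N").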
}