Commit 83debdb3 authored by Aljoscha Krettek

[scala] Add Field Name Key Support for Scala Case Classes

This does not change the runtime behaviour. The key field names are
mapped to tuple indices at pre-flight time.

Also extends tests to cover the new feature.
Parent 299cef75
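
For illustration only (not part of the commit), a minimal sketch of the new name-based keys on a case-class DataSet, mirroring the groupBy/sortGroup and where/equalTo changes in the example programs below. The object name, the sample data, and the Order import path are assumptions.

import org.apache.flink.api.scala._
import org.apache.flink.api.common.operators.Order

object FieldNameKeysSketch {
  case class Edge(v1: Int, v2: Int)

  def main(args: Array[String]) {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val edges = env.fromCollection(Seq(Edge(1, 2), Edge(2, 12), Edge(1, 12), Edge(42, 63)))

    // Keys are given by case-class field name; they are translated to tuple indices
    // at pre-flight time, so runtime behaviour is unchanged.
    val degreePerVertex = edges
      .groupBy("v1")
      .sortGroup("v2", Order.ASCENDING)
      .reduceGroup { es => es.size }

    // Join keys can also be given by name, as in the triangle examples below.
    val joined = edges
      .join(edges).where("v2").equalTo("v1") { (left, _) => Some(left) }

    joined.print()
    env.execute("Field name keys sketch")
  }
}
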
......@@ -85,9 +85,9 @@ object EnumTrianglesBasic {
val triangles = edgesById
// build triads
.groupBy(0).sortGroup(1, Order.ASCENDING).reduceGroup(new TriadBuilder())
.groupBy("v1").sortGroup("v2", Order.ASCENDING).reduceGroup(new TriadBuilder())
// filter triads
.join(edgesById).where(1,2).equalTo(0,1) { (t, _) => Some(t) }
.join(edgesById).where("v2", "v3").equalTo("v1", "v2") { (t, _) => Some(t) }
// emit result
if (fileOutput) {
......@@ -163,8 +163,7 @@ object EnumTrianglesBasic {
private def getEdgeDataSet(env: ExecutionEnvironment): DataSet[Edge] = {
if (fileOutput) {
env.readCsvFile[(Int, Int)](edgePath, fieldDelimiter = ' ', includedFields = Array(0, 1)).
map { x => new Edge(x._1, x._2) }
env.readCsvFile[Edge](edgePath, fieldDelimiter = ' ', includedFields = Array(0, 1))
} else {
val edges = EnumTrianglesData.EDGES.map{ case Array(v1, v2) => new Edge(v1.asInstanceOf[Int], v2.asInstanceOf[Int]) }
env.fromCollection(edges)
......
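
The getEdgeDataSet change above reads the whitespace-delimited input directly into the Edge case class instead of reading (Int, Int) tuples and mapping them. A self-contained sketch of that pattern, with a placeholder object name and path parameter:

import org.apache.flink.api.scala._

object CsvToCaseClassSketch {
  case class Edge(v1: Int, v2: Int)

  // Reads vertex-ID pairs straight into Edge, replacing the former tuple-read-then-map step.
  def readEdges(env: ExecutionEnvironment, path: String): DataSet[Edge] =
    env.readCsvFile[Edge](path, fieldDelimiter = ' ', includedFields = Array(0, 1))
}
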
......@@ -30,206 +30,215 @@ import scala.collection.mutable.MutableList
/**
* Triangle enumeration is a pre-processing step to find closely connected parts in graphs.
* A triangle consists of three edges that connect three vertices with each other.
*
* <p>
* The basic algorithm works as follows:
*
* The basic algorithm works as follows:
* It groups all edges that share a common vertex and builds triads, i.e., triples of vertices
* that are connected by two edges. Finally, all triads are filtered for which no third edge exists
* that closes the triangle.
*
* <p>
* For a group of <i>n</i> edges that share a common vertex, the number of built triads is quadratic <i>((n*(n-1))/2)</i>.
* Therefore, an optimization of the algorithm is to group edges on the vertex with the smaller output degree to
* reduce the number of triads.
*
* For a group of ''n'' edges that share a common vertex, the number of built triads is quadratic
* ''(n*(n-1))/2''. Therefore, an optimization of the algorithm is to group edges on the vertex
* with the smaller output degree to reduce the number of triads.
* This implementation extends the basic algorithm by computing output degrees of edge vertices and
* grouping edges on the vertex with the smaller degree.
*
* <p>
*
* Input files are plain text files and must be formatted as follows:
* <ul>
* <li>Edges are represented as pairs for vertex IDs which are separated by space
* characters. Edges are separated by new-line characters.<br>
* For example <code>"1 2\n2 12\n1 12\n42 63\n"</code> gives four (undirected) edges (1)-(2), (2)-(12), (1)-(12), and (42)-(63)
* that include a triangle
* </ul>
*
* - Edges are represented as pairs of vertex IDs which are separated by space
* characters. Edges are separated by new-line characters.
* For example `"1 2\n2 12\n1 12\n42 63\n"` gives four (undirected) edges (1)-(2), (2)-(12),
* (1)-(12), and (42)-(63) that include a triangle
*
* <pre>
* (1)
* / \
* (2)-(12)
* </pre>
*
* Usage: <code>EnumTriangleOpt &lt;edge path&gt; &lt;result path&gt;</code><br>
* If no parameters are provided, the program is run with default data from {@link EnumTrianglesData}.
*
* <p>
*
* Usage:
* {{{
* EnumTriangleOpt <edge path> <result path>
* }}}
*
* If no parameters are provided, the program is run with default data from
* [[org.apache.flink.example.java.graph.util.EnumTrianglesData]].
*
* This example shows how to use:
* <ul>
* <li>Custom Java objects which extend Tuple
* <li>Group Sorting
* </ul>
*
*
* - Custom Java objects which extend Tuple
* - Group Sorting
*
*/
object EnumTrianglesOpt {
def main(args: Array[String]) {
if (!parseParameters(args)) {
return
}
// set up execution environment
val env = ExecutionEnvironment.getExecutionEnvironment
// read input data
val edges = getEdgeDataSet(env)
val edgesWithDegrees = edges
// duplicate and switch edges
.flatMap( e => Array(e, Edge(e.v2, e.v1)) )
// add degree of first vertex
.groupBy(0).sortGroup(1, Order.ASCENDING).reduceGroup(new DegreeCounter())
// join degrees of vertices
.groupBy(0,2).reduce( (e1, e2) => if(e1.d2 == 0)
new EdgeWithDegrees(e1.v1, e1.d1, e1.v2, e2.d2)
else
new EdgeWithDegrees(e1.v1, e2.d1, e1.v2, e1.d2)
)
// project edges by degrees, vertex with smaller degree comes first
val edgesByDegree = edgesWithDegrees
.map(e => if (e.d1 < e.d2) Edge(e.v1, e.v2) else Edge(e.v2, e.v1) )
// project edges by Id, vertex with smaller Id comes first
val edgesById = edgesByDegree
.map(e => if (e.v1 < e.v2) e else Edge(e.v2, e.v1) )
val triangles = edgesByDegree
// build triads
.groupBy(0).sortGroup(1, Order.ASCENDING).reduceGroup(new TriadBuilder())
// filter triads
.join(edgesById).where(1,2).equalTo(0,1) { (t, _) => Some(t) }
// emit result
if (fileOutput) {
triangles.writeAsCsv(outputPath, "\n", " ")
} else {
triangles.print()
}
// execute program
env.execute("TriangleEnumeration Example")
}
// *************************************************************************
// USER DATA TYPES
// *************************************************************************
case class Edge(v1: Int, v2: Int) extends Serializable
case class Triad(v1: Int, v2: Int, v3: Int) extends Serializable
case class EdgeWithDegrees(v1: Int, d1: Int, v2: Int, d2: Int) extends Serializable
// *************************************************************************
// USER FUNCTIONS
// *************************************************************************
/**
* Counts the number of edges that share a common vertex.
* Emits one edge for each input edge with a degree annotation for the shared vertex.
* For each emitted edge, the first vertex is the vertex with the smaller id.
*/
class DegreeCounter extends GroupReduceFunction[Edge, EdgeWithDegrees] {
val vertices = MutableList[Integer]()
var groupVertex = 0
override def reduce(edges: java.lang.Iterable[Edge], out: Collector[EdgeWithDegrees]) = {
// empty vertex list
vertices.clear
// collect all vertices
for(e <- edges.asScala) {
groupVertex = e.v1
if(!vertices.contains(e.v2) && e.v1 != e.v2) {
vertices += e.v2
}
}
// count vertices to obtain degree of groupVertex
val degree = vertices.length
// create and emit edges with degrees
for(v <- vertices) {
if (v < groupVertex) {
out.collect(new EdgeWithDegrees(v, 0, groupVertex, degree))
} else {
out.collect(new EdgeWithDegrees(groupVertex, degree, v, 0))
}
}
}
}
/**
* Builds triads (triples of vertices) from pairs of edges that share a vertex.
* The first vertex of a triad is the shared vertex, the second and third vertex are ordered by vertexId.
* Assumes that input edges share the first vertex and are in ascending order of the second vertex.
*/
class TriadBuilder extends GroupReduceFunction[Edge, Triad] {
val vertices = MutableList[Integer]()
override def reduce(edges: java.lang.Iterable[Edge], out: Collector[Triad]) = {
// clear vertex list
vertices.clear
// build and emit triads
for(e <- edges.asScala) {
// combine vertex with all previously read vertices
for(v <- vertices) {
out.collect(Triad(e.v1, v, e.v2))
}
vertices += e.v2
}
}
}
// *************************************************************************
// UTIL METHODS
// *************************************************************************
private def parseParameters(args: Array[String]): Boolean = {
if (args.length > 0) {
fileOutput = true
if (args.length == 2) {
edgePath = args(0)
outputPath = args(1)
} else {
System.err.println("Usage: EnumTriangleBasic <edge path> <result path>")
false
}
} else {
System.out.println("Executing Enum Triangles Basic example with built-in default data.");
System.out.println(" Provide parameters to read input data from files.");
System.out.println(" See the documentation for the correct format of input files.");
System.out.println(" Usage: EnumTriangleBasic <edge path> <result path>");
}
true
}
private def getEdgeDataSet(env: ExecutionEnvironment): DataSet[Edge] = {
if (fileOutput) {
env.readCsvFile[(Int, Int)](edgePath, fieldDelimiter = ' ', includedFields = Array(0, 1)).
map { x => new Edge(x._1, x._2) }
} else {
val edges = EnumTrianglesData.EDGES.map{ case Array(v1, v2) => new Edge(v1.asInstanceOf[Int], v2.asInstanceOf[Int]) }
env.fromCollection(edges)
}
}
private var fileOutput: Boolean = false
private var edgePath: String = null
private var outputPath: String = null
def main(args: Array[String]) {
if (!parseParameters(args)) {
return
}
// set up execution environment
val env = ExecutionEnvironment.getExecutionEnvironment
// read input data
val edges = getEdgeDataSet(env)
val edgesWithDegrees = edges
// duplicate and switch edges
.flatMap(e => Seq(e, Edge(e.v2, e.v1)))
// add degree of first vertex
.groupBy("v1").sortGroup("v2", Order.ASCENDING).reduceGroup(new DegreeCounter())
// join degrees of vertices
.groupBy("v1", "v2").reduce {
(e1, e2) =>
if (e1.d2 == 0) {
new EdgeWithDegrees(e1.v1, e1.d1, e1.v2, e2.d2)
} else {
new EdgeWithDegrees(e1.v1, e2.d1, e1.v2, e1.d2)
}
}
// project edges by degrees, vertex with smaller degree comes first
val edgesByDegree = edgesWithDegrees
.map(e => if (e.d1 <= e.d2) Edge(e.v1, e.v2) else Edge(e.v2, e.v1))
// project edges by Id, vertex with smaller Id comes first
val edgesById = edgesByDegree
.map(e => if (e.v1 < e.v2) e else Edge(e.v2, e.v1))
val triangles = edgesByDegree
// build triads
.groupBy("v1").sortGroup("v2", Order.ASCENDING).reduceGroup(new TriadBuilder())
// filter triads
.join(edgesById).where("v2", "v3").equalTo("v1", "v2") { (t, _) => Some(t)}
// emit result
if (fileOutput) {
triangles.writeAsCsv(outputPath, "\n", ",")
} else {
triangles.print()
}
// execute program
env.execute("TriangleEnumeration Example")
}
// *************************************************************************
// USER DATA TYPES
// *************************************************************************
case class Edge(v1: Int, v2: Int) extends Serializable
case class Triad(v1: Int, v2: Int, v3: Int) extends Serializable
case class EdgeWithDegrees(v1: Int, d1: Int, v2: Int, d2: Int) extends Serializable
// *************************************************************************
// USER FUNCTIONS
// *************************************************************************
/**
* Counts the number of edges that share a common vertex.
* Emits one edge for each input edge with a degree annotation for the shared vertex.
* For each emitted edge, the first vertex is the vertex with the smaller id.
*/
class DegreeCounter extends GroupReduceFunction[Edge, EdgeWithDegrees] {
val vertices = mutable.MutableList[Integer]()
var groupVertex = 0
override def reduce(edges: java.lang.Iterable[Edge], out: Collector[EdgeWithDegrees]) = {
// empty vertex list
vertices.clear()
// collect all vertices
for (e <- edges.asScala) {
groupVertex = e.v1
if (!vertices.contains(e.v2) && e.v1 != e.v2) {
vertices += e.v2
}
}
// count vertices to obtain degree of groupVertex
val degree = vertices.length
// create and emit edges with degrees
for (v <- vertices) {
if (v < groupVertex) {
out.collect(new EdgeWithDegrees(v, 0, groupVertex, degree))
} else {
out.collect(new EdgeWithDegrees(groupVertex, degree, v, 0))
}
}
}
}
/**
* Builds triads (triples of vertices) from pairs of edges that share a vertex.
* The first vertex of a triad is the shared vertex, the second and third vertex are ordered by
* vertexId.
* Assumes that input edges share the first vertex and are in ascending order of the second
* vertex.
*/
class TriadBuilder extends GroupReduceFunction[Edge, Triad] {
val vertices = mutable.MutableList[Integer]()
override def reduce(edges: java.lang.Iterable[Edge], out: Collector[Triad]) = {
// clear vertex list
vertices.clear()
// build and emit triads
for (e <- edges.asScala) {
// combine vertex with all previously read vertices
for (v <- vertices) {
out.collect(Triad(e.v1, v, e.v2))
}
vertices += e.v2
}
}
}
// *************************************************************************
// UTIL METHODS
// *************************************************************************
private def parseParameters(args: Array[String]): Boolean = {
if (args.length > 0) {
fileOutput = true
if (args.length == 2) {
edgePath = args(0)
outputPath = args(1)
} else {
System.err.println("Usage: EnumTriangleOpt <edge path> <result path>")
false
}
} else {
System.out.println("Executing Enum Triangles Optimized example with built-in default data.")
System.out.println(" Provide parameters to read input data from files.")
System.out.println(" See the documentation for the correct format of input files.")
System.out.println(" Usage: EnumTriangleBasic <edge path> <result path>")
}
true
}
private def getEdgeDataSet(env: ExecutionEnvironment): DataSet[Edge] = {
if (fileOutput) {
env.readCsvFile[Edge](
edgePath,
fieldDelimiter = ' ',
includedFields = Array(0, 1))
} else {
val edges = EnumTrianglesData.EDGES.map {
case Array(v1, v2) => new Edge(v1.asInstanceOf[Int], v2.asInstanceOf[Int])}
env.fromCollection(edges)
}
}
private var fileOutput: Boolean = false
private var edgePath: String = null
private var outputPath: String = null
}
\ No newline at end of file
......@@ -17,169 +17,184 @@
*/
package org.apache.flink.examples.scala.graph
import scala.collection.JavaConverters._
import org.apache.flink.api.scala._
import org.apache.flink.example.java.graph.util.PageRankData
import org.apache.flink.util.Collector
import org.apache.flink.api.java.aggregation.Aggregations.SUM
import org.apache.flink.util.Collector
/**
* A basic implementation of the Page Rank algorithm using a bulk iteration.
*
* <p>
* This implementation requires a set of pages and a set of directed links as input and works as follows. <br>
* In each iteration, the rank of every page is evenly distributed to all pages it points to.
* Each page collects the partial ranks of all pages that point to it, sums them up, and applies a dampening factor to the sum.
* The result is the new rank of the page. A new iteration is started with the new ranks of all pages.
* This implementation terminates after a fixed number of iterations.<br>
* This is the Wikipedia entry for the <a href="http://en.wikipedia.org/wiki/Page_rank">Page Rank algorithm</a>.
* This implementation requires a set of pages and a set of directed links as input and works as
* follows.
*
* In each iteration, the rank of every page is evenly distributed to all pages it points to. Each
* page collects the partial ranks of all pages that point to it, sums them up, and applies a
* dampening factor to the sum. The result is the new rank of the page. A new iteration is started
* with the new ranks of all pages. This implementation terminates after a fixed number of
* iterations. This is the Wikipedia entry for the
* [[http://en.wikipedia.org/wiki/Page_rank Page Rank algorithm]]
*
* <p>
* Input files are plain text files and must be formatted as follows:
* <ul>
* <li>Pages represented as an (long) ID separated by new-line characters.<br>
* For example <code>"1\n2\n12\n42\n63\n"</code> gives five pages with IDs 1, 2, 12, 42, and 63.
* <li>Links are represented as pairs of page IDs which are separated by space
* characters. Links are separated by new-line characters.<br>
* For example <code>"1 2\n2 12\n1 12\n42 63\n"</code> gives four (directed) links (1)->(2), (2)->(12), (1)->(12), and (42)->(63).<br>
* For this simple implementation it is required that each page has at least one incoming and one outgoing link (a page can point to itself).
* </ul>
*
* <p>
* Usage: <code>PageRankBasic &lt;pages path&gt; &lt;links path&gt; &lt;output path&gt; &lt;num pages&gt; &lt;num iterations&gt;</code><br>
* If no parameters are provided, the program is run with default data from {@link PageRankData} and 10 iterations.
*
* - Pages are represented as a (long) ID separated by new-line characters.
* For example `"1\n2\n12\n42\n63\n"` gives five pages with IDs 1, 2, 12, 42, and 63.
* - Links are represented as pairs of page IDs which are separated by space characters. Links
* are separated by new-line characters.
* For example `"1 2\n2 12\n1 12\n42 63\n"` gives four (directed) links (1)->(2), (2)->(12),
* (1)->(12), and (42)->(63). For this simple implementation it is required that each page has
* at least one incoming and one outgoing link (a page can point to itself).
*
* Usage:
* {{{
* PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>
* }}}
*
* If no parameters are provided, the program is run with default data from
* [[org.apache.flink.example.java.graph.util.PageRankData]] and 10 iterations.
*
* <p>
* This example shows how to use:
* <ul>
* <li>Bulk Iterations
* <li>Default Join
* <li>Configure user-defined functions using constructor parameters.
* </ul>
*
* - Bulk Iterations
* - Default Join
* - Configure user-defined functions using constructor parameters.
*
*/
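
As a summary (not part of the commit), the iteration body below applies the following update to each page p, where d is the dampening factor, N the number of pages, and out(q) the set of pages that q links to:

R_{t+1}(p) = \frac{1 - d}{N} + d \sum_{q \to p} \frac{R_t(q)}{|out(q)|}, \qquad d = 0.85
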
object PageRankBasic {
private final val DAMPENING_FACTOR: Double = 0.85;
private final val EPSILON: Double = 0.0001;
def main(args: Array[String]) {
if (!parseParameters(args)) {
return
}
// set up execution environment
val env = ExecutionEnvironment.getExecutionEnvironment
// read input data
val pages = getPagesDataSet(env)
val links = getLinksDataSet(env)
// assign initial ranks to pages
val pagesWithRanks = pages.map(p => Page(p, (1.0/numPages)))
// build adjacency list from link input
val adjacencyLists = links
// initialize lists
.map( e => AdjacencyList(e.sourceId, Array[java.lang.Long](e.targetId) ))
// concatenate lists
.groupBy(0).reduce( (l1, l2) => AdjacencyList(l1.sourceId, l1.targetIds ++ l2.targetIds))
// start iteration
val finalRanks = pagesWithRanks.iterateWithTermination(maxIterations) {
currentRanks =>
val newRanks = currentRanks
// distribute ranks to target pages
.join(adjacencyLists).where(0).equalTo(0)
.flatMap { x => for(targetId <- x._2.targetIds) yield Page(targetId, (x._1.rank / x._2.targetIds.length))}
// collect ranks and sum them up
.groupBy(0).aggregate(SUM, 1)
// apply dampening factor
.map { p => Page(p.pageId, (p.rank * DAMPENING_FACTOR) + ((1 - DAMPENING_FACTOR) / numPages) ) }
// terminate if no rank update was significant
val termination = currentRanks
.join(newRanks).where(0).equalTo(0)
// check for significant update
.filter( x => math.abs(x._1.rank - x._2.rank) > EPSILON )
(newRanks, termination)
}
val result = finalRanks;
// emit result
if (fileOutput) {
result.writeAsCsv(outputPath, "\n", " ")
} else {
result.print()
}
// execute program
env.execute("Basic PageRank Example")
}
// *************************************************************************
// USER TYPES
// *************************************************************************
case class Link(sourceId: Long, targetId: Long)
case class Page(pageId: java.lang.Long, rank: Double)
case class AdjacencyList(sourceId: java.lang.Long, targetIds: Array[java.lang.Long])
// *************************************************************************
// UTIL METHODS
// *************************************************************************
private def parseParameters(args: Array[String]): Boolean = {
if (args.length > 0) {
fileOutput = true
if (args.length == 5) {
pagesInputPath = args(0)
linksInputPath = args(1)
outputPath = args(2)
numPages = args(3).toLong
maxIterations = args(4).toInt
} else {
System.err.println("Usage: PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>");
false
}
} else {
System.out.println("Executing PageRank Basic example with default parameters and built-in default data.");
System.out.println(" Provide parameters to read input data from files.");
System.out.println(" See the documentation for the correct format of input files.");
System.out.println(" Usage: PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>");
numPages = PageRankData.getNumberOfPages();
}
true
}
private def getPagesDataSet(env: ExecutionEnvironment): DataSet[Long] = {
if(fileOutput) {
env.readCsvFile[Tuple1[Long]](pagesInputPath, fieldDelimiter = ' ', lineDelimiter = "\n")
.map(x => x._1)
} else {
env.fromCollection(Seq.range(1, PageRankData.getNumberOfPages()+1))
}
}
private def getLinksDataSet(env: ExecutionEnvironment): DataSet[Link] = {
if (fileOutput) {
env.readCsvFile[(Long, Long)](linksInputPath, fieldDelimiter = ' ', includedFields = Array(0, 1))
.map { x => Link(x._1, x._2) }
} else {
val edges = PageRankData.EDGES.map{ case Array(v1, v2) => Link(v1.asInstanceOf[Long], v2.asInstanceOf[Long]) }
env.fromCollection(edges)
}
}
private var fileOutput: Boolean = false
private var pagesInputPath: String = null
private var linksInputPath: String = null
private var outputPath: String = null
private var numPages: Long = 0;
private var maxIterations: Int = 10;
private final val DAMPENING_FACTOR: Double = 0.85
private final val EPSILON: Double = 0.0001
def main(args: Array[String]) {
if (!parseParameters(args)) {
return
}
// set up execution environment
val env = ExecutionEnvironment.getExecutionEnvironment
// read input data
val pages = getPagesDataSet(env)
val links = getLinksDataSet(env)
// assign initial ranks to pages
val pagesWithRanks = pages.map(p => Page(p, 1.0 / numPages))
// build adjacency list from link input
val adjacencyLists = links
// initialize lists
.map(e => AdjacencyList(e.sourceId, Array(e.targetId)))
// concatenate lists
.groupBy("sourceId").reduce((l1, l2) => AdjacencyList(l1.sourceId, l1.targetIds ++ l2.targetIds))
// start iteration
val finalRanks = pagesWithRanks.iterateWithTermination(maxIterations) {
currentRanks =>
val newRanks = currentRanks
// distribute ranks to target pages
.join(adjacencyLists).where("pageId").equalTo("sourceId") {
(page, adjacent, out: Collector[Page]) =>
for (targetId <- adjacent.targetIds) {
out.collect(Page(targetId, page.rank / adjacent.targetIds.length))
}
}
// collect ranks and sum them up
.groupBy("pageId").aggregate(SUM, "rank")
// apply dampening factor
.map { p =>
Page(p.pageId, (p.rank * DAMPENING_FACTOR) + ((1 - DAMPENING_FACTOR) / numPages))
}
// terminate if no rank update was significant
val termination = currentRanks.join(newRanks).where("pageId").equalTo("pageId") {
(current, next) =>
// check for significant update
if (math.abs(current.rank - next.rank) > EPSILON) Some(1) else None
}
(newRanks, termination)
}
val result = finalRanks
// emit result
if (fileOutput) {
result.writeAsCsv(outputPath, "\n", " ")
} else {
result.print()
}
// execute program
env.execute("Basic PageRank Example")
}
// *************************************************************************
// USER TYPES
// *************************************************************************
case class Link(sourceId: Long, targetId: Long)
case class Page(pageId: Long, rank: Double)
case class AdjacencyList(sourceId: Long, targetIds: Array[Long])
// *************************************************************************
// UTIL METHODS
// *************************************************************************
private def parseParameters(args: Array[String]): Boolean = {
if (args.length > 0) {
fileOutput = true
if (args.length == 5) {
pagesInputPath = args(0)
linksInputPath = args(1)
outputPath = args(2)
numPages = args(3).toLong
maxIterations = args(4).toInt
} else {
System.err.println("Usage: PageRankBasic <pages path> <links path> <output path> <num " +
"pages> <num iterations>")
false
}
} else {
System.out.println("Executing PageRank Basic example with default parameters and built-in " +
"default data.")
System.out.println(" Provide parameters to read input data from files.")
System.out.println(" See the documentation for the correct format of input files.")
System.out.println(" Usage: PageRankBasic <pages path> <links path> <output path> <num " +
"pages> <num iterations>")
numPages = PageRankData.getNumberOfPages
}
true
}
private def getPagesDataSet(env: ExecutionEnvironment): DataSet[Long] = {
if (fileOutput) {
env.readCsvFile[Tuple1[Long]](pagesInputPath, fieldDelimiter = ' ', lineDelimiter = "\n")
.map(x => x._1)
} else {
env.generateSequence(1, 15)
}
}
private def getLinksDataSet(env: ExecutionEnvironment): DataSet[Link] = {
if (fileOutput) {
env.readCsvFile[Link](linksInputPath, fieldDelimiter = ' ',
includedFields = Array(0, 1))
} else {
val edges = PageRankData.EDGES.map { case Array(v1, v2) => Link(v1.asInstanceOf[Long],
v2.asInstanceOf[Long])}
env.fromCollection(edges)
}
}
private var fileOutput: Boolean = false
private var pagesInputPath: String = null
private var linksInputPath: String = null
private var outputPath: String = null
private var numPages: Long = 0
private var maxIterations: Int = 10
}
\ No newline at end of file
......@@ -26,22 +26,23 @@ object PiEstimation {
val numSamples: Long = if (args.length > 0) args(0).toLong else 1000000
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val env = ExecutionEnvironment.getExecutionEnvironment
// count how many of the samples would randomly fall into
// the unit circle
// the upper right quadrant of the unit circle
val count =
env.generateSequence(1, numSamples)
.map (sample => {
val x = Math.random()
val y = Math.random()
if (x * x + y * y < 1) 1L else 0L
})
.reduce(_+_)
// the ratio of the unit circle surface to 4 times the unit square is pi
.map { sample =>
val x = Math.random()
val y = Math.random()
if (x * x + y * y < 1) 1L else 0L
}
.reduce(_+_)
// ratio of samples in upper right quadrant vs total samples gives surface of upper
// right quadrant, times 4 gives surface of whole unit circle, i.e. PI
val pi = count
.map (_ * 4.0 / numSamples)
.map ( _ * 4.0 / numSamples)
println("We estimate Pi to be:")
......
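
For reference, the estimator in this hunk rests on the fact that a uniformly random point in the unit square falls inside the quarter disk x^2 + y^2 < 1 with probability \pi / 4, hence

\pi \approx 4 \cdot \frac{count}{numSamples}
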
......@@ -30,6 +30,7 @@ import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
import org.apache.flink.api.java.operators._
import org.apache.flink.api.java.{DataSet => JavaDataSet}
import org.apache.flink.api.scala.operators.{ScalaCsvOutputFormat, ScalaAggregateOperator}
import org.apache.flink.api.scala.typeutils.ScalaTupleTypeInfo
import org.apache.flink.core.fs.FileSystem.WriteMode
import org.apache.flink.core.fs.{FileSystem, Path}
import org.apache.flink.types.TypeInformation
......@@ -375,6 +376,25 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
case _ => wrap(new ScalaAggregateOperator[T](set, agg, field))
}
/**
* Creates a new [[DataSet]] by aggregating the specified field using the given aggregation
* function. Since this is not a keyed DataSet the aggregation will be performed on the whole
* collection of elements.
*
* This only works on CaseClass DataSets.
*/
def aggregate(agg: Aggregations, field: String): DataSet[T] = {
val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
set match {
case aggregation: ScalaAggregateOperator[T] =>
aggregation.and(agg, fieldIndex)
wrap(aggregation)
case _ => wrap(new ScalaAggregateOperator[T](set, agg, fieldIndex))
}
}
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
......@@ -396,6 +416,27 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
aggregate(Aggregations.MIN, field)
}
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: String) = {
aggregate(Aggregations.SUM, field)
}
/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: String) = {
aggregate(Aggregations.MAX, field)
}
/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: String) = {
aggregate(Aggregations.MIN, field)
}
/**
* Creates a new [[DataSet]] by merging the elements of this DataSet using an associative reduce
* function.
......@@ -486,7 +527,7 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
* Creates a new DataSet containing the distinct elements of this DataSet. The decision whether
* two elements are distinct or not is made based on only the specified tuple fields.
*
* This only works if this DataSet contains Tuples.
* This only works on tuple DataSets.
*/
def distinct(fields: Int*): DataSet[T] = {
wrap(new DistinctOperator[T](
......@@ -494,6 +535,19 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
new Keys.FieldPositionKeys[T](fields.toArray, set.getType, true)))
}
/**
* Creates a new DataSet containing the distinct elements of this DataSet. The decision whether
* two elements are distinct or not is made based on only the specified fields.
*
* This only works on CaseClass DataSets
*/
def distinct(firstField: String, otherFields: String*): DataSet[T] = {
val fieldIndices = fieldNames2Indices(set.getType, firstField +: otherFields.toArray)
wrap(new DistinctOperator[T](
set,
new Keys.FieldPositionKeys[T](fieldIndices, set.getType, true)))
}
/**
* Creates a new DataSet containing the distinct elements of this DataSet. The decision whether
* two elements are distinct or not is made based on all tuple fields.
......@@ -539,6 +593,23 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
new Keys.FieldPositionKeys[T](fields.toArray, set.getType,false))
}
/**
* Creates a [[GroupedDataSet]] which provides operations on groups of elements. Elements are
* grouped based on the given fields.
*
* This will not create a new DataSet, it will just attach the field names which will be
* used for grouping when executing a grouped operation.
*
* This only works on CaseClass DataSets.
*/
def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = {
val fieldIndices = fieldNames2Indices(set.getType, firstField +: otherFields.toArray)
new GroupedDataSetImpl[T](
set,
new Keys.FieldPositionKeys[T](fieldIndices, set.getType,false))
}
// public UnsortedGrouping<T> groupBy(String... fields) {
// new UnsortedGrouping<T>(this, new Keys.ExpressionKeys<T>(fields, getType()));
// }
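
A hedged sketch (not from the commit) of the name-based distinct, groupBy, and aggregation sugar added in this file, run on a small case-class DataSet; the driver object and sample data are assumptions:

import org.apache.flink.api.scala._

object NameBasedOpsSketch {
  case class Page(pageId: Long, rank: Double)

  def main(args: Array[String]) {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val pages = env.fromCollection(Seq(Page(1L, 0.25), Page(2L, 0.5), Page(1L, 0.25)))

    // Whole-DataSet aggregation by field name (sugar for aggregate(Aggregations.SUM, "rank")).
    val totalRank = pages.sum("rank")

    // Distinct and grouped aggregation, both keyed by field names.
    val distinctPages = pages.distinct("pageId", "rank")
    val rankPerPage = pages.groupBy("pageId").sum("rank")

    rankPerPage.print()
    env.execute("Name-based operations sketch")
  }
}
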
......@@ -587,21 +658,21 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
* }}}
*/
def join[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperationImpl(this.set, other.set, JoinHint.OPTIMIZER_CHOOSES)
new UnfinishedJoinOperationImpl(this, other, JoinHint.OPTIMIZER_CHOOSES)
/**
* Special [[join]] operation for explicitly telling the system that the right side is assumed
* to be a lot smaller than the left side of the join.
*/
def joinWithTiny[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperationImpl(this.set, other.set, JoinHint.BROADCAST_HASH_SECOND)
new UnfinishedJoinOperationImpl(this, other, JoinHint.BROADCAST_HASH_SECOND)
/**
* Special [[join]] operation for explicitly telling the system that the left side is assumed
* to be a lot smaller than the right side of the join.
*/
def joinWithHuge[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperationImpl(this.set, other.set, JoinHint.BROADCAST_HASH_FIRST)
new UnfinishedJoinOperationImpl(this, other, JoinHint.BROADCAST_HASH_FIRST)
// --------------------------------------------------------------------------------------------
// Co-Group
......@@ -641,7 +712,7 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
* }}}
*/
def coGroup[O: ClassTag](other: DataSet[O]): UnfinishedCoGroupOperation[T, O] =
new UnfinishedCoGroupOperationImpl(this.set, other.set)
new UnfinishedCoGroupOperationImpl(this, other)
// --------------------------------------------------------------------------------------------
// Cross
......
......@@ -19,6 +19,7 @@ package org.apache.flink.api.scala
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.scala.operators.ScalaAggregateOperator
import org.apache.flink.api.scala.typeutils.ScalaTupleTypeInfo
import scala.collection.JavaConverters._
......@@ -45,10 +46,20 @@ trait GroupedDataSet[T] {
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
* use one of the group-at-a-time, i.e. `reduceGroup`
* use one of the group-at-a-time, i.e. `reduceGroup`.
*
* This only works on Tuple DataSets.
*/
def sortGroup(field: Int, order: Order): GroupedDataSet[T]
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
* use one of the group-at-a-time, i.e. `reduceGroup`.
*
* This only works on CaseClass DataSets.
*/
def sortGroup(field: String, order: Order): GroupedDataSet[T]
/**
* Creates a new [[DataSet]] by aggregating the specified tuple field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
......@@ -58,6 +69,15 @@ trait GroupedDataSet[T] {
*/
def aggregate(agg: Aggregations, field: Int): DataSet[T]
/**
* Creates a new [[DataSet]] by aggregating the specified field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* elements with the same key.
*
* This only works on CaseClass DataSets.
*/
def aggregate(agg: Aggregations, field: String): DataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
......@@ -73,6 +93,21 @@ trait GroupedDataSet[T] {
*/
def min(field: Int): DataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: String): DataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: String): DataSet[T]
/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: String): DataSet[T]
/**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
* using an associative reduce function.
......@@ -124,14 +159,10 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
private val groupSortKeyPositions = mutable.MutableList[Int]()
private val groupSortOrders = mutable.MutableList[Order]()
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
* use one of the group-at-a-time, i.e. `reduceGroup`
*/
def sortGroup(field: Int, order: Order): GroupedDataSet[T] = {
if (!set.getType.isTupleType) {
throw new InvalidProgramException("Specifying order keys via field positions is only valid " +
"for tuple data types")
"for tuple data types.")
}
if (field >= set.getType.getArity) {
throw new IllegalArgumentException("Order key out of tuple bounds.")
......@@ -141,10 +172,14 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
this
}
/**
* Creates a [[SortedGrouping]] if any secondary sort fields were specified. Otherwise, just
* create an [[UnsortedGrouping]].
*/
def sortGroup(field: String, order: Order): GroupedDataSet[T] = {
val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
groupSortKeyPositions += fieldIndex
groupSortOrders += order
this
}
private def maybeCreateSortedGrouping(): Grouping[T] = {
if (groupSortKeyPositions.length > 0) {
val grouping = new SortedGrouping[T](set, keys, groupSortKeyPositions(0), groupSortOrders(0))
......@@ -161,13 +196,18 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
/** Convenience methods for creating the [[UnsortedGrouping]] */
private def createUnsortedGrouping(): Grouping[T] = new UnsortedGrouping[T](set, keys)
/**
* Creates a new [[DataSet]] by aggregating the specified tuple field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* tuples with the same key.
*
* This only works on Tuple DataSets.
*/
def aggregate(agg: Aggregations, field: String): DataSet[T] = {
val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
set match {
case aggregation: ScalaAggregateOperator[T] =>
aggregation.and(agg, fieldIndex)
wrap(aggregation)
case _ => wrap(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, fieldIndex))
}
}
def aggregate(agg: Aggregations, field: Int): DataSet[T] = set match {
case aggregation: ScalaAggregateOperator[T] =>
aggregation.and(agg, field)
......@@ -176,31 +216,30 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
case _ => wrap(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, field))
}
/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: Int): DataSet[T] = {
aggregate(Aggregations.SUM, field)
}
/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: Int): DataSet[T] = {
aggregate(Aggregations.MAX, field)
}
/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: Int): DataSet[T] = {
aggregate(Aggregations.MIN, field)
}
/**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
* using an associative reduce function.
*/
def sum(field: String): DataSet[T] = {
aggregate(Aggregations.SUM, field)
}
def max(field: String): DataSet[T] = {
aggregate(Aggregations.MAX, field)
}
def min(field: String): DataSet[T] = {
aggregate(Aggregations.MIN, field)
}
def reduce(fun: (T, T) => T): DataSet[T] = {
Validate.notNull(fun, "Reduce function must not be null.")
val reducer = new ReduceFunction[T] {
......@@ -211,20 +250,11 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
wrap(new ReduceOperator[T](createUnsortedGrouping(), reducer))
}
/**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
* using an associative reduce function.
*/
def reduce(reducer: ReduceFunction[T]): DataSet[T] = {
Validate.notNull(reducer, "Reduce function must not be null.")
wrap(new ReduceOperator[T](createUnsortedGrouping(), reducer))
}
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the group reduce function. The function must output one element. The
* concatenation of those will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T]) => R): DataSet[R] = {
Validate.notNull(fun, "Group reduce function must not be null.")
......@@ -238,11 +268,6 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
implicitly[TypeInformation[R]], reducer))
}
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the group reduce function. The function can output zero or more elements using
* the [[Collector]]. The concatenation of the emitted values will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T], Collector[R]) => Unit): DataSet[R] = {
Validate.notNull(fun, "Group reduce function must not be null.")
......@@ -256,11 +281,6 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
implicitly[TypeInformation[R]], reducer))
}
/**
* Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
* of elements to the [[GroupReduceFunction]]. The function can output zero or more elements. The
* concatenation of the emitted values will form the resulting [[DataSet]].
*/
def reduceGroup[R: TypeInformation: ClassTag](reducer: GroupReduceFunction[T, R]): DataSet[R] = {
Validate.notNull(reducer, "GroupReduce function must not be null.")
wrap(
......
......@@ -94,8 +94,8 @@ trait CoGroupDataSet[T, O] extends DataSet[(Array[T], Array[O])] {
*/
private[flink] class CoGroupDataSetImpl[T, O](
coGroupOperator: CoGroupOperator[T, O, (Array[T], Array[O])],
thisSet: JavaDataSet[T],
otherSet: JavaDataSet[O],
thisSet: DataSet[T],
otherSet: DataSet[O],
thisKeys: Keys[T],
otherKeys: Keys[O]) extends DataSet(coGroupOperator) with CoGroupDataSet[T, O] {
......@@ -107,7 +107,7 @@ private[flink] class CoGroupDataSetImpl[T, O](
fun(left.iterator.asScala, right.iterator.asScala) map { out.collect(_) }
}
}
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, coGrouper, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
......@@ -120,14 +120,14 @@ private[flink] class CoGroupDataSetImpl[T, O](
fun(left.iterator.asScala, right.iterator.asScala, out)
}
}
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, coGrouper, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
def apply[R: TypeInformation: ClassTag](joiner: CoGroupFunction[T, O, R]): DataSet[R] = {
Validate.notNull(joiner, "CoGroup function must not be null.")
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, joiner, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
......@@ -153,8 +153,8 @@ trait UnfinishedCoGroupOperation[T, O]
* i.e. the parameters of the constructor, hidden.
*/
private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
leftSet: JavaDataSet[T],
rightSet: JavaDataSet[O])
leftSet: DataSet[T],
rightSet: DataSet[O])
extends UnfinishedKeyPairOperation[T, O, CoGroupDataSet[T, O]](leftSet, rightSet)
with UnfinishedCoGroupOperation[T, O] {
......@@ -173,11 +173,13 @@ private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
// We have to use this hack, for some reason classOf[Array[T]] does not work.
// Maybe because ObjectArrayTypeInfo does not accept the Scala Array as an array class.
val leftArrayType = ObjectArrayTypeInfo.getInfoFor(new Array[T](0).getClass, leftSet.getType)
val rightArrayType = ObjectArrayTypeInfo.getInfoFor(new Array[O](0).getClass, rightSet.getType)
val leftArrayType =
ObjectArrayTypeInfo.getInfoFor(new Array[T](0).getClass, leftSet.set.getType)
val rightArrayType =
ObjectArrayTypeInfo.getInfoFor(new Array[O](0).getClass, rightSet.set.getType)
val returnType = new ScalaTupleTypeInfo[(Array[T], Array[O])](
classOf[(Array[T], Array[O])], Seq(leftArrayType, rightArrayType)) {
classOf[(Array[T], Array[O])], Seq(leftArrayType, rightArrayType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(Array[T], Array[O])] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
......@@ -195,10 +197,10 @@ private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
}
}
val coGroupOperator = new CoGroupOperator[T, O, (Array[T], Array[O])](
leftSet, rightSet, leftKey, rightKey, coGrouper, returnType)
leftSet.set, rightSet.set, leftKey, rightKey, coGrouper, returnType)
// sanity check solution set key mismatches
leftSet match {
leftSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
leftKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
......@@ -211,7 +213,7 @@ private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
}
case _ =>
}
rightSet match {
rightSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
rightKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
......
......@@ -68,8 +68,11 @@ private[flink] trait TypeInformationGen[C <: Context] {
}
val fieldsExpr = c.Expr[Seq[TypeInformation[_]]](mkList(fields))
val instance = mkCreateTupleInstance[T](desc)(c.WeakTypeTag(desc.tpe))
val fieldNames = desc.getters map { f => Literal(Constant(f.getter.name.toString)) } toList
val fieldNamesExpr = c.Expr[Seq[String]](mkSeq(fieldNames))
reify {
new ScalaTupleTypeInfo[T](tpeClazz.splice, fieldsExpr.splice) {
new ScalaTupleTypeInfo[T](tpeClazz.splice, fieldsExpr.splice, fieldNamesExpr.splice) {
override def createSerializer: TypeSerializer[T] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
for (i <- 0 until getArity) {
......
......@@ -110,7 +110,7 @@ private[flink] object CrossDataSetImpl {
}
}
val returnType = new ScalaTupleTypeInfo[(T, O)](
classOf[(T, O)], Seq(leftSet.getType, rightSet.getType)) {
classOf[(T, O)], Seq(leftSet.getType, rightSet.getType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(T, O)] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
......
......@@ -168,8 +168,8 @@ trait UnfinishedJoinOperation[T, O] extends UnfinishedKeyPairOperation[T, O, Joi
* i.e. the parameters of the constructor, hidden.
*/
private[flink] class UnfinishedJoinOperationImpl[T, O](
leftSet: JavaDataSet[T],
rightSet: JavaDataSet[O],
leftSet: DataSet[T],
rightSet: DataSet[O],
joinHint: JoinHint)
extends UnfinishedKeyPairOperation[T, O, JoinDataSet[T, O]](leftSet, rightSet)
with UnfinishedJoinOperation[T, O] {
......@@ -181,7 +181,7 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
}
}
val returnType = new ScalaTupleTypeInfo[(T, O)](
classOf[(T, O)], Seq(leftSet.getType, rightSet.getType)) {
classOf[(T, O)], Seq(leftSet.set.getType, rightSet.set.getType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(T, O)] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
......@@ -197,10 +197,10 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
}
}
val joinOperator = new EquiJoin[T, O, (T, O)](
leftSet, rightSet, leftKey, rightKey, joiner, returnType, joinHint)
leftSet.set, rightSet.set, leftKey, rightKey, joiner, returnType, joinHint)
// sanity check solution set key mismatches
leftSet match {
leftSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
leftKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
......@@ -213,7 +213,7 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
}
case _ =>
}
rightSet match {
rightSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
rightKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
......@@ -227,6 +227,6 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
case _ =>
}
new JoinDataSetImpl(joinOperator, leftSet, rightSet, leftKey, rightKey)
new JoinDataSetImpl(joinOperator, leftSet.set, rightSet.set, leftKey, rightKey)
}
}
\ No newline at end of file
......@@ -21,7 +21,7 @@ package org.apache.flink.api
import _root_.scala.reflect.ClassTag
import language.experimental.macros
import org.apache.flink.types.TypeInformation
import org.apache.flink.api.scala.typeutils.TypeUtils
import org.apache.flink.api.scala.typeutils.{ScalaTupleTypeInfo, TypeUtils}
import org.apache.flink.api.java.{DataSet => JavaDataSet}
package object scala {
......@@ -31,4 +31,17 @@ package object scala {
// We need to wrap Java DataSet because we need the scala operations
private[flink] def wrap[R: ClassTag](set: JavaDataSet[R]) = new DataSet[R](set)
private[flink] def fieldNames2Indices(
typeInfo: TypeInformation[_],
fields: Array[String]): Array[Int] = {
typeInfo match {
case ti: ScalaTupleTypeInfo[_] =>
ti.getFieldIndices(fields)
case _ =>
throw new UnsupportedOperationException("Specifying fields by name is only " +
"supported on Case Classes (for now).")
}
}
}
\ No newline at end of file
......@@ -18,8 +18,7 @@
package org.apache.flink.api.scala.typeutils
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.api.java.typeutils.{TupleTypeInfo, AtomicType, TupleTypeInfoBase}
import org.apache.flink.api.java.typeutils.{AtomicType, TupleTypeInfoBase}
import org.apache.flink.types.TypeInformation
import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
......@@ -29,7 +28,8 @@ import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
*/
abstract class ScalaTupleTypeInfo[T <: Product](
tupleClass: Class[T],
fieldTypes: Seq[TypeInformation[_]])
fieldTypes: Seq[TypeInformation[_]],
val fieldNames: Seq[String])
extends TupleTypeInfoBase[T](tupleClass, fieldTypes: _*) {
def createComparator(logicalKeyFields: Array[Int], orders: Array[Boolean]): TypeComparator[T] = {
......@@ -76,5 +76,14 @@ abstract class ScalaTupleTypeInfo[T <: Product](
new ScalaTupleComparator[T](logicalKeyFields, fieldComparators, fieldSerializers)
}
def getFieldIndices(fields: Array[String]): Array[Int] = {
val result = fields map { x => fieldNames.indexOf(x) }
if (result.contains(-1)) {
throw new IllegalArgumentException("Fields '" + fields.mkString(", ") + "' are not valid for" +
" " + tupleClass + " with fields '" + fieldNames.mkString(", ") + "'.")
}
result
}
override def toString = "Scala " + super.toString
}
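
The commit message states that key field names are mapped to tuple indices at pre-flight time; getFieldIndices above performs that mapping. A standalone sketch of the same lookup logic, using hypothetical field names:

object FieldIndexLookupSketch {
  // Resolve requested names against the case-class field names captured at compile time,
  // failing fast on unknown names (mirrors ScalaTupleTypeInfo.getFieldIndices).
  def fieldIndices(fieldNames: Seq[String], requested: Array[String]): Array[Int] = {
    val result = requested.map(name => fieldNames.indexOf(name))
    if (result.contains(-1)) {
      throw new IllegalArgumentException(
        "Fields '" + requested.mkString(", ") + "' are not valid for fields '" +
          fieldNames.mkString(", ") + "'.")
    }
    result
  }

  def main(args: Array[String]) {
    // For case class Edge(v1: Int, v2: Int), "v2" resolves to index 1.
    println(fieldIndices(Seq("v1", "v2"), Array("v2")).toList)
  }
}
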
......@@ -24,6 +24,7 @@ import org.apache.flink.api.java.{DataSet => JavaDataSet}
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.operators.Keys
import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
import org.apache.flink.api.scala.typeutils.ScalaTupleTypeInfo
import org.apache.flink.types.TypeInformation
/**
......@@ -43,27 +44,42 @@ import org.apache.flink.types.TypeInformation
* @tparam R The type of the resulting Operation.
*/
private[flink] abstract class UnfinishedKeyPairOperation[T, O, R](
private[flink] val leftSet: JavaDataSet[T],
private[flink] val rightSet: JavaDataSet[O]) {
private[flink] val leftSet: DataSet[T],
private[flink] val rightSet: DataSet[O]) {
private[flink] def finish(leftKey: Keys[T], rightKey: Keys[O]): R
/**
* Specify the key fields for the left side of the key based operation. This returns
* a [[HalfUnfinishedKeyPairOperation]] on which `isEqualTo` must be called to specify the
* a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the
* key for the right side. The result after specifying the right side key is the finished
* operation.
*
* This only works on a Tuple [[DataSet]].
* This only works on Tuple [[DataSet]].
*/
def where(leftKeys: Int*) = {
val leftKey = new FieldPositionKeys[T](leftKeys.toArray, leftSet.getType)
val leftKey = new FieldPositionKeys[T](leftKeys.toArray, leftSet.set.getType)
new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
}
/**
* Specify the key fields for the left side of the key based operation. This returns
* a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the
* key for the right side. The result after specifying the right side key is the finished
* operation.
*
* This only works on a CaseClass [[DataSet]].
*/
def where(firstLeftField: String, otherLeftFields: String*) = {
val fieldIndices = fieldNames2Indices(leftSet.set.getType, firstLeftField +: otherLeftFields.toArray)
val leftKey = new FieldPositionKeys[T](fieldIndices, leftSet.set.getType)
new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
}
/**
* Specify the key selector function for the left side of the key based operation. This returns
* a [[HalfUnfinishedKeyPairOperation]] on which `isEqualTo` must be called to specify the
* a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the
* key for the right side. The result after specifying the right side key is the finished
* operation.
*/
......@@ -72,7 +88,7 @@ private[flink] abstract class UnfinishedKeyPairOperation[T, O, R](
val keyExtractor = new KeySelector[T, K] {
def getKey(in: T) = fun(in)
}
val leftKey = new Keys.SelectorFunctionKeys[T, K](keyExtractor, leftSet.getType, keyType)
val leftKey = new Keys.SelectorFunctionKeys[T, K](keyExtractor, leftSet.set.getType, keyType)
new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
}
}
......@@ -87,7 +103,7 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
* This only works on a Tuple [[DataSet]].
*/
def equalTo(rightKeys: Int*): R = {
val rightKey = new FieldPositionKeys[O](rightKeys.toArray, unfinished.rightSet.getType)
val rightKey = new FieldPositionKeys[O](rightKeys.toArray, unfinished.rightSet.set.getType)
if (!leftKey.areCompatibale(rightKey)) {
throw new InvalidProgramException("The types of the key fields do not match. Left: " +
leftKey + " Right: " + rightKey)
......@@ -95,6 +111,26 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
unfinished.finish(leftKey, rightKey)
}
/**
* Specify the key fields for the right side of the key based operation. This returns
* the finished operation.
*
* This only works on a CaseClass [[DataSet]].
*/
def equalTo(firstRightField: String, otherRightFields: String*): R = {
val fieldIndices = fieldNames2Indices(
unfinished.rightSet.set.getType,
firstRightField +: otherRightFields.toArray)
val rightKey = new FieldPositionKeys[O](fieldIndices, unfinished.rightSet.set.getType)
if (!leftKey.areCompatibale(rightKey)) {
throw new InvalidProgramException("The types of the key fields do not match. Left: " +
leftKey + " Right: " + rightKey)
}
unfinished.finish(leftKey, rightKey)
}
/**
* Specify the key selector function for the right side of the key based operation. This returns
* the finished operation.
......@@ -105,7 +141,7 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
def getKey(in: O) = fun(in)
}
val rightKey =
new Keys.SelectorFunctionKeys[O, K](keyExtractor, unfinished.rightSet.getType, keyType)
new Keys.SelectorFunctionKeys[O, K](keyExtractor, unfinished.rightSet.set.getType, keyType)
if (!leftKey.areCompatibale(rightKey)) {
throw new InvalidProgramException("The types of the key fields do not match. Left: " +
leftKey + " Right: " + rightKey)
......
......@@ -65,6 +65,41 @@ class AggregateOperatorTest {
}
}
@Test
def testFieldNamesAggregate(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
// should work
try {
tupleDs.aggregate(Aggregations.SUM, "_2")
} catch {
case e: Exception => Assert.fail()
}
// should not work: invalid field
try {
tupleDs.aggregate(Aggregations.SUM, "foo")
Assert.fail()
} catch {
case iae: IllegalArgumentException =>
case e: Exception => Assert.fail()
}
val longDs = env.fromCollection(emptyLongData)
// should not work: not applied to tuple DataSet
try {
longDs.aggregate(Aggregations.MIN, "_1")
Assert.fail()
} catch {
case uoe: InvalidProgramException =>
case uoe: UnsupportedOperationException =>
case e: Exception => Assert.fail()
}
}
@Test
def testAggregationTypes(): Unit = {
try {
......@@ -72,8 +107,13 @@ class AggregateOperatorTest {
val tupleDs = env.fromCollection(emptyTupleData)
// should work: multiple aggregates
tupleDs.aggregate(Aggregations.SUM, 0).aggregate(Aggregations.MIN, 4)
// should work: nested aggregates
tupleDs.aggregate(Aggregations.MIN, 2).aggregate(Aggregations.SUM, 1)
// should not work: average on string
try {
tupleDs.aggregate(Aggregations.SUM, 2)
Assert.fail()
......
......@@ -94,6 +94,71 @@ class CoGroupOperatorTest {
ds1.coGroup(ds2).where(5).equalTo(0)
}
@Test
def testCoGroupKeyFieldNames1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// Should work
try {
ds1.coGroup(ds2).where("_1").equalTo("_1")
}
catch {
case e: Exception => Assert.fail()
}
}
@Test(expected = classOf[InvalidProgramException])
def testCoGroupKeyFieldNames2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// Should not work, incompatible key types
ds1.coGroup(ds2).where("_1").equalTo("_3")
}
@Test(expected = classOf[InvalidProgramException])
def testCoGroupKeyFieldNames3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// Should not work, incompatible number of key fields
ds1.coGroup(ds2).where("_1", "_2").equalTo("_3")
}
@Test(expected = classOf[IllegalArgumentException])
def testCoGroupKeyFieldNames4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// Should not work, invalid field name
ds1.coGroup(ds2).where("_6").equalTo("_1")
}
@Test(expected = classOf[IllegalArgumentException])
def testCoGroupKeyFieldNames5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// Should not work, invalid field name
ds1.coGroup(ds2).where("_1").equalTo("bar")
}
@Test(expected = classOf[UnsupportedOperationException])
def testCoGroupKeyFieldNames6(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(customTypeData)
// Should not work, field name key on custom data type
ds1.coGroup(ds2).where("_3").equalTo("_1")
}
@Ignore
@Test
def testCoGroupKeyExpressions1(): Unit = {
......
......@@ -30,7 +30,7 @@ class DistinctOperatorTest {
private val emptyLongData = Array[Long]()
@Test
def testDistinctByKeyFields1(): Unit = {
def testDistinctByKeyIndices1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
......@@ -44,7 +44,7 @@ class DistinctOperatorTest {
}
@Test(expected = classOf[InvalidProgramException])
def testDistinctByKeyFields2(): Unit = {
def testDistinctByKeyIndices2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData)
......@@ -53,7 +53,7 @@ class DistinctOperatorTest {
}
@Test(expected = classOf[InvalidProgramException])
def testDistinctByKeyFields3(): Unit = {
def testDistinctByKeyIndices3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
......@@ -62,7 +62,7 @@ class DistinctOperatorTest {
}
@Test
def testDistinctByKeyFields4(): Unit = {
def testDistinctByKeyIndices4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
......@@ -71,7 +71,7 @@ class DistinctOperatorTest {
}
@Test(expected = classOf[InvalidProgramException])
def testDistinctByKeyFields5(): Unit = {
def testDistinctByKeyIndices5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
......@@ -80,7 +80,7 @@ class DistinctOperatorTest {
}
@Test(expected = classOf[IllegalArgumentException])
def testDistinctByKeyFields6(): Unit = {
def testDistinctByKeyIndices6(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
......@@ -88,6 +88,47 @@ class DistinctOperatorTest {
tupleDs.distinct(-1)
}
@Test
def testDistinctByKeyFields1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
// Should work
try {
tupleDs.distinct("_1")
}
catch {
case e: Exception => Assert.fail()
}
}
@Test(expected = classOf[UnsupportedOperationException])
def testDistinctByKeyFields2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData)
// should not work: distinct on basic type
longDs.distinct("_1")
}
@Test(expected = classOf[UnsupportedOperationException])
def testDistinctByKeyFields3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
// should not work: field key on custom type
customDs.distinct("_1")
}
@Test(expected = classOf[IllegalArgumentException])
def testDistinctByKeyFields4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
// should not work, invalid field
tupleDs.distinct("foo")
}
@Test
def testDistinctByKeySelector1(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
......
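Analogously for distinct, a hedged sketch with an invented Order type (imports and env as in the first sketch):

case class Order(customer: String, amount: Double)

val orders = env.fromCollection(Seq(Order("alice", 10.0), Order("alice", 20.0), Order("bob", 5.0)))

// keeps one arbitrary Order per distinct "customer" value
val uniqueCustomers = orders.distinct("customer")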
......@@ -33,7 +33,7 @@ class GroupingTest {
private val emptyLongData = Array[Long]()
@Test
def testGroupByKeyFields1(): Unit = {
def testGroupByKeyIndices1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
......@@ -47,7 +47,7 @@ class GroupingTest {
}
@Test(expected = classOf[InvalidProgramException])
def testGroupByKeyFields2(): Unit = {
def testGroupByKeyIndices2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData)
......@@ -56,7 +56,7 @@ class GroupingTest {
}
@Test(expected = classOf[InvalidProgramException])
def testGroupByKeyFields3(): Unit = {
def testGroupByKeyIndices3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
......@@ -65,7 +65,7 @@ class GroupingTest {
}
@Test(expected = classOf[IllegalArgumentException])
def testGroupByKeyFields4(): Unit = {
def testGroupByKeyIndices4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
......@@ -74,7 +74,7 @@ class GroupingTest {
}
@Test(expected = classOf[IllegalArgumentException])
def testGroupByKeyFields5(): Unit = {
def testGroupByKeyIndices5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
......@@ -82,6 +82,47 @@ class GroupingTest {
tupleDs.groupBy(-1)
}
@Test
def testGroupByKeyFields1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
// should work
try {
tupleDs.groupBy("_1")
}
catch {
case e: Exception => Assert.fail()
}
}
@Test(expected = classOf[UnsupportedOperationException])
def testGroupByKeyFields2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData)
// should not work, grouping on basic type
longDs.groupBy("_1")
}
@Test(expected = classOf[UnsupportedOperationException])
def testGroupByKeyFields3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
// should not work, field key on custom type
customDs.groupBy("_1")
}
@Test(expected = classOf[IllegalArgumentException])
def testGroupByKeyFields4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
// should not work, invalid field
tupleDs.groupBy("foo")
}
@Ignore
@Test
def testGroupByKeyExpressions1(): Unit = {
......
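The grouping tests follow the same pattern; a sketch with an invented Page type, assuming the usual reduce on a grouped DataSet (imports and env as in the first sketch):

case class Page(site: String, hits: Int)

val pages = env.fromCollection(Seq(Page("a", 1), Page("a", 2), Page("b", 3)))

// "site" is mapped to a tuple position at pre-flight time; the grouping itself is unchanged
val hitsPerSite = pages.groupBy("site").reduce((p1, p2) => Page(p1.site, p1.hits + p2.hits))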
......@@ -31,10 +31,12 @@ class JoinOperatorTest {
private val emptyLongData = Array[Long]()
@Test
def testJoinKeyFields1(): Unit = {
def testJoinKeyIndices1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should work
try {
ds1.join(ds2).where(0).equalTo(0)
}
......@@ -44,45 +46,121 @@ class JoinOperatorTest {
}
@Test(expected = classOf[InvalidProgramException])
def testJoinKeyFields2(): Unit = {
def testJoinKeyIndices2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should not work, incompatible key types
ds1.join(ds2).where(0).equalTo(2)
}
@Test(expected = classOf[InvalidProgramException])
def testJoinKeyFields3(): Unit = {
def testJoinKeyIndices3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should not work, non-matching number of key indices
ds1.join(ds2).where(0, 1).equalTo(2)
}
@Test(expected = classOf[IllegalArgumentException])
def testJoinKeyFields4(): Unit = {
def testJoinKeyIndices4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should not work, index out of range
ds1.join(ds2).where(5).equalTo(0)
}
@Test(expected = classOf[IllegalArgumentException])
def testJoinKeyFields5(): Unit = {
def testJoinKeyIndices5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should not work, negative position
ds1.join(ds2).where(-1).equalTo(-1)
}
@Test(expected = classOf[IllegalArgumentException])
def testJoinKeyFields6(): Unit = {
def testJoinKeyIndices6(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(customTypeData)
// should not work, key index on custom type
ds1.join(ds2).where(5).equalTo(0)
}
@Test
def testJoinKeyFields1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should work
try {
ds1.join(ds2).where("_1").equalTo("_1")
}
catch {
case e: Exception => Assert.fail()
}
}
@Test(expected = classOf[InvalidProgramException])
def testJoinKeyFields2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should not work, incompatible field types
ds1.join(ds2).where("_1").equalTo("_3")
}
@Test(expected = classOf[InvalidProgramException])
def testJoinKeyFields3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should not work, non-matching number of key indices
ds1.join(ds2).where("_1", "_2").equalTo("_3")
}
@Test(expected = classOf[IllegalArgumentException])
def testJoinKeyFields4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
// should not work, non-existent key
ds1.join(ds2).where("foo").equalTo("_1")
}
@Test(expected = classOf[IllegalArgumentException])
def testJoinKeyFields5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
    // should not work, non-existent key
ds1.join(ds2).where("_1").equalTo("bar")
}
@Test(expected = classOf[UnsupportedOperationException])
def testJoinKeyFields6(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(customTypeData)
// should not work, field key on custom type
ds1.join(ds2).where("_2").equalTo("_1")
}
@Ignore
@Test
def testJoinKeyExpressions1(): Unit = {
......
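Finally, a sketch of the string-keyed join on case classes (types and data invented; the default pairwise output of a join without an explicit join function is assumed; imports and env as in the first sketch):

case class Rating(name: String, points: Int)
case class Category(name: String, kind: String)

val ratings    = env.fromCollection(Seq(Rating("movie", 5)))
val categories = env.fromCollection(Seq(Category("movie", "entertainment")))

// without an explicit join function this yields a DataSet of (Rating, Category) pairs
val joined = ratings.join(categories).where("name").equalTo("name")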