Commit 6be85554 authored by Aljoscha Krettek, committed by Robert Metzger

Really add POJO support and nested keys for Scala API

This also adds more integration tests, but not all tests of the Java API
have been ported to Scala yet.
Parent 598ae376
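What the change enables, as a minimal sketch (not part of the commit): plain POJO classes and nested field expressions can now be used as keys in the Scala API. The class names, values, and output paths below are made up for illustration and loosely mirror the Pojo/Nested helpers added to the test utilities in this diff.

import org.apache.flink.api.scala._

// Illustrative POJO: public var fields plus a zero-argument constructor,
// which is what the new Scala POJO analysis expects.
class Nested(var myLong: Long) {
  def this() = this(0)
}

class Person(var name: String, var myInt: Int, var nested: Nested) {
  def this() = this("", 0, new Nested(0))
  override def toString = s"$name,$myInt,${nested.myLong}"
}

object PojoKeysExample {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment

    val persons = env.fromElements(
      new Person("alice", 1, new Nested(10)),
      new Person("bob", 1, new Nested(10)),
      new Person("carol", 2, new Nested(20)))

    // Nested field expressions as grouping keys on a POJO DataSet.
    persons
      .groupBy("nested.myLong")
      .reduceGroup { people => people.map(_.name).mkString("-") }
      .writeAsText("/tmp/grouped-by-nested-key")   // placeholder path

    // distinct() also accepts field-expression keys now.
    persons.distinct("myInt").writeAsText("/tmp/distinct-by-myInt")   // placeholder path

    env.execute("POJO and nested key example")
  }
}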
@@ -76,7 +76,7 @@ object ConnectedComponents {
    val edges = getEdgesDataSet(env).flatMap { edge => Seq(edge, (edge._2, edge._1)) }
    // open a delta iteration
-   val verticesWithComponents = vertices.iterateDelta(vertices, maxIterations, Array(0)) {
+   val verticesWithComponents = vertices.iterateDelta(vertices, maxIterations, Array("_1")) {
      (s, ws) =>
        // apply the step logic: join with the edges
......
@@ -23,6 +23,7 @@ import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
+import com.google.common.base.Joiner;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -306,7 +307,12 @@ public abstract class Keys<T> {
            }
            return Ints.toArray(logicalKeys);
        }
+       @Override
+       public String toString() {
+           Joiner join = Joiner.on('.');
+           return "ExpressionKeys: " + join.join(keyFields);
+       }
    }
    private static String[] removeDuplicates(String[] in) {
......
@@ -345,7 +345,9 @@ public class PojoTypeExtractionTest {
        Assert.assertEquals(typeInfo.getTypeClass(), WC.class);
        Assert.assertEquals(typeInfo.getArity(), 2);
    }
+   // Kryo is required for this, so disable for now.
+   @Ignore
    @Test
    public void testPojoAllPublic() {
        TypeInformation<?> typeForClass = TypeExtractor.createTypeInfo(AllPublic.class);
......
@@ -550,14 +550,11 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
  /**
   * Creates a new DataSet containing the distinct elements of this DataSet. The decision whether
   * two elements are distinct or not is made based on only the specified fields.
-  *
-  * This only works on CaseClass DataSets
   */
  def distinct(firstField: String, otherFields: String*): DataSet[T] = {
-   val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray)
    wrap(new DistinctOperator[T](
      javaSet,
-     new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType, true)))
+     new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType)))
  }
  /**
@@ -615,8 +612,6 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
   * This only works on CaseClass DataSets.
   */
  def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = {
-   // val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray)
    new GroupedDataSet[T](
      this,
      new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType))
@@ -862,10 +857,8 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
   */
  def iterateDelta[R: ClassTag](workset: DataSet[R], maxIterations: Int, keyFields: Array[String])(
      stepFunction: (DataSet[T], DataSet[R]) => (DataSet[T], DataSet[R])) = {
-   val fieldIndices = fieldNames2Indices(javaSet.getType, keyFields)
-   val key = new ExpressionKeys[T](fieldIndices, javaSet.getType, false)
+   val key = new ExpressionKeys[T](keyFields, javaSet.getType)
    val iterativeSet = new DeltaIteration[T, R](
      javaSet.getExecutionEnvironment,
      javaSet.getType,
@@ -931,12 +924,10 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
   * significant amount of time.
   */
  def partitionByHash(firstField: String, otherFields: String*): DataSet[T] = {
-   val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray)
    val op = new PartitionOperator[T](
      javaSet,
      PartitionMethod.HASH,
-     new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType, false))
+     new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType))
    wrap(op)
  }
......
@@ -47,7 +47,7 @@ class GroupedDataSet[T: ClassTag](
  // These are for optional secondary sort. They are only used
  // when using a group-at-a-time reduce function.
- private val groupSortKeyPositions = mutable.MutableList[Int]()
+ private val groupSortKeyPositions = mutable.MutableList[Either[Int, String]]()
  private val groupSortOrders = mutable.MutableList[Order]()
  /**
@@ -64,7 +64,7 @@ class GroupedDataSet[T: ClassTag](
    if (field >= set.getType.getArity) {
      throw new IllegalArgumentException("Order key out of tuple bounds.")
    }
-   groupSortKeyPositions += field
+   groupSortKeyPositions += Left(field)
    groupSortOrders += order
    this
  }
@@ -76,9 +76,7 @@ class GroupedDataSet[T: ClassTag](
   * This only works on CaseClass DataSets.
   */
  def sortGroup(field: String, order: Order): GroupedDataSet[T] = {
-   val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
-   groupSortKeyPositions += fieldIndex
+   groupSortKeyPositions += Right(field)
    groupSortOrders += order
    this
  }
@@ -88,14 +86,32 @@ class GroupedDataSet[T: ClassTag](
   */
  private def maybeCreateSortedGrouping(): Grouping[T] = {
    if (groupSortKeyPositions.length > 0) {
-     val grouping = new SortedGrouping[T](
-       set.javaSet,
-       keys,
-       groupSortKeyPositions(0),
-       groupSortOrders(0))
+     val grouping = groupSortKeyPositions(0) match {
+       case Left(pos) =>
+         new SortedGrouping[T](
+           set.javaSet,
+           keys,
+           pos,
+           groupSortOrders(0))
+       case Right(field) =>
+         new SortedGrouping[T](
+           set.javaSet,
+           keys,
+           field,
+           groupSortOrders(0))
+     }
      // now manually add the rest of the keys
      for (i <- 1 until groupSortKeyPositions.length) {
-       grouping.sortGroup(groupSortKeyPositions(i), groupSortOrders(i))
+       groupSortKeyPositions(i) match {
+         case Left(pos) =>
+           grouping.sortGroup(pos, groupSortOrders(i))
+         case Right(field) =>
+           grouping.sortGroup(field, groupSortOrders(i))
+       }
      }
      grouping
    } else {
@@ -209,7 +225,7 @@ class GroupedDataSet[T: ClassTag](
      }
    }
    wrap(
-     new GroupReduceOperator[T, R](createUnsortedGrouping(),
+     new GroupReduceOperator[T, R](maybeCreateSortedGrouping(),
        implicitly[TypeInformation[R]], reducer))
  }
@@ -227,7 +243,7 @@ class GroupedDataSet[T: ClassTag](
      }
    }
    wrap(
-     new GroupReduceOperator[T, R](createUnsortedGrouping(),
+     new GroupReduceOperator[T, R](maybeCreateSortedGrouping(),
        implicitly[TypeInformation[R]], reducer))
  }
......
@@ -17,7 +17,6 @@
 */
package org.apache.flink.api.scala.codegen
-import scala.Option.option2Iterable
import scala.collection.GenTraversableOnce
import scala.collection.mutable
import scala.reflect.macros.Context
@@ -59,12 +58,17 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
        case PrimitiveType(default, wrapper) => PrimitiveDescriptor(id, tpe, default, wrapper)
        case BoxedPrimitiveType(default, wrapper, box, unbox) =>
          BoxedPrimitiveDescriptor(id, tpe, default, wrapper, box, unbox)
-       case ListType(elemTpe, iter) => analyzeList(id, tpe, elemTpe, iter)
+       case ListType(elemTpe, iter) =>
+         analyzeList(id, tpe, elemTpe, iter)
        case CaseClassType() => analyzeCaseClass(id, tpe)
-       case BaseClassType() => analyzeClassHierarchy(id, tpe)
        case ValueType() => ValueDescriptor(id, tpe)
        case WritableType() => WritableDescriptor(id, tpe)
-       case _ => GenericClassDescriptor(id, tpe)
+       case JavaType() =>
+         // It's a Java Class, let the TypeExtractor deal with it...
+         c.warning(c.enclosingPosition, s"Type $tpe is a java class. Will be analyzed by " +
+           s"TypeExtractor at runtime.")
+         GenericClassDescriptor(id, tpe)
+       case _ => analyzePojo(id, tpe)
      }
    }
  }
@@ -78,110 +82,63 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
      case desc => ListDescriptor(id, tpe, iter, desc)
    }
- private def analyzeClassHierarchy(id: Int, tpe: Type): UDTDescriptor = {
-   val tagField = {
-     val (intTpe, intDefault, intWrapper) = PrimitiveType.intPrimitive
-     FieldAccessor(
-       NoSymbol,
-       NoSymbol,
-       NullaryMethodType(intTpe),
-       isBaseField = true,
-       PrimitiveDescriptor(cache.newId, intTpe, intDefault, intWrapper))
-   }
-   val subTypes = tpe.typeSymbol.asClass.knownDirectSubclasses.toList flatMap { d =>
-     val dTpe =
-       {
-         val tArgs = (tpe.typeSymbol.asClass.typeParams, typeArgs(tpe)).zipped.toMap
-         val dArgs = d.asClass.typeParams map { dp =>
-           val tArg = tArgs.keySet.find { tp =>
-             dp == tp.typeSignature.asSeenFrom(d.typeSignature, tpe.typeSymbol).typeSymbol
-           }
-           tArg map { tArgs(_) } getOrElse dp.typeSignature
-         }
-         appliedType(d.asType.toType, dArgs)
-       }
-     if (dTpe <:< tpe) {
-       Some(analyze(dTpe))
-     } else {
-       None
-     }
-   }
-   val errors = subTypes flatMap { _.findByType[UnsupportedDescriptor] }
-   errors match {
-     case _ :: _ =>
-       val errorMessage = errors flatMap {
-         case UnsupportedDescriptor(_, subType, errs) =>
-           errs map { err => "Subtype " + subType + " - " + err }
-       }
-       UnsupportedDescriptor(id, tpe, errorMessage)
-     case Nil if subTypes.isEmpty =>
-       UnsupportedDescriptor(id, tpe, Seq("No instantiable subtypes found for base class"))
-     case Nil =>
-       val (tParams, _) = tpe.typeSymbol.asClass.typeParams.zip(typeArgs(tpe)).unzip
-       val baseMembers =
-         tpe.members filter { f => f.isMethod } filter { f => f.asMethod.isSetter } map {
-           f => (f, f.asMethod.setter, f.asMethod.returnType)
-         }
-       val subMembers = subTypes map {
-         case BaseClassDescriptor(_, _, getters, _) => getters
-         case CaseClassDescriptor(_, _, _, _, getters) => getters
-         case _ => Seq()
-       }
-       val baseFields = baseMembers flatMap {
-         case (bGetter, bSetter, bTpe) =>
-           val accessors = subMembers map {
-             _ find { sf =>
-               sf.getter.name == bGetter.name &&
-                 sf.tpe.termSymbol.asMethod.returnType <:< bTpe.termSymbol.asMethod.returnType
-             }
-           }
-           accessors.forall { _.isDefined } match {
-             case true =>
-               Some(
-                 FieldAccessor(
-                   bGetter,
-                   bSetter,
-                   bTpe,
-                   isBaseField = true,
-                   analyze(bTpe.termSymbol.asMethod.returnType)))
-             case false => None
-           }
-       }
-       def wireBaseFields(desc: UDTDescriptor): UDTDescriptor = {
-         def updateField(field: FieldAccessor) = {
-           baseFields find { bf => bf.getter.name == field.getter.name } match {
-             case Some(FieldAccessor(_, _, _, _, fieldDesc)) =>
-               field.copy(isBaseField = true, desc = fieldDesc)
-             case None => field
-           }
-         }
-         desc match {
-           case desc @ BaseClassDescriptor(_, _, getters, baseSubTypes) =>
-             desc.copy(
-               getters = getters map updateField,
-               subTypes = baseSubTypes map wireBaseFields)
-           case desc @ CaseClassDescriptor(_, _, _, _, getters) =>
-             desc.copy(getters = getters map updateField)
-           case _ => desc
-         }
-       }
-       BaseClassDescriptor(id, tpe, tagField +: baseFields.toSeq, subTypes map wireBaseFields)
-   }
- }
+ private def analyzePojo(id: Int, tpe: Type): UDTDescriptor = {
+   val immutableFields = tpe.members filter { _.isTerm } map { _.asTerm } filter { _.isVal }
+   if (immutableFields.nonEmpty) {
+     // We don't support POJOs with immutable fields
+     c.warning(
+       c.enclosingPosition,
+       s"Type $tpe is no POJO, has immutable fields: ${immutableFields.mkString(", ")}.")
+     return GenericClassDescriptor(id, tpe)
+   }
+   val fields = tpe.members
+     .filter { _.isTerm }
+     .map { _.asTerm }
+     .filter { _.isVar }
+     .filterNot { _.annotations.exists( _.tpe <:< typeOf[scala.transient]) }
+   if (fields.isEmpty) {
+     c.warning(c.enclosingPosition, "Type $tpe has no fields that are visible from Scala Type" +
+       " analysis. Falling back to Java Type Analysis (TypeExtractor).")
+     return GenericClassDescriptor(id, tpe)
+   }
+   // check whether all fields are either: 1. public, 2. have getter/setter
+   val invalidFields = fields filterNot {
+     f =>
+       f.isPublic ||
+         (f.getter != NoSymbol && f.getter.isPublic && f.setter != NoSymbol && f.setter.isPublic)
+   }
+   if (invalidFields.nonEmpty) {
+     c.warning(c.enclosingPosition, s"Type $tpe is no POJO because it has non-public fields '" +
+       s"${invalidFields.mkString(", ")}' that don't have public getters/setters.")
+     return GenericClassDescriptor(id, tpe)
+   }
+   // check whether we have a zero-parameter ctor
+   val hasZeroCtor = tpe.declarations exists {
+     case m: MethodSymbol
+       if m.isConstructor && m.paramss.length == 1 && m.paramss(0).length == 0 => true
+     case _ => false
+   }
+   if (!hasZeroCtor) {
+     // We don't support POJOs without zero-paramter ctor
+     c.warning(
+       c.enclosingPosition,
+       s"Class $tpe is no POJO, has no zero-parameters constructor.")
+     return GenericClassDescriptor(id, tpe)
+   }
+   val fieldDescriptors = fields map {
+     f =>
+       val fieldTpe = f.getter.asMethod.returnType.asSeenFrom(tpe, tpe.typeSymbol)
+       FieldDescriptor(f.name.toString.trim, f.getter, f.setter, fieldTpe, analyze(fieldTpe))
+   }
+   PojoDescriptor(id, tpe, fieldDescriptors.toSeq)
+ }
  private def analyzeCaseClass(id: Int, tpe: Type): UDTDescriptor = {
@@ -216,7 +173,7 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
      }
      val fields = caseFields map {
        case (fgetter, fsetter, fTpe) =>
-         FieldAccessor(fgetter, fsetter, fTpe, isBaseField = false, analyze(fTpe))
+         FieldDescriptor(fgetter.name.toString.trim, fgetter, fsetter, fTpe, analyze(fTpe))
      }
      val mutable = enableMutableUDTs && (fields forall { f => f.setter != NoSymbol })
      if (mutable) {
@@ -226,8 +183,9 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
      case errs @ _ :: _ =>
        val msgs = errs flatMap { f =>
          (f: @unchecked) match {
-           case FieldAccessor(fgetter, _,_,_, UnsupportedDescriptor(_, fTpe, errors)) =>
-             errors map { err => "Field " + fgetter.name + ": " + fTpe + " - " + err }
+           case FieldDescriptor(
+             fName, _, _, _, UnsupportedDescriptor(_, fTpe, errors)) =>
+               errors map { err => "Field " + fName + ": " + fTpe + " - " + err }
          }
        }
        UnsupportedDescriptor(id, tpe, msgs)
@@ -296,11 +254,6 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
    def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isCaseClass
  }
- private object BaseClassType {
-   def unapply(tpe: Type): Boolean =
-     tpe.typeSymbol.asClass.isAbstractClass && tpe.typeSymbol.asClass.isSealed
- }
  private object ValueType {
    def unapply(tpe: Type): Boolean =
      tpe.typeSymbol.asClass.baseClasses exists {
@@ -315,6 +268,10 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
    }
  }
+ private object JavaType {
+   def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isJava
+ }
  private class UDTAnalyzerCache {
    private val caches = new DynamicVariable[Map[Type, RecursiveDescriptor]](Map())
......
@@ -36,10 +36,8 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
    def canBeKey: Boolean
-   def mkRoot: UDTDescriptor = this
    def flatten: Seq[UDTDescriptor]
-   def getters: Seq[FieldAccessor] = Seq()
+   def getters: Seq[FieldDescriptor] = Seq()
    def select(member: String): Option[UDTDescriptor] =
      getters find { _.getter.name.toString == member } map { _.desc }
@@ -48,7 +46,7 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
      case Nil => Seq(Some(this))
      case head :: tail => getters find { _.getter.name.toString == head } match {
        case None => Seq(None)
-       case Some(d : FieldAccessor) => d.desc.select(tail)
+       case Some(d : FieldDescriptor) => d.desc.select(tail)
      }
    }
@@ -60,7 +58,7 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
    }
    def getRecursiveRefs: Seq[UDTDescriptor] =
-     findByType[RecursiveDescriptor].flatMap { rd => findById(rd.refId) }.map { _.mkRoot }.distinct
+     findByType[RecursiveDescriptor].flatMap { rd => findById(rd.refId) }.distinct
  }
  case class GenericClassDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
@@ -116,30 +114,45 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
    }
  }
- case class BaseClassDescriptor(
-   id: Int, tpe: Type, override val getters: Seq[FieldAccessor], subTypes: Seq[UDTDescriptor])
+ case class PojoDescriptor(id: Int, tpe: Type, override val getters: Seq[FieldDescriptor])
    extends UDTDescriptor {
-   override def flatten =
-     this +: ((getters flatMap { _.desc.flatten }) ++ (subTypes flatMap { _.flatten }))
+   override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct)
+   override def flatten = this +: (getters flatMap { _.desc.flatten })
    override def canBeKey = flatten forall { f => f.canBeKey }
+   // Hack: ignore the ctorTpe, since two Type instances representing
+   // the same ctor function type don't appear to be considered equal.
+   // Equality of the tpe and ctor fields implies equality of ctorTpe anyway.
+   override def hashCode = (id, tpe, getters).hashCode
+   override def equals(that: Any) = that match {
+     case PojoDescriptor(thatId, thatTpe, thatGetters) =>
+       (id, tpe, getters).equals(
+         thatId, thatTpe, thatGetters)
+     case _ => false
+   }
    override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
      case Nil => getters flatMap { g => g.desc.select(Nil) }
      case head :: tail => getters find { _.getter.name.toString == head } match {
        case None => Seq(None)
-       case Some(d : FieldAccessor) => d.desc.select(tail)
+       case Some(d : FieldDescriptor) => d.desc.select(tail)
      }
    }
  }
  case class CaseClassDescriptor(
-   id: Int, tpe: Type, mutable: Boolean, ctor: Symbol, override val getters: Seq[FieldAccessor])
+   id: Int,
+   tpe: Type,
+   mutable: Boolean,
+   ctor: Symbol,
+   override val getters: Seq[FieldDescriptor])
    extends UDTDescriptor {
    override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct)
-   override def mkRoot = this.copy(getters = getters map { _.copy(isBaseField = false) })
    override def flatten = this +: (getters flatMap { _.desc.flatten })
    override def canBeKey = flatten forall { f => f.canBeKey }
@@ -159,16 +172,16 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
      case Nil => getters flatMap { g => g.desc.select(Nil) }
      case head :: tail => getters find { _.getter.name.toString == head } match {
        case None => Seq(None)
-       case Some(d : FieldAccessor) => d.desc.select(tail)
+       case Some(d : FieldDescriptor) => d.desc.select(tail)
      }
    }
  }
- case class FieldAccessor(
+ case class FieldDescriptor(
+   name: String,
    getter: Symbol,
    setter: Symbol,
    tpe: Type,
-   isBaseField: Boolean,
    desc: UDTDescriptor)
  case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor {
......
@@ -20,7 +20,6 @@ package org.apache.flink.api.scala.codegen
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.typeutils._
@@ -29,6 +28,8 @@ import org.apache.flink.types.Value
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.hadoop.io.Writable
+import scala.collection.JavaConverters._
import scala.reflect.macros.Context
private[flink] trait TypeInformationGen[C <: Context] {
@@ -41,7 +42,7 @@ private[flink] trait TypeInformationGen[C <: Context] {
  // This is for external calling by TypeUtils.createTypeInfo
  def mkTypeInfo[T: c.WeakTypeTag]: c.Expr[TypeInformation[T]] = {
-   val desc = getUDTDescriptor(weakTypeOf[T])
+   val desc = getUDTDescriptor(weakTypeTag[T].tpe)
    val result: c.Expr[TypeInformation[T]] = mkTypeInfo(desc)(c.WeakTypeTag(desc.tpe))
    result
  }
@@ -61,6 +62,7 @@ private[flink] trait TypeInformationGen[C <: Context] {
    case d : WritableDescriptor =>
      mkWritableTypeInfo(d)(c.WeakTypeTag(d.tpe).asInstanceOf[c.WeakTypeTag[Writable]])
        .asInstanceOf[c.Expr[TypeInformation[T]]]
+   case pojo: PojoDescriptor => mkPojo(pojo)
    case d => mkGenericTypeInfo(d)
  }
@@ -96,7 +98,7 @@ private[flink] trait TypeInformationGen[C <: Context] {
  def mkListTypeInfo[T: c.WeakTypeTag](desc: ListDescriptor): c.Expr[TypeInformation[T]] = {
    val arrayClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
    val elementClazz = c.Expr[Class[T]](Literal(Constant(desc.elem.tpe)))
-   val elementTypeInfo = mkTypeInfo(desc.elem)
+   val elementTypeInfo = mkTypeInfo(desc.elem)(c.WeakTypeTag(desc.elem.tpe))
    desc.elem match {
      // special case for string, which in scala is a primitive, but not in java
      case p: PrimitiveDescriptor if p.tpe <:< typeOf[String] =>
@@ -115,7 +117,8 @@ private[flink] trait TypeInformationGen[C <: Context] {
      reify {
        ObjectArrayTypeInfo.getInfoFor(
          arrayClazz.splice,
-         elementTypeInfo.splice).asInstanceOf[TypeInformation[T]]
+         elementTypeInfo.splice.asInstanceOf[TypeInformation[_]])
+           .asInstanceOf[TypeInformation[T]]
      }
    }
  }
@@ -136,6 +139,35 @@ private[flink] trait TypeInformationGen[C <: Context] {
    }
  }
+ def mkPojo[T: c.WeakTypeTag](desc: PojoDescriptor): c.Expr[TypeInformation[T]] = {
+   val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
+   val fieldsTrees = desc.getters map {
+     f =>
+       val name = c.Expr(Literal(Constant(f.name)))
+       val fieldType = mkTypeInfo(f.desc)(c.WeakTypeTag(f.tpe))
+       reify { (name.splice, fieldType.splice) }.tree
+   }
+   val fieldsList = c.Expr[List[(String, TypeInformation[_])]](mkList(fieldsTrees.toList))
+   reify {
+     val fields = fieldsList.splice
+     val clazz: Class[T] = tpeClazz.splice
+     val fieldMap = TypeExtractor.getAllDeclaredFields(clazz).asScala map {
+       f => (f.getName, f)
+     } toMap
+     val pojoFields = fields map {
+       case (fName, fTpe) =>
+         new PojoField(fieldMap(fName), fTpe)
+     }
+     new PojoTypeInfo(clazz, pojoFields.asJava)
+   }
+ }
  def mkGenericTypeInfo[T: c.WeakTypeTag](desc: UDTDescriptor): c.Expr[TypeInformation[T]] = {
    val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
    reify {
@@ -158,39 +190,4 @@ private[flink] trait TypeInformationGen[C <: Context] {
    val result = Apply(Select(New(TypeTree(desc.tpe)), nme.CONSTRUCTOR), fields.toList)
    c.Expr[T](result)
  }
- // def mkCaseClassTypeInfo[T: c.WeakTypeTag](
- //     desc: CaseClassDescriptor): c.Expr[TypeInformation[T]] = {
- //   val tpeClazz = c.Expr[Class[_]](Literal(Constant(desc.tpe)))
- //   val caseFields = mkCaseFields(desc)
- //   reify {
- //     new ScalaTupleTypeInfo[T] {
- //       def createSerializer: TypeSerializer[T] = {
- //         null
- //       }
- //
- //       val fields: Map[String, TypeInformation[_]] = caseFields.splice
- //       val clazz = tpeClazz.splice
- //     }
- //   }
- // }
- //
- // private def mkCaseFields(desc: UDTDescriptor): c.Expr[Map[String, TypeInformation[_]]] = {
- //   val fields = getFields("_root_", desc).toList map { case (fieldName, fieldDesc) =>
- //     val nameTree = c.Expr(Literal(Constant(fieldName)))
- //     val fieldTypeInfo = mkTypeInfo(fieldDesc)(c.WeakTypeTag(fieldDesc.tpe))
- //     reify { (nameTree.splice, fieldTypeInfo.splice) }.tree
- //   }
- //
- //   c.Expr(mkMap(fields))
- // }
- //
- // protected def getFields(name: String, desc: UDTDescriptor): Seq[(String, UDTDescriptor)] =
- //   desc match {
- //     // Flatten product types
- //     case CaseClassDescriptor(_, _, _, _, getters) =>
- //       getters filterNot { _.isBaseField } flatMap {
- //         f => getFields(name + "." + f.getter.name, f.desc) }
- //     case _ => Seq((name, desc))
- //   }
}
@@ -19,8 +19,10 @@ package org.apache.flink.api.scala.typeutils
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeinfo.AtomicType
+import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor
+import org.apache.flink.api.java.operators.Keys.ExpressionKeys
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
-import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
+import org.apache.flink.api.common.typeutils.{CompositeType, TypeComparator}
/**
 * TypeInformation for Case Classes. Creation and access is different from
@@ -58,16 +60,82 @@ abstract class CaseClassTypeInfo[T <: Product](
  override protected def getNewComparator: TypeComparator[T] = {
    val finalLogicalKeyFields = logicalKeyFields.take(comparatorHelperIndex)
    val finalComparators = fieldComparators.take(comparatorHelperIndex)
-   var maxKey: Int = 0
-   for (key <- finalLogicalKeyFields) {
-     maxKey = Math.max(maxKey, key)
-   }
-   val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](maxKey + 1)
-   for (i <- 0 to maxKey) {
-     fieldSerializers(i) = types(i).createSerializer
-   }
-   new CaseClassComparator[T](finalLogicalKeyFields, finalComparators, fieldSerializers)
+   val maxKey = finalLogicalKeyFields.max
+   // create serializers only up to the last key, fields after that are not needed
+   val fieldSerializers = types.take(maxKey + 1).map(_.createSerializer)
+   new CaseClassComparator[T](finalLogicalKeyFields, finalComparators, fieldSerializers.toArray)
+ }
+ override def getKey(
+     fieldExpression: String,
+     offset: Int,
+     result: java.util.List[FlatFieldDescriptor]): Unit = {
+   if (fieldExpression == ExpressionKeys.SELECT_ALL_CHAR) {
+     var keyPosition = 0
+     for (tpe <- types) {
+       tpe match {
+         case a: AtomicType[_] =>
+           result.add(new CompositeType.FlatFieldDescriptor(offset + keyPosition, tpe))
+         case co: CompositeType[_] =>
+           co.getKey(ExpressionKeys.SELECT_ALL_CHAR, offset + keyPosition, result)
+           keyPosition += co.getTotalFields - 1
+         case _ => throw new RuntimeException(s"Unexpected key type: $tpe")
+       }
+       keyPosition += 1
+     }
+     return
+   }
+   if (fieldExpression == null || fieldExpression.length <= 0) {
+     throw new IllegalArgumentException("Field expression must not be empty.")
+   }
+   fieldExpression.split('.').toList match {
+     case headField :: Nil =>
+       var fieldId = 0
+       for (i <- 0 until fieldNames.length) {
+         fieldId += types(i).getTotalFields - 1
+         if (fieldNames(i) == headField) {
+           if (fieldTypes(i).isInstanceOf[CompositeType[_]]) {
+             throw new IllegalArgumentException(
+               s"The specified field '$fieldExpression' is refering to a composite type.\n"
+                 + s"Either select all elements in this type with the " +
+                 s"'${ExpressionKeys.SELECT_ALL_CHAR}' operator or specify a field in" +
+                 s" the sub-type")
+           }
+           result.add(new CompositeType.FlatFieldDescriptor(offset + fieldId, fieldTypes(i)))
+           return
+         }
+         fieldId += 1
+       }
+     case firstField :: rest =>
+       var fieldId = 0
+       for (i <- 0 until fieldNames.length) {
+         if (fieldNames(i) == firstField) {
+           fieldTypes(i) match {
+             case co: CompositeType[_] =>
+               co.getKey(rest.mkString("."), offset + fieldId, result)
+               return
+             case _ =>
+               throw new RuntimeException(s"Field ${fieldTypes(i)} is not a composite type.")
+           }
+         }
+         fieldId += types(i).getTotalFields
+       }
+   }
+   throw new RuntimeException(s"Unable to find field $fieldExpression in type $this.")
  }
  override def toString = clazz.getSimpleName + "(" + fieldNames.zip(types).map {
......
@@ -69,11 +69,9 @@ private[flink] abstract class UnfinishedKeyPairOperation[L, R, O](
   * This only works on a CaseClass [[DataSet]].
   */
  def where(firstLeftField: String, otherLeftFields: String*) = {
-   val fieldIndices = fieldNames2Indices(
-     leftInput.getType,
-     firstLeftField +: otherLeftFields.toArray)
-   val leftKey = new ExpressionKeys[L](fieldIndices, leftInput.getType)
+   val leftKey = new ExpressionKeys[L](
+     firstLeftField +: otherLeftFields.toArray,
+     leftInput.getType)
    new HalfUnfinishedKeyPairOperation[L, R, O](this, leftKey)
  }
@@ -118,11 +116,9 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
   * This only works on a CaseClass [[DataSet]].
   */
  def equalTo(firstRightField: String, otherRightFields: String*): O = {
-   val fieldIndices = fieldNames2Indices(
-     unfinished.rightInput.getType,
-     firstRightField +: otherRightFields.toArray)
-   val rightKey = new ExpressionKeys[R](fieldIndices, unfinished.rightInput.getType)
+   val rightKey = new ExpressionKeys[R](
+     firstRightField +: otherRightFields.toArray,
+     unfinished.rightInput.getType)
    if (!leftKey.areCompatible(rightKey)) {
      throw new InvalidProgramException("The types of the key fields do not match. Left: " +
        leftKey + " Right: " + rightKey)
......
@@ -19,6 +19,7 @@ package org.apache.flink.api.scala.operators
import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.configuration.Configuration
import org.apache.flink.test.util.JavaProgramTestBase
import org.junit.runner.RunWith
@@ -34,38 +35,14 @@ import org.apache.flink.api.scala._
object AggregateProgs {
  var NUM_PROGRAMS: Int = 3
- val tupleInput = Array(
-   (1,1L,"Hi"),
-   (2,2L,"Hello"),
-   (3,2L,"Hello world"),
-   (4,3L,"Hello world, how are you?"),
-   (5,3L,"I am fine."),
-   (6,3L,"Luke Skywalker"),
-   (7,4L,"Comment#1"),
-   (8,4L,"Comment#2"),
-   (9,4L,"Comment#3"),
-   (10,4L,"Comment#4"),
-   (11,5L,"Comment#5"),
-   (12,5L,"Comment#6"),
-   (13,5L,"Comment#7"),
-   (14,5L,"Comment#8"),
-   (15,5L,"Comment#9"),
-   (16,6L,"Comment#10"),
-   (17,6L,"Comment#11"),
-   (18,6L,"Comment#12"),
-   (19,6L,"Comment#13"),
-   (20,6L,"Comment#14"),
-   (21,6L,"Comment#15")
- )
  def runProgram(progId: Int, resultPath: String): String = {
    progId match {
      case 1 =>
        // Full aggregate
        val env = ExecutionEnvironment.getExecutionEnvironment
        env.setDegreeOfParallelism(10)
-       val ds = env.fromCollection(tupleInput)
+       // val ds = CollectionDataSets.get3TupleDataSet(env)
+       val ds = CollectionDataSets.get3TupleDataSet(env)
        val aggregateDs = ds
          .aggregate(Aggregations.SUM,0)
@@ -84,7 +61,7 @@ object AggregateProgs {
      case 2 =>
        // Grouped aggregate
        val env = ExecutionEnvironment.getExecutionEnvironment
-       val ds = env.fromCollection(tupleInput)
+       val ds = CollectionDataSets.get3TupleDataSet(env)
        val aggregateDs = ds
          .groupBy(1)
@@ -103,7 +80,7 @@ object AggregateProgs {
      case 3 =>
        // Nested aggregate
        val env = ExecutionEnvironment.getExecutionEnvironment
-       val ds = env.fromCollection(tupleInput)
+       val ds = CollectionDataSets.get3TupleDataSet(env)
        val aggregateDs = ds
          .groupBy(1)
@@ -111,7 +88,7 @@ object AggregateProgs {
          .aggregate(Aggregations.MIN, 0)
          // Ensure aggregate operator correctly copies other fields
          .filter(_._3 != null)
-         .map { t => Tuple1(t._1) }
+         .map { t => new Tuple1(t._1) }
        aggregateDs.writeAsCsv(resultPath)
@@ -140,7 +117,7 @@ class AggregateITCase(config: Configuration) extends JavaProgramTestBase(config)
  }
  protected def testProgram(): Unit = {
-   expectedResult = AggregateProgs.runProgram(curProgId, resultPath)
+   expectedResult = DistinctProgs.runProgram(curProgId, resultPath)
  }
  protected override def postSubmit(): Unit = {
@@ -152,7 +129,7 @@ object AggregateITCase {
  @Parameters
  def getConfigurations: java.util.Collection[Array[AnyRef]] = {
    val configs = mutable.MutableList[Array[AnyRef]]()
-   for (i <- 1 to AggregateProgs.NUM_PROGRAMS) {
+   for (i <- 1 to DistinctProgs.NUM_PROGRAMS) {
      val config = new Configuration()
      config.setInteger("ProgramId", i)
      configs += Array(config)
......
@@ -17,13 +17,11 @@
 */
package org.apache.flink.api.scala.operators
-import java.io.Serializable
-import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
import org.junit.Assert
-import org.junit.Ignore
import org.junit.Test
import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
class CoGroupOperatorTest {
@@ -130,7 +128,7 @@ class CoGroupOperatorTest {
    ds1.coGroup(ds2).where("_1", "_2").equalTo("_3")
  }
- @Test(expected = classOf[IllegalArgumentException])
+ @Test(expected = classOf[RuntimeException])
  def testCoGroupKeyFieldNames4(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = env.fromCollection(emptyTupleData)
@@ -140,7 +138,7 @@ class CoGroupOperatorTest {
    ds1.coGroup(ds2).where("_6").equalTo("_1")
  }
- @Test(expected = classOf[IllegalArgumentException])
+ @Test(expected = classOf[RuntimeException])
  def testCoGroupKeyFieldNames5(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = env.fromCollection(emptyTupleData)
@@ -150,7 +148,7 @@ class CoGroupOperatorTest {
    ds1.coGroup(ds2).where("_1").equalTo("bar")
  }
- @Test(expected = classOf[UnsupportedOperationException])
+ @Test(expected = classOf[RuntimeException])
  def testCoGroupKeyFieldNames6(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = env.fromCollection(emptyTupleData)
@@ -160,7 +158,6 @@ class CoGroupOperatorTest {
    ds1.coGroup(ds2).where("_3").equalTo("_1")
  }
- @Ignore
  @Test
  def testCoGroupKeyExpressions1(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
@@ -176,29 +173,26 @@ class CoGroupOperatorTest {
    }
  }
- @Ignore
- @Test(expected = classOf[InvalidProgramException])
+ @Test(expected = classOf[IncompatibleKeysException])
  def testCoGroupKeyExpressions2(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = env.fromCollection(customTypeData)
    val ds2 = env.fromCollection(customTypeData)
    // should not work, incompatible key types
-   // ds1.coGroup(ds2).where("i").equalTo("s")
+   ds1.coGroup(ds2).where("myInt").equalTo("myString")
  }
- @Ignore
- @Test(expected = classOf[InvalidProgramException])
+ @Test(expected = classOf[IncompatibleKeysException])
  def testCoGroupKeyExpressions3(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val ds1 = env.fromCollection(customTypeData)
    val ds2 = env.fromCollection(customTypeData)
    // should not work, incompatible number of keys
-   // ds1.coGroup(ds2).where("i", "s").equalTo("s")
+   ds1.coGroup(ds2).where("myInt", "myString").equalTo("myString")
  }
- @Ignore
  @Test(expected = classOf[IllegalArgumentException])
  def testCoGroupKeyExpressions4(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
@@ -207,7 +201,7 @@ class CoGroupOperatorTest {
    // should not work, key non-existent
-   // ds1.coGroup(ds2).where("myNonExistent").equalTo("i")
+   ds1.coGroup(ds2).where("myNonExistent").equalTo("i")
  }
  @Test
@@ -218,7 +212,7 @@ class CoGroupOperatorTest {
    // Should work
    try {
-     ds1.coGroup(ds2).where { _.l } equalTo { _.l }
+     ds1.coGroup(ds2).where { _.myLong } equalTo { _.myLong }
    }
    catch {
      case e: Exception => Assert.fail()
@@ -233,7 +227,7 @@ class CoGroupOperatorTest {
    // Should work
    try {
-     ds1.coGroup(ds2).where { _.l}.equalTo(3)
+     ds1.coGroup(ds2).where { _.myLong }.equalTo(3)
    }
    catch {
      case e: Exception => Assert.fail()
@@ -248,7 +242,7 @@ class CoGroupOperatorTest {
    // Should work
    try {
-     ds1.coGroup(ds2).where(3).equalTo { _.l }
+     ds1.coGroup(ds2).where(3).equalTo { _.myLong }
    }
    catch {
      case e: Exception => Assert.fail()
@@ -262,7 +256,7 @@ class CoGroupOperatorTest {
    val ds2 = env.fromCollection(customTypeData)
    // Should not work, incompatible types
-   ds1.coGroup(ds2).where(2).equalTo { _.l }
+   ds1.coGroup(ds2).where(2).equalTo { _.myLong }
  }
  @Test(expected = classOf[IncompatibleKeysException])
@@ -272,7 +266,7 @@ class CoGroupOperatorTest {
    val ds2 = env.fromCollection(customTypeData)
    // Should not work, more than one field position key
-   ds1.coGroup(ds2).where(1, 3).equalTo { _.l }
+   ds1.coGroup(ds2).where(1, 3).equalTo { _.myLong }
  }
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.operators
import java.io.Serializable
/**
* A custom data type that is used by the operator Tests.
*/
class CustomType(var i:Int, var l: Long, var s: String) extends Serializable {
def this() {
this(0, 0, null)
}
override def toString: String = {
i + "," + l + "," + s
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.operators
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.configuration.Configuration
import org.apache.flink.test.util.JavaProgramTestBase
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.flink.api.scala._
object DistinctProgs {
var NUM_PROGRAMS: Int = 8
def runProgram(progId: Int, resultPath: String): String = {
progId match {
case 1 =>
/*
* Check correctness of distinct on tuples with key field selector
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.getSmall3TupleDataSet(env)
val distinctDs = ds.union(ds).distinct(0, 1, 2)
distinctDs.writeAsCsv(resultPath)
env.execute()
// return expected result
"1,1,Hi\n" +
"2,2,Hello\n" +
"3,2,Hello world\n"
case 2 =>
/*
* check correctness of distinct on tuples with key field selector with not all fields
* selected
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.getSmall5TupleDataSet(env)
val distinctDs = ds.union(ds).distinct(0).map(_._1)
distinctDs.writeAsText(resultPath)
env.execute()
"1\n" + "2\n"
case 3 =>
/*
* check correctness of distinct on tuples with key extractor
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.getSmall5TupleDataSet(env)
val reduceDs = ds.union(ds).distinct(_._1).map(_._1)
reduceDs.writeAsText(resultPath)
env.execute()
"1\n" + "2\n"
case 4 =>
/*
* check correctness of distinct on custom type with type extractor
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.getCustomTypeDataSet(env)
val reduceDs = ds.distinct(_.myInt).map( t => new Tuple1(t.myInt))
reduceDs.writeAsCsv(resultPath)
env.execute()
"1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
case 5 =>
/*
* check correctness of distinct on tuples
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.getSmall3TupleDataSet(env)
val distinctDs = ds.union(ds).distinct()
distinctDs.writeAsCsv(resultPath)
env.execute()
"1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n"
case 6 =>
/*
* check correctness of distinct on custom type with tuple-returning type extractor
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.get5TupleDataSet(env)
val reduceDs = ds.distinct( t => (t._1, t._5)).map( t => (t._1, t._5) )
reduceDs.writeAsCsv(resultPath)
env.execute()
"1,1\n" + "2,1\n" + "2,2\n" + "3,2\n" + "3,3\n" + "4,1\n" + "4,2\n" + "5," +
"1\n" + "5,2\n" + "5,3\n"
case 7 =>
/*
* check correctness of distinct on tuples with field expressions
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.getSmall5TupleDataSet(env)
val reduceDs = ds.union(ds).distinct("_1").map(t => new Tuple1(t._1))
reduceDs.writeAsCsv(resultPath)
env.execute()
"1\n" + "2\n"
case 8 =>
/*
* check correctness of distinct on Pojos
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = CollectionDataSets.getDuplicatePojoDataSet(env)
val reduceDs = ds.distinct("nestedPojo.longNumber").map(_.nestedPojo.longNumber.toInt)
reduceDs.writeAsText(resultPath)
env.execute()
"10000\n20000\n30000\n"
case _ =>
throw new IllegalArgumentException("Invalid program id")
}
}
}
@RunWith(classOf[Parameterized])
class DistinctITCase(config: Configuration) extends JavaProgramTestBase(config) {
private var curProgId: Int = config.getInteger("ProgramId", -1)
private var resultPath: String = null
private var expectedResult: String = null
protected override def preSubmit(): Unit = {
resultPath = getTempDirPath("result")
}
protected def testProgram(): Unit = {
expectedResult = DistinctProgs.runProgram(curProgId, resultPath)
}
protected override def postSubmit(): Unit = {
compareResultsByLinesInMemory(expectedResult, resultPath)
}
}
object DistinctITCase {
@Parameters
def getConfigurations: java.util.Collection[Array[AnyRef]] = {
val configs = mutable.MutableList[Array[AnyRef]]()
for (i <- 1 to DistinctProgs.NUM_PROGRAMS) {
val config = new Configuration()
config.setInteger("ProgramId", i)
configs += Array(config)
}
configs.asJavaCollection
}
}
@@ -17,6 +17,7 @@
 */
package org.apache.flink.api.scala.operators
+import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.Assert
import org.apache.flink.api.common.InvalidProgramException
import org.junit.Test
@@ -102,7 +103,7 @@ class DistinctOperatorTest {
    }
  }
- @Test(expected = classOf[UnsupportedOperationException])
+ @Test(expected = classOf[RuntimeException])
  def testDistinctByKeyFields2(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val longDs = env.fromCollection(emptyLongData)
@@ -111,16 +112,16 @@ class DistinctOperatorTest {
    longDs.distinct("_1")
  }
- @Test(expected = classOf[UnsupportedOperationException])
+ @Test(expected = classOf[RuntimeException])
  def testDistinctByKeyFields3(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val customDs = env.fromCollection(customTypeData)
-   // should not work: field key on custom type
+   // should not work: invalid fields
    customDs.distinct("_1")
  }
- @Test(expected = classOf[IllegalArgumentException])
+ @Test(expected = classOf[RuntimeException])
  def testDistinctByKeyFields4(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val tupleDs = env.fromCollection(emptyTupleData)
@@ -129,12 +130,21 @@ class DistinctOperatorTest {
    tupleDs.distinct("foo")
  }
+ @Test
+ def testDistinctByKeyFields5(): Unit = {
+   val env = ExecutionEnvironment.getExecutionEnvironment
+   val customDs = env.fromCollection(customTypeData)
+   // should work
+   customDs.distinct("myInt")
+ }
  @Test
  def testDistinctByKeySelector1(): Unit = {
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    try {
      val customDs = env.fromCollection(customTypeData)
-     customDs.distinct {_.l}
+     customDs.distinct {_.myLong}
    }
    catch {
      case e: Exception => Assert.fail()
......
...@@ -29,25 +29,36 @@ import scala.collection.JavaConverters._ ...@@ -29,25 +29,36 @@ import scala.collection.JavaConverters._
import scala.collection.mutable import scala.collection.mutable
// TODO case class Tuple2[T1, T2](_1: T1, _2: T2) // TODO case class Tuple2[T1, T2](_1: T1, _2: T2)
// TODO case class Foo(a: Int, b: String) // TODO case class Foo(a: Int, b: String)
class Nested(var myLong: Long) { case class Nested(myLong: Long)
class Pojo(var myString: String, var myInt: Int, var nested: Nested) {
def this() = { def this() = {
this(0); this("", 0, new Nested(1))
} }
def this(myString: String, myInt: Int, myLong: Long) { this(myString, myInt, new Nested(myLong)) }
override def toString = s"myString=$myString myInt=$myInt nested.myLong=${nested.myLong}"
}
class NestedPojo(var myLong: Long) {
def this() { this(0) }
} }
class Pojo(var myString: String, var myInt: Int, myLong: Long) {
var nested = new Nested(myLong)
class PojoWithPojo(var myString: String, var myInt: Int, var nested: Nested) {
def this() = { def this() = {
this("", 0, 0) this("", 0, new Nested(1))
} }
override def toString() = "myString="+myString+" myInt="+myInt+" nested.myLong="+nested.myLong def this(myString: String, myInt: Int, myLong: Long) { this(myString, myInt, new Nested(myLong)) }
override def toString = s"myString=$myString myInt=$myInt nested.myLong=${nested.myLong}"
} }
object ExampleProgs { object ExampleProgs {
var NUM_PROGRAMS: Int = 3 var NUM_PROGRAMS: Int = 4
def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = { def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = {
progId match { progId match {
...@@ -58,27 +69,53 @@ object ExampleProgs { ...@@ -58,27 +69,53 @@ object ExampleProgs {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) )
val grouped = ds.groupBy(0).reduce( { (e1, e2) => ((e1._1._1,e1._1._2), e1._2+e2._2)}) val grouped = ds.groupBy(0).reduce( { (e1, e2) => ((e1._1._1, e1._1._2), e1._2 + e2._2)})
grouped.writeAsText(resultPath) grouped.writeAsText(resultPath)
env.execute() env.execute()
"((this,hello),3)\n((this,is),3)\n" "((this,hello),3)\n((this,is),3)\n"
case 2 => case 2 =>
/* /*
Test nested tuples with int offset Test nested tuples with int offset
*/ */
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) )
val grouped = ds.groupBy("f0.f0").reduce( { (e1, e2) => ((e1._1._1,e1._1._2), e1._2+e2._2)}) val grouped = ds.groupBy("_1._1").reduce{
grouped.writeAsText(resultPath) (e1, e2) => ((e1._1._1, e1._1._2), e1._2 + e2._2)
env.execute() }
"((this,is),6)\n" grouped.writeAsText(resultPath)
env.execute()
"((this,is),6)\n"
case 3 => case 3 =>
/* /*
Test nested pojos Test nested pojos
*/ */
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements( new Pojo("one", 1, 1L),new Pojo("one", 1, 1L),new Pojo("two", 666, 2L) ) val ds = env.fromElements(
new PojoWithPojo("one", 1, 1L),
new PojoWithPojo("one", 1, 1L),
new PojoWithPojo("two", 666, 2L) )
val grouped = ds.groupBy("nested.myLong").reduce {
(p1, p2) =>
p1.myInt += p2.myInt
p1
}
grouped.writeAsText(resultPath)
env.execute()
"myString=two myInt=666 nested.myLong=2\nmyString=one myInt=2 nested.myLong=1\n"
case 4 =>
/*
Test pojo with nested case class
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(
new Pojo("one", 1, 1L),
new Pojo("one", 1, 1L),
new Pojo("two", 666, 2L) )
val grouped = ds.groupBy("nested.myLong").reduce { val grouped = ds.groupBy("nested.myLong").reduce {
(p1, p2) => (p1, p2) =>
...@@ -124,4 +161,4 @@ object ExamplesITCase { ...@@ -124,4 +161,4 @@ object ExamplesITCase {
configs.asJavaCollection configs.asJavaCollection
} }
} }
\ No newline at end of file
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
*/ */
package org.apache.flink.api.scala.operators package org.apache.flink.api.scala.operators
import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.Assert import org.junit.Assert
import org.apache.flink.api.common.InvalidProgramException import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.operators.Order import org.apache.flink.api.common.operators.Order
import org.junit.Ignore
import org.junit.Test import org.junit.Test
import org.apache.flink.api.scala._ import org.apache.flink.api.scala._
...@@ -96,7 +96,7 @@ class GroupingTest { ...@@ -96,7 +96,7 @@ class GroupingTest {
} }
} }
@Test(expected = classOf[UnsupportedOperationException]) @Test(expected = classOf[IllegalArgumentException])
def testGroupByKeyFields2(): Unit = { def testGroupByKeyFields2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData) val longDs = env.fromCollection(emptyLongData)
...@@ -105,7 +105,7 @@ class GroupingTest { ...@@ -105,7 +105,7 @@ class GroupingTest {
longDs.groupBy("_1") longDs.groupBy("_1")
} }
@Test(expected = classOf[UnsupportedOperationException]) @Test(expected = classOf[IllegalArgumentException])
def testGroupByKeyFields3(): Unit = { def testGroupByKeyFields3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData) val customDs = env.fromCollection(customTypeData)
...@@ -114,7 +114,7 @@ class GroupingTest { ...@@ -114,7 +114,7 @@ class GroupingTest {
customDs.groupBy("_1") customDs.groupBy("_1")
} }
@Test(expected = classOf[IllegalArgumentException]) @Test(expected = classOf[RuntimeException])
def testGroupByKeyFields4(): Unit = { def testGroupByKeyFields4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData) val tupleDs = env.fromCollection(emptyTupleData)
...@@ -123,7 +123,15 @@ class GroupingTest { ...@@ -123,7 +123,15 @@ class GroupingTest {
tupleDs.groupBy("foo") tupleDs.groupBy("foo")
} }
@Ignore @Test
def testGroupByKeyFields5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
    // should work
customDs.groupBy("myInt")
}
@Test @Test
def testGroupByKeyExpressions1(): Unit = { def testGroupByKeyExpressions1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
...@@ -131,24 +139,22 @@ class GroupingTest { ...@@ -131,24 +139,22 @@ class GroupingTest {
// should work // should work
try { try {
// ds.groupBy("i"); ds.groupBy("myInt")
} }
catch { catch {
case e: Exception => Assert.fail() case e: Exception => Assert.fail()
} }
} }
@Ignore @Test(expected = classOf[IllegalArgumentException])
@Test(expected = classOf[UnsupportedOperationException])
def testGroupByKeyExpressions2(): Unit = { def testGroupByKeyExpressions2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
// should not work: groups on basic type // should not work: groups on basic type
// longDs.groupBy("l");
val longDs = env.fromCollection(emptyLongData) val longDs = env.fromCollection(emptyLongData)
longDs.groupBy("l")
} }
@Ignore
@Test(expected = classOf[InvalidProgramException]) @Test(expected = classOf[InvalidProgramException])
def testGroupByKeyExpressions3(): Unit = { def testGroupByKeyExpressions3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
...@@ -158,14 +164,13 @@ class GroupingTest { ...@@ -158,14 +164,13 @@ class GroupingTest {
customDs.groupBy(0) customDs.groupBy(0)
} }
@Ignore
@Test(expected = classOf[IllegalArgumentException]) @Test(expected = classOf[IllegalArgumentException])
def testGroupByKeyExpressions4(): Unit = { def testGroupByKeyExpressions4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromCollection(customTypeData) val ds = env.fromCollection(customTypeData)
// should not work, non-existent field // should not work, non-existent field
// ds.groupBy("myNonExistent"); ds.groupBy("myNonExistent")
} }
@Test @Test
...@@ -173,7 +178,7 @@ class GroupingTest { ...@@ -173,7 +178,7 @@ class GroupingTest {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
try { try {
val customDs = env.fromCollection(customTypeData) val customDs = env.fromCollection(customTypeData)
customDs.groupBy { _.l } customDs.groupBy { _.myLong }
} }
catch { catch {
case e: Exception => Assert.fail() case e: Exception => Assert.fail()
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
package org.apache.flink.api.scala.operators package org.apache.flink.api.scala.operators
import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.Assert import org.junit.Assert
import org.apache.flink.api.common.InvalidProgramException import org.apache.flink.api.common.InvalidProgramException
import org.junit.Ignore import org.junit.Ignore
...@@ -132,7 +133,7 @@ class JoinOperatorTest { ...@@ -132,7 +133,7 @@ class JoinOperatorTest {
ds1.join(ds2).where("_1", "_2").equalTo("_3") ds1.join(ds2).where("_1", "_2").equalTo("_3")
} }
@Test(expected = classOf[IllegalArgumentException]) @Test(expected = classOf[RuntimeException])
def testJoinKeyFields4(): Unit = { def testJoinKeyFields4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData) val ds1 = env.fromCollection(emptyTupleData)
...@@ -142,7 +143,7 @@ class JoinOperatorTest { ...@@ -142,7 +143,7 @@ class JoinOperatorTest {
ds1.join(ds2).where("foo").equalTo("_1") ds1.join(ds2).where("foo").equalTo("_1")
} }
@Test(expected = classOf[IllegalArgumentException]) @Test(expected = classOf[RuntimeException])
def testJoinKeyFields5(): Unit = { def testJoinKeyFields5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData) val ds1 = env.fromCollection(emptyTupleData)
...@@ -152,7 +153,7 @@ class JoinOperatorTest { ...@@ -152,7 +153,7 @@ class JoinOperatorTest {
ds1.join(ds2).where("_1").equalTo("bar") ds1.join(ds2).where("_1").equalTo("bar")
} }
@Test(expected = classOf[UnsupportedOperationException]) @Test(expected = classOf[IllegalArgumentException])
def testJoinKeyFields6(): Unit = { def testJoinKeyFields6(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData) val ds1 = env.fromCollection(emptyTupleData)
...@@ -162,7 +163,6 @@ class JoinOperatorTest { ...@@ -162,7 +163,6 @@ class JoinOperatorTest {
ds1.join(ds2).where("_2").equalTo("_1") ds1.join(ds2).where("_2").equalTo("_1")
} }
@Ignore
@Test @Test
def testJoinKeyExpressions1(): Unit = { def testJoinKeyExpressions1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
...@@ -171,36 +171,33 @@ class JoinOperatorTest { ...@@ -171,36 +171,33 @@ class JoinOperatorTest {
// should work // should work
try { try {
// ds1.join(ds2).where("i").equalTo("i") ds1.join(ds2).where("myInt").equalTo("myInt")
} }
catch { catch {
case e: Exception => Assert.fail() case e: Exception => Assert.fail()
} }
} }
@Ignore @Test(expected = classOf[IncompatibleKeysException])
@Test(expected = classOf[InvalidProgramException])
def testJoinKeyExpressions2(): Unit = { def testJoinKeyExpressions2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(customTypeData) val ds1 = env.fromCollection(customTypeData)
val ds2 = env.fromCollection(customTypeData) val ds2 = env.fromCollection(customTypeData)
// should not work, incompatible join key types // should not work, incompatible join key types
// ds1.join(ds2).where("i").equalTo("s") ds1.join(ds2).where("myInt").equalTo("myString")
} }
@Ignore @Test(expected = classOf[IncompatibleKeysException])
@Test(expected = classOf[InvalidProgramException])
def testJoinKeyExpressions3(): Unit = { def testJoinKeyExpressions3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(customTypeData) val ds1 = env.fromCollection(customTypeData)
val ds2 = env.fromCollection(customTypeData) val ds2 = env.fromCollection(customTypeData)
// should not work, incompatible number of keys // should not work, incompatible number of keys
// ds1.join(ds2).where("i", "s").equalTo("i") ds1.join(ds2).where("myInt", "myString").equalTo("myInt")
} }
@Ignore
@Test(expected = classOf[IllegalArgumentException]) @Test(expected = classOf[IllegalArgumentException])
def testJoinKeyExpressions4(): Unit = { def testJoinKeyExpressions4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
...@@ -208,7 +205,7 @@ class JoinOperatorTest { ...@@ -208,7 +205,7 @@ class JoinOperatorTest {
val ds2 = env.fromCollection(customTypeData) val ds2 = env.fromCollection(customTypeData)
// should not work, join key non-existent // should not work, join key non-existent
// ds1.join(ds2).where("myNonExistent").equalTo("i") ds1.join(ds2).where("myNonExistent").equalTo("i")
} }
@Test @Test
...@@ -219,7 +216,7 @@ class JoinOperatorTest { ...@@ -219,7 +216,7 @@ class JoinOperatorTest {
// should work // should work
try { try {
ds1.join(ds2).where { _.l} equalTo { _.l } ds1.join(ds2).where { _.myLong} equalTo { _.myLong }
} }
catch { catch {
case e: Exception => Assert.fail() case e: Exception => Assert.fail()
...@@ -234,7 +231,7 @@ class JoinOperatorTest { ...@@ -234,7 +231,7 @@ class JoinOperatorTest {
// should work // should work
try { try {
ds1.join(ds2).where { _.l }.equalTo(3) ds1.join(ds2).where { _.myLong }.equalTo(3)
} }
catch { catch {
case e: Exception => Assert.fail() case e: Exception => Assert.fail()
...@@ -249,7 +246,7 @@ class JoinOperatorTest { ...@@ -249,7 +246,7 @@ class JoinOperatorTest {
// should work // should work
try { try {
ds1.join(ds2).where(3).equalTo { _.l } ds1.join(ds2).where(3).equalTo { _.myLong }
} }
catch { catch {
case e: Exception => Assert.fail() case e: Exception => Assert.fail()
...@@ -263,7 +260,7 @@ class JoinOperatorTest { ...@@ -263,7 +260,7 @@ class JoinOperatorTest {
val ds2 = env.fromCollection(customTypeData) val ds2 = env.fromCollection(customTypeData)
// should not work, incompatible types // should not work, incompatible types
ds1.join(ds2).where(2).equalTo { _.l } ds1.join(ds2).where(2).equalTo { _.myLong }
} }
@Test(expected = classOf[IncompatibleKeysException]) @Test(expected = classOf[IncompatibleKeysException])
...@@ -273,7 +270,7 @@ class JoinOperatorTest { ...@@ -273,7 +270,7 @@ class JoinOperatorTest {
val ds2 = env.fromCollection(customTypeData) val ds2 = env.fromCollection(customTypeData)
// should not work, more than one field position key // should not work, more than one field position key
ds1.join(ds2).where(1, 3) equalTo { _.l } ds1.join(ds2).where(1, 3) equalTo { _.myLong }
} }
} }
...@@ -19,6 +19,7 @@ package org.apache.flink.api.scala.operators ...@@ -19,6 +19,7 @@ package org.apache.flink.api.scala.operators
import org.apache.flink.api.common.functions.{RichFilterFunction, RichMapFunction} import org.apache.flink.api.common.functions.{RichFilterFunction, RichMapFunction}
import org.apache.flink.api.scala.ExecutionEnvironment import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.configuration.Configuration import org.apache.flink.configuration.Configuration
import org.apache.flink.test.util.JavaProgramTestBase import org.apache.flink.test.util.JavaProgramTestBase
import org.junit.runner.RunWith import org.junit.runner.RunWith
...@@ -32,40 +33,18 @@ import org.apache.flink.api.scala._ ...@@ -32,40 +33,18 @@ import org.apache.flink.api.scala._
object PartitionProgs { object PartitionProgs {
var NUM_PROGRAMS: Int = 6 var NUM_PROGRAMS: Int = 7
val tupleInput = Array(
(1, "Foo"),
(1, "Foo"),
(1, "Foo"),
(2, "Foo"),
(2, "Foo"),
(2, "Foo"),
(2, "Foo"),
(2, "Foo"),
(3, "Foo"),
(3, "Foo"),
(3, "Foo"),
(4, "Foo"),
(4, "Foo"),
(4, "Foo"),
(4, "Foo"),
(5, "Foo"),
(5, "Foo"),
(6, "Foo"),
(6, "Foo"),
(6, "Foo"),
(6, "Foo")
)
def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = { def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = {
progId match { progId match {
case 1 => case 1 =>
/*
* Test hash partition by tuple field
*/
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromCollection(tupleInput) val ds = CollectionDataSets.get3TupleDataSet(env)
val unique = ds.partitionByHash(0).mapPartition( _.map(_._1).toSet ) val unique = ds.partitionByHash(1).mapPartition( _.map(_._2).toSet )
unique.writeAsText(resultPath) unique.writeAsText(resultPath)
env.execute() env.execute()
...@@ -73,16 +52,22 @@ object PartitionProgs { ...@@ -73,16 +52,22 @@ object PartitionProgs {
"1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
case 2 => case 2 =>
/*
* Test hash partition by key selector
*/
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromCollection(tupleInput) val ds = CollectionDataSets.get3TupleDataSet(env)
val unique = ds.partitionByHash( _._1 ).mapPartition( _.map(_._1).toSet ) val unique = ds.partitionByHash( _._2 ).mapPartition( _.map(_._2).toSet )
unique.writeAsText(resultPath) unique.writeAsText(resultPath)
env.execute() env.execute()
"1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
case 3 => case 3 =>
val env = ExecutionEnvironment.getExecutionEnvironment /*
* Test forced rebalancing
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.generateSequence(1, 3000) val ds = env.generateSequence(1, 3000)
val skewed = ds.filter(_ > 780) val skewed = ds.filter(_ > 780)
...@@ -101,8 +86,8 @@ object PartitionProgs { ...@@ -101,8 +86,8 @@ object PartitionProgs {
countsInPartition.writeAsText(resultPath) countsInPartition.writeAsText(resultPath)
env.execute() env.execute()
val numPerPartition : Int = 2220 / env.getDegreeOfParallelism / 10; val numPerPartition : Int = 2220 / env.getDegreeOfParallelism / 10
var result = ""; var result = ""
for (i <- 0 until env.getDegreeOfParallelism) { for (i <- 0 until env.getDegreeOfParallelism) {
result += "(" + i + "," + numPerPartition + ")\n" result += "(" + i + "," + numPerPartition + ")\n"
} }
...@@ -112,10 +97,12 @@ object PartitionProgs { ...@@ -112,10 +97,12 @@ object PartitionProgs {
// Verify that mapPartition operation after repartition picks up correct // Verify that mapPartition operation after repartition picks up correct
// DOP // DOP
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromCollection(tupleInput) val ds = CollectionDataSets.get3TupleDataSet(env)
env.setDegreeOfParallelism(1) env.setDegreeOfParallelism(1)
val unique = ds.partitionByHash(0).setParallelism(4).mapPartition( _.map(_._1).toSet ) val unique = ds.partitionByHash(1)
.setParallelism(4)
.mapPartition( _.map(_._2).toSet )
unique.writeAsText(resultPath) unique.writeAsText(resultPath)
env.execute() env.execute()
...@@ -126,13 +113,13 @@ object PartitionProgs { ...@@ -126,13 +113,13 @@ object PartitionProgs {
// Verify that map operation after repartition picks up correct // Verify that map operation after repartition picks up correct
// DOP // DOP
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromCollection(tupleInput) val ds = CollectionDataSets.get3TupleDataSet(env)
env.setDegreeOfParallelism(1) env.setDegreeOfParallelism(1)
val count = ds.partitionByHash(0).setParallelism(4).map( val count = ds.partitionByHash(0).setParallelism(4).map(
new RichMapFunction[(Int, String), Tuple1[Int]] { new RichMapFunction[(Int, Long, String), Tuple1[Int]] {
var first = true var first = true
override def map(in: (Int, String)): Tuple1[Int] = { override def map(in: (Int, Long, String)): Tuple1[Int] = {
// only output one value with count 1 // only output one value with count 1
if (first) { if (first) {
first = false first = false
...@@ -152,13 +139,13 @@ object PartitionProgs { ...@@ -152,13 +139,13 @@ object PartitionProgs {
// Verify that filter operation after repartition picks up correct // Verify that filter operation after repartition picks up correct
// DOP // DOP
val env = ExecutionEnvironment.getExecutionEnvironment val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromCollection(tupleInput) val ds = CollectionDataSets.get3TupleDataSet(env)
env.setDegreeOfParallelism(1) env.setDegreeOfParallelism(1)
val count = ds.partitionByHash(0).setParallelism(4).filter( val count = ds.partitionByHash(0).setParallelism(4).filter(
new RichFilterFunction[(Int, String)] { new RichFilterFunction[(Int, Long, String)] {
var first = true var first = true
override def filter(in: (Int, String)): Boolean = { override def filter(in: (Int, Long, String)): Boolean = {
// only output one value with count 1 // only output one value with count 1
if (first) { if (first) {
first = false first = false
...@@ -175,6 +162,19 @@ object PartitionProgs { ...@@ -175,6 +162,19 @@ object PartitionProgs {
if (onCollection) "(1)\n" else "(4)\n" if (onCollection) "(1)\n" else "(4)\n"
case 7 =>
val env = ExecutionEnvironment.getExecutionEnvironment
env.setDegreeOfParallelism(3)
val ds = CollectionDataSets.getDuplicatePojoDataSet(env)
val uniqLongs = ds
.partitionByHash("nestedPojo.longNumber")
.setParallelism(4)
.mapPartition( _.map(_.nestedPojo.longNumber).toSet )
uniqLongs.writeAsText(resultPath)
env.execute()
"10000\n" + "20000\n" + "30000\n"
case _ => case _ =>
throw new IllegalArgumentException("Invalid program id") throw new IllegalArgumentException("Invalid program id")
} }
...@@ -194,7 +194,7 @@ class PartitionITCase(config: Configuration) extends JavaProgramTestBase(config) ...@@ -194,7 +194,7 @@ class PartitionITCase(config: Configuration) extends JavaProgramTestBase(config)
} }
protected def testProgram(): Unit = { protected def testProgram(): Unit = {
    expectedResult = PartitionProgs.runProgram(curProgId, resultPath, isCollectionExecution)     expectedResult = PartitionProgs.runProgram(curProgId, resultPath, isCollectionExecution)
} }
protected override def postSubmit(): Unit = { protected override def postSubmit(): Unit = {
......
...@@ -44,7 +44,9 @@ class CustomType(var myField1: String, var myField2: Int) { ...@@ -44,7 +44,9 @@ class CustomType(var myField1: String, var myField2: Int) {
} }
} }
class MyObject[A](var a: A) class MyObject[A](var a: A) {
def this() { this(null.asInstanceOf[A]) }
}
class TypeInformationGenTest { class TypeInformationGenTest {
...@@ -139,7 +141,7 @@ class TypeInformationGenTest { ...@@ -139,7 +141,7 @@ class TypeInformationGenTest {
Assert.assertFalse(ti.isBasicType) Assert.assertFalse(ti.isBasicType)
Assert.assertFalse(ti.isTupleType) Assert.assertFalse(ti.isTupleType)
Assert.assertTrue(ti.isInstanceOf[GenericTypeInfo[_]]) Assert.assertTrue(ti.isInstanceOf[PojoTypeInfo[_]])
Assert.assertEquals(ti.getTypeClass, classOf[CustomType]) Assert.assertEquals(ti.getTypeClass, classOf[CustomType])
} }
...@@ -152,7 +154,7 @@ class TypeInformationGenTest { ...@@ -152,7 +154,7 @@ class TypeInformationGenTest {
val tti = ti.asInstanceOf[TupleTypeInfoBase[_]] val tti = ti.asInstanceOf[TupleTypeInfoBase[_]]
Assert.assertEquals(classOf[Tuple2[_, _]], tti.getTypeClass) Assert.assertEquals(classOf[Tuple2[_, _]], tti.getTypeClass)
Assert.assertEquals(classOf[java.lang.Long], tti.getTypeAt(0).getTypeClass) Assert.assertEquals(classOf[java.lang.Long], tti.getTypeAt(0).getTypeClass)
Assert.assertTrue(tti.getTypeAt(1).isInstanceOf[GenericTypeInfo[_]]) Assert.assertTrue(tti.getTypeAt(1).isInstanceOf[PojoTypeInfo[_]])
Assert.assertEquals(classOf[CustomType], tti.getTypeAt(1).getTypeClass) Assert.assertEquals(classOf[CustomType], tti.getTypeAt(1).getTypeClass)
} }
...@@ -235,7 +237,7 @@ class TypeInformationGenTest { ...@@ -235,7 +237,7 @@ class TypeInformationGenTest {
def testParamertizedCustomObject(): Unit = { def testParamertizedCustomObject(): Unit = {
val ti = createTypeInformation[MyObject[String]] val ti = createTypeInformation[MyObject[String]]
Assert.assertTrue(ti.isInstanceOf[GenericTypeInfo[_]]) Assert.assertTrue(ti.isInstanceOf[PojoTypeInfo[_]])
} }
@Test @Test
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.util
import org.apache.hadoop.io.IntWritable
import org.apache.flink.api.scala._
import scala.collection.mutable
import scala.util.Random
/**
* #################################################################################################
*
* BE AWARE THAT OTHER TESTS DEPEND ON THIS TEST DATA.
* IF YOU MODIFY THE DATA MAKE SURE YOU CHECK THAT ALL TESTS ARE STILL WORKING!
*
* #################################################################################################
*/
object CollectionDataSets {
def get3TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, String)] = {
val data = new mutable.MutableList[(Int, Long, String)]
data.+=((1, 1L, "Hi"))
data.+=((2, 2L, "Hello"))
data.+=((3, 2L, "Hello world"))
data.+=((4, 3L, "Hello world, how are you?"))
data.+=((5, 3L, "I am fine."))
data.+=((6, 3L, "Luke Skywalker"))
data.+=((7, 4L, "Comment#1"))
data.+=((8, 4L, "Comment#2"))
data.+=((9, 4L, "Comment#3"))
data.+=((10, 4L, "Comment#4"))
data.+=((11, 5L, "Comment#5"))
data.+=((12, 5L, "Comment#6"))
data.+=((13, 5L, "Comment#7"))
data.+=((14, 5L, "Comment#8"))
data.+=((15, 5L, "Comment#9"))
data.+=((16, 6L, "Comment#10"))
data.+=((17, 6L, "Comment#11"))
data.+=((18, 6L, "Comment#12"))
data.+=((19, 6L, "Comment#13"))
data.+=((20, 6L, "Comment#14"))
data.+=((21, 6L, "Comment#15"))
    env.fromCollection(Random.shuffle(data))
}
def getSmall3TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, String)] = {
val data = new mutable.MutableList[(Int, Long, String)]
data.+=((1, 1L, "Hi"))
data.+=((2, 2L, "Hello"))
data.+=((3, 2L, "Hello world"))
env.fromCollection(Random.shuffle(data))
}
def get5TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, Int, String, Long)] = {
val data = new mutable.MutableList[(Int, Long, Int, String, Long)]
data.+=((1, 1L, 0, "Hallo", 1L))
data.+=((2, 2L, 1, "Hallo Welt", 2L))
data.+=((2, 3L, 2, "Hallo Welt wie", 1L))
data.+=((3, 4L, 3, "Hallo Welt wie gehts?", 2L))
data.+=((3, 5L, 4, "ABC", 2L))
data.+=((3, 6L, 5, "BCD", 3L))
data.+=((4, 7L, 6, "CDE", 2L))
data.+=((4, 8L, 7, "DEF", 1L))
data.+=((4, 9L, 8, "EFG", 1L))
data.+=((4, 10L, 9, "FGH", 2L))
data.+=((5, 11L, 10, "GHI", 1L))
data.+=((5, 12L, 11, "HIJ", 3L))
data.+=((5, 13L, 12, "IJK", 3L))
data.+=((5, 14L, 13, "JKL", 2L))
data.+=((5, 15L, 14, "KLM", 2L))
env.fromCollection(Random.shuffle(data))
}
def getSmall5TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, Int, String, Long)] = {
val data = new mutable.MutableList[(Int, Long, Int, String, Long)]
data.+=((1, 1L, 0, "Hallo", 1L))
data.+=((2, 2L, 1, "Hallo Welt", 2L))
data.+=((2, 3L, 2, "Hallo Welt wie", 1L))
env.fromCollection(Random.shuffle(data))
}
def getSmallNestedTupleDataSet(env: ExecutionEnvironment): DataSet[((Int, Int), String)] = {
val data = new mutable.MutableList[((Int, Int), String)]
data.+=(((1, 1), "one"))
data.+=(((2, 2), "two"))
data.+=(((3, 3), "three"))
env.fromCollection(Random.shuffle(data))
}
def getGroupSortedNestedTupleDataSet(env: ExecutionEnvironment): DataSet[((Int, Int), String)] = {
val data = new mutable.MutableList[((Int, Int), String)]
data.+=(((1, 3), "a"))
data.+=(((1, 2), "a"))
data.+=(((2, 1), "a"))
data.+=(((2, 2), "b"))
data.+=(((3, 3), "c"))
data.+=(((3, 6), "c"))
data.+=(((4, 9), "c"))
env.fromCollection(Random.shuffle(data))
}
def getStringDataSet(env: ExecutionEnvironment): DataSet[String] = {
val data = new mutable.MutableList[String]
data.+=("Hi")
data.+=("Hello")
data.+=("Hello world")
data.+=("Hello world, how are you?")
data.+=("I am fine.")
data.+=("Luke Skywalker")
data.+=("Random comment")
data.+=("LOL")
env.fromCollection(Random.shuffle(data))
}
def getIntDataSet(env: ExecutionEnvironment): DataSet[Int] = {
val data = new mutable.MutableList[Int]
data.+=(1)
data.+=(2)
data.+=(2)
data.+=(3)
data.+=(3)
data.+=(3)
data.+=(4)
data.+=(4)
data.+=(4)
data.+=(4)
data.+=(5)
data.+=(5)
data.+=(5)
data.+=(5)
data.+=(5)
env.fromCollection(Random.shuffle(data))
}
def getCustomTypeDataSet(env: ExecutionEnvironment): DataSet[CustomType] = {
val data = new mutable.MutableList[CustomType]
data.+=(new CustomType(1, 0L, "Hi"))
data.+=(new CustomType(2, 1L, "Hello"))
data.+=(new CustomType(2, 2L, "Hello world"))
data.+=(new CustomType(3, 3L, "Hello world, how are you?"))
data.+=(new CustomType(3, 4L, "I am fine."))
data.+=(new CustomType(3, 5L, "Luke Skywalker"))
data.+=(new CustomType(4, 6L, "Comment#1"))
data.+=(new CustomType(4, 7L, "Comment#2"))
data.+=(new CustomType(4, 8L, "Comment#3"))
data.+=(new CustomType(4, 9L, "Comment#4"))
data.+=(new CustomType(5, 10L, "Comment#5"))
data.+=(new CustomType(5, 11L, "Comment#6"))
data.+=(new CustomType(5, 12L, "Comment#7"))
data.+=(new CustomType(5, 13L, "Comment#8"))
data.+=(new CustomType(5, 14L, "Comment#9"))
data.+=(new CustomType(6, 15L, "Comment#10"))
data.+=(new CustomType(6, 16L, "Comment#11"))
data.+=(new CustomType(6, 17L, "Comment#12"))
data.+=(new CustomType(6, 18L, "Comment#13"))
data.+=(new CustomType(6, 19L, "Comment#14"))
data.+=(new CustomType(6, 20L, "Comment#15"))
env.fromCollection(Random.shuffle(data))
}
def getSmallCustomTypeDataSet(env: ExecutionEnvironment): DataSet[CustomType] = {
val data = new mutable.MutableList[CustomType]
data.+=(new CustomType(1, 0L, "Hi"))
data.+=(new CustomType(2, 1L, "Hello"))
data.+=(new CustomType(2, 2L, "Hello world"))
env.fromCollection(Random.shuffle(data))
}
def getSmallTuplebasedPojoMatchingDataSet(env: ExecutionEnvironment):
DataSet[(Int, String, Int, Int, Long, String, Long)] = {
val data = new mutable.MutableList[(Int, String, Int, Int, Long, String, Long)]
data.+=((1, "First", 10, 100, 1000L, "One", 10000L))
data.+=((2, "Second", 20, 200, 2000L, "Two", 20000L))
data.+=((3, "Third", 30, 300, 3000L, "Three", 30000L))
env.fromCollection(Random.shuffle(data))
}
def getSmallPojoDataSet(env: ExecutionEnvironment): DataSet[POJO] = {
val data = new mutable.MutableList[POJO]
data.+=(new POJO(1, "First", 10, 100, 1000L, "One", 10000L))
data.+=(new POJO(2, "Second", 20, 200, 2000L, "Two", 20000L))
data.+=(new POJO(3, "Third", 30, 300, 3000L, "Three", 30000L))
env.fromCollection(Random.shuffle(data))
}
def getDuplicatePojoDataSet(env: ExecutionEnvironment): DataSet[POJO] = {
val data = new mutable.MutableList[POJO]
data.+=(new POJO(1, "First", 10, 100, 1000L, "One", 10000L))
data.+=(new POJO(1, "First", 10, 100, 1000L, "One", 10000L))
data.+=(new POJO(1, "First", 10, 100, 1000L, "One", 10000L))
data.+=(new POJO(1, "First", 10, 100, 1000L, "One", 10000L))
data.+=(new POJO(1, "First", 10, 100, 1000L, "One", 10000L))
data.+=(new POJO(2, "Second", 20, 200, 2000L, "Two", 20000L))
data.+=(new POJO(3, "Third", 30, 300, 3000L, "Three", 30000L))
data.+=(new POJO(3, "Third", 30, 300, 3000L, "Three", 30000L))
env.fromCollection(data)
}
def getCrazyNestedDataSet(env: ExecutionEnvironment): DataSet[CrazyNested] = {
val data = new mutable.MutableList[CrazyNested]
data.+=(new CrazyNested("aa"))
data.+=(new CrazyNested("bb"))
data.+=(new CrazyNested("bb"))
data.+=(new CrazyNested("cc"))
data.+=(new CrazyNested("cc"))
data.+=(new CrazyNested("cc"))
env.fromCollection(data)
}
def getPojoContainingTupleAndWritable(env: ExecutionEnvironment): DataSet[CollectionDataSets
.PojoContainingTupleAndWritable] = {
    val data = new mutable.MutableList[PojoContainingTupleAndWritable]
data.+=(new PojoContainingTupleAndWritable(1, 10L, 100L))
data.+=(new PojoContainingTupleAndWritable(2, 20L, 200L))
data.+=(new PojoContainingTupleAndWritable(2, 20L, 200L))
data.+=(new PojoContainingTupleAndWritable(2, 20L, 200L))
data.+=(new PojoContainingTupleAndWritable(2, 20L, 200L))
data.+=(new PojoContainingTupleAndWritable(2, 20L, 200L))
env.fromCollection(data)
}
def getTupleContainingPojos(env: ExecutionEnvironment): DataSet[(Int, CrazyNested, POJO)] = {
val data = new mutable.MutableList[(Int, CrazyNested, POJO)]
data.+=((
1,
new CrazyNested("one", "uno", 1L),
new POJO(1, "First", 10, 100, 1000L, "One", 10000L)))
data.+=((
1,
new CrazyNested("one", "uno", 1L),
new POJO(1, "First", 10, 100, 1000L, "One", 10000L)))
data.+=((
1,
new CrazyNested("one", "uno", 1L),
new POJO(1, "First", 10, 100, 1000L, "One", 10000L)))
data.+=((
2,
new CrazyNested("two", "duo", 2L),
new POJO(1, "First", 10, 100, 1000L, "One", 10000L)))
env.fromCollection(data)
}
def getPojoWithMultiplePojos(env: ExecutionEnvironment): DataSet[CollectionDataSets
.PojoWithMultiplePojos] = {
    val data = new mutable.MutableList[CollectionDataSets.PojoWithMultiplePojos]
data.+=(new PojoWithMultiplePojos("a", "aa", "b", "bb", 1))
data.+=(new PojoWithMultiplePojos("b", "bb", "c", "cc", 2))
data.+=(new PojoWithMultiplePojos("d", "dd", "e", "ee", 3))
env.fromCollection(data)
}
class CustomType(var myInt: Int, var myLong: Long, var myString: String) {
def this() {
this(0, 0, "")
}
override def toString: String = {
myInt + "," + myLong + "," + myString
}
}
class POJO(
var number: Int,
var str: String,
var nestedTupleWithCustom: (Int, CustomType),
var nestedPojo: NestedPojo) {
def this() {
this(0, "", null, null)
}
def this(i0: Int, s0: String, i1: Int, i2: Int, l0: Long, s1: String, l1: Long) {
this(i0, s0, (i1, new CustomType(i2, l0, s1)), new NestedPojo(l1))
}
override def toString: String = {
number + " " + str + " " + nestedTupleWithCustom + " " + nestedPojo.longNumber
}
@transient var ignoreMe: Long = 1L
}
class NestedPojo(var longNumber: Long) {
def this() {
this(0)
}
}
class CrazyNested(var nest_Lvl1: CrazyNestedL1, var something: Long) {
def this() {
this(new CrazyNestedL1, 0)
}
def this(set: String) {
this()
nest_Lvl1 = new CrazyNestedL1
nest_Lvl1.nest_Lvl2 = new CrazyNestedL2
nest_Lvl1.nest_Lvl2.nest_Lvl3 = new CrazyNestedL3
nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4 = new CrazyNestedL4
nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4.f1nal = set
}
def this(set: String, second: String, s: Long) {
this(set)
something = s
nest_Lvl1.a = second
}
}
class CrazyNestedL1 {
var a: String = null
var b: Int = 0
var nest_Lvl2: CrazyNestedL2 = null
}
class CrazyNestedL2 {
var nest_Lvl3: CrazyNestedL3 = null
}
class CrazyNestedL3 {
var nest_Lvl4: CrazyNestedL4 = null
}
class CrazyNestedL4 {
var f1nal: String = null
}
class PojoContainingTupleAndWritable(
var someInt: Int,
var someString: String,
var hadoopFan: IntWritable,
var theTuple: (Long, Long)) {
def this() {
this(0, "", new IntWritable(0), (0, 0))
}
def this(i: Int, l1: Long, l2: Long) {
this()
hadoopFan = new IntWritable(i)
someInt = i
theTuple = (l1, l2)
}
}
class Pojo1 {
var a: String = null
var b: String = null
}
class Pojo2 {
var a2: String = null
var b2: String = null
}
class PojoWithMultiplePojos {
def this(a: String, b: String, a1: String, b1: String, i0: Int) {
this()
p1 = new Pojo1
p1.a = a
p1.b = b
p2 = new Pojo2
p2.a2 = a1
      p2.b2 = b1
this.i0 = i0
}
var p1: Pojo1 = null
var p2: Pojo2 = null
var i0: Int = 0
}
}