diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala
index 6e32e1a3b76c767426f2dca7111c2b46902029db..6ad73a5f5458cd261e9d86139af7d3b31bdf84a8 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala
@@ -17,21 +17,13 @@
  */
 package org.apache.flink.api.scala.codegen
 
+import org.apache.flink.types.{BooleanValue, ByteValue, CharValue, DoubleValue, FloatValue, IntValue, LongValue, ShortValue, StringValue}
+
 import scala.collection._
 import scala.collection.generic.CanBuildFrom
 import scala.reflect.macros.Context
 import scala.util.DynamicVariable
 
-import org.apache.flink.types.BooleanValue
-import org.apache.flink.types.ByteValue
-import org.apache.flink.types.CharValue
-import org.apache.flink.types.DoubleValue
-import org.apache.flink.types.FloatValue
-import org.apache.flink.types.IntValue
-import org.apache.flink.types.StringValue
-import org.apache.flink.types.LongValue
-import org.apache.flink.types.ShortValue
-
 private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
   with TypeDescriptors[C] =>
@@ -83,6 +75,8 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
 
         case TraitType() => GenericClassDescriptor(id, tpe)
 
+        case JavaTupleType() => analyzeJavaTuple(id, tpe)
+
         case JavaType() =>
           // It's a Java Class, let the TypeExtractor deal with it...
           GenericClassDescriptor(id, tpe)
@@ -136,6 +130,20 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
       case elemDesc => OptionDescriptor(id, tpe, elemDesc)
     }
 
+  private def analyzeJavaTuple(id: Int, tpe: Type): UDTDescriptor = {
+    // determine the tuple arity by probing for fields f0..f(MAX_ARITY - 1) and analyze each
+    val fields = (0 until org.apache.flink.api.java.tuple.Tuple.MAX_ARITY) flatMap { i =>
+      tpe.members find { m => m.name.toString.equals("f" + i) } match {
+        case Some(m) => Some(analyze(m.typeSignatureIn(tpe)))
+
+        case _ => None
+      }
+    }
+
+    JavaTupleDescriptor(id, tpe, fields)
+  }
+
+
   private def analyzePojo(id: Int, tpe: Type): UDTDescriptor = {
     val immutableFields = tpe.members filter { _.isTerm } map { _.asTerm } filter { _.isVal }
     if (immutableFields.nonEmpty) {
@@ -150,6 +158,7 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
       .filter { _.isTerm }
       .map { _.asTerm }
       .filter { _.isVar }
+      .filter { !_.isStatic }
      .filterNot { _.annotations.exists( _.tpe <:< typeOf[scala.transient]) }
 
     if (fields.isEmpty) {
@@ -188,7 +197,7 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
 
       val fieldDescriptors = fields map { f =>
-        val fieldTpe = f.getter.asMethod.returnType.asSeenFrom(tpe, tpe.typeSymbol)
+        val fieldTpe = f.typeSignatureIn(tpe)
         FieldDescriptor(f.name.toString.trim, f.getter, f.setter, fieldTpe, analyze(fieldTpe))
       }
 
@@ -391,6 +400,10 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
     def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isJava
   }
 
+  private object JavaTupleType {
+    def unapply(tpe: Type): Boolean = tpe <:< typeOf[org.apache.flink.api.java.tuple.Tuple]
+  }
+
   private class UDTAnalyzerCache {
 
     private val caches = new DynamicVariable[Map[Type, RecursiveDescriptor]](Map())
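Side note on the probe above: Flink's Java tuple classes expose their elements as public fields f0..f(n-1), and Tuple.MAX_ARITY bounds the search, so detecting the arity amounts to checking which of those fields exist. The sketch below replays the same idea with runtime reflection instead of the macro-time Type API; ArityProbe and arityOf are illustrative names, not part of this patch.

```scala
import org.apache.flink.api.java.tuple.{Tuple, Tuple3 => JTuple3}

object ArityProbe {
  // Count the consecutive public fields f0, f1, ...; this is the runtime
  // analogue of the `tpe.members find { _.name == "f" + i }` probe that
  // analyzeJavaTuple performs on the compile-time Type.
  def arityOf(clazz: Class[_ <: Tuple]): Int =
    (0 until Tuple.MAX_ARITY)
      .takeWhile(i => scala.util.Try(clazz.getField("f" + i)).isSuccess)
      .size

  def main(args: Array[String]): Unit = {
    // Tuple3 declares f0, f1 and f2, so the probe reports 3.
    println(arityOf(classOf[JTuple3[Integer, String, Integer]]))
  }
}
```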
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala
index fcb03174529119509711ca5c1012042b15480d0d..c6006a2a6d8dbe50407a2f0efa0ba25f72409179 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala
@@ -32,97 +32,36 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
 
   abstract sealed class UDTDescriptor {
     val id: Int
     val tpe: Type
-    val isPrimitiveProduct: Boolean = false
-
-    def canBeKey: Boolean
-
-    def flatten: Seq[UDTDescriptor]
-    def getters: Seq[FieldDescriptor] = Seq()
-
-    def select(member: String): Option[UDTDescriptor] =
-      getters find { _.getter.name.toString == member } map { _.desc }
-
-    def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
-      case Nil => Seq(Some(this))
-      case head :: tail => getters find { _.getter.name.toString == head } match {
-        case None => Seq(None)
-        case Some(d : FieldDescriptor) => d.desc.select(tail)
-      }
-    }
-
-    def findById(id: Int): Option[UDTDescriptor] = flatten.find { _.id == id }
-
-    def findByType[T <: UDTDescriptor: ClassTag]: Seq[T] = {
-      val clazz = classTag[T].runtimeClass
-      flatten filter { item => clazz.isAssignableFrom(item.getClass) } map { _.asInstanceOf[T] }
-    }
-
-    def getRecursiveRefs: Seq[UDTDescriptor] =
-      findByType[RecursiveDescriptor].flatMap { rd => findById(rd.refId) }.distinct
   }
 
-  case class GenericClassDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
-    override def flatten = Seq(this)
-
-    def canBeKey = false
-  }
+  case class GenericClassDescriptor(id: Int, tpe: Type) extends UDTDescriptor
 
-  case class UnsupportedDescriptor(id: Int, tpe: Type, errors: Seq[String]) extends UDTDescriptor {
-    override def flatten = Seq(this)
-
-    def canBeKey = tpe <:< typeOf[Comparable[_]]
-  }
+  case class UnsupportedDescriptor(id: Int, tpe: Type, errors: Seq[String]) extends UDTDescriptor
 
-  case class TypeParameterDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
-    override val isPrimitiveProduct = false
-    override def flatten = Seq(this)
-    override def canBeKey = false
-  }
+  case class TypeParameterDescriptor(id: Int, tpe: Type) extends UDTDescriptor
 
   case class PrimitiveDescriptor(id: Int, tpe: Type, default: Literal, wrapper: Type)
-    extends UDTDescriptor {
-    override val isPrimitiveProduct = true
-    override def flatten = Seq(this)
-    override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
-  }
+    extends UDTDescriptor
 
-  case class NothingDesciptor(id: Int, tpe: Type)
-    extends UDTDescriptor {
-    override val isPrimitiveProduct = false
-    override def flatten = Seq(this)
-    override def canBeKey = false
-  }
+  case class NothingDesciptor(id: Int, tpe: Type) extends UDTDescriptor
 
   case class EitherDescriptor(id: Int, tpe: Type, left: UDTDescriptor, right: UDTDescriptor)
-    extends UDTDescriptor {
-    override val isPrimitiveProduct = false
-    override def flatten = Seq(this)
-    override def canBeKey = false
-  }
+    extends UDTDescriptor
 
-  case class TryDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
-    extends UDTDescriptor {
-    override val isPrimitiveProduct = false
-    override def flatten = Seq(this)
-    override def canBeKey = false
-  }
+  case class TryDescriptor(id: Int, tpe: Type, elem: UDTDescriptor) extends UDTDescriptor
 
-  case class OptionDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
-    extends UDTDescriptor {
-    override val isPrimitiveProduct = false
-    override def flatten = Seq(this)
-    override def canBeKey = false
-  }
+  case class OptionDescriptor(id: Int, tpe: Type, elem: UDTDescriptor) extends UDTDescriptor
 
   case class BoxedPrimitiveDescriptor(
-    id: Int, tpe: Type, default: Literal, wrapper: Type, box: Tree => Tree, unbox: Tree => Tree)
-    extends UDTDescriptor {
-
-    override val isPrimitiveProduct = true
-    override def flatten = Seq(this)
-    override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
+      id: Int,
+      tpe: Type,
+      default: Literal,
+      wrapper: Type,
+      box: Tree => Tree,
+      unbox: Tree => Tree) extends UDTDescriptor {
 
     override def hashCode() = (id, tpe, default, wrapper, "BoxedPrimitiveDescriptor").hashCode()
+
     override def equals(that: Any) = that match {
       case BoxedPrimitiveDescriptor(thatId, thatTpe, thatDefault, thatWrapper, _, _) =>
         (id, tpe, default, wrapper).equals(thatId, thatTpe, thatDefault, thatWrapper)
@@ -130,12 +69,10 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
     }
   }
 
-  case class ArrayDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
-    extends UDTDescriptor {
-    override def canBeKey = false
-    override def flatten = this +: elem.flatten
+  case class ArrayDescriptor(id: Int, tpe: Type, elem: UDTDescriptor) extends UDTDescriptor {
 
     override def hashCode() = (id, tpe, elem).hashCode()
+
     override def equals(that: Any) = that match {
       case that @ ArrayDescriptor(thatId, thatTpe, thatElem) =>
         (id, tpe, elem).equals((thatId, thatTpe, thatElem))
@@ -143,17 +80,15 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
     }
   }
 
-  case class TraversableDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
-    extends UDTDescriptor {
-    override def canBeKey = false
-    override def flatten = this +: elem.flatten
+  case class TraversableDescriptor(id: Int, tpe: Type, elem: UDTDescriptor) extends UDTDescriptor {
 
-    def getInnermostElem: UDTDescriptor = elem match {
-      case list: TraversableDescriptor => list.getInnermostElem
-      case _ => elem
-    }
+//    def getInnermostElem: UDTDescriptor = elem match {
+//      case list: TraversableDescriptor => list.getInnermostElem
+//      case _ => elem
+//    }
 
     override def hashCode() = (id, tpe, elem).hashCode()
+
     override def equals(that: Any) = that match {
       case that @ TraversableDescriptor(thatId, thatTpe, thatElem) =>
         (id, tpe, elem).equals((thatId, thatTpe, thatElem))
@@ -161,19 +96,14 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
     }
   }
 
-  case class PojoDescriptor(id: Int, tpe: Type, override val getters: Seq[FieldDescriptor])
+  case class PojoDescriptor(id: Int, tpe: Type, getters: Seq[FieldDescriptor])
     extends UDTDescriptor {
 
-    override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct)
-
-    override def flatten = this +: (getters flatMap { _.desc.flatten })
-
-    override def canBeKey = flatten forall { f => f.canBeKey }
-
     // Hack: ignore the ctorTpe, since two Type instances representing
     // the same ctor function type don't appear to be considered equal.
     // Equality of the tpe and ctor fields implies equality of ctorTpe anyway.
     override def hashCode = (id, tpe, getters).hashCode
+
     override def equals(that: Any) = that match {
       case PojoDescriptor(thatId, thatTpe, thatGetters) =>
         (id, tpe, getters).equals(
@@ -181,13 +111,6 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
       case _ => false
     }
 
-    override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
-      case Nil => getters flatMap { g => g.desc.select(Nil) }
-      case head :: tail => getters find { _.getter.name.toString == head } match {
-        case None => Seq(None)
-        case Some(d : FieldDescriptor) => d.desc.select(tail)
-      }
-    }
   }
 
   case class CaseClassDescriptor(
@@ -195,33 +118,20 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
       tpe: Type,
       mutable: Boolean,
       ctor: Symbol,
-      override val getters: Seq[FieldDescriptor])
-    extends UDTDescriptor {
-
-    override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct)
-
-    override def flatten = this +: (getters flatMap { _.desc.flatten })
-
-    override def canBeKey = flatten forall { f => f.canBeKey }
+      getters: Seq[FieldDescriptor]) extends UDTDescriptor {
 
     // Hack: ignore the ctorTpe, since two Type instances representing
     // the same ctor function type don't appear to be considered equal.
     // Equality of the tpe and ctor fields implies equality of ctorTpe anyway.
     override def hashCode = (id, tpe, ctor, getters).hashCode
+
     override def equals(that: Any) = that match {
       case CaseClassDescriptor(thatId, thatTpe, thatMutable, thatCtor, thatGetters) =>
         (id, tpe, mutable, ctor, getters).equals(
           thatId, thatTpe, thatMutable, thatCtor, thatGetters)
       case _ => false
     }
-
-    override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
-      case Nil => getters flatMap { g => g.desc.select(Nil) }
-      case head :: tail => getters find { _.getter.name.toString == head } match {
-        case None => Seq(None)
-        case Some(d : FieldDescriptor) => d.desc.select(tail)
-      }
-    }
+
   }
 
   case class FieldDescriptor(
@@ -231,21 +141,30 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
       tpe: Type,
       desc: UDTDescriptor)
 
-  case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor {
-    override def flatten = Seq(this)
-    override def canBeKey = tpe <:< typeOf[org.apache.flink.types.Key[_]]
-  }
-
-  case class ValueDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
-    override val isPrimitiveProduct = true
-    override def flatten = Seq(this)
-    override def canBeKey = tpe <:< typeOf[org.apache.flink.types.Key[_]]
-  }
+  case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor
+
+  case class ValueDescriptor(id: Int, tpe: Type) extends UDTDescriptor
+
+  case class WritableDescriptor(id: Int, tpe: Type) extends UDTDescriptor
+
+  case class JavaTupleDescriptor(
+      id: Int,
+      tpe: Type,
+      fields: Seq[UDTDescriptor])
+    extends UDTDescriptor {
+
+    // Define hashCode and equals over (id, tpe, fields), mirroring the
+    // hashCode/equals pattern of the case class and POJO descriptors above.
+    override def hashCode = (id, tpe, fields).hashCode
+
+    override def equals(that: Any) = that match {
+      case JavaTupleDescriptor(thatId, thatTpe, thatFields) =>
+        (id, tpe, fields).equals(
+          thatId, thatTpe, thatFields)
+      case _ => false
+    }
-
-  case class WritableDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
-    override val isPrimitiveProduct = true
-    override def flatten = Seq(this)
-    override def canBeKey = tpe <:< typeOf[org.apache.hadoop.io.WritableComparable[_]]
   }
 }
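With the helpers (flatten, canBeKey, select, isPrimitiveProduct) gone, UDTDescriptor and its cases are plain shape data, and consumers recurse over the constructors directly. Below is a self-contained miniature of that style; Descriptor, Primitive, JavaTuple, Pojo and allIds are simplified stand-ins for illustration, not the macro-bundle classes.

```scala
sealed trait Descriptor { def id: Int }
case class Primitive(id: Int) extends Descriptor
case class JavaTuple(id: Int, fields: Seq[Descriptor]) extends Descriptor
case class Pojo(id: Int, getters: Seq[Descriptor]) extends Descriptor

object DescriptorDemo {
  // Explicit recursion over the constructors replaces the removed flatten helper.
  def allIds(d: Descriptor): Seq[Int] = d match {
    case Primitive(id)         => Seq(id)
    case JavaTuple(id, fields) => id +: fields.flatMap(allIds)
    case Pojo(id, getters)     => id +: getters.flatMap(allIds)
  }

  def main(args: Array[String]): Unit = {
    val desc = JavaTuple(0, Seq(Primitive(1), Pojo(2, Seq(Primitive(3)))))
    println(allIds(desc).mkString(", ")) // 0, 1, 2, 3
  }
}
```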
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
index 0a8874b959798868148726e09a2de338facc1ecb..a6fbb71e4feaa4235b32c670f49afcbfe6511f76 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
@@ -23,6 +23,7 @@
 import org.apache.flink.api.common.ExecutionConfig
 import org.apache.flink.api.common.typeinfo._
 import org.apache.flink.api.common.typeutils._
+import org.apache.flink.api.java.tuple.Tuple
 import org.apache.flink.api.java.typeutils._
 import org.apache.flink.api.scala.typeutils.{CaseClassSerializer, CaseClassTypeInfo}
 import org.apache.flink.types.Value
@@ -83,6 +84,8 @@ private[flink] trait TypeInformationGen[C <: Context] {
 
     case pojo: PojoDescriptor => mkPojo(pojo)
 
+    case javaTuple: JavaTupleDescriptor => mkJavaTuple(javaTuple)
+
     case d => mkGenericTypeInfo(d)
   }
 
@@ -275,6 +278,23 @@ private[flink] trait TypeInformationGen[C <: Context] {
     }
   }
 
+  def mkJavaTuple[T: c.WeakTypeTag](desc: JavaTupleDescriptor): c.Expr[TypeInformation[T]] = {
+
+    val fieldsTrees = desc.fields map { f => mkTypeInfo(f)(c.WeakTypeTag(f.tpe)).tree }
+
+    val fieldsList = c.Expr[List[TypeInformation[_]]](mkList(fieldsTrees.toList))
+
+    val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
+
+    reify {
+      val fields = fieldsList.splice
+      val clazz = tpeClazz.splice.asInstanceOf[Class[org.apache.flink.api.java.tuple.Tuple]]
+      new TupleTypeInfo[org.apache.flink.api.java.tuple.Tuple](clazz, fields: _*)
+        .asInstanceOf[TypeInformation[T]]
+    }
+  }
+
+
   def mkPojo[T: c.WeakTypeTag](desc: PojoDescriptor): c.Expr[TypeInformation[T]] = {
     val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
     val fieldsTrees = desc.getters map {
@@ -296,7 +316,7 @@ private[flink] trait TypeInformationGen[C <: Context] {
       var error = false
       while (traversalClazz != null) {
         for (field <- traversalClazz.getDeclaredFields) {
-          if (clazzFields.contains(field.getName)) {
+          if (clazzFields.contains(field.getName) && !Modifier.isStatic(field.getModifiers)) {
            println(s"The field $field is already contained in the " +
              s"hierarchy of the class $clazz. Please use unique field names throughout " +
              "your class hierarchy")
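For intuition, the mkJavaTuple expansion for a call site like createTypeInformation[Tuple3[Int, String, Integer]] boils down to constructing a TupleTypeInfo from the tuple class plus the per-field TypeInformation, much like the hand-written equivalent below. This illustrates the shape of the generated code, not the generated code itself; JavaTupleTypeInfoDemo is an illustrative name.

```scala
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple3 => JTuple3}
import org.apache.flink.api.java.typeutils.TupleTypeInfo

object JavaTupleTypeInfoDemo {
  def main(args: Array[String]): Unit = {
    // The same (Class, TypeInformation*) constructor that the reify block uses.
    val ti: TypeInformation[JTuple3[Integer, String, Integer]] =
      new TupleTypeInfo[JTuple3[Integer, String, Integer]](
        classOf[JTuple3[Integer, String, Integer]],
        BasicTypeInfo.INT_TYPE_INFO,
        BasicTypeInfo.STRING_TYPE_INFO,
        BasicTypeInfo.INT_TYPE_INFO)

    println(ti.getArity)    // 3
    println(ti.isTupleType) // true
  }
}
```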
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/types/TypeInformationGenTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/types/TypeInformationGenTest.scala
index 3d1ec0c2b8abe93ed347d32ad40f21274d495923..8460bbcb2e4083e90e66804ab06b02402aa15ab6 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/types/TypeInformationGenTest.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/types/TypeInformationGenTest.scala
@@ -19,6 +19,7 @@
 package org.apache.flink.api.scala.types
 
 import java.io.{DataInput, DataOutput}
 
+import org.apache.flink.api.java.`type`.extractor.TypeExtractorTest.CustomTuple
 import org.apache.hadoop.io.Writable
 
 import org.junit.{Assert, Test}
@@ -52,6 +53,41 @@ class MyObject[A](var a: A) {
 
 class TypeInformationGenTest {
 
+  @Test
+  def testJavaTuple(): Unit = {
+    val ti = createTypeInformation[org.apache.flink.api.java.tuple.Tuple3[Int, String, Integer]]
+
+    Assert.assertTrue(ti.isTupleType)
+    Assert.assertEquals(3, ti.getArity)
+    Assert.assertTrue(ti.isInstanceOf[TupleTypeInfoBase[_]])
+    val tti = ti.asInstanceOf[TupleTypeInfoBase[_]]
+    Assert.assertEquals(classOf[org.apache.flink.api.java.tuple.Tuple3[_, _, _]], tti.getTypeClass)
+    for (i <- 0 until 3) {
+      Assert.assertTrue(tti.getTypeAt(i).isInstanceOf[BasicTypeInfo[_]])
+    }
+
+    Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, tti.getTypeAt(0))
+    Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, tti.getTypeAt(1))
+    Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, tti.getTypeAt(2))
+  }
+
+  @Test
+  def testCustomJavaTuple(): Unit = {
+    val ti = createTypeInformation[CustomTuple]
+
+    Assert.assertTrue(ti.isTupleType)
+    Assert.assertEquals(2, ti.getArity)
+    Assert.assertTrue(ti.isInstanceOf[TupleTypeInfoBase[_]])
+    val tti = ti.asInstanceOf[TupleTypeInfoBase[_]]
+    Assert.assertEquals(classOf[CustomTuple], tti.getTypeClass)
+    for (i <- 0 until 2) {
+      Assert.assertTrue(tti.getTypeAt(i).isInstanceOf[BasicTypeInfo[_]])
+    }
+
+    Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, tti.getTypeAt(0))
+    Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, tti.getTypeAt(1))
+  }
+
   @Test
   def testBasicType(): Unit = {
     val ti = createTypeInformation[Boolean]
@@ -162,7 +198,7 @@ class TypeInformationGenTest {
     Assert.assertTrue(ti.isInstanceOf[TupleTypeInfoBase[_]])
     val tti = ti.asInstanceOf[TupleTypeInfoBase[_]]
     Assert.assertEquals(classOf[Tuple9[_,_,_,_,_,_,_,_,_]], tti.getTypeClass)
-    for (i <- 0 until 0) {
+    for (i <- 0 until 9) {
       Assert.assertTrue(tti.getTypeAt(i).isInstanceOf[BasicTypeInfo[_]])
     }
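Taken together, the tests pin down the new behavior: the Scala createTypeInformation macro now yields a TupleTypeInfo for Flink's Java tuples and for Tuple subclasses such as CustomTuple. Below is a minimal standalone sketch of the same round trip, assuming only a standard Flink dependency; JavaTupleFromScalaDemo is an illustrative name and the assertions mirror testJavaTuple above.

```scala
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
import org.apache.flink.api.scala._

object JavaTupleFromScalaDemo {
  def main(args: Array[String]): Unit = {
    // The macro now classifies the Java tuple instead of falling back to
    // a generic type, so we get a TupleTypeInfo with per-field type info.
    val ti = createTypeInformation[org.apache.flink.api.java.tuple.Tuple2[Int, String]]
    val tti = ti.asInstanceOf[TupleTypeInfoBase[_]]

    assert(tti.getArity == 2)
    assert(tti.getTypeAt(0) == BasicTypeInfo.INT_TYPE_INFO)
    assert(tti.getTypeAt(1) == BasicTypeInfo.STRING_TYPE_INFO)
  }
}
```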