Commit f8a381f9 authored by Aljoscha Krettek, committed by Stephan Ewen

[FLINK-2023] [scala] Improve type analysis

 - Exclude static fields in Scala Pojo analysis
 - Recognize Java Tuples
 - Clean up legacy code

(And also make one of the field type retrieval methods nicer)

This closes #669
Parent b92ff121
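For context on the "Recognize Java Tuples" item: with this change, the Scala type-analysis macro produces a real tuple TypeInformation for Flink's Java Tuple classes instead of falling back to a generic type. A minimal usage sketch, assuming only a standard Flink Scala API dependency (the object name is made up; createTypeInformation is the actual macro entry point from org.apache.flink.api.scala):

import org.apache.flink.api.scala._
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}

object JavaTupleRecognition {
  def main(args: Array[String]): Unit = {
    // After this commit the macro recognizes the Java tuple and
    // produces tuple type information rather than a generic fallback.
    val ti = createTypeInformation[JTuple2[Int, String]]
    println(ti.isTupleType) // expected: true
    println(ti.getArity)    // expected: 2
  }
}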
@@ -17,21 +17,13 @@
*/
package org.apache.flink.api.scala.codegen
import org.apache.flink.types.{BooleanValue, ByteValue, CharValue, DoubleValue, FloatValue, IntValue, LongValue, ShortValue, StringValue}
import scala.collection._
import scala.collection.generic.CanBuildFrom
import scala.reflect.macros.Context
import scala.util.DynamicVariable
import org.apache.flink.types.BooleanValue
import org.apache.flink.types.ByteValue
import org.apache.flink.types.CharValue
import org.apache.flink.types.DoubleValue
import org.apache.flink.types.FloatValue
import org.apache.flink.types.IntValue
import org.apache.flink.types.StringValue
import org.apache.flink.types.LongValue
import org.apache.flink.types.ShortValue
private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
with TypeDescriptors[C] =>
@@ -83,6 +75,8 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
case TraitType() => GenericClassDescriptor(id, tpe)
case JavaTupleType() => analyzeJavaTuple(id, tpe)
case JavaType() =>
// It's a Java Class, let the TypeExtractor deal with it...
GenericClassDescriptor(id, tpe)
@@ -136,6 +130,20 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
case elemDesc => OptionDescriptor(id, tpe, elemDesc)
}
private def analyzeJavaTuple(id: Int, tpe: Type): UDTDescriptor = {
// check how many tuple fields we have and determine type
val fields = (0 until org.apache.flink.api.java.tuple.Tuple.MAX_ARITY ) flatMap { i =>
tpe.members find { m => m.name.toString.equals("f" + i)} match {
case Some(m) => Some(analyze(m.typeSignatureIn(tpe)))
case _ => None
}
}
JavaTupleDescriptor(id, tpe, fields)
}
private def analyzePojo(id: Int, tpe: Type): UDTDescriptor = {
val immutableFields = tpe.members filter { _.isTerm } map { _.asTerm } filter { _.isVal }
if (immutableFields.nonEmpty) {
@@ -150,6 +158,7 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
.filter { _.isTerm }
.map { _.asTerm }
.filter { _.isVar }
.filter { !_.isStatic }
.filterNot { _.annotations.exists( _.tpe <:< typeOf[scala.transient]) }
if (fields.isEmpty) {
@@ -188,7 +197,7 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
val fieldDescriptors = fields map {
f =>
val fieldTpe = f.getter.asMethod.returnType.asSeenFrom(tpe, tpe.typeSymbol)
val fieldTpe = f.typeSignatureIn(tpe)
FieldDescriptor(f.name.toString.trim, f.getter, f.setter, fieldTpe, analyze(fieldTpe))
}
@@ -391,6 +400,10 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isJava
}
private object JavaTupleType {
def unapply(tpe: Type): Boolean = tpe <:< typeOf[org.apache.flink.api.java.tuple.Tuple]
}
private class UDTAnalyzerCache {
private val caches = new DynamicVariable[Map[Type, RecursiveDescriptor]](Map())
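The analyzeJavaTuple helper added above determines the tuple's arity by probing for members named f0, f1, ... up to Tuple.MAX_ARITY and analyzing whichever ones exist. A self-contained sketch of the same probing idea, using Scala runtime reflection instead of the macro-time types (FieldProbe is a hypothetical name; the example type assumes the Flink jars on the classpath):

import scala.reflect.runtime.universe._

object FieldProbe {
  // Collect the types of members named f0, f1, ... on a given type,
  // skipping indices that do not exist -- the same pattern
  // analyzeJavaTuple uses at macro time. The default bound matches
  // Flink's Tuple.MAX_ARITY of 25.
  def tupleFieldTypes(tpe: Type, maxArity: Int = 25): Seq[Type] =
    (0 until maxArity) flatMap { i =>
      tpe.members.find(_.name.toString == s"f$i").map(_.typeSignatureIn(tpe))
    }

  def main(args: Array[String]): Unit = {
    val tpe = typeOf[org.apache.flink.api.java.tuple.Tuple2[Integer, String]]
    tupleFieldTypes(tpe) foreach println
  }
}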
@@ -32,97 +32,36 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
abstract sealed class UDTDescriptor {
val id: Int
val tpe: Type
val isPrimitiveProduct: Boolean = false
def canBeKey: Boolean
def flatten: Seq[UDTDescriptor]
def getters: Seq[FieldDescriptor] = Seq()
def select(member: String): Option[UDTDescriptor] =
getters find { _.getter.name.toString == member } map { _.desc }
def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
case Nil => Seq(Some(this))
case head :: tail => getters find { _.getter.name.toString == head } match {
case None => Seq(None)
case Some(d : FieldDescriptor) => d.desc.select(tail)
}
}
def findById(id: Int): Option[UDTDescriptor] = flatten.find { _.id == id }
def findByType[T <: UDTDescriptor: ClassTag]: Seq[T] = {
val clazz = classTag[T].runtimeClass
flatten filter { item => clazz.isAssignableFrom(item.getClass) } map { _.asInstanceOf[T] }
}
def getRecursiveRefs: Seq[UDTDescriptor] =
findByType[RecursiveDescriptor].flatMap { rd => findById(rd.refId) }.distinct
}
case class GenericClassDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
override def flatten = Seq(this)
def canBeKey = false
}
case class GenericClassDescriptor(id: Int, tpe: Type) extends UDTDescriptor
case class UnsupportedDescriptor(id: Int, tpe: Type, errors: Seq[String]) extends UDTDescriptor {
override def flatten = Seq(this)
def canBeKey = tpe <:< typeOf[Comparable[_]]
}
case class UnsupportedDescriptor(id: Int, tpe: Type, errors: Seq[String]) extends UDTDescriptor
case class TypeParameterDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}
case class TypeParameterDescriptor(id: Int, tpe: Type) extends UDTDescriptor
case class PrimitiveDescriptor(id: Int, tpe: Type, default: Literal, wrapper: Type)
extends UDTDescriptor {
override val isPrimitiveProduct = true
override def flatten = Seq(this)
override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
}
extends UDTDescriptor
case class NothingDesciptor(id: Int, tpe: Type)
extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}
case class NothingDesciptor(id: Int, tpe: Type) extends UDTDescriptor
case class EitherDescriptor(id: Int, tpe: Type, left: UDTDescriptor, right: UDTDescriptor)
extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}
extends UDTDescriptor
case class TryDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}
case class TryDescriptor(id: Int, tpe: Type, elem: UDTDescriptor) extends UDTDescriptor
case class OptionDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}
case class OptionDescriptor(id: Int, tpe: Type, elem: UDTDescriptor) extends UDTDescriptor
case class BoxedPrimitiveDescriptor(
id: Int, tpe: Type, default: Literal, wrapper: Type, box: Tree => Tree, unbox: Tree => Tree)
extends UDTDescriptor {
override val isPrimitiveProduct = true
override def flatten = Seq(this)
override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
id: Int,
tpe: Type,
default: Literal,
wrapper: Type,
box: Tree => Tree,
unbox: Tree => Tree) extends UDTDescriptor {
override def hashCode() = (id, tpe, default, wrapper, "BoxedPrimitiveDescriptor").hashCode()
override def equals(that: Any) = that match {
case BoxedPrimitiveDescriptor(thatId, thatTpe, thatDefault, thatWrapper, _, _) =>
(id, tpe, default, wrapper).equals(thatId, thatTpe, thatDefault, thatWrapper)
@@ -130,12 +69,10 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
}
}
case class ArrayDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override def canBeKey = false
override def flatten = this +: elem.flatten
case class ArrayDescriptor(id: Int, tpe: Type, elem: UDTDescriptor) extends UDTDescriptor {
override def hashCode() = (id, tpe, elem).hashCode()
override def equals(that: Any) = that match {
case that @ ArrayDescriptor(thatId, thatTpe, thatElem) =>
(id, tpe, elem).equals((thatId, thatTpe, thatElem))
@@ -143,17 +80,15 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
}
}
case class TraversableDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override def canBeKey = false
override def flatten = this +: elem.flatten
case class TraversableDescriptor(id: Int, tpe: Type, elem: UDTDescriptor) extends UDTDescriptor {
def getInnermostElem: UDTDescriptor = elem match {
case list: TraversableDescriptor => list.getInnermostElem
case _ => elem
}
// def getInnermostElem: UDTDescriptor = elem match {
// case list: TraversableDescriptor => list.getInnermostElem
// case _ => elem
// }
override def hashCode() = (id, tpe, elem).hashCode()
override def equals(that: Any) = that match {
case that @ TraversableDescriptor(thatId, thatTpe, thatElem) =>
(id, tpe, elem).equals((thatId, thatTpe, thatElem))
@@ -161,19 +96,14 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
}
}
case class PojoDescriptor(id: Int, tpe: Type, override val getters: Seq[FieldDescriptor])
case class PojoDescriptor(id: Int, tpe: Type, getters: Seq[FieldDescriptor])
extends UDTDescriptor {
override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct)
override def flatten = this +: (getters flatMap { _.desc.flatten })
override def canBeKey = flatten forall { f => f.canBeKey }
// Hack: ignore the ctorTpe, since two Type instances representing
// the same ctor function type don't appear to be considered equal.
// Equality of the tpe and ctor fields implies equality of ctorTpe anyway.
override def hashCode = (id, tpe, getters).hashCode
override def equals(that: Any) = that match {
case PojoDescriptor(thatId, thatTpe, thatGetters) =>
(id, tpe, getters).equals(
@@ -181,13 +111,6 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
case _ => false
}
override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
case Nil => getters flatMap { g => g.desc.select(Nil) }
case head :: tail => getters find { _.getter.name.toString == head } match {
case None => Seq(None)
case Some(d : FieldDescriptor) => d.desc.select(tail)
}
}
}
case class CaseClassDescriptor(
@@ -195,33 +118,20 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
tpe: Type,
mutable: Boolean,
ctor: Symbol,
override val getters: Seq[FieldDescriptor])
extends UDTDescriptor {
override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct)
override def flatten = this +: (getters flatMap { _.desc.flatten })
override def canBeKey = flatten forall { f => f.canBeKey }
getters: Seq[FieldDescriptor]) extends UDTDescriptor {
// Hack: ignore the ctorTpe, since two Type instances representing
// the same ctor function type don't appear to be considered equal.
// Equality of the tpe and ctor fields implies equality of ctorTpe anyway.
override def hashCode = (id, tpe, ctor, getters).hashCode
override def equals(that: Any) = that match {
case CaseClassDescriptor(thatId, thatTpe, thatMutable, thatCtor, thatGetters) =>
(id, tpe, mutable, ctor, getters).equals(
thatId, thatTpe, thatMutable, thatCtor, thatGetters)
case _ => false
}
override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
case Nil => getters flatMap { g => g.desc.select(Nil) }
case head :: tail => getters find { _.getter.name.toString == head } match {
case None => Seq(None)
case Some(d : FieldDescriptor) => d.desc.select(tail)
}
}
}
case class FieldDescriptor(
@@ -231,21 +141,30 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
tpe: Type,
desc: UDTDescriptor)
case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor {
override def flatten = Seq(this)
override def canBeKey = tpe <:< typeOf[org.apache.flink.types.Key[_]]
}
case class ValueDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
override val isPrimitiveProduct = true
override def flatten = Seq(this)
override def canBeKey = tpe <:< typeOf[org.apache.flink.types.Key[_]]
}
case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor
case class ValueDescriptor(id: Int, tpe: Type) extends UDTDescriptor
case class WritableDescriptor(id: Int, tpe: Type) extends UDTDescriptor
case class JavaTupleDescriptor(
id: Int,
tpe: Type,
fields: Seq[UDTDescriptor])
extends UDTDescriptor {
// Hack: ignore the ctorTpe, since two Type instances representing
// the same ctor function type don't appear to be considered equal.
// Equality of the tpe and ctor fields implies equality of ctorTpe anyway.
override def hashCode = (id, tpe, fields).hashCode
override def equals(that: Any) = that match {
case JavaTupleDescriptor(thatId, thatTpe, thatFields) =>
(id, tpe, fields).equals(
thatId, thatTpe, thatFields)
case _ => false
}
case class WritableDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
override val isPrimitiveProduct = true
override def flatten = Seq(this)
override def canBeKey = tpe <:< typeOf[org.apache.hadoop.io.WritableComparable[_]]
}
}
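A pattern worth noting in the slimmed-down descriptors above: the case classes that carry function-valued fields (BoxedPrimitiveDescriptor's box and unbox of type Tree => Tree) keep hand-written hashCode/equals, because functions only compare by reference and would break the structural equality a case class normally provides. A minimal illustration under made-up names (Descriptor and FixedDescriptor are not Flink code):

// Two structurally identical instances that carry a function field
// are NOT equal: the two lambdas are distinct objects.
case class Descriptor(id: Int, box: Int => Int)

object EqualityDemo {
  def main(args: Array[String]): Unit = {
    println(Descriptor(1, x => x + 1) == Descriptor(1, x => x + 1)) // false

    // Hand-written equality that ignores the function field restores
    // the intended semantics -- the approach BoxedPrimitiveDescriptor takes.
    case class FixedDescriptor(id: Int, box: Int => Int) {
      override def hashCode(): Int = id.hashCode()
      override def equals(that: Any): Boolean = that match {
        case FixedDescriptor(thatId, _) => id == thatId
        case _ => false
      }
    }
    println(FixedDescriptor(1, x => x + 1) == FixedDescriptor(1, x => x + 2)) // true
  }
}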
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.api.common.typeutils._
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.api.java.typeutils._
import org.apache.flink.api.scala.typeutils.{CaseClassSerializer, CaseClassTypeInfo}
import org.apache.flink.types.Value
@@ -83,6 +84,8 @@ private[flink] trait TypeInformationGen[C <: Context] {
case pojo: PojoDescriptor => mkPojo(pojo)
case javaTuple: JavaTupleDescriptor => mkJavaTuple(javaTuple)
case d => mkGenericTypeInfo(d)
}
@@ -275,6 +278,23 @@ private[flink] trait TypeInformationGen[C <: Context] {
}
}
def mkJavaTuple[T: c.WeakTypeTag](desc: JavaTupleDescriptor): c.Expr[TypeInformation[T]] = {
val fieldsTrees = desc.fields map { f => mkTypeInfo(f)(c.WeakTypeTag(f.tpe)).tree }
val fieldsList = c.Expr[List[TypeInformation[_]]](mkList(fieldsTrees.toList))
val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
reify {
val fields = fieldsList.splice
val clazz = tpeClazz.splice.asInstanceOf[Class[org.apache.flink.api.java.tuple.Tuple]]
new TupleTypeInfo[org.apache.flink.api.java.tuple.Tuple](clazz, fields: _*)
.asInstanceOf[TypeInformation[T]]
}
}
def mkPojo[T: c.WeakTypeTag](desc: PojoDescriptor): c.Expr[TypeInformation[T]] = {
val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
val fieldsTrees = desc.getters map {
@@ -296,7 +316,7 @@ private[flink] trait TypeInformationGen[C <: Context] {
var error = false
while (traversalClazz != null) {
for (field <- traversalClazz.getDeclaredFields) {
if (clazzFields.contains(field.getName)) {
if (clazzFields.contains(field.getName) && !Modifier.isStatic(field.getModifiers)) {
println(s"The field $field is already contained in the " +
s"hierarchy of the class $clazz. Please use unique field names throughout " +
"your class hierarchy")
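The last hunk above makes the duplicate-field check skip static fields, using java.lang.reflect.Modifier while walking up the class hierarchy. A standalone sketch of that walk (collectInstanceFields is a hypothetical helper, not the macro code):

import java.lang.reflect.Modifier

object FieldWalk {
  // Walk a class hierarchy and collect non-static declared fields,
  // flagging duplicate names -- the same check the macro performs
  // when validating POJO candidates.
  def collectInstanceFields(clazz: Class[_]): Seq[String] = {
    var names = Vector.empty[String]
    var current: Class[_] = clazz
    while (current != null) {
      for (field <- current.getDeclaredFields
           if !Modifier.isStatic(field.getModifiers)) {
        if (names.contains(field.getName)) {
          println(s"The field ${field.getName} is already contained in the " +
            s"hierarchy of the class $clazz.")
        }
        names :+= field.getName
      }
      current = current.getSuperclass
    }
    names
  }

  def main(args: Array[String]): Unit =
    collectInstanceFields(classOf[java.awt.Point]) foreach println
}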
@@ -19,6 +19,7 @@ package org.apache.flink.api.scala.types
import java.io.{DataInput, DataOutput}
import org.apache.flink.api.java.`type`.extractor.TypeExtractorTest.CustomTuple
import org.apache.hadoop.io.Writable
import org.junit.{Assert, Test}
@@ -52,6 +53,41 @@ class MyObject[A](var a: A) {
class TypeInformationGenTest {
@Test
def testJavaTuple(): Unit = {
val ti = createTypeInformation[org.apache.flink.api.java.tuple.Tuple3[Int, String, Integer]]
Assert.assertTrue(ti.isTupleType)
Assert.assertEquals(3, ti.getArity)
Assert.assertTrue(ti.isInstanceOf[TupleTypeInfoBase[_]])
val tti = ti.asInstanceOf[TupleTypeInfoBase[_]]
Assert.assertEquals(classOf[org.apache.flink.api.java.tuple.Tuple3[_, _, _]], tti.getTypeClass)
for (i <- 0 until 3) {
Assert.assertTrue(tti.getTypeAt(i).isInstanceOf[BasicTypeInfo[_]])
}
Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, tti.getTypeAt(0))
Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, tti.getTypeAt(1))
Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, tti.getTypeAt(2))
}
@Test
def testCustomJavaTuple(): Unit = {
val ti = createTypeInformation[CustomTuple]
Assert.assertTrue(ti.isTupleType)
Assert.assertEquals(2, ti.getArity)
Assert.assertTrue(ti.isInstanceOf[TupleTypeInfoBase[_]])
val tti = ti.asInstanceOf[TupleTypeInfoBase[_]]
Assert.assertEquals(classOf[CustomTuple], tti.getTypeClass)
for (i <- 0 until 2) {
Assert.assertTrue(tti.getTypeAt(i).isInstanceOf[BasicTypeInfo[_]])
}
Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, tti.getTypeAt(0))
Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, tti.getTypeAt(1))
}
@Test
def testBasicType(): Unit = {
val ti = createTypeInformation[Boolean]
@@ -162,7 +198,7 @@ class TypeInformationGenTest {
Assert.assertTrue(ti.isInstanceOf[TupleTypeInfoBase[_]])
val tti = ti.asInstanceOf[TupleTypeInfoBase[_]]
Assert.assertEquals(classOf[Tuple9[_,_,_,_,_,_,_,_,_]], tti.getTypeClass)
for (i <- 0 until 0) {
for (i <- 0 until 9) {
Assert.assertTrue(tti.getTypeAt(i).isInstanceOf[BasicTypeInfo[_]])
}
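The new tests pin down the expected TypeInformation for both standard and custom Java tuples. For reference, a hand-written equivalent of what the generated code from mkJavaTuple constructs for Tuple3[Int, String, Integer], using the same Flink classes the diff imports (a sketch mirroring the reify block, not verbatim macro output):

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple, Tuple3 => JTuple3}
import org.apache.flink.api.java.typeutils.TupleTypeInfo

object ManualTupleTypeInfo {
  def main(args: Array[String]): Unit = {
    // The class of the tuple plus one TypeInformation per field,
    // exactly the two ingredients mkJavaTuple splices together.
    val fields: Seq[TypeInformation[_]] = Seq(
      BasicTypeInfo.INT_TYPE_INFO,
      BasicTypeInfo.STRING_TYPE_INFO,
      BasicTypeInfo.INT_TYPE_INFO)
    val clazz = classOf[JTuple3[Integer, String, Integer]].asInstanceOf[Class[Tuple]]
    val ti = new TupleTypeInfo[Tuple](clazz, fields: _*)
    println(ti.isTupleType) // true
    println(ti.getArity)    // 3
  }
}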