提交 2b82578a 编写于 作者: T twalthr

[FLINK-7596] [table] Restrict equality and improve tests

上级 62ebda3f
......@@ -19,7 +19,6 @@
package org.apache.flink.table.calcite
import java.util
import java.util.List
import org.apache.calcite.avatica.util.TimeUnit
import org.apache.calcite.jdbc.JavaTypeFactoryImpl
......@@ -248,39 +247,35 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
canonize(newType)
}
private def resolveAnySqlType(types: java.util.List[RelDataType]): RelDataType = {
val hasAny = types.asScala.map(_.getSqlTypeName).exists(_ == SqlTypeName.ANY)
val nullable = types.asScala.exists(
sqlType => sqlType.isNullable || sqlType.getSqlTypeName == SqlTypeName.NULL
)
if (hasAny) {
if (types.get(0).isInstanceOf[GenericRelDataType] &&
types.get(1).isInstanceOf[GenericRelDataType]) {
createTypeWithNullability(types.get(0), nullable)
} else {
throw new RuntimeException("only GenericRelDataType of ANY is supported")
}
} else {
null
}
}
override def leastRestrictive(types: util.List[RelDataType]): RelDataType = {
assert(types != null)
assert(types.size >= 1)
val type0 = types.get(0)
if (type0.getSqlTypeName != null) {
val resultType = resolveAnySqlType(types)
val resultType = resolveAny(types)
if (resultType != null) {
resultType
return resultType
}
}
super.leastRestrictive(types)
}
private def resolveAny(types: util.List[RelDataType]): RelDataType = {
val allTypes = types.asScala
val hasAny = allTypes.exists(_.getSqlTypeName == SqlTypeName.ANY)
if (hasAny) {
val head = allTypes.head
// only allow ANY with exactly the same GenericRelDataType for all types
if (allTypes.forall(_ == head)) {
val nullable = allTypes.exists(
sqlType => sqlType.isNullable || sqlType.getSqlTypeName == SqlTypeName.NULL
)
createTypeWithNullability(head, nullable)
} else {
super.leastRestrictive(types)
throw TableException("Generic ANY types must have a common type information.")
}
} else {
super.leastRestrictive(types)
null
}
}
}
object FlinkTypeFactory {
......
......@@ -73,18 +73,24 @@ class SetOperatorsITCase extends StreamingMultipleProgramsTestBase {
@Test
def testUnionWithAnyType(): Unit = {
val list = List((1, new NODE), (2, new NODE))
val list2 = List((3, new NODE), (4, new NODE))
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
val s1 = tEnv.fromDataStream(env.fromCollection(list))
val s2 = tEnv.fromDataStream(env.fromCollection(list2))
StreamITCase.testResults = mutable.MutableList()
val s1 = env.fromElements((1, new NonPojo), (2, new NonPojo)).toTable(tEnv, 'a, 'b)
val s2 = env.fromElements((3, new NonPojo), (4, new NonPojo)).toTable(tEnv, 'a, 'b)
val result = s1.unionAll(s2).toAppendStream[Row]
result.addSink(new StreamITCase.StringSink[Row])
env.execute()
val expected = mutable.MutableList("1,{}", "2,{}", "3,{}", "4,{}")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
class NODE {
class NonPojo {
val x = new java.util.HashMap[String, String]()
override def toString: String = x.toString
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册