From 6e118d1dc97b3a8c0b013d2002fad80219751253 Mon Sep 17 00:00:00 2001 From: Fabian Hueske Date: Thu, 2 Nov 2017 21:10:03 +0100 Subject: [PATCH] [FLINK-6226] [table] Add tests for UDFs with Byte, Short, and Float arguments. --- .../UserDefinedScalarFunctionTest.scala | 28 +++++++++++++++++-- .../utils/userDefinedScalarFunctions.scala | 6 ++++ .../runtime/batch/table/CorrelateITCase.scala | 23 ++++++++++++++- .../utils/UserDefinedTableFunctions.scala | 12 +++++++- 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala index 71ff70d1d04..a3b2f07b3b5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala @@ -47,6 +47,24 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "Func1(f0)", "43") + testAllApis( + Func1('f11), + "Func1(f11)", + "Func1(f11)", + "4") + + testAllApis( + Func1('f12), + "Func1(f12)", + "Func1(f12)", + "4") + + testAllApis( + Func1('f13), + "Func1(f13)", + "Func1(f13)", + "4.0") + testAllApis( Func2('f0, 'f1, 'f3), "Func2(f0, f1, f3)", @@ -360,7 +378,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { // ---------------------------------------------------------------------------------------------- override def testData: Any = { - val testData = new Row(11) + val testData = new Row(14) testData.setField(0, 42) testData.setField(1, "Test") testData.setField(2, null) @@ -372,6 +390,9 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { testData.setField(8, 1000L) testData.setField(9, Seq("Hello", "World")) testData.setField(10, Array[Integer](1, 2, null)) + testData.setField(11, 3.toByte) + testData.setField(12, 3.toShort) + testData.setField(13, 3.toFloat) testData } @@ -387,7 +408,10 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { Types.INTERVAL_MONTHS, Types.INTERVAL_MILLIS, TypeInformation.of(classOf[Seq[String]]), - BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO + BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO, + Types.BYTE, + Types.SHORT, + Types.FLOAT ).asInstanceOf[TypeInformation[Any]] } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala index 528556968bf..9535cdf18d6 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/userDefinedScalarFunctions.scala @@ -41,6 +41,12 @@ object Func1 extends ScalarFunction { def eval(index: Integer): Integer = { index + 1 } + + def eval(b: Byte): Byte = (b + 1).toByte + + def eval(s: Short): Short = (s + 1).toShort + + def eval(f: Float): Float = f + 1 } object Func2 extends ScalarFunction { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala index b10975214de..79243dd0ddd 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CorrelateITCase.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.flink.api.scala._ import org.apache.flink.api.scala.util.CollectionDataSets -import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException} +import org.apache.flink.table.api.{TableEnvironment, TableException, Types, ValidationException} import org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.JavaTableFunc0 import org.apache.flink.table.api.scala._ import org.apache.flink.table.expressions.utils.{Func1, Func13, Func18, RichFunc2} @@ -230,6 +230,27 @@ class CorrelateITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } + @Test + def testByteShortFloatArguments(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + val tFunc = new TableFunc4 + + val result = in + .select('a.cast(Types.BYTE) as 'a, 'a.cast(Types.SHORT) as 'b, 'b.cast(Types.FLOAT) as 'c) + .join(tFunc('a, 'b, 'c) as ('a2, 'b2, 'c2)) + .toDataSet[Row] + + val results = result.collect() + val expected = Seq( + "1,1,1.0,Byte=1,Short=1,Float=1.0", + "2,2,2.0,Byte=2,Short=2,Float=2.0", + "3,3,2.0,Byte=3,Short=3,Float=2.0", + "4,4,3.0,Byte=4,Short=4,Float=3.0").mkString("\n") + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + @Test def testUserDefinedTableFunctionWithParameter(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala index d0ffade253c..e1af23b71ec 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala @@ -22,7 +22,7 @@ import java.lang.Boolean import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.Tuple3 import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.api.{Types, ValidationException} import org.apache.flink.table.functions.{FunctionContext, TableFunction} import org.apache.flink.types.Row import org.junit.Assert @@ -109,6 +109,16 @@ class TableFunc3(data: String, conf: Map[String, String]) extends TableFunction[ } } +class TableFunc4 extends TableFunction[Row] { + def eval(b: Byte, s: Short, f: Float): Unit = { + collect(Row.of("Byte=" + b, "Short=" + s, "Float=" + f)) + } + + override def getResultType: TypeInformation[Row] = { + new RowTypeInfo(Types.STRING, Types.STRING, Types.STRING) + } +} + class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] { def eval(user: String) { if (user.contains("#")) { -- GitLab