未验证 提交 68afd048 编写于 作者: R Rui Li 提交者: GitHub

[FLINK-16275][table-planner-blink] AggsHandlerCodeGenerator can fail with custom module (#11215)

上级 ee0c21e1
......@@ -17,6 +17,7 @@
*/
package org.apache.flink.table.planner.codegen.agg
import org.apache.flink.api.common.functions.RuntimeContext
import org.apache.flink.table.api.TableException
import org.apache.flink.table.dataformat.GenericRow
import org.apache.flink.table.dataformat.util.BaseRowUtil
......@@ -333,18 +334,27 @@ class AggsHandlerCodeGenerator(
val functionName = newName(name)
val RUNTIME_CONTEXT = className[RuntimeContext]
val functionCode =
j"""
public final class $functionName implements $AGGS_HANDLER_FUNCTION {
${ctx.reuseMemberCode()}
private $STATE_DATA_VIEW_STORE store;
public $functionName(java.lang.Object[] references) throws Exception {
${ctx.reuseInitCode()}
}
private $RUNTIME_CONTEXT getRuntimeContext() {
return store.getRuntimeContext();
}
@Override
public void open($STATE_DATA_VIEW_STORE store) throws Exception {
this.store = store;
${ctx.reuseOpenCode()}
}
......
......@@ -23,16 +23,19 @@ import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo}
import org.apache.flink.api.scala._
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.{AggregateFunction, FunctionDefinition, ScalarFunctionDefinition}
import org.apache.flink.table.module.{CoreModule, Module}
import org.apache.flink.table.planner.runtime.utils.BatchTestBase
import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
import org.apache.flink.table.planner.runtime.utils.TestData._
import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils.IsNullUDF
import org.apache.flink.table.planner.utils.DateTimeTestUtil._
import org.apache.flink.types.Row
import org.junit.{Before, Test}
import java.lang.{Iterable => JIterable, Long => JLong}
import java.util.{Collections, Optional}
import scala.collection.Seq
import scala.util.Random
......@@ -2522,6 +2525,20 @@ class OverWindowITCase extends BatchTestBase {
row(5, 14L, 30, 1, 15, 1, 1, true, null, false, null),
row(5, 15L, 30, 1, 15, 1, 1, true, null, false, null)))
}
@Test
def testRankWithCustomModule(): Unit = {
tEnv.unloadModule("core")
tEnv.loadModule("test-module", new TestModule)
tEnv.loadModule("core", CoreModule.INSTANCE)
registerCollection("emp",
Seq(row("1", "A", 1), row("1", "B", 2), row("2", "C", 3)),
new RowTypeInfo(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO),
"dep,name,salary")
checkResult(
"select dep,name,rank() over (partition by dep order by salary desc) as rnk from emp",
Seq(row("1", "A", 2), row("1", "B", 1), row("2", "C", 1)))
}
}
/** The initial accumulator for count aggregate function */
......@@ -2576,3 +2593,18 @@ class CountAggFunction extends AggregateFunction[JLong, CountAccumulator] {
override def getResultType: TypeInformation[JLong] = Types.LONG
}
private class TestModule extends Module {
private val funcName = "isnull"
override def listFunctions(): java.util.Set[String] = Collections.singleton(funcName)
override def getFunctionDefinition(name: String): Optional[FunctionDefinition] = {
if (name.equalsIgnoreCase(funcName)) {
Optional.of(new ScalarFunctionDefinition(name, IsNullUDF))
} else {
Optional.empty()
}
}
}
......@@ -377,6 +377,13 @@ object UserDefinedFunctionTestUtils {
override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = Types.JAVA_BIG_DEC
}
@SerialVersionUID(1L)
object IsNullUDF extends ScalarFunction {
def eval(v: Any): Boolean = v == null
override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = Types.BOOLEAN
}
// ------------------------------------------------------------------------------------
// POJOs
// ------------------------------------------------------------------------------------
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册