提交 3954edb3 编写于 作者: X Xiaoyu Wang

feat: avg function rewrite

上级 3eda47e1
......@@ -171,6 +171,10 @@ bool fmIsRepeatScanFunc(int32_t funcId);
bool fmIsUserDefinedFunc(int32_t funcId);
bool fmIsDistExecFunc(int32_t funcId);
bool fmIsForbidFillFunc(int32_t funcId);
bool fmIsForbidStreamFunc(int32_t funcId);
bool fmNeedRewrite(int32_t funcId);
int32_t fmRewriteFunc(SNode** pFunc);
int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc, SFunctionNode** pMergeFunc);
......
......@@ -23,6 +23,7 @@ extern "C" {
#include "functionMgtInt.h"
typedef int32_t (*FTranslateFunc)(SFunctionNode* pFunc, char* pErrBuf, int32_t len);
typedef int32_t (*FRewriteFunc)(SNode** pFunc);
typedef EFuncDataRequired (*FFuncDataRequired)(SFunctionNode* pFunc, STimeWindow* pTimeWindow);
typedef struct SBuiltinFuncDefinition {
......@@ -30,6 +31,7 @@ typedef struct SBuiltinFuncDefinition {
EFunctionType type;
uint64_t classification;
FTranslateFunc translateFunc;
FRewriteFunc rewriteFunc;
FFuncDataRequired dataRequiredFunc;
FExecGetEnv getEnvFunc;
FExecInit initFunc;
......
......@@ -42,6 +42,7 @@ extern "C" {
#define FUNC_MGT_SELECT_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(13)
#define FUNC_MGT_REPEAT_SCAN_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(14)
#define FUNC_MGT_FORBID_FILL_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(15)
#define FUNC_MGT_FORBID_STREAM_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(16)
#define FUNC_MGT_TEST_MASK(val, mask) (((val) & (mask)) != 0)
......
......@@ -1333,6 +1333,49 @@ static bool getBlockDistFuncEnv(SFunctionNode* UNUSED_PARAM(pFunc), SFuncExecEnv
return true;
}
static int32_t rewriteAvg(SNode** pFunc) {
SOperatorNode* pOper = (SOperatorNode*)nodesMakeNode(QUERY_NODE_OPERATOR);
if (NULL == pOper) {
return TSDB_CODE_OUT_OF_MEMORY;
}
SFunctionNode* pAvg = (SFunctionNode*)*pFunc;
pOper->node.resType = pAvg->node.resType;
strcpy(pOper->node.aliasName, pAvg->node.aliasName);
pOper->opType = OP_TYPE_DIV;
pOper->pLeft = nodesMakeNode(QUERY_NODE_FUNCTION);
pOper->pRight = nodesMakeNode(QUERY_NODE_FUNCTION);
if (NULL == pOper->pLeft || NULL == pOper->pRight) {
nodesDestroyNode((SNode*)pOper);
return TSDB_CODE_OUT_OF_MEMORY;
}
SFunctionNode* pSum = (SFunctionNode*)pOper->pLeft;
strcpy(pSum->functionName, "sum");
pSum->pParameterList = nodesCloneList(pAvg->pParameterList);
if (NULL == pSum->pParameterList) {
nodesDestroyNode((SNode*)pOper);
return TSDB_CODE_OUT_OF_MEMORY;
}
char msgBuf[64] = {0};
int32_t code = fmGetFuncInfo(pSum, msgBuf, sizeof(msgBuf));
if (TSDB_CODE_SUCCESS == code) {
SFunctionNode* pCount = (SFunctionNode*)pOper->pRight;
strcpy(pCount->functionName, "count");
TSWAP(pCount->pParameterList, pAvg->pParameterList);
code = fmGetFuncInfo(pCount, msgBuf, sizeof(msgBuf));
}
if (TSDB_CODE_SUCCESS == code) {
nodesDestroyNode((SNode*)pAvg);
*pFunc = (SNode*)pOper;
} else {
nodesDestroyNode((SNode*)pOper);
}
return code;
}
// clang-format off
const SBuiltinFuncDefinition funcMgtBuiltins[] = {
{
......@@ -1422,6 +1465,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = {
.type = FUNCTION_TYPE_AVG,
.classification = FUNC_MGT_AGG_FUNC,
.translateFunc = translateInNumOutDou,
.rewriteFunc = rewriteAvg,
.getEnvFunc = getAvgFuncEnv,
.initFunc = avgFunctionSetup,
.processFunc = avgFunction,
......
......@@ -161,6 +161,8 @@ bool fmIsUserDefinedFunc(int32_t funcId) { return funcId > FUNC_UDF_ID_START; }
bool fmIsForbidFillFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_FILL_FUNC); }
bool fmIsForbidStreamFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_STREAM_FUNC); }
void fmFuncMgtDestroy() {
void* m = gFunMgtService.pFuncNameHashTable;
if (m != NULL && atomic_val_compare_exchange_ptr((void**)&gFunMgtService.pFuncNameHashTable, m, 0) == m) {
......@@ -297,3 +299,12 @@ int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc
return code;
}
bool fmNeedRewrite(int32_t funcId) {
if (fmIsUserDefinedFunc(funcId)) {
return false;
}
return NULL != funcMgtBuiltins[funcId].rewriteFunc;
}
int32_t fmRewriteFunc(SNode** pFunc) { return funcMgtBuiltins[((SFunctionNode*)*pFunc)->funcId].rewriteFunc(pFunc); }
......@@ -1076,29 +1076,32 @@ static void setFuncClassification(SSelectStmt* pSelect, SFunctionNode* pFunc) {
}
}
static EDealRes translateFunction(STranslateContext* pCxt, SFunctionNode* pFunc) {
static EDealRes translateFunction(STranslateContext* pCxt, SFunctionNode** pFunc) {
SNode* pParam = NULL;
FOREACH(pParam, pFunc->pParameterList) {
FOREACH(pParam, (*pFunc)->pParameterList) {
if (isMultiResFunc(pParam)) {
return generateDealNodeErrMsg(pCxt, TSDB_CODE_PAR_WRONG_VALUE_TYPE, ((SExprNode*)pParam)->aliasName);
}
}
pCxt->errCode = getFuncInfo(pCxt, pFunc);
pCxt->errCode = getFuncInfo(pCxt, *pFunc);
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
pCxt->errCode = translateAggFunc(pCxt, pFunc);
pCxt->errCode = translateAggFunc(pCxt, *pFunc);
}
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
pCxt->errCode = translateScanPseudoColumnFunc(pCxt, pFunc);
pCxt->errCode = translateScanPseudoColumnFunc(pCxt, *pFunc);
}
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
pCxt->errCode = translateIndefiniteRowsFunc(pCxt, pFunc);
pCxt->errCode = translateIndefiniteRowsFunc(pCxt, *pFunc);
}
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
pCxt->errCode = translateForbidFillFunc(pCxt, pFunc);
pCxt->errCode = translateForbidFillFunc(pCxt, *pFunc);
}
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
setFuncClassification(pCxt->pCurrSelectStmt, pFunc);
setFuncClassification(pCxt->pCurrSelectStmt, *pFunc);
}
if (TSDB_CODE_SUCCESS == pCxt->errCode && fmNeedRewrite((*pFunc)->funcId)) {
pCxt->errCode = fmRewriteFunc((SNode**)pFunc);
}
return TSDB_CODE_SUCCESS == pCxt->errCode ? DEAL_RES_CONTINUE : DEAL_RES_ERROR;
}
......@@ -1123,7 +1126,7 @@ static EDealRes doTranslateExpr(SNode** pNode, void* pContext) {
case QUERY_NODE_OPERATOR:
return translateOperator(pCxt, (SOperatorNode**)pNode);
case QUERY_NODE_FUNCTION:
return translateFunction(pCxt, (SFunctionNode*)*pNode);
return translateFunction(pCxt, (SFunctionNode**)pNode);
case QUERY_NODE_LOGIC_CONDITION:
return translateLogicCond(pCxt, (SLogicConditionNode*)*pNode);
case QUERY_NODE_TEMP_TABLE:
......
......@@ -35,6 +35,7 @@ string toString(int32_t code) { return tstrerror(code); }
// [...];
class InsertTest : public Test {
protected:
InsertTest() : res_(nullptr) {}
~InsertTest() { reset(); }
void setDatabase(const string& acctId, const string& db) {
......
......@@ -53,6 +53,14 @@ TEST_F(PlanGroupByTest, aggFunc) {
run("SELECT SUM(10), COUNT(c1) FROM t1 GROUP BY c2");
}
TEST_F(PlanGroupByTest, rewriteFunc) {
useDb("root", "test");
run("SELECT AVG(c1) FROM t1");
run("SELECT AVG(c1) FROM t1 GROUP BY c2");
}
TEST_F(PlanGroupByTest, selectFunc) {
useDb("root", "test");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册