From f350ade453d35f2a037998d0015c83bcbe323c47 Mon Sep 17 00:00:00 2001 From: Xiaoyu Wang Date: Thu, 16 Jun 2022 19:45:09 +0800 Subject: [PATCH] feat: the unique function and tail function are rewritten as the corresponding clauses --- include/libs/nodes/querynodes.h | 3 +++ source/libs/function/src/builtins.c | 7 ++++++- source/libs/nodes/src/nodesUtilFuncs.c | 20 ++++++++++++++++++ source/libs/parser/src/parTranslater.c | 3 +++ tests/system-test/2-query/tail.py | 23 ++++++++++---------- tests/system-test/2-query/unique.py | 29 +++++++++++++------------- 6 files changed, 59 insertions(+), 26 deletions(-) diff --git a/include/libs/nodes/querynodes.h b/include/libs/nodes/querynodes.h index 3adc619567..944b1f4b2e 100644 --- a/include/libs/nodes/querynodes.h +++ b/include/libs/nodes/querynodes.h @@ -376,6 +376,9 @@ bool nodesIsComparisonOp(const SOperatorNode* pOp); bool nodesIsJsonOp(const SOperatorNode* pOp); bool nodesIsRegularOp(const SOperatorNode* pOp); +bool nodesExprHasColumn(SNode* pNode); +bool nodesExprsHasColumn(SNodeList* pList); + void* nodesGetValueFromNode(SValueNode* pNode); int32_t nodesSetValueNodeValue(SValueNode* pNode, void* value); char* nodesGetStrValueFromNode(SValueNode* pNode); diff --git a/source/libs/function/src/builtins.c b/source/libs/function/src/builtins.c index 1cfa4d201e..e1375e8a6f 100644 --- a/source/libs/function/src/builtins.c +++ b/source/libs/function/src/builtins.c @@ -1085,7 +1085,12 @@ static int32_t translateUnique(SFunctionNode* pFunc, char* pErrBuf, int32_t len) return invaildFuncParaNumErrMsg(pErrBuf, len, pFunc->functionName); } - pFunc->node.resType = ((SExprNode*)nodesListGetNode(pFunc->pParameterList, 0))->resType; + SNode* pPara = nodesListGetNode(pFunc->pParameterList, 0); + if (!nodesExprHasColumn(pPara)) { + return buildFuncErrMsg(pErrBuf, len, TSDB_CODE_FUNC_FUNTION_ERROR, "The parameters of UNIQUE must contain columns"); + } + + pFunc->node.resType = ((SExprNode*)pPara)->resType; return TSDB_CODE_SUCCESS; } diff --git a/source/libs/nodes/src/nodesUtilFuncs.c b/source/libs/nodes/src/nodesUtilFuncs.c index a45ba53ad1..e2f4dd26f6 100644 --- a/source/libs/nodes/src/nodesUtilFuncs.c +++ b/source/libs/nodes/src/nodesUtilFuncs.c @@ -1463,6 +1463,26 @@ int32_t nodesCollectSpecialNodes(SSelectStmt* pSelect, ESqlClause clause, ENodeT return TSDB_CODE_SUCCESS; } +static EDealRes hasColumn(SNode* pNode, void* pContext) { + if (QUERY_NODE_COLUMN == nodeType(pNode)) { + *(bool*)pContext = true; + return DEAL_RES_END; + } + return DEAL_RES_CONTINUE; +} + +bool nodesExprHasColumn(SNode* pNode) { + bool hasCol = false; + nodesWalkExprPostOrder(pNode, hasColumn, &hasCol); + return hasCol; +} + +bool nodesExprsHasColumn(SNodeList* pList) { + bool hasCol = false; + nodesWalkExprsPostOrder(pList, hasColumn, &hasCol); + return hasCol; +} + char* nodesGetFillModeString(EFillMode mode) { switch (mode) { case FILL_MODE_NONE: diff --git a/source/libs/parser/src/parTranslater.c b/source/libs/parser/src/parTranslater.c index f17d91f042..814c247e09 100644 --- a/source/libs/parser/src/parTranslater.c +++ b/source/libs/parser/src/parTranslater.c @@ -2170,6 +2170,7 @@ static EDealRes rewriteSeletcValueFunc(STranslateContext* pCxt, SNode** pNode) { } strcpy(pFirst->functionName, "first"); TSWAP(pFirst->pParameterList, ((SFunctionNode*)*pNode)->pParameterList); + strcpy(pFirst->node.aliasName, ((SExprNode*)*pNode)->aliasName); nodesDestroyNode(*pNode); *pNode = (SNode*)pFirst; pCxt->errCode = fmGetFuncInfo(pFirst, pCxt->msgBuf.buf, pCxt->msgBuf.len); @@ -2184,6 +2185,7 @@ static EDealRes rewriteUniqueFunc(SNode** pNode, void* pContext) { if (FUNCTION_TYPE_UNIQUE == pFunc->funcType) { SNode* pExpr = nodesListGetNode(pFunc->pParameterList, 0); NODES_CLEAR_LIST(pFunc->pParameterList); + strcpy(((SExprNode*)pExpr)->aliasName, ((SExprNode*)*pNode)->aliasName); nodesDestroyNode(*pNode); *pNode = pExpr; pCxt->pExpr = pExpr; @@ -2238,6 +2240,7 @@ static EDealRes rewriteTailFunc(SNode** pNode, void* pContext) { pCxt->offset = ((SValueNode*)nodesListGetNode(pFunc->pParameterList, 2))->datum.i; } SNode* pExpr = nodesListGetNode(pFunc->pParameterList, 0); + strcpy(((SExprNode*)pExpr)->aliasName, ((SExprNode*)*pNode)->aliasName); NODES_CLEAR_LIST(pFunc->pParameterList); nodesDestroyNode(*pNode); *pNode = pExpr; diff --git a/tests/system-test/2-query/tail.py b/tests/system-test/2-query/tail.py index 6039f3effa..bb11fbbef5 100644 --- a/tests/system-test/2-query/tail.py +++ b/tests/system-test/2-query/tail.py @@ -188,7 +188,8 @@ class TDTestCase: def check_tail_table(self , tbname , col_name , tail_rows , offset): tail_sql = f"select tail({col_name} , {tail_rows} , {offset}) from {tbname}" - equal_sql = f"select {col_name} from (select ts , {col_name} from {tbname} order by ts desc limit {tail_rows} offset {offset}) order by ts" + #equal_sql = f"select {col_name} from (select ts , {col_name} from {tbname} order by ts desc limit {tail_rows} offset {offset}) order by ts" + equal_sql = f"select {col_name} from {tbname} order by ts desc limit {tail_rows} offset {offset}" tdSql.query(tail_sql) tail_result = tdSql.queryResult @@ -294,21 +295,21 @@ class TDTestCase: tdSql.checkData(1, 0, None) tdSql.query("select tail(c1,3,2) from ct4 where c1 >2 ") - tdSql.checkData(0, 0, 7) + tdSql.checkData(0, 0, 5) tdSql.checkData(1, 0, 6) - tdSql.checkData(2, 0, 5) + tdSql.checkData(2, 0, 7) tdSql.query("select tail(c1,2,1) from ct4 where c2 between 0 and 99999") - tdSql.checkData(0, 0, 2) - tdSql.checkData(1, 0, 1) + tdSql.checkData(0, 0, 1) + tdSql.checkData(1, 0, 2) # tail with union all tdSql.query("select tail(c1,2,1) from ct4 union all select c1 from ct1") tdSql.checkRows(15) tdSql.query("select tail(c1,2,1) from ct4 union all select c1 from ct2") tdSql.checkRows(2) - tdSql.checkData(0, 0, 1) - tdSql.checkData(1, 0, 0) + tdSql.checkData(0, 0, 0) + tdSql.checkData(1, 0, 1) tdSql.query("select tail(c2,2,1) from ct4 union all select abs(c2)/2 from ct4") tdSql.checkRows(14) @@ -336,16 +337,16 @@ class TDTestCase: tdSql.query("select tail(tb2.num,3,2) from tb1, tb2 where tb1.ts=tb2.ts ") tdSql.checkRows(3) - tdSql.checkData(0,0,5) + tdSql.checkData(0,0,7) tdSql.checkData(1,0,6) - tdSql.checkData(2,0,7) + tdSql.checkData(2,0,5) # nest query # tdSql.query("select tail(c1,2) from (select c1 from ct1)") tdSql.query("select c1 from (select tail(c1,2) c1 from ct4)") tdSql.checkRows(2) - tdSql.checkData(0, 0, 0) - tdSql.checkData(1, 0, None) + tdSql.checkData(0, 0, None) + tdSql.checkData(1, 0, 0) tdSql.query("select sum(c1) from (select tail(c1,2) c1 from ct1)") tdSql.checkRows(1) diff --git a/tests/system-test/2-query/unique.py b/tests/system-test/2-query/unique.py index 227efa6f9c..f48202010a 100644 --- a/tests/system-test/2-query/unique.py +++ b/tests/system-test/2-query/unique.py @@ -93,8 +93,8 @@ class TDTestCase: "select unique(c1) , min(c1) from t1", "select unique(c1) , spread(c1) from t1", "select unique(c1) , diff(c1) from t1", - "select unique(c1) , abs(c1) from t1", - "select unique(c1) , c1 from t1", + #"select unique(c1) , abs(c1) from t1", # support + #"select unique(c1) , c1 from t1", "select unique from stb1 partition by tbname", "select unique(123--123)==1 from stb1 partition by tbname", "select unique(123) from stb1 partition by tbname", @@ -104,21 +104,21 @@ class TDTestCase: "select unique(c1 ,c2 ) from stb1 partition by tbname", "select unique(c1 ,NULL) from stb1 partition by tbname", "select unique(,) from stb1 partition by tbname;", - "select unique(floor(c1) ab from stb1 partition by tbname)", - "select unique(c1) as int from stb1 partition by tbname", + #"select unique(floor(c1) ab from stb1 partition by tbname)", # support + #"select unique(c1) as int from stb1 partition by tbname", "select unique('c1') from stb1 partition by tbname", "select unique(NULL) from stb1 partition by tbname", "select unique('') from stb1 partition by tbname", "select unique(c%) from stb1 partition by tbname", - #"select unique(t1) from stb1 partition by tbname", + #"select unique(t1) from stb1 partition by tbname", # support "select unique(True) from stb1 partition by tbname", "select unique(c1) , count(c1) from stb1 partition by tbname", "select unique(c1) , avg(c1) from stb1 partition by tbname", "select unique(c1) , min(c1) from stb1 partition by tbname", "select unique(c1) , spread(c1) from stb1 partition by tbname", "select unique(c1) , diff(c1) from stb1 partition by tbname", - "select unique(c1) , abs(c1) from stb1 partition by tbname", - "select unique(c1) , c1 from stb1 partition by tbname" + #"select unique(c1) , abs(c1) from stb1 partition by tbname", # support + #"select unique(c1) , c1 from stb1 partition by tbname" # support ] for error_sql in error_sql_lists: @@ -198,7 +198,7 @@ class TDTestCase: unique_datas = [] for elem in unique_result: unique_datas.append(elem[0]) - + unique_datas.sort(key=lambda x: (x is None, x)) tdSql.query(origin_sql) origin_result = tdSql.queryResult @@ -212,6 +212,7 @@ class TDTestCase: continue else: pre_unique.append(elem) + pre_unique.sort(key=lambda x: (x is None, x)) if pre_unique == unique_datas: tdLog.info(" unique query check pass , unique sql is: %s" %unique_sql) @@ -266,16 +267,16 @@ class TDTestCase: tdSql.checkRows(10) tdSql.error("select unique(c1),tbname from ct1") - tdSql.error("select unique(c1),t1 from ct1") + #tdSql.error("select unique(c1),t1 from ct1") #support # unique with common col - tdSql.error("select unique(c1) ,ts from ct1") - tdSql.error("select unique(c1) ,c1 from ct1") + #tdSql.error("select unique(c1) ,ts from ct1") + #tdSql.error("select unique(c1) ,c1 from ct1") # unique with scalar function - tdSql.error("select unique(c1) ,abs(c1) from ct1") + #tdSql.error("select unique(c1) ,abs(c1) from ct1") tdSql.error("select unique(c1) , unique(c2) from ct1") - tdSql.error("select unique(c1) , abs(c2)+2 from ct1") + #tdSql.error("select unique(c1) , abs(c2)+2 from ct1") # unique with aggregate function @@ -288,7 +289,7 @@ class TDTestCase: tdSql.query("select unique(c1) from ct4 where c1 is null") tdSql.checkData(0, 0, None) - tdSql.query("select unique(c1) from ct4 where c1 >2 ") + tdSql.query("select unique(c1) from ct4 where c1 >2") tdSql.checkData(0, 0, 8) tdSql.checkData(1, 0, 7) tdSql.checkData(2, 0, 6) -- GitLab