From 2d3992b72b33ca308b00f58e3bc15dca9a476b54 Mon Sep 17 00:00:00 2001 From: dapan1121 Date: Thu, 29 Sep 2022 19:19:49 +0800 Subject: [PATCH] feat: support case when clause --- include/libs/function/function.h | 2 +- source/libs/scalar/src/scalar.c | 179 +++++++++++++++++++++++++---- source/libs/scalar/src/sclvector.c | 23 +++- 3 files changed, 172 insertions(+), 32 deletions(-) diff --git a/include/libs/function/function.h b/include/libs/function/function.h index 60c7b18367..25eeda1c3a 100644 --- a/include/libs/function/function.h +++ b/include/libs/function/function.h @@ -176,7 +176,7 @@ struct SScalarParam { SColumnInfoData *columnData; SHashObj *pHashFilter; int32_t hashValueType; - void *param; // other parameter, such as meta handle from vnode, to extract table name/tag value + void *param; // other parameter, such as meta handle from vnode, to extract table name/tag value int32_t numOfRows; int32_t numOfQualified; // number of qualified elements in the final results }; diff --git a/source/libs/scalar/src/scalar.c b/source/libs/scalar/src/scalar.c index 9394c4f6ca..26c4f948cd 100644 --- a/source/libs/scalar/src/scalar.c +++ b/source/libs/scalar/src/scalar.c @@ -177,7 +177,7 @@ void sclFreeRes(SHashObj *res) { } void sclFreeParam(SScalarParam *param) { - if (!param->colAlloced) { + if (NULL == param || !param->colAlloced) { return; } @@ -378,7 +378,8 @@ int32_t sclInitParam(SNode* node, SScalarParam *param, SScalarCtx *ctx, int32_t } case QUERY_NODE_FUNCTION: case QUERY_NODE_OPERATOR: - case QUERY_NODE_LOGIC_CONDITION: { + case QUERY_NODE_LOGIC_CONDITION: + case QUERY_NODE_CASE_WHEN: { SScalarParam *res = (SScalarParam *)taosHashGet(ctx->pRes, &node, POINTER_BYTES); if (NULL == res) { sclError("no result for node, type:%d, node:%p", nodeType(node), node); @@ -532,10 +533,100 @@ _return: SCL_RET(code); } -int32_t sclGetNodeRes(SNode* node, SArray* pBlockList, SScalarParam **res) { +int32_t sclGetNodeRes(SNode* node, SScalarCtx *ctx, SScalarParam **res) { + if (NULL == node) { + return TSDB_CODE_SUCCESS; + } + + int32_t rowNum = 0; + *res = taosMemoryCalloc(1, sizeof(**res)); + if (NULL == *res) { + SCL_ERR_RET(TSDB_CODE_OUT_OF_MEMORY); + } + + SCL_ERR_RET(sclInitParam(node, *res, ctx, &rowNum)); + + return TSDB_CODE_SUCCESS; +} + +int32_t sclWalkCaseWhenList(SScalarCtx *ctx, SNodeList* pList, struct SListCell* pCell, SScalarParam *pCase, SScalarParam *pElse, SScalarParam *pComp, SScalarParam *output, int32_t rowIdx) { + SNode *node = NULL; + SWhenThenNode* pWhenThen = NULL; + SScalarParam *pWhen = NULL; + SScalarParam *pThen = NULL; + + for (SListCell* cell = pCell; (NULL != cell ? (node = cell->pNode, true) : (node = NULL, false)); cell = cell->pNext) { + pWhenThen = (SWhenThenNode*)node; + + SCL_ERR_RET(sclGetNodeRes(pWhenThen->pWhen, ctx, &pWhen)); + SCL_ERR_RET(sclGetNodeRes(pWhenThen->pThen, ctx, &pThen)); + + doVectorCompare(pCase, pWhen, pComp, rowIdx, 1, TSDB_ORDER_ASC, OP_TYPE_EQUAL); + + bool *equal = colDataGetData(pComp->columnData, rowIdx); + if (*equal) { + colDataAppend(output->columnData, rowIdx, colDataGetData(pThen, (pThen->numOfRows > 1 ? rowIdx : 0)), colDataIsNull_s(pThen, (pThen->numOfRows > 1 ? rowIdx : 0))); + + return TSDB_CODE_SUCCESS; + } + } + + if (pElse) { + colDataAppend(output->columnData, rowIdx, colDataGetData(pElse, (pElse->numOfRows > 1 ? rowIdx : 0)), colDataIsNull_s(pElse, (pElse->numOfRows > 1 ? rowIdx : 0))); + + return TSDB_CODE_SUCCESS; + } + + colDataAppend(output->columnData, rowIdx, NULL, true); + + return TSDB_CODE_SUCCESS; +} + +int32_t sclWalkWhenList(SScalarCtx *ctx, SNodeList* pList, struct SListCell* pCell, SScalarParam *pElse, SScalarParam *output, int32_t rowIdx) { + SNode *node = NULL; + SWhenThenNode* pWhenThen = NULL; + SScalarParam *pWhen = NULL; + SScalarParam *pThen = NULL; + int32_t code = 0; + + for (SListCell* cell = pCell; (NULL != cell ? (node = cell->pNode, true) : (node = NULL, false)); cell = cell->pNext) { + pWhenThen = (SWhenThenNode*)node; + pWhen = NULL; + pThen = NULL; + + SCL_ERR_JRET(sclGetNodeRes(pWhenThen->pWhen, ctx, &pWhen)); + SCL_ERR_JRET(sclGetNodeRes(pWhenThen->pThen, ctx, &pThen)); + + bool *whenValue = colDataGetData(pWhen->columnData, (pWhen->numOfRows > 1 ? rowIdx : 0)); + + if (*whenValue) { + colDataAppend(output->columnData, rowIdx, colDataGetData(pThen, (pThen->numOfRows > 1 ? rowIdx : 0)), colDataIsNull_s(pThen, (pThen->numOfRows > 1 ? rowIdx : 0))); + + goto _return; + } + + sclFreeParam(pWhen); + sclFreeParam(pThen); + } + + if (pElse) { + colDataAppend(output->columnData, rowIdx, colDataGetData(pElse, (pElse->numOfRows > 1 ? rowIdx : 0)), colDataIsNull_s(pElse, (pElse->numOfRows > 1 ? rowIdx : 0))); + + goto _return; + } + + colDataAppend(output->columnData, rowIdx, NULL, true); +_return: + + sclFreeParam(pWhen); + sclFreeParam(pThen); + + SCL_RET(code); } + + int32_t sclExecFunction(SFunctionNode *node, SScalarCtx *ctx, SScalarParam *output) { SScalarParam *params = NULL; int32_t rowNum = 0; @@ -693,42 +784,80 @@ _return: } int32_t sclExecCaseWhen(SCaseWhenNode *node, SScalarCtx *ctx, SScalarParam *output) { - int32_t rowNum = 0; int32_t code = 0; SScalarParam *pCase = NULL; SScalarParam *pElse = NULL; - - if (node->pCase) { - SCL_ERR_RET(sclGetNodeRes(node->pCase, ctx->pBlockList, &pCase)); + SScalarParam *pWhen = NULL; + SScalarParam *pThen = NULL; + SScalarParam *pComp = NULL; + int32_t rowNum = 1; + + if (NULL == node->pWhenThenList || node->pWhenThenList->length <= 0) { + sclError("invalid whenThen list"); + SCL_ERR_RET(TSDB_CODE_INVALID_PARA); } + + if (ctx->pBlockList) { + SSDataBlock* pb = taosArrayGetP(ctx->pBlockList, 0); + rowNum = pb->info.rows; + } + + SCL_ERR_JRET(sclCreateColumnInfoData(&node->node.resType, rowNum, output)); + + SCL_ERR_JRET(sclGetNodeRes(node->pCase, ctx, &pCase)); + SCL_ERR_JRET(sclGetNodeRes(node->pElse, ctx, &pElse)); + + SDataType compType = {0}; + compType.type = TSDB_DATA_TYPE_BOOL; + compType.bytes = tDataTypes[compType.type].bytes; + SCL_ERR_JRET(sclCreateColumnInfoData(&compType, rowNum, pComp)); + SNode* tnode = NULL; - FOREACH(tnode, node->pWhenThenList) { - if (QUERY_NODE_VALUE == tnode->type) { - return DEAL_RES_CONTINUE; + SWhenThenNode* pWhenThen = (SWhenThenNode*)node->pWhenThenList->pHead->pNode; + + SCL_ERR_JRET(sclGetNodeRes(pWhenThen->pWhen, ctx, &pWhen)); + SCL_ERR_JRET(sclGetNodeRes(pWhenThen->pThen, ctx, &pThen)); + + if (pCase) { + vectorCompare(pCase, pWhen, pComp, TSDB_ORDER_ASC, OP_TYPE_EQUAL); + + for (int32_t i = 0; i < rowNum; ++i) { + bool *equal = colDataGetData(pComp->columnData, i); + if (*equal) { + colDataAppend(output->columnData, i, colDataGetData(pThen, (pThen->numOfRows > 1 ? i : 0)), colDataIsNull_s(pThen, (pThen->numOfRows > 1 ? i : 0))); + } else { + SCL_ERR_JRET(sclWalkCaseWhenList(ctx, node->pWhenThenList, node->pWhenThenList->pHead->pNext, pCase, pElse, pComp, output, i)); + } } - } - - if (output->columnData == NULL) { - code = sclCreateColumnInfoData(&node->node.resType, rowNum, output); - if (code != TSDB_CODE_SUCCESS) { - SCL_ERR_JRET(code); + } else { + for (int32_t i = 0; i < rowNum; ++i) { + bool *equal = colDataGetData(pWhen->columnData, i); + if (*equal) { + colDataAppend(output->columnData, i, colDataGetData(pThen, (pThen->numOfRows > 1 ? i : 0)), colDataIsNull_s(pThen, (pThen->numOfRows > 1 ? i : 0))); + } else { + SCL_ERR_JRET(sclWalkWhenList(ctx, node->pWhenThenList, node->pWhenThenList->pHead->pNext, pElse, output, i)); + } } } - _bin_scalar_fn_t OperatorFn = getBinScalarOperatorFn(node->opType); + sclFreeParam(pCase); + sclFreeParam(pElse); + sclFreeParam(pComp); + sclFreeParam(pWhen); + sclFreeParam(pThen); - int32_t paramNum = scalarGetOperatorParamNum(node->opType); - SScalarParam* pLeft = ¶ms[0]; - SScalarParam* pRight = paramNum > 1 ? ¶ms[1] : NULL; - - terrno = TSDB_CODE_SUCCESS; - OperatorFn(pLeft, pRight, output, TSDB_ORDER_ASC); - code = terrno; + return TSDB_CODE_SUCCESS; _return: - sclFreeParamList(params, paramNum); + sclFreeParam(pCase); + sclFreeParam(pElse); + sclFreeParam(pComp); + sclFreeParam(pWhen); + sclFreeParam(pThen); + sclFreeParam(output); + SCL_RET(code); } diff --git a/source/libs/scalar/src/sclvector.c b/source/libs/scalar/src/sclvector.c index fe2a970aaa..41725a1cb7 100644 --- a/source/libs/scalar/src/sclvector.c +++ b/source/libs/scalar/src/sclvector.c @@ -1646,8 +1646,8 @@ void vectorBitOr(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *pOut, doReleaseVec(pRightCol, rightConvert); } -int32_t doVectorCompareImpl(int32_t numOfRows, SScalarParam *pOut, int32_t startIndex, int32_t step, __compar_fn_t fp, - SScalarParam *pLeft, SScalarParam *pRight, int32_t optr) { +int32_t doVectorCompareImpl(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t startIndex, int32_t numOfRows, + int32_t step, __compar_fn_t fp, int32_t optr) { int32_t num = 0; for (int32_t i = startIndex; i < numOfRows && i >= 0; i += step) { @@ -1700,12 +1700,14 @@ int32_t doVectorCompareImpl(int32_t numOfRows, SScalarParam *pOut, int32_t start return num; } -void vectorCompareImpl(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *pOut, int32_t _ord, int32_t optr) { - int32_t i = ((_ord) == TSDB_ORDER_ASC) ? 0 : TMAX(pLeft->numOfRows, pRight->numOfRows) - 1; +void doVectorCompare(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *pOut, int32_t startIndex, int32_t numOfRows, + int32_t _ord, int32_t optr) { + int32_t i = 0; int32_t step = ((_ord) == TSDB_ORDER_ASC) ? 1 : -1; int32_t lType = GET_PARAM_TYPE(pLeft); int32_t rType = GET_PARAM_TYPE(pRight); __compar_fn_t fp = NULL; + int32_t compRows = 0; if (lType == rType) { fp = filterGetCompFunc(lType, optr); @@ -1715,6 +1717,14 @@ void vectorCompareImpl(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam * pOut->numOfRows = TMAX(pLeft->numOfRows, pRight->numOfRows); + if (startIndex < 0) { + i = ((_ord) == TSDB_ORDER_ASC) ? 0 : TMAX(pLeft->numOfRows, pRight->numOfRows) - 1; + compRows = pOut->numOfRows; + } else { + compRows = startIndex + numOfRows; + i = startIndex; + } + if (pRight->pHashFilter != NULL) { for (; i >= 0 && i < pLeft->numOfRows; i += step) { if (IS_HELPER_NULL(pLeft->columnData, i)) { @@ -1731,7 +1741,7 @@ void vectorCompareImpl(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam * } } } else { // normal compare - pOut->numOfQualified = doVectorCompareImpl(pOut->numOfRows, pOut, i, step, fp, pLeft, pRight, optr); + pOut->numOfQualified = doVectorCompareImpl(pLeft, pRight, pOut, i, compRows, step, fp, optr); } } @@ -1760,7 +1770,8 @@ void vectorCompare(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *pOut } } - vectorCompareImpl(param1, param2, pOut, _ord, optr); + doVectorCompare(param1, param2, pOut, -1, -1, _ord, optr); + sclFreeParam(&pLeftOut); sclFreeParam(&pRightOut); } -- GitLab