提交 2d3992b7 编写于 作者: D dapan1121

feat: support case when clause

上级 0022a766
......@@ -176,7 +176,7 @@ struct SScalarParam {
SColumnInfoData *columnData;
SHashObj *pHashFilter;
int32_t hashValueType;
void *param; // other parameter, such as meta handle from vnode, to extract table name/tag value
void *param; // other parameter, such as meta handle from vnode, to extract table name/tag value
int32_t numOfRows;
int32_t numOfQualified; // number of qualified elements in the final results
};
......
......@@ -177,7 +177,7 @@ void sclFreeRes(SHashObj *res) {
}
void sclFreeParam(SScalarParam *param) {
if (!param->colAlloced) {
if (NULL == param || !param->colAlloced) {
return;
}
......@@ -378,7 +378,8 @@ int32_t sclInitParam(SNode* node, SScalarParam *param, SScalarCtx *ctx, int32_t
}
case QUERY_NODE_FUNCTION:
case QUERY_NODE_OPERATOR:
case QUERY_NODE_LOGIC_CONDITION: {
case QUERY_NODE_LOGIC_CONDITION:
case QUERY_NODE_CASE_WHEN: {
SScalarParam *res = (SScalarParam *)taosHashGet(ctx->pRes, &node, POINTER_BYTES);
if (NULL == res) {
sclError("no result for node, type:%d, node:%p", nodeType(node), node);
......@@ -532,10 +533,100 @@ _return:
SCL_RET(code);
}
int32_t sclGetNodeRes(SNode* node, SArray* pBlockList, SScalarParam **res) {
int32_t sclGetNodeRes(SNode* node, SScalarCtx *ctx, SScalarParam **res) {
if (NULL == node) {
return TSDB_CODE_SUCCESS;
}
int32_t rowNum = 0;
*res = taosMemoryCalloc(1, sizeof(**res));
if (NULL == *res) {
SCL_ERR_RET(TSDB_CODE_OUT_OF_MEMORY);
}
SCL_ERR_RET(sclInitParam(node, *res, ctx, &rowNum));
return TSDB_CODE_SUCCESS;
}
int32_t sclWalkCaseWhenList(SScalarCtx *ctx, SNodeList* pList, struct SListCell* pCell, SScalarParam *pCase, SScalarParam *pElse, SScalarParam *pComp, SScalarParam *output, int32_t rowIdx) {
SNode *node = NULL;
SWhenThenNode* pWhenThen = NULL;
SScalarParam *pWhen = NULL;
SScalarParam *pThen = NULL;
for (SListCell* cell = pCell; (NULL != cell ? (node = cell->pNode, true) : (node = NULL, false)); cell = cell->pNext) {
pWhenThen = (SWhenThenNode*)node;
SCL_ERR_RET(sclGetNodeRes(pWhenThen->pWhen, ctx, &pWhen));
SCL_ERR_RET(sclGetNodeRes(pWhenThen->pThen, ctx, &pThen));
doVectorCompare(pCase, pWhen, pComp, rowIdx, 1, TSDB_ORDER_ASC, OP_TYPE_EQUAL);
bool *equal = colDataGetData(pComp->columnData, rowIdx);
if (*equal) {
colDataAppend(output->columnData, rowIdx, colDataGetData(pThen, (pThen->numOfRows > 1 ? rowIdx : 0)), colDataIsNull_s(pThen, (pThen->numOfRows > 1 ? rowIdx : 0)));
return TSDB_CODE_SUCCESS;
}
}
if (pElse) {
colDataAppend(output->columnData, rowIdx, colDataGetData(pElse, (pElse->numOfRows > 1 ? rowIdx : 0)), colDataIsNull_s(pElse, (pElse->numOfRows > 1 ? rowIdx : 0)));
return TSDB_CODE_SUCCESS;
}
colDataAppend(output->columnData, rowIdx, NULL, true);
return TSDB_CODE_SUCCESS;
}
int32_t sclWalkWhenList(SScalarCtx *ctx, SNodeList* pList, struct SListCell* pCell, SScalarParam *pElse, SScalarParam *output, int32_t rowIdx) {
SNode *node = NULL;
SWhenThenNode* pWhenThen = NULL;
SScalarParam *pWhen = NULL;
SScalarParam *pThen = NULL;
int32_t code = 0;
for (SListCell* cell = pCell; (NULL != cell ? (node = cell->pNode, true) : (node = NULL, false)); cell = cell->pNext) {
pWhenThen = (SWhenThenNode*)node;
pWhen = NULL;
pThen = NULL;
SCL_ERR_JRET(sclGetNodeRes(pWhenThen->pWhen, ctx, &pWhen));
SCL_ERR_JRET(sclGetNodeRes(pWhenThen->pThen, ctx, &pThen));
bool *whenValue = colDataGetData(pWhen->columnData, (pWhen->numOfRows > 1 ? rowIdx : 0));
if (*whenValue) {
colDataAppend(output->columnData, rowIdx, colDataGetData(pThen, (pThen->numOfRows > 1 ? rowIdx : 0)), colDataIsNull_s(pThen, (pThen->numOfRows > 1 ? rowIdx : 0)));
goto _return;
}
sclFreeParam(pWhen);
sclFreeParam(pThen);
}
if (pElse) {
colDataAppend(output->columnData, rowIdx, colDataGetData(pElse, (pElse->numOfRows > 1 ? rowIdx : 0)), colDataIsNull_s(pElse, (pElse->numOfRows > 1 ? rowIdx : 0)));
goto _return;
}
colDataAppend(output->columnData, rowIdx, NULL, true);
_return:
sclFreeParam(pWhen);
sclFreeParam(pThen);
SCL_RET(code);
}
int32_t sclExecFunction(SFunctionNode *node, SScalarCtx *ctx, SScalarParam *output) {
SScalarParam *params = NULL;
int32_t rowNum = 0;
......@@ -693,42 +784,80 @@ _return:
}
int32_t sclExecCaseWhen(SCaseWhenNode *node, SScalarCtx *ctx, SScalarParam *output) {
int32_t rowNum = 0;
int32_t code = 0;
SScalarParam *pCase = NULL;
SScalarParam *pElse = NULL;
if (node->pCase) {
SCL_ERR_RET(sclGetNodeRes(node->pCase, ctx->pBlockList, &pCase));
SScalarParam *pWhen = NULL;
SScalarParam *pThen = NULL;
SScalarParam *pComp = NULL;
int32_t rowNum = 1;
if (NULL == node->pWhenThenList || node->pWhenThenList->length <= 0) {
sclError("invalid whenThen list");
SCL_ERR_RET(TSDB_CODE_INVALID_PARA);
}
if (ctx->pBlockList) {
SSDataBlock* pb = taosArrayGetP(ctx->pBlockList, 0);
rowNum = pb->info.rows;
}
SCL_ERR_JRET(sclCreateColumnInfoData(&node->node.resType, rowNum, output));
SCL_ERR_JRET(sclGetNodeRes(node->pCase, ctx, &pCase));
SCL_ERR_JRET(sclGetNodeRes(node->pElse, ctx, &pElse));
SDataType compType = {0};
compType.type = TSDB_DATA_TYPE_BOOL;
compType.bytes = tDataTypes[compType.type].bytes;
SCL_ERR_JRET(sclCreateColumnInfoData(&compType, rowNum, pComp));
SNode* tnode = NULL;
FOREACH(tnode, node->pWhenThenList) {
if (QUERY_NODE_VALUE == tnode->type) {
return DEAL_RES_CONTINUE;
SWhenThenNode* pWhenThen = (SWhenThenNode*)node->pWhenThenList->pHead->pNode;
SCL_ERR_JRET(sclGetNodeRes(pWhenThen->pWhen, ctx, &pWhen));
SCL_ERR_JRET(sclGetNodeRes(pWhenThen->pThen, ctx, &pThen));
if (pCase) {
vectorCompare(pCase, pWhen, pComp, TSDB_ORDER_ASC, OP_TYPE_EQUAL);
for (int32_t i = 0; i < rowNum; ++i) {
bool *equal = colDataGetData(pComp->columnData, i);
if (*equal) {
colDataAppend(output->columnData, i, colDataGetData(pThen, (pThen->numOfRows > 1 ? i : 0)), colDataIsNull_s(pThen, (pThen->numOfRows > 1 ? i : 0)));
} else {
SCL_ERR_JRET(sclWalkCaseWhenList(ctx, node->pWhenThenList, node->pWhenThenList->pHead->pNext, pCase, pElse, pComp, output, i));
}
}
}
if (output->columnData == NULL) {
code = sclCreateColumnInfoData(&node->node.resType, rowNum, output);
if (code != TSDB_CODE_SUCCESS) {
SCL_ERR_JRET(code);
} else {
for (int32_t i = 0; i < rowNum; ++i) {
bool *equal = colDataGetData(pWhen->columnData, i);
if (*equal) {
colDataAppend(output->columnData, i, colDataGetData(pThen, (pThen->numOfRows > 1 ? i : 0)), colDataIsNull_s(pThen, (pThen->numOfRows > 1 ? i : 0)));
} else {
SCL_ERR_JRET(sclWalkWhenList(ctx, node->pWhenThenList, node->pWhenThenList->pHead->pNext, pElse, output, i));
}
}
}
_bin_scalar_fn_t OperatorFn = getBinScalarOperatorFn(node->opType);
sclFreeParam(pCase);
sclFreeParam(pElse);
sclFreeParam(pComp);
sclFreeParam(pWhen);
sclFreeParam(pThen);
int32_t paramNum = scalarGetOperatorParamNum(node->opType);
SScalarParam* pLeft = &params[0];
SScalarParam* pRight = paramNum > 1 ? &params[1] : NULL;
terrno = TSDB_CODE_SUCCESS;
OperatorFn(pLeft, pRight, output, TSDB_ORDER_ASC);
code = terrno;
return TSDB_CODE_SUCCESS;
_return:
sclFreeParamList(params, paramNum);
sclFreeParam(pCase);
sclFreeParam(pElse);
sclFreeParam(pComp);
sclFreeParam(pWhen);
sclFreeParam(pThen);
sclFreeParam(output);
SCL_RET(code);
}
......
......@@ -1646,8 +1646,8 @@ void vectorBitOr(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *pOut,
doReleaseVec(pRightCol, rightConvert);
}
int32_t doVectorCompareImpl(int32_t numOfRows, SScalarParam *pOut, int32_t startIndex, int32_t step, __compar_fn_t fp,
SScalarParam *pLeft, SScalarParam *pRight, int32_t optr) {
int32_t doVectorCompareImpl(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t startIndex, int32_t numOfRows,
int32_t step, __compar_fn_t fp, int32_t optr) {
int32_t num = 0;
for (int32_t i = startIndex; i < numOfRows && i >= 0; i += step) {
......@@ -1700,12 +1700,14 @@ int32_t doVectorCompareImpl(int32_t numOfRows, SScalarParam *pOut, int32_t start
return num;
}
void vectorCompareImpl(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *pOut, int32_t _ord, int32_t optr) {
int32_t i = ((_ord) == TSDB_ORDER_ASC) ? 0 : TMAX(pLeft->numOfRows, pRight->numOfRows) - 1;
void doVectorCompare(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *pOut, int32_t startIndex, int32_t numOfRows,
int32_t _ord, int32_t optr) {
int32_t i = 0;
int32_t step = ((_ord) == TSDB_ORDER_ASC) ? 1 : -1;
int32_t lType = GET_PARAM_TYPE(pLeft);
int32_t rType = GET_PARAM_TYPE(pRight);
__compar_fn_t fp = NULL;
int32_t compRows = 0;
if (lType == rType) {
fp = filterGetCompFunc(lType, optr);
......@@ -1715,6 +1717,14 @@ void vectorCompareImpl(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *
pOut->numOfRows = TMAX(pLeft->numOfRows, pRight->numOfRows);
if (startIndex < 0) {
i = ((_ord) == TSDB_ORDER_ASC) ? 0 : TMAX(pLeft->numOfRows, pRight->numOfRows) - 1;
compRows = pOut->numOfRows;
} else {
compRows = startIndex + numOfRows;
i = startIndex;
}
if (pRight->pHashFilter != NULL) {
for (; i >= 0 && i < pLeft->numOfRows; i += step) {
if (IS_HELPER_NULL(pLeft->columnData, i)) {
......@@ -1731,7 +1741,7 @@ void vectorCompareImpl(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *
}
}
} else { // normal compare
pOut->numOfQualified = doVectorCompareImpl(pOut->numOfRows, pOut, i, step, fp, pLeft, pRight, optr);
pOut->numOfQualified = doVectorCompareImpl(pLeft, pRight, pOut, i, compRows, step, fp, optr);
}
}
......@@ -1760,7 +1770,8 @@ void vectorCompare(SScalarParam* pLeft, SScalarParam* pRight, SScalarParam *pOut
}
}
vectorCompareImpl(param1, param2, pOut, _ord, optr);
doVectorCompare(param1, param2, pOut, -1, -1, _ord, optr);
sclFreeParam(&pLeftOut);
sclFreeParam(&pRightOut);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册