From 9489e1400062c9f9b95dad97e82f3bd7ad6d11b3 Mon Sep 17 00:00:00 2001
From: zhagnlu <1542303831@qq.com>
Date: Fri, 11 Aug 2023 15:31:29 +0800
Subject: [PATCH] Optimize multi logical exprs performance when meet some situations (#26265)

Signed-off-by: luzhang
Co-authored-by: luzhang
---
 .../src/query/visitors/ExecExprVisitor.cpp    | 15 ++++
 internal/core/unittest/test_expr.cpp          | 87 +++++++++++++++++++
 2 files changed, 102 insertions(+)

diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp
index ef753c6bc..8e1c011df 100644
--- a/internal/core/src/query/visitors/ExecExprVisitor.cpp
+++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp
@@ -128,7 +128,22 @@ ExecExprVisitor::visit(LogicalUnaryExpr& expr) {
 void
 ExecExprVisitor::visit(LogicalBinaryExpr& expr) {
     using OpType = LogicalBinaryExpr::OpType;
+    // The right operand cannot change the outcome when the left bitset is
+    // all-zero under LogicalAnd or all-one under LogicalOr.
+    auto skip_right_expr = [](const BitsetType& left_res,
+                              const OpType& op_type) -> bool {
+        return (op_type == OpType::LogicalAnd && left_res.none()) ||
+               (op_type == OpType::LogicalOr && left_res.all());
+    };
+
     auto left = call_child(*expr.left_);
+    // Short-circuit: keep the left result and skip the right child.
+    if (skip_right_expr(left, expr.op_type_)) {
+        AssertInfo(left.size() == row_count_,
+                   "[ExecExprVisitor]Size of results not equal row count");
+        bitset_opt_ = std::move(left);
+        return;
+    }
     auto right = call_child(*expr.right_);
     AssertInfo(left.size() == right.size(),
                "[ExecExprVisitor]Left size not equal to right size");
diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp
index 0dd3ec46f..0e177a7f1 100644
--- a/internal/core/unittest/test_expr.cpp
+++ b/internal/core/unittest/test_expr.cpp
@@ -1358,6 +1358,93 @@ TEST(Expr, TestCompareExpr) {
     std::cout << "end compare test" << std::endl;
 }
 
+// Times two LogicalAnd expressions on one sealed segment and expects the
+// first (left operand "int64 < -1") to finish sooner than the second.
+// NOTE(review): this presumably relies on DataGen producing no int64 value
+// below -1, so the left bitset is empty and the right child is skipped --
+// confirm against DataGen.
+TEST(Expr, TestMultiLogicalExprsOptimization) {
+    using namespace milvus;
+    using namespace milvus::query;
+    using namespace milvus::segcore;
+    auto schema = std::make_shared<Schema>();
+    auto vec_fid = schema->AddDebugField(
+        "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
+    auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
+    auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
+    schema->set_primary_field_id(str1_fid);
+
+    auto seg = CreateSealedSegment(schema);
+    size_t N = 1000000;
+    auto raw_data = DataGen(schema, N);
+    auto fields = schema->get_fields();
+    // Load every generated field into the sealed segment.
+    for (auto field_data : raw_data.raw_->fields_data()) {
+        int64_t field_id = field_data.field_id();
+
+        auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a");
+        auto field_meta = fields.at(FieldId(field_id));
+        info.channel->push(
+            CreateFieldDataFromDataArray(N, &field_data, field_meta));
+        info.channel->close();
+
+        seg->LoadFieldData(FieldId(field_id), info);
+    }
+
+    ExecExprVisitor visitor(*seg, seg->get_row_count(), MAX_TIMESTAMP);
+    // (int64 < -1) AND (int64 != 100): the right child may be skipped.
+    auto build_expr_with_optim = [&]() -> std::shared_ptr<query::Expr> {
+        ExprPtr child1_expr =
+            std::make_unique<query::UnaryRangeExprImpl<int64_t>>(
+                ColumnInfo(int64_fid, DataType::INT64),
+                proto::plan::OpType::LessThan,
+                -1,
+                proto::plan::GenericValue::ValCase::kInt64Val);
+        ExprPtr child2_expr =
+            std::make_unique<query::UnaryRangeExprImpl<int64_t>>(
+                ColumnInfo(int64_fid, DataType::INT64),
+                proto::plan::OpType::NotEqual,
+                100,
+                proto::plan::GenericValue::ValCase::kInt64Val);
+        return std::make_shared<query::LogicalBinaryExpr>(
+            LogicalBinaryExpr::OpType::LogicalAnd, child1_expr, child2_expr);
+    };
+    // (int64 > 10) AND (int64 != 100): expression timed for comparison.
+    auto build_expr = [&]() -> std::shared_ptr<query::Expr> {
+        ExprPtr child1_expr =
+            std::make_unique<query::UnaryRangeExprImpl<int64_t>>(
+                ColumnInfo(int64_fid, DataType::INT64),
+                proto::plan::OpType::GreaterThan,
+                10,
+                proto::plan::GenericValue::ValCase::kInt64Val);
+        ExprPtr child2_expr =
+            std::make_unique<query::UnaryRangeExprImpl<int64_t>>(
+                ColumnInfo(int64_fid, DataType::INT64),
+                proto::plan::OpType::NotEqual,
+                100,
+                proto::plan::GenericValue::ValCase::kInt64Val);
+        return std::make_shared<query::LogicalBinaryExpr>(
+            LogicalBinaryExpr::OpType::LogicalAnd, child1_expr, child2_expr);
+    };
+    auto start = std::chrono::steady_clock::now();
+    auto expr = build_expr_with_optim();
+    auto final = visitor.call_child(*expr);
+    auto cost_op = std::chrono::duration_cast<std::chrono::microseconds>(
+                       std::chrono::steady_clock::now() - start)
+                       .count();
+    std::cout << "cost: " << cost_op << "us" << std::endl;
+    start = std::chrono::steady_clock::now();
+    expr = build_expr();
+    final = visitor.call_child(*expr);
+    auto cost = std::chrono::duration_cast<std::chrono::microseconds>(
+                    std::chrono::steady_clock::now() - start)
+                    .count();
+    std::cout << "cost: " << cost << "us" << std::endl;
+    // NOTE(review): asserting on wall-clock time can be flaky on a loaded
+    // machine, and the first measurement may also pay any warm-up cost.
+    ASSERT_LT(cost_op, cost);
+}
+
 TEST(Expr, TestExprs) {
     using namespace milvus;
     using namespace milvus::query;
-- 
GitLab