未验证 提交 04fffb08 编写于 作者: A Aivin V. Solatorio 提交者: GitHub

Support arithmetic operations on numerical fields for scalar filtering (#16520)

Signed-off-by: NAivin V. Solatorio <avsolatorio@gmail.com>
上级 b82f6a1a
此差异已折叠。
此差异已折叠。
......@@ -112,6 +112,49 @@ enum class OpType {
NotEqual = 6,
};
enum class ArithOpType {
Unknown = 0,
Add = 1,
Sub = 2,
Mul = 3,
Div = 4,
Mod = 5,
};
static const std::map<std::string, ArithOpType> arith_op_mapping_ = {
// arith_op_name -> arith_op
{"add", ArithOpType::Add}, {"sub", ArithOpType::Sub}, {"mul", ArithOpType::Mul},
{"div", ArithOpType::Div}, {"mod", ArithOpType::Mod},
};
static const std::map<ArithOpType, std::string> mapping_arith_op_ = {
// arith_op_name -> arith_op
{ArithOpType::Add, "add"}, {ArithOpType::Sub, "sub"}, {ArithOpType::Mul, "mul"},
{ArithOpType::Div, "div"}, {ArithOpType::Mod, "mod"},
};
struct BinaryArithOpEvalRangeExpr : Expr {
const FieldOffset field_offset_;
const DataType data_type_;
const OpType op_type_;
const ArithOpType arith_op_;
protected:
// prevent accidential instantiation
BinaryArithOpEvalRangeExpr() = delete;
BinaryArithOpEvalRangeExpr(const FieldOffset field_offset,
const DataType data_type,
const OpType op_type,
const ArithOpType arith_op)
: field_offset_(field_offset), data_type_(data_type), op_type_(op_type), arith_op_(arith_op) {
}
public:
void
accept(ExprVisitor&) override;
};
static const std::map<std::string, OpType> mapping_ = {
// op_name -> op
{"lt", OpType::LessThan}, {"le", OpType::LessEqual}, {"lte", OpType::LessEqual},
......
......@@ -33,6 +33,23 @@ struct TermExprImpl : TermExpr {
}
};
template <typename T>
struct BinaryArithOpEvalRangeExprImpl : BinaryArithOpEvalRangeExpr {
const T right_operand_;
const T value_;
BinaryArithOpEvalRangeExprImpl(const FieldOffset field_offset,
const DataType data_type,
const ArithOpType arith_op,
const T right_operand,
const OpType op_type,
const T value)
: BinaryArithOpEvalRangeExpr(field_offset, data_type, op_type, arith_op),
right_operand_(right_operand),
value_(value) {
}
};
template <typename T>
struct UnaryRangeExprImpl : UnaryRangeExpr {
const T value_;
......
......@@ -243,6 +243,65 @@ Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) {
auto item = body.begin();
auto op_name = boost::algorithm::to_lower_copy(std::string(item.key()));
AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found");
// This is an expression with an arithmetic operation
if (item.value().is_object()) {
/* // This is the expected DSL expression
{
range: {
field_name: {
op: {
arith_op: {
right_operand: operand,
value: value
},
}
}
}
}
EXAMPLE:
{
range: {
field_name: {
"EQ": {
"ADD": {
right_operand: 10,
value: 25
},
}
}
}
}
*/
auto arith = item.value();
auto arith_body = arith.begin();
auto arith_op_name = boost::algorithm::to_lower_copy(std::string(arith_body.key()));
AssertInfo(arith_op_mapping_.count(arith_op_name), "arith op(" + arith_op_name + ") not found");
auto& arith_op_body = arith_body.value();
Assert(arith_op_body.is_object());
auto right_operand = arith_op_body["right_operand"];
auto value = arith_op_body["value"];
if constexpr (std::is_same_v<T, bool>) {
throw std::runtime_error("bool type is not supported");
} else if constexpr (std::is_integral_v<T>) {
Assert(right_operand.is_number_integer());
Assert(value.is_number_integer());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(right_operand.is_number());
Assert(value.is_number());
} else {
static_assert(always_false<T>, "unsupported type");
}
return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
schema.get_offset(field_name), schema[field_name].get_data_type(), arith_op_mapping_.at(arith_op_name),
right_operand, mapping_.at(op_name), value);
}
if constexpr (std::is_same_v<T, bool>) {
Assert(item.value().is_boolean());
} else if constexpr (std::is_integral_v<T>) {
......
......@@ -89,6 +89,31 @@ ExtractBinaryRangeExprImpl(FieldOffset field_offset, DataType data_type, const p
getValue(expr_proto.upper_value()));
}
template <typename T>
std::unique_ptr<BinaryArithOpEvalRangeExprImpl<T>>
ExtractBinaryArithOpEvalRangeExprImpl(FieldOffset field_offset,
DataType data_type,
const planpb::BinaryArithOpEvalRangeExpr& expr_proto) {
static_assert(std::is_fundamental_v<T>);
auto getValue = [&](const auto& value_proto) -> T {
if constexpr (std::is_same_v<T, bool>) {
// Handle bool here. Otherwise, it can go in `is_integral_v<T>`
static_assert(always_false<T>);
} else if constexpr (std::is_integral_v<T>) {
Assert(value_proto.val_case() == planpb::GenericValue::kInt64Val);
return static_cast<T>(value_proto.int64_val());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(value_proto.val_case() == planpb::GenericValue::kFloatVal);
return static_cast<T>(value_proto.float_val());
} else {
static_assert(always_false<T>);
}
};
return std::make_unique<BinaryArithOpEvalRangeExprImpl<T>>(
field_offset, data_type, static_cast<ArithOpType>(expr_proto.arith_op()), getValue(expr_proto.right_operand()),
static_cast<OpType>(expr_proto.op()), getValue(expr_proto.value()));
}
std::unique_ptr<VectorPlanNode>
ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
// TODO: add more buffs
......@@ -337,6 +362,42 @@ ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) {
return std::make_unique<LogicalBinaryExpr>(op, left_expr, right_expr);
}
ExprPtr
ProtoParser::ParseBinaryArithOpEvalRangeExpr(const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
auto& column_info = expr_pb.column_info();
auto field_id = FieldId(column_info.field_id());
auto field_offset = schema.get_offset(field_id);
auto data_type = schema[field_offset].get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
auto result = [&]() -> ExprPtr {
switch (data_type) {
case DataType::INT8: {
return ExtractBinaryArithOpEvalRangeExprImpl<int8_t>(field_offset, data_type, expr_pb);
}
case DataType::INT16: {
return ExtractBinaryArithOpEvalRangeExprImpl<int16_t>(field_offset, data_type, expr_pb);
}
case DataType::INT32: {
return ExtractBinaryArithOpEvalRangeExprImpl<int32_t>(field_offset, data_type, expr_pb);
}
case DataType::INT64: {
return ExtractBinaryArithOpEvalRangeExprImpl<int64_t>(field_offset, data_type, expr_pb);
}
case DataType::FLOAT: {
return ExtractBinaryArithOpEvalRangeExprImpl<float>(field_offset, data_type, expr_pb);
}
case DataType::DOUBLE: {
return ExtractBinaryArithOpEvalRangeExprImpl<double>(field_offset, data_type, expr_pb);
}
default: {
PanicInfo("unsupported data type");
}
}
}();
return result;
}
ExprPtr
ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) {
using ppe = proto::plan::Expr;
......@@ -359,6 +420,9 @@ ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) {
case ppe::kCompareExpr: {
return ParseCompareExpr(expr_pb.compare_expr());
}
case ppe::kBinaryArithOpEvalRangeExpr: {
return ParseBinaryArithOpEvalRangeExpr(expr_pb.binary_arith_op_eval_range_expr());
}
default:
PanicInfo("unsupported expr proto node");
}
......
......@@ -29,6 +29,9 @@ class ProtoParser {
// ExprPtr
// ExprFromProto(const proto::plan::Expr& expr_proto);
ExprPtr
ParseBinaryArithOpEvalRangeExpr(const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb);
ExprPtr
ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb);
......
......@@ -35,6 +35,9 @@ class ExecExprVisitor : public ExprVisitor {
void
visit(UnaryRangeExpr& expr) override;
void
visit(BinaryArithOpEvalRangeExpr& expr) override;
void
visit(BinaryRangeExpr& expr) override;
......@@ -61,10 +64,18 @@ class ExecExprVisitor : public ExprVisitor {
auto
ExecRangeVisitorImpl(FieldOffset field_offset, IndexFunc func, ElementFunc element_func) -> BitsetType;
template <typename T, typename ElementFunc>
auto
ExecDataRangeVisitorImpl(FieldOffset field_offset, ElementFunc element_func) -> BitsetType;
template <typename T>
auto
ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType;
template <typename T>
auto
ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
template <typename T>
auto
ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType;
......
......@@ -35,6 +35,10 @@ UnaryRangeExpr::accept(ExprVisitor& visitor) {
visitor.visit(*this);
}
void
BinaryArithOpEvalRangeExpr::accept(ExprVisitor& visitor) {
visitor.visit(*this);
}
void
BinaryRangeExpr::accept(ExprVisitor& visitor) {
visitor.visit(*this);
......
......@@ -31,6 +31,9 @@ class ExprVisitor {
virtual void
visit(UnaryRangeExpr&) = 0;
virtual void
visit(BinaryArithOpEvalRangeExpr&) = 0;
virtual void
visit(BinaryRangeExpr&) = 0;
......
......@@ -30,6 +30,9 @@ class ExtractInfoExprVisitor : public ExprVisitor {
void
visit(UnaryRangeExpr& expr) override;
void
visit(BinaryArithOpEvalRangeExpr& expr) override;
void
visit(BinaryRangeExpr& expr) override;
......
......@@ -31,6 +31,9 @@ class ShowExprVisitor : public ExprVisitor {
void
visit(UnaryRangeExpr& expr) override;
void
visit(BinaryArithOpEvalRangeExpr& expr) override;
void
visit(BinaryRangeExpr& expr) override;
......@@ -39,6 +42,7 @@ class ShowExprVisitor : public ExprVisitor {
public:
Json
call_child(Expr& expr) {
assert(!json_opt_.has_value());
expr.accept(*this);
......
......@@ -35,6 +35,9 @@ class VerifyExprVisitor : public ExprVisitor {
void
visit(UnaryRangeExpr& expr) override;
void
visit(BinaryArithOpEvalRangeExpr& expr) override;
void
visit(BinaryRangeExpr& expr) override;
......
......@@ -48,6 +48,10 @@ class ExecExprVisitor : ExprVisitor {
auto
ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType;
template <typename T>
auto
ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType;
template <typename T>
auto
ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType;
......@@ -174,6 +178,31 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldOffset field_offset, IndexFunc index_
return final_result;
}
template <typename T, typename ElementFunc>
auto
ExecExprVisitor::ExecDataRangeVisitorImpl(FieldOffset field_offset, ElementFunc element_func) -> BitsetType {
auto& schema = segment_.get_schema();
auto& field_meta = schema[field_offset];
auto size_per_chunk = segment_.size_per_chunk();
auto num_chunk = upper_div(row_count_, size_per_chunk);
std::deque<BitsetType> results;
for (auto chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk;
BitsetType result(this_size);
auto chunk = segment_.chunk_data<T>(field_offset, chunk_id);
const T* data = chunk.data();
for (int index = 0; index < this_size; ++index) {
result[index] = element_func(data[index]);
}
AssertInfo(result.size() == this_size, "[ExecExprVisitor]Chunk result size not equal to expected size");
results.emplace_back(std::move(result));
}
auto final_result = Assemble(results);
AssertInfo(final_result.size() == row_count_, "[ExecExprVisitor]Final result size not equal to row count");
return final_result;
}
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
......@@ -222,6 +251,84 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> Bi
}
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
auto
ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher(BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType {
auto& expr = static_cast<BinaryArithOpEvalRangeExprImpl<T>&>(expr_raw);
using Index = scalar::ScalarIndex<T>;
auto arith_op = expr.arith_op_;
auto right_operand = expr.right_operand_;
auto op = expr.op_type_;
auto val = expr.value_;
switch (op) {
case OpType::Equal: {
switch (arith_op) {
case ArithOpType::Add: {
auto elem_func = [val, right_operand](T x) { return ((x + right_operand) == val); };
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
case ArithOpType::Sub: {
auto elem_func = [val, right_operand](T x) { return ((x - right_operand) == val); };
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
case ArithOpType::Mul: {
auto elem_func = [val, right_operand](T x) { return ((x * right_operand) == val); };
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
case ArithOpType::Div: {
auto elem_func = [val, right_operand](T x) { return ((x / right_operand) == val); };
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
case ArithOpType::Mod: {
auto elem_func = [val, right_operand](T x) {
return (static_cast<T>(fmod(x, right_operand)) == val);
};
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
default: {
PanicInfo("unsupported arithmetic operation");
}
}
}
case OpType::NotEqual: {
switch (arith_op) {
case ArithOpType::Add: {
auto elem_func = [val, right_operand](T x) { return ((x + right_operand) != val); };
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
case ArithOpType::Sub: {
auto elem_func = [val, right_operand](T x) { return ((x - right_operand) != val); };
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
case ArithOpType::Mul: {
auto elem_func = [val, right_operand](T x) { return ((x * right_operand) != val); };
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
case ArithOpType::Div: {
auto elem_func = [val, right_operand](T x) { return ((x / right_operand) != val); };
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
case ArithOpType::Mod: {
auto elem_func = [val, right_operand](T x) {
return (static_cast<T>(fmod(x, right_operand)) != val);
};
return ExecDataRangeVisitorImpl<T>(expr.field_offset_, elem_func);
}
default: {
PanicInfo("unsupported arithmetic operation");
}
}
}
default: {
PanicInfo("unsupported range node with arithmetic operation");
}
}
}
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
template <typename T>
......@@ -297,6 +404,44 @@ ExecExprVisitor::visit(UnaryRangeExpr& expr) {
bitset_opt_ = std::move(res);
}
void
ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
auto& field_meta = segment_.get_schema()[expr.field_offset_];
AssertInfo(expr.data_type_ == field_meta.get_data_type(),
"[ExecExprVisitor]DataType of expr isn't field_meta data type");
BitsetType res;
switch (expr.data_type_) {
case DataType::INT8: {
res = ExecBinaryArithOpEvalRangeVisitorDispatcher<int8_t>(expr);
break;
}
case DataType::INT16: {
res = ExecBinaryArithOpEvalRangeVisitorDispatcher<int16_t>(expr);
break;
}
case DataType::INT32: {
res = ExecBinaryArithOpEvalRangeVisitorDispatcher<int32_t>(expr);
break;
}
case DataType::INT64: {
res = ExecBinaryArithOpEvalRangeVisitorDispatcher<int64_t>(expr);
break;
}
case DataType::FLOAT: {
res = ExecBinaryArithOpEvalRangeVisitorDispatcher<float>(expr);
break;
}
case DataType::DOUBLE: {
res = ExecBinaryArithOpEvalRangeVisitorDispatcher<double>(expr);
break;
}
default:
PanicInfo("unsupported");
}
AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count");
bitset_opt_ = std::move(res);
}
void
ExecExprVisitor::visit(BinaryRangeExpr& expr) {
auto& field_meta = segment_.get_schema()[expr.field_offset_];
......
......@@ -59,4 +59,9 @@ ExtractInfoExprVisitor::visit(CompareExpr& expr) {
plan_info_.add_involved_field(expr.right_field_offset_);
}
void
ExtractInfoExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
plan_info_.add_involved_field(expr.field_offset_);
}
} // namespace milvus::query
......@@ -248,4 +248,53 @@ ShowExprVisitor::visit(CompareExpr& expr) {
json_opt_ = res;
}
template <typename T>
static Json
BinaryArithOpEvalRangeExtract(const BinaryArithOpEvalRangeExpr& expr_raw) {
using proto::plan::ArithOpType;
using proto::plan::ArithOpType_Name;
using proto::plan::OpType;
using proto::plan::OpType_Name;
auto expr = dynamic_cast<const BinaryArithOpEvalRangeExprImpl<T>*>(&expr_raw);
AssertInfo(expr, "[ShowExprVisitor]BinaryArithOpEvalRangeExpr cast to BinaryArithOpEvalRangeExprImpl failed");
Json res{{"expr_type", "BinaryArithOpEvalRange"},
{"field_offset", expr->field_offset_.get()},
{"data_type", datatype_name(expr->data_type_)},
{"arith_op", ArithOpType_Name(static_cast<ArithOpType>(expr->arith_op_))},
{"right_operand", expr->right_operand_},
{"op", OpType_Name(static_cast<OpType>(expr->op_type_))},
{"value", expr->value_}};
return res;
}
void
ShowExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
AssertInfo(!json_opt_.has_value(), "[ShowExprVisitor]Ret json already has value before visit");
AssertInfo(datatype_is_vector(expr.data_type_) == false, "[ShowExprVisitor]Data type of expr isn't vector type");
switch (expr.data_type_) {
case DataType::INT8:
json_opt_ = BinaryArithOpEvalRangeExtract<int8_t>(expr);
return;
case DataType::INT16:
json_opt_ = BinaryArithOpEvalRangeExtract<int16_t>(expr);
return;
case DataType::INT32:
json_opt_ = BinaryArithOpEvalRangeExtract<int32_t>(expr);
return;
case DataType::INT64:
json_opt_ = BinaryArithOpEvalRangeExtract<int64_t>(expr);
return;
case DataType::DOUBLE:
json_opt_ = BinaryArithOpEvalRangeExtract<double>(expr);
return;
case DataType::FLOAT:
json_opt_ = BinaryArithOpEvalRangeExtract<float>(expr);
return;
default:
PanicInfo("unsupported type");
}
}
} // namespace milvus::query
......@@ -32,6 +32,11 @@ VerifyExprVisitor::visit(UnaryRangeExpr& expr) {
// TODO
}
void
VerifyExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) {
// TODO
}
void
VerifyExprVisitor::visit(BinaryRangeExpr& expr) {
// TODO
......
......@@ -581,3 +581,364 @@ TEST(Expr, TestCompare) {
}
}
}
TEST(Expr, TestBinaryArithOpEvalRange) {
using namespace milvus::query;
using namespace milvus::segcore;
std::vector<std::tuple<std::string, std::function<bool(int)>, DataType>> testcases = {
// Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types
{R"("EQ": {
"ADD": {
"right_operand": 4,
"value": 8
}
})", [](int8_t v) { return (v + 4) == 8; }, DataType::INT8},
{R"("EQ": {
"SUB": {
"right_operand": 500,
"value": 1500
}
})", [](int16_t v) { return (v - 500) == 1500; }, DataType::INT16},
{R"("EQ": {
"MUL": {
"right_operand": 2,
"value": 4000
}
})", [](int32_t v) { return (v * 2) == 4000; }, DataType::INT32},
{R"("EQ": {
"DIV": {
"right_operand": 2,
"value": 1000
}
})", [](int64_t v) { return (v / 2) == 1000; }, DataType::INT64},
{R"("EQ": {
"MOD": {
"right_operand": 100,
"value": 0
}
})", [](int32_t v) { return (v % 100) == 0; }, DataType::INT32},
{R"("EQ": {
"ADD": {
"right_operand": 500,
"value": 2500
}
})", [](float v) { return (v + 500) == 2500; }, DataType::FLOAT},
{R"("EQ": {
"ADD": {
"right_operand": 500,
"value": 2500
}
})", [](double v) { return (v + 500) == 2500; }, DataType::DOUBLE},
// Add test cases for BinaryArithOpEvalRangeExpr NE of various data types
{R"("NE": {
"ADD": {
"right_operand": 500,
"value": 2500
}
})", [](float v) { return (v + 500) != 2500; }, DataType::FLOAT},
{R"("NE": {
"SUB": {
"right_operand": 500,
"value": 2500
}
})", [](double v) { return (v - 500) != 2500; }, DataType::DOUBLE},
{R"("NE": {
"MUL": {
"right_operand": 2,
"value": 2
}
})", [](int8_t v) { return (v * 2) != 2; }, DataType::INT8},
{R"("NE": {
"DIV": {
"right_operand": 2,
"value": 1000
}
})", [](int16_t v) { return (v / 2) != 1000; }, DataType::INT16},
{R"("NE": {
"MOD": {
"right_operand": 100,
"value": 0
}
})", [](int32_t v) { return (v % 100) != 0; }, DataType::INT32},
{R"("NE": {
"ADD": {
"right_operand": 500,
"value": 2500
}
})", [](int64_t v) { return (v + 500) != 2500; }, DataType::INT64},
};
std::string dsl_string_tmp = R"({
"bool": {
"must": [
{
"range": {
@@@@@
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
})";
std::string dsl_string_int8 = R"(
"age8": {
@@@@
})";
std::string dsl_string_int16 = R"(
"age16": {
@@@@
})";
std::string dsl_string_int32 = R"(
"age32": {
@@@@
})";
std::string dsl_string_int64 = R"(
"age64": {
@@@@
})";
std::string dsl_string_float = R"(
"age_float": {
@@@@
})";
std::string dsl_string_double = R"(
"age_double": {
@@@@
})";
auto schema = std::make_shared<Schema>();
schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddDebugField("age8", DataType::INT8);
schema->AddDebugField("age16", DataType::INT16);
schema->AddDebugField("age32", DataType::INT32);
schema->AddDebugField("age64", DataType::INT64);
schema->AddDebugField("age_float", DataType::FLOAT);
schema->AddDebugField("age_double", DataType::DOUBLE);
auto seg = CreateGrowingSegment(schema);
int N = 1000;
std::vector<int8_t> age8_col;
std::vector<int16_t> age16_col;
std::vector<int32_t> age32_col;
std::vector<int64_t> age64_col;
std::vector<float> age_float_col;
std::vector<double> age_double_col;
int num_iters = 100;
for (int iter = 0; iter < num_iters; ++iter) {
auto raw_data = DataGen(schema, N, iter);
auto new_age8_col = raw_data.get_col<int8_t>(1);
auto new_age16_col = raw_data.get_col<int16_t>(2);
auto new_age32_col = raw_data.get_col<int32_t>(3);
auto new_age64_col = raw_data.get_col<int64_t>(4);
auto new_age_float_col = raw_data.get_col<float>(5);
auto new_age_double_col = raw_data.get_col<double>(6);
age8_col.insert(age8_col.end(), new_age8_col.begin(), new_age8_col.end());
age16_col.insert(age16_col.end(), new_age16_col.begin(), new_age16_col.end());
age32_col.insert(age32_col.end(), new_age32_col.begin(), new_age32_col.end());
age64_col.insert(age64_col.end(), new_age64_col.begin(), new_age64_col.end());
age_float_col.insert(age_float_col.end(), new_age_float_col.begin(), new_age_float_col.end());
age_double_col.insert(age_double_col.end(), new_age_double_col.begin(), new_age_double_col.end());
seg->PreInsert(N);
seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_);
}
auto seg_promote = dynamic_cast<SegmentGrowingImpl*>(seg.get());
ExecExprVisitor visitor(*seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP);
for (auto [clause, ref_func, dtype] : testcases) {
auto loc = dsl_string_tmp.find("@@@@@");
auto dsl_string = dsl_string_tmp;
if (dtype == DataType::INT8) {
dsl_string.replace(loc, 5, dsl_string_int8);
} else if (dtype == DataType::INT16) {
dsl_string.replace(loc, 5, dsl_string_int16);
} else if (dtype == DataType::INT32) {
dsl_string.replace(loc, 5, dsl_string_int32);
} else if (dtype == DataType::INT64) {
dsl_string.replace(loc, 5, dsl_string_int64);
} else if (dtype == DataType::FLOAT) {
dsl_string.replace(loc, 5, dsl_string_float);
} else if (dtype == DataType::DOUBLE) {
dsl_string.replace(loc, 5, dsl_string_double);
} else {
ASSERT_TRUE(false) << "No test case defined for this data type";
}
loc = dsl_string.find("@@@@");
dsl_string.replace(loc, 4, clause);
auto plan = CreatePlan(*schema, dsl_string);
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
EXPECT_EQ(final.size(), N * num_iters);
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
if (dtype == DataType::INT8) {
auto val = age8_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
} else if (dtype == DataType::INT16) {
auto val = age16_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
} else if (dtype == DataType::INT32) {
auto val = age32_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
} else if (dtype == DataType::INT64) {
auto val = age64_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
} else if (dtype == DataType::FLOAT) {
auto val = age_float_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
} else if (dtype == DataType::DOUBLE) {
auto val = age_double_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
} else {
ASSERT_TRUE(false) << "No test case defined for this data type";
}
}
}
}
TEST(Expr, TestBinaryArithOpEvalRangeExceptions) {
using namespace milvus::query;
using namespace milvus::segcore;
std::vector<std::tuple<std::string, std::string, DataType>> testcases = {
// Add test for data type mismatch
{R"("EQ": {
"ADD": {
"right_operand": 500,
"value": 2500.00
}
})", "Assert \"(value.is_number_integer())\"", DataType::INT32},
{R"("EQ": {
"ADD": {
"right_operand": 500.0,
"value": 2500
}
})", "Assert \"(right_operand.is_number_integer())\"", DataType::INT32},
{R"("EQ": {
"ADD": {
"right_operand": 500.0,
"value": true
}
})", "Assert \"(value.is_number())\"", DataType::FLOAT},
{R"("EQ": {
"ADD": {
"right_operand": "500",
"value": 2500.0
}
})", "Assert \"(right_operand.is_number())\"", DataType::FLOAT},
// Check unsupported arithmetic operator type
{R"("EQ": {
"EXP": {
"right_operand": 500,
"value": 2500
}
})", "arith op(exp) not found", DataType::INT32},
// Check unsupported data type
{R"("EQ": {
"ADD": {
"right_operand": true,
"value": false
}
})", "bool type is not supported", DataType::BOOL},
};
std::string dsl_string_tmp = R"({
"bool": {
"must": [
{
"range": {
@@@@@
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
})";
std::string dsl_string_int = R"(
"age": {
@@@@
})";
std::string dsl_string_num = R"(
"FloatN": {
@@@@
})";
std::string dsl_string_bool = R"(
"BoolField": {
@@@@
})";
auto schema = std::make_shared<Schema>();
schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddDebugField("age", DataType::INT32);
schema->AddDebugField("FloatN", DataType::FLOAT);
schema->AddDebugField("BoolField", DataType::BOOL);
for (auto [clause, assert_info, dtype] : testcases) {
auto loc = dsl_string_tmp.find("@@@@@");
auto dsl_string = dsl_string_tmp;
if (dtype == DataType::INT32) {
dsl_string.replace(loc, 5, dsl_string_int);
} else if (dtype == DataType::FLOAT) {
dsl_string.replace(loc, 5, dsl_string_num);
} else if (dtype == DataType::BOOL) {
dsl_string.replace(loc, 5, dsl_string_bool);
} else {
ASSERT_TRUE(false) << "No test case defined for this data type";
}
loc = dsl_string.find("@@@@");
dsl_string.replace(loc, 4, clause);
try {
auto plan = CreatePlan(*schema, dsl_string);
FAIL() << "Expected AssertionError: " << assert_info << " not thrown";
}
catch(const std::exception& err) {
std::string err_msg = err.what();
ASSERT_TRUE(err_msg.find(assert_info) != std::string::npos);
}
catch(...) {
FAIL() << "Expected AssertionError: " << assert_info << " not thrown";
}
}
}
......@@ -543,4 +543,96 @@ vector_anns: <
auto ref_plan = CreatePlan(*schema, dsl_text);
plan->check_identical(*ref_plan);
}
\ No newline at end of file
}
TEST_P(PlanProtoTest, BinaryArithOpEvalRange) {
// xxx.query(predicates = "int64field > 3", topk = 10, ...)
auto data_type = std::get<0>(GetParam());
auto data_type_str = spb::DataType_Name(data_type);
auto field_id = 100 + (int)data_type;
auto field_name = data_type_str + "Field";
string value_tag = "bool_val";
if (datatype_is_floating((DataType)data_type)) {
value_tag = "float_val";
} else if (datatype_is_integer((DataType)data_type)) {
value_tag = "int64_val";
}
auto fmt1 = boost::format(R"(
vector_anns: <
field_id: 201
predicates: <
binary_arith_op_eval_range_expr: <
column_info: <
field_id: %1%
data_type: %2%
>
arith_op: Add
right_operand: <
%3%: 1029
>
op: Equal
value: <
%3%: 2016
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>
)") % field_id % data_type_str %
value_tag;
auto proto_text = fmt1.str();
planpb::PlanNode node_proto;
google::protobuf::TextFormat::ParseFromString(proto_text, &node_proto);
// std::cout << node_proto.DebugString();
auto plan = ProtoParser(*schema).CreatePlan(node_proto);
ShowPlanNodeVisitor visitor;
auto json = visitor.call_child(*plan->plan_node_);
// std::cout << json.dump(2);
auto extra_info = plan->extra_info_opt_.value();
std::string dsl_text = boost::str(boost::format(R"(
{
"bool": {
"must": [
{
"range": {
"%1%": {
"EQ": {
"ADD": {
"right_operand": 1029,
"value": 2016
}
}
}
}
},
{
"vector": {
"FloatVectorField": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
]
}
}
)") % field_name);
auto ref_plan = CreatePlan(*schema, dsl_text);
plan->check_identical(*ref_plan);
}
......@@ -168,6 +168,22 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42, uint64_t ts_offset = 0)
insert_cols(data);
break;
}
case engine::DataType::INT16: {
vector<int16_t> data(N);
for (auto& x : data) {
x = er() % (2 * N);
}
insert_cols(data);
break;
}
case engine::DataType::INT8: {
vector<int8_t> data(N);
for (auto& x : data) {
x = er() % (2 * N);
}
insert_cols(data);
break;
}
case engine::DataType::FLOAT: {
vector<float> data(N);
for (auto& x : data) {
......
......@@ -14,6 +14,15 @@ enum OpType {
NotEqual = 6;
};
enum ArithOpType {
Unknown = 0;
Add = 1;
Sub = 2;
Mul = 3;
Div = 4;
Mod = 5;
};
message GenericValue {
oneof val {
bool bool_val = 1;
......@@ -82,6 +91,20 @@ message BinaryExpr {
Expr right = 3;
}
message BinaryArithOp {
ColumnInfo column_info = 1;
ArithOpType arith_op = 2;
GenericValue right_operand = 3;
}
message BinaryArithOpEvalRangeExpr {
ColumnInfo column_info = 1;
ArithOpType arith_op = 2;
GenericValue right_operand = 3;
OpType op = 4;
GenericValue value = 5;
}
message Expr {
oneof expr {
TermExpr term_expr = 1;
......@@ -90,6 +113,7 @@ message Expr {
CompareExpr compare_expr = 4;
UnaryRangeExpr unary_range_expr = 5;
BinaryRangeExpr binary_range_expr = 6;
BinaryArithOpEvalRangeExpr binary_arith_op_eval_range_expr = 7;
};
}
......
此差异已折叠。
......@@ -72,8 +72,17 @@ func (optimizer *optimizer) Exit(node *ant_ast.Node) {
floatNodeRight, rightFloat := node.Right.(*ant_ast.FloatNode)
integerNodeRight, rightInteger := node.Right.(*ant_ast.IntegerNode)
// Check IdentifierNodes
identifierNodeLeft, leftIdentifier := node.Left.(*ant_ast.IdentifierNode)
identifierNodeRight, rightIdentifier := node.Right.(*ant_ast.IdentifierNode)
switch node.Operator {
case "+":
funcName, err := getFuncNameByNodeOp(node.Operator)
if err != nil {
optimizer.err = err
return
}
if leftFloat && rightFloat {
patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value + floatNodeRight.Value})
} else if leftFloat && rightInteger {
......@@ -82,11 +91,24 @@ func (optimizer *optimizer) Exit(node *ant_ast.Node) {
patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) + floatNodeRight.Value})
} else if leftInteger && rightInteger {
patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value + integerNodeRight.Value})
} else if leftIdentifier && rightFloat {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, floatNodeRight}})
} else if leftIdentifier && rightInteger {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}})
} else if leftFloat && rightIdentifier {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeRight, floatNodeLeft}})
} else if leftInteger && rightIdentifier {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeRight, integerNodeLeft}})
} else {
optimizer.err = fmt.Errorf("invalid data type")
return
}
case "-":
funcName, err := getFuncNameByNodeOp(node.Operator)
if err != nil {
optimizer.err = err
return
}
if leftFloat && rightFloat {
patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value - floatNodeRight.Value})
} else if leftFloat && rightInteger {
......@@ -95,11 +117,26 @@ func (optimizer *optimizer) Exit(node *ant_ast.Node) {
patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) - floatNodeRight.Value})
} else if leftInteger && rightInteger {
patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value - integerNodeRight.Value})
} else if leftIdentifier && rightFloat {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, floatNodeRight}})
} else if leftIdentifier && rightInteger {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}})
} else if leftFloat && rightIdentifier {
optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator)
return
} else if leftInteger && rightIdentifier {
optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator)
return
} else {
optimizer.err = fmt.Errorf("invalid data type")
return
}
case "*":
funcName, err := getFuncNameByNodeOp(node.Operator)
if err != nil {
optimizer.err = err
return
}
if leftFloat && rightFloat {
patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value * floatNodeRight.Value})
} else if leftFloat && rightInteger {
......@@ -108,11 +145,24 @@ func (optimizer *optimizer) Exit(node *ant_ast.Node) {
patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) * floatNodeRight.Value})
} else if leftInteger && rightInteger {
patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value * integerNodeRight.Value})
} else if leftIdentifier && rightFloat {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, floatNodeRight}})
} else if leftIdentifier && rightInteger {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}})
} else if leftFloat && rightIdentifier {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeRight, floatNodeLeft}})
} else if leftInteger && rightIdentifier {
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeRight, integerNodeLeft}})
} else {
optimizer.err = fmt.Errorf("invalid data type")
return
}
case "/":
funcName, err := getFuncNameByNodeOp(node.Operator)
if err != nil {
optimizer.err = err
return
}
if leftFloat && rightFloat {
if floatNodeRight.Value == 0 {
optimizer.err = fmt.Errorf("divide by zero")
......@@ -137,17 +187,49 @@ func (optimizer *optimizer) Exit(node *ant_ast.Node) {
return
}
patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value / integerNodeRight.Value})
} else if leftIdentifier && rightFloat {
if floatNodeRight.Value == 0 {
optimizer.err = fmt.Errorf("divide by zero")
return
}
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, floatNodeRight}})
} else if leftIdentifier && rightInteger {
if integerNodeRight.Value == 0 {
optimizer.err = fmt.Errorf("divide by zero")
return
}
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}})
} else if leftFloat && rightIdentifier {
optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator)
return
} else if leftInteger && rightIdentifier {
optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator)
return
} else {
optimizer.err = fmt.Errorf("invalid data type")
return
}
case "%":
funcName, err := getFuncNameByNodeOp(node.Operator)
if err != nil {
optimizer.err = err
return
}
if leftInteger && rightInteger {
if integerNodeRight.Value == 0 {
optimizer.err = fmt.Errorf("modulo by zero")
return
}
patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value % integerNodeRight.Value})
} else if leftIdentifier && rightInteger {
if integerNodeRight.Value == 0 {
optimizer.err = fmt.Errorf("modulo by zero")
return
}
patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}})
} else if leftInteger && rightIdentifier {
optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator)
return
} else {
optimizer.err = fmt.Errorf("invalid data type")
return
......@@ -254,6 +336,46 @@ func getLogicalOpType(opStr string) planpb.BinaryExpr_BinaryOp {
}
}
func getArithOpType(funcName string) (planpb.ArithOpType, error) {
var op planpb.ArithOpType
switch funcName {
case "add":
op = planpb.ArithOpType_Add
case "sub":
op = planpb.ArithOpType_Sub
case "mul":
op = planpb.ArithOpType_Mul
case "div":
op = planpb.ArithOpType_Div
case "mod":
op = planpb.ArithOpType_Mod
default:
return op, fmt.Errorf("unsupported or invalid arith op type: %s", funcName)
}
return op, nil
}
func getFuncNameByNodeOp(nodeOp string) (string, error) {
var funcName string
switch nodeOp {
case "+":
funcName = "add"
case "-":
funcName = "sub"
case "*":
funcName = "mul"
case "/":
funcName = "div"
case "%":
funcName = "mod"
default:
return funcName, fmt.Errorf("no defined funcName assigned to nodeOp: %s", nodeOp)
}
return funcName, nil
}
func parseBoolNode(nodeRaw *ant_ast.Node) *ant_ast.BoolNode {
switch node := (*nodeRaw).(type) {
case *ant_ast.IdentifierNode:
......@@ -352,10 +474,54 @@ func (pc *parserContext) createCmpExpr(left, right ant_ast.Node, operator string
return expr, nil
}
func (pc *parserContext) createBinaryArithOpEvalExpr(left *ant_ast.FunctionNode, right *ant_ast.Node, operator string) (*planpb.Expr, error) {
switch operator {
case "==", "!=":
binArithOp, err := pc.handleFunction(left)
if err != nil {
return nil, fmt.Errorf("createBinaryArithOpEvalExpr: %v", err)
}
op := getCompareOpType(operator, false)
val, err := pc.handleLeafValue(right, binArithOp.ColumnInfo.DataType)
if err != nil {
return nil, err
}
expr := &planpb.Expr{
Expr: &planpb.Expr_BinaryArithOpEvalRangeExpr{
BinaryArithOpEvalRangeExpr: &planpb.BinaryArithOpEvalRangeExpr{
ColumnInfo: binArithOp.ColumnInfo,
ArithOp: binArithOp.ArithOp,
RightOperand: binArithOp.RightOperand,
Op: op,
Value: val,
},
},
}
return expr, nil
}
return nil, fmt.Errorf("operator(%s) not yet supported for function nodes", operator)
}
func (pc *parserContext) handleCmpExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
return pc.createCmpExpr(node.Left, node.Right, node.Operator)
}
func (pc *parserContext) handleBinaryArithCmpExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
leftNode, funcNodeLeft := node.Left.(*ant_ast.FunctionNode)
_, funcNodeRight := node.Right.(*ant_ast.FunctionNode)
if funcNodeRight {
return nil, fmt.Errorf("right node as a function is not supported yet")
} else if !funcNodeLeft {
// Both left and right are not function nodes, pass to createCmpExpr
return pc.createCmpExpr(node.Left, node.Right, node.Operator)
} else {
// Only the left node is a function node
return pc.createBinaryArithOpEvalExpr(leftNode, &node.Right, node.Operator)
}
}
func (pc *parserContext) handleLogicalExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
op := getLogicalOpType(node.Operator)
if op == planpb.BinaryExpr_Invalid {
......@@ -517,6 +683,12 @@ func (pc *parserContext) handleMultiCmpExpr(node *ant_ast.BinaryNode) (*planpb.E
}
func (pc *parserContext) handleBinaryExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
_, arithExpr := node.Left.(*ant_ast.FunctionNode)
if arithExpr {
return pc.handleBinaryArithCmpExpr(node)
}
switch node.Operator {
case "<", "<=", ">", ">=":
return pc.handleMultiCmpExpr(node)
......@@ -597,6 +769,37 @@ func (pc *parserContext) handleLeafValue(nodeRaw *ant_ast.Node, dataType schemap
return gv, nil
}
func (pc *parserContext) handleFunction(node *ant_ast.FunctionNode) (*planpb.BinaryArithOp, error) {
funcArithOp, err := getArithOpType(node.Name)
if err != nil {
return nil, err
}
idNode, ok := node.Arguments[0].(*ant_ast.IdentifierNode)
if !ok {
return nil, fmt.Errorf("left operand of the function must be an identifier")
}
field, err := pc.handleIdentifier(idNode)
if err != nil {
return nil, err
}
valueNode := node.Arguments[1]
val, err := pc.handleLeafValue(&valueNode, field.DataType)
if err != nil {
return nil, err
}
arithOp := &planpb.BinaryArithOp{
ColumnInfo: createColumnInfo(field),
ArithOp: funcArithOp,
RightOperand: val,
}
return arithOp, nil
}
func (pc *parserContext) handleIdentifier(node *ant_ast.IdentifierNode) (*schemapb.FieldSchema, error) {
fieldName := node.Value
field, err := pc.schema.GetFieldFromName(fieldName)
......
......@@ -157,6 +157,66 @@ func TestParseExpr_Naive(t *testing.T) {
assert.Nil(t, exprProto)
}
})
t.Run("test BinaryArithOpNode", func(t *testing.T) {
exprStrs := []string{
// "+"
"FloatField + 1.2 == 3",
"Int64Field + 3 == 5",
"1.2 + FloatField != 3",
"3 + Int64Field != 5",
// "-"
"FloatField - 1.2 == 3",
"Int64Field - 3 != 5",
// "*"
"FloatField * 1.2 == 3",
"Int64Field * 3 == 5",
"1.2 * FloatField != 3",
"3 * Int64Field != 5",
// "/"
"FloatField / 1.2 == 3",
"Int64Field / 3 != 5",
// "%"
"Int64Field % 7 == 5",
}
for _, exprStr := range exprStrs {
exprProto, err := parseExpr(schema, exprStr)
assert.Nil(t, err)
str := proto.MarshalTextString(exprProto)
println(str)
}
})
t.Run("test BinaryArithOpNode invalid", func(t *testing.T) {
exprStrs := []string{
// "+"
"FloatField + FloatField == 20",
"Int64Field + Int64Field != 10",
// "-"
"FloatField - FloatField == 20.0",
"Int64Field - Int64Field != 10",
"10 - FloatField == 20",
"20 - Int64Field != 10",
// "*"
"FloatField * FloatField == 20",
"Int64Field * Int64Field != 10",
// "/"
"FloatField / FloatField == 20",
"Int64Field / Int64Field != 10",
"FloatField / 0 == 20",
"Int64Field / 0 != 10",
// "%"
"Int64Field % Int64Field != 10",
"FloatField % 0 == 20",
"Int64Field % 0 != 10",
"FloatField % 2.3 == 20",
}
for _, exprStr := range exprStrs {
exprProto, err := parseExpr(schema, exprStr)
assert.Error(t, err)
assert.Nil(t, exprProto)
}
})
}
func TestParsePlanNode_Naive(t *testing.T) {
......@@ -330,6 +390,79 @@ func TestExprFieldCompare_Str(t *testing.T) {
}
}
func TestExprBinaryArithOp_Str(t *testing.T) {
exprStrs := []string{
// Basic arithmetic
"(age1 + 5) == 2",
// Float data type
"(FloatN - 5.2) == 0",
// Other operators
"(age1 - 5) == 1",
"(age1 * 5) == 6",
"(age1 / 5) == 1",
"(age1 % 5) == 0",
// Allow for commutative property for + and *
"(6 + age1) != 2",
"(age1 * 4) != 9",
"(5 * FloatN) != 0",
"(9 * FloatN) != 0",
}
unsupportedExprStrs := []string{
// Comparison operators except for "==" and "!=" are unsupported
"(age1 + 2) > 4",
"(age1 + 2) >= 4",
"(age1 + 2) < 4",
"(age1 + 2) <= 4",
// Functional nodes at the right of the comparison are not allowed
"0 == (age1 + 3)",
// Field as the right operand for -, /, and % operators are not supported
"(10 - age1) == 0",
"(20 / age1) == 0",
"(30 % age1) == 0",
// Modulo is not supported in the parser but the engine can handle it since fmod is used
"(FloatN % 2.1) == 0",
// Different data types are not supported
"(age1 + 20.16) == 35.16",
// Left operand of the function must be an identifier
"(10.5 / floatN) == 5.75",
}
fields := []*schemapb.FieldSchema{
{FieldID: 100, Name: "fakevec", DataType: schemapb.DataType_FloatVector},
{FieldID: 101, Name: "age1", DataType: schemapb.DataType_Int64},
{FieldID: 102, Name: "FloatN", DataType: schemapb.DataType_Float},
}
schema := &schemapb.CollectionSchema{
Name: "default-collection",
Description: "",
AutoID: true,
Fields: fields,
}
queryInfo := &planpb.QueryInfo{
Topk: 10,
MetricType: "L2",
SearchParams: "{\"nprobe\": 10}",
}
for offset, exprStr := range exprStrs {
fmt.Printf("case %d: %s\n", offset, exprStr)
planProto, err := createQueryPlan(schema, exprStr, "fakevec", queryInfo)
assert.Nil(t, err)
dbgStr := proto.MarshalTextString(planProto)
println(dbgStr)
}
for offset, exprStr := range unsupportedExprStrs {
fmt.Printf("case %d: %s\n", offset, exprStr)
planProto, err := createQueryPlan(schema, exprStr, "fakevec", queryInfo)
assert.Error(t, err)
dbgStr := proto.MarshalTextString(planProto)
println(dbgStr)
}
}
func TestPlanParseAPIs(t *testing.T) {
t.Run("get compare op type", func(t *testing.T) {
var op planpb.OpType
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册