未验证 提交 c073aa0d 编写于 作者: C cai.zhang 提交者: GitHub

Fix bug for json_contains_all has multiple array elements (#26446)

Signed-off-by: Ncai.zhang <cai.zhang@zilliz.com>
上级 533f0ddf
......@@ -2067,10 +2067,21 @@ ExecExprVisitor::visit(AlwaysTrueExpr& expr) {
bitset_opt_ = std::move(res);
}
template <typename T>
bool
compareTwoJsonArray(simdjson::simdjson_result<simdjson::ondemand::array> arr1,
const proto::plan::Array& arr2) {
if (arr2.array_size() != arr1.count_elements()) {
compareTwoJsonArray(T arr1, const proto::plan::Array& arr2) {
int json_array_length = 0;
if constexpr (std::is_same_v<
T,
simdjson::simdjson_result<simdjson::ondemand::array>>) {
json_array_length = arr1.count_elements();
}
if constexpr (std::is_same_v<T,
std::vector<simdjson::simdjson_result<
simdjson::ondemand::value>>>) {
json_array_length = arr1.size();
}
if (arr2.array_size() != json_array_length) {
return false;
}
int i = 0;
......@@ -2165,13 +2176,19 @@ ExecExprVisitor::ExecJsonContainsArray(JsonContainsExpr& expr_raw)
if (array.error()) {
return false;
}
for (auto const& element : elements) {
for (auto&& it : array) {
auto val = it.get_array();
if (val.error()) {
continue;
}
if (compareTwoJsonArray(val, element)) {
for (auto&& it : array) {
auto val = it.get_array();
if (val.error()) {
continue;
}
std::vector<simdjson::simdjson_result<simdjson::ondemand::value>>
json_array;
json_array.reserve(val.count_elements());
for (auto&& e : val) {
json_array.emplace_back(e);
}
for (auto const& element : elements) {
if (compareTwoJsonArray(json_array, element)) {
return true;
}
}
......@@ -2322,33 +2339,39 @@ ExecExprVisitor::ExecJsonContainsAllArray(JsonContainsExpr& expr_raw)
elements_index.insert(i);
i++;
}
auto elem_func =
[&elements, &elements_index, &pointer](const milvus::Json& json) {
auto doc = json.doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
return false;
auto elem_func = [&elements, &elements_index, &pointer](
const milvus::Json& json) {
auto doc = json.doc();
auto array = doc.at_pointer(pointer).get_array();
if (array.error()) {
return false;
}
std::unordered_set<int> tmp_elements_index(elements_index);
for (auto&& it : array) {
auto val = it.get_array();
if (val.error()) {
continue;
}
std::unordered_set<int> tmp_elements_index(elements_index);
for (auto&& it : array) {
auto val = it.get_array();
if (val.error()) {
continue;
}
int i = -1;
for (auto const& element : elements) {
i++;
if (compareTwoJsonArray(val, element)) {
tmp_elements_index.erase(i);
break;
}
}
if (tmp_elements_index.size() == 0) {
return true;
std::vector<simdjson::simdjson_result<simdjson::ondemand::value>>
json_array;
json_array.reserve(val.count_elements());
for (auto&& e : val) {
json_array.emplace_back(e);
}
for (auto index : tmp_elements_index) {
if (compareTwoJsonArray(json_array, elements[index])) {
tmp_elements_index.erase(index);
// TODO: construct array set.
// prevent expression json_contains_all(json_array, [[1,2], [3,4], [1,2]]) being unsuccessful
// break;
}
}
return tmp_elements_index.size() == 0;
};
if (tmp_elements_index.size() == 0) {
return true;
}
}
return tmp_elements_index.size() == 0;
};
return ExecRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
......
......@@ -3997,6 +3997,164 @@ TEST(Expr, TestJsonContainsArray) {
ASSERT_EQ(ans, check(res, i));
}
}
proto::plan::Array sub_arr1;
sub_arr1.set_same_type(true);
proto::plan::GenericValue int_val11;
int_val11.set_int64_val(int64_t(1));
sub_arr1.add_array()->CopyFrom(int_val11);
proto::plan::GenericValue int_val12;
int_val12.set_int64_val(int64_t(2));
sub_arr1.add_array()->CopyFrom(int_val12);
proto::plan::Array sub_arr2;
sub_arr2.set_same_type(true);
proto::plan::GenericValue int_val21;
int_val21.set_int64_val(int64_t(3));
sub_arr2.add_array()->CopyFrom(int_val21);
proto::plan::GenericValue int_val22;
int_val22.set_int64_val(int64_t(4));
sub_arr2.add_array()->CopyFrom(int_val22);
std::vector<Testcase<proto::plan::Array>> diff_testcases2{{{sub_arr1, sub_arr2}, {"array2"}}};
for (auto& testcase : diff_testcases2) {
auto check = [&](const std::vector<bool>& values, int i) {
return true;
};
RetrievePlanNode plan;
auto pointer = milvus::Json::pointer(testcase.nested_path);
plan.predicate_ =
std::make_unique<JsonContainsExprImpl<proto::plan::Array>>(
ColumnInfo(json_fid, DataType::JSON, testcase.nested_path),
testcase.term,
true,
proto::plan::JSONContainsExpr_JSONOp_ContainsAny,
proto::plan::GenericValue::ValCase::kArrayVal);
auto start = std::chrono::steady_clock::now();
auto final = visitor.call_child(*plan.predicate_.value());
std::cout << "cost"
<< std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - start)
.count()
<< std::endl;
EXPECT_EQ(final.size(), N * num_iters);
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
std::vector<bool> res;
ASSERT_EQ(ans, check(res, i));
}
}
for (auto& testcase : diff_testcases2) {
auto check = [&](const std::vector<bool>& values, int i) {
return true;
};
RetrievePlanNode plan;
auto pointer = milvus::Json::pointer(testcase.nested_path);
plan.predicate_ =
std::make_unique<JsonContainsExprImpl<proto::plan::Array>>(
ColumnInfo(json_fid, DataType::JSON, testcase.nested_path),
testcase.term,
true,
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
proto::plan::GenericValue::ValCase::kArrayVal);
auto start = std::chrono::steady_clock::now();
auto final = visitor.call_child(*plan.predicate_.value());
std::cout << "cost"
<< std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - start)
.count()
<< std::endl;
EXPECT_EQ(final.size(), N * num_iters);
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
std::vector<bool> res;
ASSERT_EQ(ans, check(res, i));
}
}
proto::plan::Array sub_arr3;
sub_arr3.set_same_type(true);
proto::plan::GenericValue int_val31;
int_val31.set_int64_val(int64_t(5));
sub_arr3.add_array()->CopyFrom(int_val31);
proto::plan::GenericValue int_val32;
int_val32.set_int64_val(int64_t(6));
sub_arr3.add_array()->CopyFrom(int_val32);
proto::plan::Array sub_arr4;
sub_arr4.set_same_type(true);
proto::plan::GenericValue int_val41;
int_val41.set_int64_val(int64_t(7));
sub_arr4.add_array()->CopyFrom(int_val41);
proto::plan::GenericValue int_val42;
int_val42.set_int64_val(int64_t(8));
sub_arr4.add_array()->CopyFrom(int_val42);
std::vector<Testcase<proto::plan::Array>> diff_testcases3{{{sub_arr3, sub_arr4}, {"array2"}}};
for (auto& testcase : diff_testcases2) {
auto check = [&](const std::vector<bool>& values, int i) {
return true;
};
RetrievePlanNode plan;
auto pointer = milvus::Json::pointer(testcase.nested_path);
plan.predicate_ =
std::make_unique<JsonContainsExprImpl<proto::plan::Array>>(
ColumnInfo(json_fid, DataType::JSON, testcase.nested_path),
testcase.term,
true,
proto::plan::JSONContainsExpr_JSONOp_ContainsAny,
proto::plan::GenericValue::ValCase::kArrayVal);
auto start = std::chrono::steady_clock::now();
auto final = visitor.call_child(*plan.predicate_.value());
std::cout << "cost"
<< std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - start)
.count()
<< std::endl;
EXPECT_EQ(final.size(), N * num_iters);
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
std::vector<bool> res;
ASSERT_EQ(ans, check(res, i));
}
}
for (auto& testcase : diff_testcases2) {
auto check = [&](const std::vector<bool>& values, int i) {
return true;
};
RetrievePlanNode plan;
auto pointer = milvus::Json::pointer(testcase.nested_path);
plan.predicate_ =
std::make_unique<JsonContainsExprImpl<proto::plan::Array>>(
ColumnInfo(json_fid, DataType::JSON, testcase.nested_path),
testcase.term,
true,
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
proto::plan::GenericValue::ValCase::kArrayVal);
auto start = std::chrono::steady_clock::now();
auto final = visitor.call_child(*plan.predicate_.value());
std::cout << "cost"
<< std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - start)
.count()
<< std::endl;
EXPECT_EQ(final.size(), N * num_iters);
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
std::vector<bool> res;
ASSERT_EQ(ans, check(res, i));
}
}
}
TEST(Expr, TestJsonContainsDiffType) {
......
......@@ -420,7 +420,8 @@ DataGenForJsonArray(SchemaPtr schema,
R"(],"double":[)" + join(doubleVec, ",") +
R"(],"string":[)" + join(stringVec, ",") +
R"(],"bool": [)" + join(boolVec, ",") +
R"(],"array": [)" + join(arrayVec, ",") + "]}";
R"(],"array": [)" + join(arrayVec, ",") +
R"(],"array2": [[1,2], [3,4]])" + "}";
//std::cout << str << std::endl;
data[i] = str;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册