未验证 提交 a287a2b3 编写于 作者: Z zhenshan.cao 提交者: GitHub

Return empty result in advance if all data filtered out (#18329) (#18438)

Signed-off-by: zhenshan.cao <zhenshan.cao@zilliz.com>
上级 a3ec44cd
......@@ -98,6 +98,11 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
segment->mask_with_timestamps(bitset_holder, timestamp_);
segment->mask_with_delete(bitset_holder, active_count, timestamp_);
// if bitset_holder is all 1's, we got empty result
if (bitset_holder.count() == bitset_holder.size()) {
search_result_opt_ = empty_search_result(num_queries, node.search_info_);
return;
}
BitsetView final_view = bitset_holder;
segment->vector_search(active_count, node.search_info_, src_data, num_queries, timestamp_, final_view,
search_result);
......@@ -128,6 +133,12 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) {
segment->mask_with_timestamps(bitset_holder, timestamp_);
segment->mask_with_delete(bitset_holder, active_count, timestamp_);
// if bitset_holder is all 1's, we got empty result
if (bitset_holder.count() == bitset_holder.size()) {
retrieve_result_opt_ = std::move(retrieve_result);
return;
}
BitsetView final_view = bitset_holder;
auto seg_offsets = segment->search_ids(final_view, timestamp_);
retrieve_result.result_offsets_.assign((int64_t*)seg_offsets.data(),
......
......@@ -110,7 +110,6 @@ generate_max_float_query_data(int all_nq, int max_float_nq) {
}
auto blob = raw_group.SerializeAsString();
return blob;
}
std::string
......@@ -1155,14 +1154,12 @@ TEST(CApiTest, ReudceNullResult) {
EXPECT_EQ(size, num_queries / 2);
DeleteSearchResult(res);
}
DeleteSearchPlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteCollection(collection);
DeleteSegment(segment);
}
TEST(CApiTest, ReduceRemoveDuplicates) {
......
......@@ -312,9 +312,9 @@ TEST(Indexing, BinaryBruteForce) {
auto bin_vec = dataset.get_col<uint8_t>(vec_fid);
auto query_data = 1024 * dim / 8 + bin_vec.data();
query::dataset::SearchDataset search_dataset{
metric_type, //
num_queries, //
topk, //
metric_type, //
num_queries, //
topk, //
round_decimal,
dim, //
query_data //
......
......@@ -14,6 +14,7 @@
#include "knowhere/index/VecIndex.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexHNSW.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "segcore/SegmentSealedImpl.h"
#include "test_utils/DataGen.h"
......@@ -79,14 +80,12 @@ TEST(Sealed, without_predicate) {
auto pre_result = SearchResultToJson(*sr);
auto indexing = std::make_shared<knowhere::IVF>();
auto conf = knowhere::Config{
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, topK},
{knowhere::indexparam::NLIST, 100},
{knowhere::indexparam::NPROBE, 10},
{knowhere::meta::DEVICE_ID, 0}
};
auto conf = knowhere::Config{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, topK},
{knowhere::indexparam::NLIST, 100},
{knowhere::indexparam::NPROBE, 10},
{knowhere::meta::DEVICE_ID, 0}};
auto database = knowhere::GenDataset(N, dim, vec_col.data() + 1000 * dim);
indexing->Train(database, conf);
......@@ -99,7 +98,7 @@ TEST(Sealed, without_predicate) {
auto result = indexing->Query(query_dataset, conf, nullptr);
auto ids = knowhere::GetDatasetIDs(result); // for comparison
auto ids = knowhere::GetDatasetIDs(result); // for comparison
auto dis = knowhere::GetDatasetDistance(result); // for comparison
std::vector<int64_t> vec_ids(ids, ids + topK * num_queries);
std::vector<float> vec_dis(dis, dis + topK * num_queries);
......@@ -126,6 +125,9 @@ TEST(Sealed, without_predicate) {
std::cout << "post_result" << std::endl;
std::cout << post_result.dump(1);
// ASSERT_EQ(ref_result.dump(1), post_result.dump(1));
sr = sealed_segment->Search(plan.get(), ph_group.get(), 0);
EXPECT_EQ(sr->get_total_result_count(), 0);
}
TEST(Sealed, with_predicate) {
......@@ -186,14 +188,12 @@ TEST(Sealed, with_predicate) {
auto sr = segment->Search(plan.get(), ph_group.get(), time);
auto indexing = std::make_shared<knowhere::IVF>();
auto conf = knowhere::Config{
{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, topK},
{knowhere::indexparam::NLIST, 100},
{knowhere::indexparam::NPROBE, 10},
{knowhere::meta::DEVICE_ID, 0}
};
auto conf = knowhere::Config{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, topK},
{knowhere::indexparam::NLIST, 100},
{knowhere::indexparam::NPROBE, 10},
{knowhere::meta::DEVICE_ID, 0}};
auto database = knowhere::GenDataset(N, dim, vec_col.data());
indexing->Train(database, conf);
......@@ -225,6 +225,115 @@ TEST(Sealed, with_predicate) {
}
}
// Checks that searching a sealed segment yields zero results when the scalar
// predicate rejects every row. The same plan and placeholder group are run
// against two segments, one loaded with an IVF index and one with an HNSW index.
TEST(Sealed, with_predicate_filter_all) {
    using namespace milvus::query;
    using namespace milvus::segcore;

    auto schema = std::make_shared<Schema>();
    auto dim = 16;
    auto topK = 5;
    auto metric_type = knowhere::metric::L2;
    auto fake_id = schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
    auto i64_fid = schema->AddDebugField("counter", DataType::INT64);
    schema->set_primary_field_id(i64_fid);

    // No value can be both >= 42000 and < 41999, so the range predicate below
    // is meant to filter out all rows.
    std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"counter": {
"GE": 42000,
"LT": 41999
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 6
}
}
}
]
}
})";

    auto N = ROW_COUNT;
    auto dataset = DataGen(schema, N);
    auto vec_col = dataset.get_col<float>(fake_id);
    // Query vectors are taken from the generated column, starting at row 42000.
    // NOTE(review): this presumes ROW_COUNT >= 42000 + num_queries — confirm.
    auto query_ptr = vec_col.data() + 42000 * dim;

    auto plan = CreatePlan(*schema, dsl);
    auto num_queries = 5;
    // Use `dim` rather than a repeated literal so the placeholder width cannot
    // drift from the schema's vector dimension.
    auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, dim, query_ptr);
    auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
    Timestamp time = 10000000;

    auto database = knowhere::GenDataset(N, dim, vec_col.data());

    // Trains and fills `indexing` from `database`, loads it into a freshly
    // created sealed segment, runs the search and expects an empty result.
    // `index_name` only labels failures via SCOPED_TRACE.
    auto expect_empty_search = [&](const char* index_name, auto indexing, const knowhere::Config& conf) {
        SCOPED_TRACE(index_name);

        indexing->Train(database, conf);
        indexing->AddWithoutIds(database, conf);
        EXPECT_EQ(indexing->Count(), N);
        EXPECT_EQ(indexing->Dim(), dim);

        LoadIndexInfo load_info;
        load_info.field_id = fake_id.get();
        load_info.index = indexing;
        load_info.index_params["metric_type"] = "L2";

        // load index for vec field, load raw data for scalar field
        auto sealed_segment = SealedCreator(schema, dataset);
        sealed_segment->DropFieldData(fake_id);
        sealed_segment->LoadIndex(load_info);

        auto sr = sealed_segment->Search(plan.get(), ph_group.get(), time);
        EXPECT_EQ(sr->get_total_result_count(), 0);
    };

    auto ivf_conf = knowhere::Config{{knowhere::meta::DIM, dim},
                                     {knowhere::meta::TOPK, topK},
                                     {knowhere::indexparam::NLIST, 100},
                                     {knowhere::indexparam::NPROBE, 10},
                                     {knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
                                     {knowhere::meta::DEVICE_ID, 0}};
    expect_empty_search("IVF", std::make_shared<knowhere::IVF>(), ivf_conf);

    auto hnsw_conf = knowhere::Config{{knowhere::meta::DIM, dim},
                                      {knowhere::meta::TOPK, topK},
                                      {knowhere::indexparam::HNSW_M, 16},
                                      {knowhere::indexparam::EFCONSTRUCTION, 200},
                                      {knowhere::indexparam::EF, 200},
                                      {knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
                                      {knowhere::meta::DEVICE_ID, 0}};
    expect_empty_search("HNSW", std::make_shared<knowhere::IndexHNSW>(), hnsw_conf);
}
TEST(Sealed, LoadFieldData) {
auto dim = 16;
auto topK = 5;
......@@ -620,6 +729,9 @@ TEST(Sealed, BF) {
EXPECT_GT(ves[0].first, 0);
EXPECT_LE(ves[0].first, N);
EXPECT_LE(ves[0].second, dim);
auto result2 = segment->Search(plan.get(), ph_group.get(), 0);
EXPECT_EQ(result2->get_total_result_count(), 0);
}
TEST(Sealed, BF_Overflow) {
......@@ -637,7 +749,7 @@ TEST(Sealed, BF_Overflow) {
LoadFieldDataInfo load_info{100, base_arr.get(), N};
auto dataset = DataGen(schema, N);
auto segment = CreateSealedSegment(schema);
std::cout<< fake_id.get() <<std::endl;
std::cout << fake_id.get() << std::endl;
SealedLoadFieldData(dataset, *segment, {fake_id.get()});
segment->LoadFieldData(load_info);
......
......@@ -531,9 +531,9 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
query::dataset::SearchDataset search_dataset{
knowhere::metric::L2, //
num_queries, //
topk, //
knowhere::metric::L2, //
num_queries, //
topk, //
round_decimal,
dim, //
query_ptr //
......
......@@ -290,7 +290,7 @@ CreatePlaceholderGroup(int64_t num_queries, int dim, const std::vector<float>& v
for (int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(vecs[i*dim+d]);
vec.push_back(vecs[i * dim + d]);
}
value->add_values(vec.data(), vec.size() * sizeof(float));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册