From 257ea7782f24b463004be63a2450b0102d097267 Mon Sep 17 00:00:00 2001 From: "shengjun.li" <49774184+shengjun1985@users.noreply.github.com> Date: Mon, 1 Jun 2020 15:31:58 +0800 Subject: [PATCH] fix merge result (#2463) * fix merge result Signed-off-by: shengjun.li * fix tests Signed-off-by: shengjun.li --- core/src/scheduler/JobMgr.cpp | 7 ++----- core/src/scheduler/task/SearchTask.cpp | 17 ++--------------- core/unittest/db/test_delete.cpp | 6 +----- tests/milvus_python_test/entity/test_delete.py | 10 ++++------ 4 files changed, 9 insertions(+), 31 deletions(-) diff --git a/core/src/scheduler/JobMgr.cpp b/core/src/scheduler/JobMgr.cpp index 9534f1a4..0856c749 100644 --- a/core/src/scheduler/JobMgr.cpp +++ b/core/src/scheduler/JobMgr.cpp @@ -84,11 +84,8 @@ JobMgr::worker_function() { // TODO(zhiru): if the job is search by ids, pass any task where the ids don't exist auto search_job = std::dynamic_pointer_cast(job); if (search_job != nullptr) { - scheduler::ResultIds ids(search_job->nq() * search_job->topk(), -1); - scheduler::ResultDistances distances(search_job->nq() * search_job->topk(), - std::numeric_limits::max()); - search_job->GetResultIds() = ids; - search_job->GetResultDistances() = distances; + search_job->GetResultIds().resize(search_job->nq(), -1); + search_job->GetResultDistances().resize(search_job->nq(), std::numeric_limits::max()); if (search_job->vectors().float_data_.empty() && search_job->vectors().binary_data_.empty() && !search_job->vectors().id_array_.empty()) { diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index 154df0a9..595dca40 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -279,9 +279,7 @@ XSearchTask::Execute() { auto spec_k = file_->row_count_ < topk ? file_->row_count_ : topk; if (spec_k == 0) { LOG_ENGINE_WARNING_ << "Searching in an empty file. file location = " << file_->location_; - } - - { + } else { std::unique_lock lock(search_job->mutex()); search_job->vector_count() = nq; XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce, @@ -315,19 +313,8 @@ XSearchTask::Execute() { if (spec_k == 0) { LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty file. file location = %s", "search", 0, file_->location_.c_str()); - } - - { + } else { std::unique_lock lock(search_job->mutex()); - - if (search_job->GetResultIds().size() > spec_k) { - if (search_job->GetResultIds().front() == -1) { - // initialized results set - search_job->GetResultIds().resize(spec_k * nq); - search_job->GetResultDistances().resize(spec_k * nq); - } - } - XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce, search_job->GetResultIds(), search_job->GetResultDistances()); } diff --git a/core/unittest/db/test_delete.cpp b/core/unittest/db/test_delete.cpp index 1feb5b82..4c06721f 100644 --- a/core/unittest/db/test_delete.cpp +++ b/core/unittest/db/test_delete.cpp @@ -565,11 +565,7 @@ TEST_F(DeleteTest, delete_single_vector) { milvus::engine::ResultDistances result_distances; stat = db_->Query(dummy_context_, collection_info.collection_id_, tags, topk, json_params, xb, result_ids, result_distances); - ASSERT_TRUE(result_ids.empty()); - ASSERT_TRUE(result_distances.empty()); - // ASSERT_EQ(result_ids[0], -1); - // ASSERT_LT(result_distances[0], 1e-4); - // ASSERT_EQ(result_distances[0], std::numeric_limits::max()); + ASSERT_TRUE(result_ids.empty() || (result_ids[0] == -1)); } TEST_F(DeleteTest, delete_add_create_index) { diff --git a/tests/milvus_python_test/entity/test_delete.py b/tests/milvus_python_test/entity/test_delete.py index 5bf64e92..148291c3 100644 --- a/tests/milvus_python_test/entity/test_delete.py +++ b/tests/milvus_python_test/entity/test_delete.py @@ -61,7 +61,7 @@ class TestDeleteBase: status, res = connect.search(collection, top_k, vector, params=search_param) logging.getLogger().info(res) assert status.OK() - assert len(res) == 0 + assert len(res[0]) == 0 def test_delete_vector_multi_same_ids(self, connect, collection, get_simple_index): ''' @@ -83,7 +83,7 @@ class TestDeleteBase: status, res = connect.search(collection, top_k, [vectors[0]], params=search_param) logging.getLogger().info(res) assert status.OK() - assert len(res) == 0 + assert len(res[0]) == 0 def test_delete_vector_collection_count(self, connect, collection): ''' @@ -327,7 +327,7 @@ class TestDeleteIndexedVectors: status, res = connect.search(collection, top_k, vector, params=search_param) logging.getLogger().info(res) assert status.OK() - assert len(res) == 0 + assert len(res[0]) == 0 def test_insert_delete_vector(self, connect, collection, get_simple_index): ''' @@ -399,9 +399,7 @@ class TestDeleteBinary: status, res = connect.search(jac_collection, top_k, vector, params=search_param) logging.getLogger().info(res) assert status.OK() - assert len(res) == 0 - assert status.OK() - assert len(res) == 0 + assert len(res[0]) == 0 # TODO: soft delete def test_delete_vector_collection_count(self, connect, jac_collection): -- GitLab