From 508da260ca76778be11c3a35dfd19400ebbb7091 Mon Sep 17 00:00:00 2001 From: groot Date: Thu, 18 Jun 2020 09:36:04 +0800 Subject: [PATCH] fix #2578 (#2591) Signed-off-by: yhmo --- .../delivery/request/SearchCombineRequest.cpp | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/core/src/server/delivery/request/SearchCombineRequest.cpp b/core/src/server/delivery/request/SearchCombineRequest.cpp index fa0c6666..818695bb 100644 --- a/core/src/server/delivery/request/SearchCombineRequest.cpp +++ b/core/src/server/delivery/request/SearchCombineRequest.cpp @@ -384,14 +384,6 @@ SearchCombineRequest::OnExecute() { return status; } - // avoid memcpy crash, check id count = target vector count * topk - if (result_ids.size() != total_count * search_topk_) { - status = Status(DB_ERROR, "Result count doesn't match target vectors count"); - // let all request return - FreeRequests(status); - return status; - } - // avoid memcpy crash, check distance count = id count if (result_distances.size() != result_ids.size()) { status = Status(DB_ERROR, "Result distance and id count doesn't match"); @@ -401,18 +393,26 @@ SearchCombineRequest::OnExecute() { } // step 5: construct result array + // engine ensure each target vector has same count of id/distance pairs + size_t pair_each_vector = result_ids.size() / vectors_data_.vector_count_; offset = 0; for (auto& request : request_list_) { uint64_t count = request->VectorsData().vector_count_; int64_t topk = request->TopK(); - uint64_t element_cnt = count * topk; + uint64_t pair_cnt = (pair_each_vector > topk) ? topk : pair_each_vector; TopKQueryResult& result = request->QueryResult(); result.row_num_ = count; - result.id_list_.resize(element_cnt); - result.distance_list_.resize(element_cnt); - memcpy(result.id_list_.data(), result_ids.data() + offset, element_cnt * sizeof(int64_t)); - memcpy(result.distance_list_.data(), result_distances.data() + offset, element_cnt * sizeof(float)); - offset += (count * search_topk_); + result.id_list_.resize(count * pair_cnt); + result.distance_list_.resize(count * pair_cnt); + + for (uint64_t i = 0; i < count; i++) { + uint64_t poz = i * pair_cnt; + memcpy(result.id_list_.data() + poz, result_ids.data() + offset + poz, pair_cnt * sizeof(int64_t)); + memcpy(result.distance_list_.data() + poz, result_distances.data() + offset + poz, + pair_cnt * sizeof(float)); + } + + offset += count * pair_cnt; // let request return FreeRequest(request, Status::OK()); -- GitLab