提交 0d33778a 编写于 作者: J jinhai

Merge branch 'branch-0.5.0' into '0.5.0'

#59 Topk result is incorrect for small dataset

See merge request megasearch/milvus!783

Former-commit-id: 1fe9c333a48b8436c3495efaca2bda06d7178f73
...@@ -307,71 +307,71 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const s ...@@ -307,71 +307,71 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const s
} }
} }
void //void
XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, //XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, // const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance,
uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { // uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) {
if (src_ids.empty() || src_distance.empty()) { // if (src_ids.empty() || src_distance.empty()) {
return; // return;
} // }
//
uint64_t output_k = std::min(topk, tar_input_k + src_input_k); // uint64_t output_k = std::min(topk, tar_input_k + src_input_k);
std::vector<int64_t> id_buf(nq * output_k, -1); // std::vector<int64_t> id_buf(nq * output_k, -1);
std::vector<float> dist_buf(nq * output_k, 0.0); // std::vector<float> dist_buf(nq * output_k, 0.0);
//
uint64_t buf_k, src_k, tar_k; // uint64_t buf_k, src_k, tar_k;
uint64_t src_idx, tar_idx, buf_idx; // uint64_t src_idx, tar_idx, buf_idx;
uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; // uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i;
//
for (uint64_t i = 0; i < nq; i++) { // for (uint64_t i = 0; i < nq; i++) {
src_input_k_multi_i = src_input_k * i; // src_input_k_multi_i = src_input_k * i;
tar_input_k_multi_i = tar_input_k * i; // tar_input_k_multi_i = tar_input_k * i;
buf_k_multi_i = output_k * i; // buf_k_multi_i = output_k * i;
buf_k = src_k = tar_k = 0; // buf_k = src_k = tar_k = 0;
while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { // while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) {
src_idx = src_input_k_multi_i + src_k; // src_idx = src_input_k_multi_i + src_k;
tar_idx = tar_input_k_multi_i + tar_k; // tar_idx = tar_input_k_multi_i + tar_k;
buf_idx = buf_k_multi_i + buf_k; // buf_idx = buf_k_multi_i + buf_k;
if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || // if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) ||
(!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { // (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) {
id_buf[buf_idx] = src_ids[src_idx]; // id_buf[buf_idx] = src_ids[src_idx];
dist_buf[buf_idx] = src_distance[src_idx]; // dist_buf[buf_idx] = src_distance[src_idx];
src_k++; // src_k++;
} else { // } else {
id_buf[buf_idx] = tar_ids[tar_idx]; // id_buf[buf_idx] = tar_ids[tar_idx];
dist_buf[buf_idx] = tar_distance[tar_idx]; // dist_buf[buf_idx] = tar_distance[tar_idx];
tar_k++; // tar_k++;
} // }
buf_k++; // buf_k++;
} // }
//
if (buf_k < output_k) { // if (buf_k < output_k) {
if (src_k < src_input_k) { // if (src_k < src_input_k) {
while (buf_k < output_k && src_k < src_input_k) { // while (buf_k < output_k && src_k < src_input_k) {
src_idx = src_input_k_multi_i + src_k; // src_idx = src_input_k_multi_i + src_k;
buf_idx = buf_k_multi_i + buf_k; // buf_idx = buf_k_multi_i + buf_k;
id_buf[buf_idx] = src_ids[src_idx]; // id_buf[buf_idx] = src_ids[src_idx];
dist_buf[buf_idx] = src_distance[src_idx]; // dist_buf[buf_idx] = src_distance[src_idx];
src_k++; // src_k++;
buf_k++; // buf_k++;
} // }
} else { // } else {
while (buf_k < output_k && tar_k < tar_input_k) { // while (buf_k < output_k && tar_k < tar_input_k) {
tar_idx = tar_input_k_multi_i + tar_k; // tar_idx = tar_input_k_multi_i + tar_k;
buf_idx = buf_k_multi_i + buf_k; // buf_idx = buf_k_multi_i + buf_k;
id_buf[buf_idx] = tar_ids[tar_idx]; // id_buf[buf_idx] = tar_ids[tar_idx];
dist_buf[buf_idx] = tar_distance[tar_idx]; // dist_buf[buf_idx] = tar_distance[tar_idx];
tar_k++; // tar_k++;
buf_k++; // buf_k++;
} // }
} // }
} // }
} // }
//
tar_ids.swap(id_buf); // tar_ids.swap(id_buf);
tar_distance.swap(dist_buf); // tar_distance.swap(dist_buf);
tar_input_k = output_k; // tar_input_k = output_k;
} //}
} // namespace scheduler } // namespace scheduler
} // namespace milvus } // namespace milvus
...@@ -42,10 +42,10 @@ class XSearchTask : public Task { ...@@ -42,10 +42,10 @@ class XSearchTask : public Task {
MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
static void // static void
MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, // MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, // const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k,
uint64_t nq, uint64_t topk, bool ascending); // uint64_t nq, uint64_t topk, bool ascending);
public: public:
TableFileSchemaPtr file_; TableFileSchemaPtr file_;
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册