提交 274d96f9 编写于 作者: Y yudong.cai

MS-606 optimize reduce API, update unittest


Former-commit-id: 1e79b7dba4b7c90e5218fc75e5cc552152ec4bbe
上级 bf1f7ce6
...@@ -155,8 +155,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) { ...@@ -155,8 +155,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {
size_t file_size = index_engine_->PhysicalSize(); size_t file_size = index_engine_->PhysicalSize();
std::string info = "Load file id:" + std::to_string(file_->id_) + std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" +
" file type:" + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) +
" bytes from location: " + file_->location_ + " totally cost"; " bytes from location: " + file_->location_ + " totally cost";
double span = rc.ElapseFromBegin(info); double span = rc.ElapseFromBegin(info);
// for (auto &context : search_contexts_) { // for (auto &context : search_contexts_) {
...@@ -209,7 +209,8 @@ XSearchTask::Execute() { ...@@ -209,7 +209,8 @@ XSearchTask::Execute() {
// step 3: pick up topk result // step 3: pick up topk result
auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk; auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2,
search_job->GetResult());
span = rc.RecordSection(hdr + ", reduce topk"); span = rc.RecordSection(hdr + ", reduce topk");
// search_job->AccumReduceCost(span); // search_job->AccumReduceCost(span);
...@@ -229,12 +230,8 @@ XSearchTask::Execute() { ...@@ -229,12 +230,8 @@ XSearchTask::Execute() {
} }
void void
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending,
uint64_t input_k,
uint64_t nq,
uint64_t topk,
bool ascending,
scheduler::ResultSet& result) { scheduler::ResultSet& result) {
if (result.empty()) { if (result.empty()) {
result.resize(nq); result.resize(nq);
...@@ -242,14 +239,14 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, ...@@ -242,14 +239,14 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
for (uint64_t i = 0; i < nq; i++) { for (uint64_t i = 0; i < nq; i++) {
scheduler::Id2DistVec result_buf; scheduler::Id2DistVec result_buf;
auto &result_i = result[i]; auto& result_i = result[i];
if (result[i].empty()) { if (result[i].empty()) {
result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0)); result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0));
uint64_t input_k_multi_i = input_k * i; uint64_t input_k_multi_i = input_k * i;
for (auto k = 0; k < input_k; ++k) { for (auto k = 0; k < input_k; ++k) {
uint64_t idx = input_k_multi_i + k; uint64_t idx = input_k_multi_i + k;
auto &result_buf_item = result_buf[k]; auto& result_buf_item = result_buf[k];
result_buf_item.first = input_ids[idx]; result_buf_item.first = input_ids[idx];
result_buf_item.second = input_distance[idx]; result_buf_item.second = input_distance[idx];
} }
...@@ -262,8 +259,8 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, ...@@ -262,8 +259,8 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
uint64_t input_k_multi_i = input_k * i; uint64_t input_k_multi_i = input_k * i;
while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
src_idx = input_k_multi_i + src_k; src_idx = input_k_multi_i + src_k;
auto &result_buf_item = result_buf[buf_k]; auto& result_buf_item = result_buf[buf_k];
auto &result_item = result_i[tar_k]; auto& result_item = result_i[tar_k];
if ((ascending && input_distance[src_idx] < result_item.second) || if ((ascending && input_distance[src_idx] < result_item.second) ||
(!ascending && input_distance[src_idx] > result_item.second)) { (!ascending && input_distance[src_idx] > result_item.second)) {
result_buf_item.first = input_ids[src_idx]; result_buf_item.first = input_ids[src_idx];
...@@ -280,7 +277,7 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, ...@@ -280,7 +277,7 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
if (src_k < input_k) { if (src_k < input_k) {
while (buf_k < output_k && src_k < input_k) { while (buf_k < output_k && src_k < input_k) {
src_idx = input_k_multi_i + src_k; src_idx = input_k_multi_i + src_k;
auto &result_buf_item = result_buf[buf_k]; auto& result_buf_item = result_buf[buf_k];
result_buf_item.first = input_ids[src_idx]; result_buf_item.first = input_ids[src_idx];
result_buf_item.second = input_distance[src_idx]; result_buf_item.second = input_distance[src_idx];
src_k++; src_k++;
...@@ -301,19 +298,15 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, ...@@ -301,19 +298,15 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids,
} }
void void
XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
std::vector<float>& tar_distance, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance,
uint64_t& tar_input_k, uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) {
const std::vector<int64_t>& src_ids, if (src_ids.empty() || src_distance.empty()) {
const std::vector<float>& src_distance, return;
uint64_t src_input_k, }
uint64_t nq,
uint64_t topk, std::vector<int64_t> id_buf(nq * topk, -1);
bool ascending) { std::vector<float> dist_buf(nq * topk, 0.0);
if (src_ids.empty() || src_distance.empty()) return;
std::vector<int64_t> id_buf(nq*topk, -1);
std::vector<float> dist_buf(nq*topk, 0.0);
uint64_t output_k = std::min(topk, tar_input_k + src_input_k); uint64_t output_k = std::min(topk, tar_input_k + src_input_k);
uint64_t buf_k, src_k, tar_k; uint64_t buf_k, src_k, tar_k;
......
...@@ -39,24 +39,13 @@ class XSearchTask : public Task { ...@@ -39,24 +39,13 @@ class XSearchTask : public Task {
public: public:
static void static void
MergeTopkToResultSet(const std::vector<int64_t>& input_ids, MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
const std::vector<float>& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
uint64_t input_k,
uint64_t nq,
uint64_t topk,
bool ascending,
scheduler::ResultSet& result);
static void static void
MergeTopkArray(std::vector<int64_t>& tar_ids, MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
std::vector<float>& tar_distance, const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k,
uint64_t& tar_input_k, uint64_t nq, uint64_t topk, bool ascending);
const std::vector<int64_t>& src_ids,
const std::vector<float>& src_distance,
uint64_t src_input_k,
uint64_t nq,
uint64_t topk,
bool ascending);
public: public:
TableFileSchemaPtr file_; TableFileSchemaPtr file_;
......
...@@ -28,20 +28,44 @@ namespace { ...@@ -28,20 +28,44 @@ namespace {
namespace ms = milvus::scheduler; namespace ms = milvus::scheduler;
void void
BuildResult(uint64_t nq, BuildResult(std::vector<int64_t>& output_ids,
std::vector<float>& output_distance,
uint64_t topk, uint64_t topk,
bool ascending, uint64_t nq,
std::vector<int64_t>& output_ids, bool ascending) {
std::vector<float>& output_distence) {
output_ids.clear(); output_ids.clear();
output_ids.resize(nq * topk); output_ids.resize(nq * topk);
output_distence.clear(); output_distance.clear();
output_distence.resize(nq * topk); output_distance.resize(nq * topk);
for (uint64_t i = 0; i < nq; i++) { for (uint64_t i = 0; i < nq; i++) {
for (uint64_t j = 0; j < topk; j++) { for (uint64_t j = 0; j < topk; j++) {
output_ids[i * topk + j] = (int64_t)(drand48() * 100000); output_ids[i * topk + j] = (int64_t)(drand48() * 100000);
output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48()); output_distance[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
}
}
}
void
CopyResult(std::vector<int64_t>& output_ids,
std::vector<float>& output_distance,
uint64_t output_topk,
std::vector<int64_t>& input_ids,
std::vector<float>& input_distance,
uint64_t input_topk,
uint64_t nq) {
ASSERT_TRUE(input_ids.size() >= nq * input_topk);
ASSERT_TRUE(input_distance.size() >= nq * input_topk);
ASSERT_TRUE(output_topk <= input_topk);
output_ids.clear();
output_ids.resize(nq * output_topk);
output_distance.clear();
output_distance.resize(nq * output_topk);
for (uint64_t i = 0; i < nq; i++) {
for (uint64_t j = 0; j < output_topk; j++) {
output_ids[i * output_topk + j] = input_ids[i * input_topk + j];
output_distance[i * output_topk + j] = input_distance[i * input_topk + j];
} }
} }
} }
...@@ -51,8 +75,8 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1, ...@@ -51,8 +75,8 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
const std::vector<float>& input_distance_1, const std::vector<float>& input_distance_1,
const std::vector<int64_t>& input_ids_2, const std::vector<int64_t>& input_ids_2,
const std::vector<float>& input_distance_2, const std::vector<float>& input_distance_2,
uint64_t nq,
uint64_t topk, uint64_t topk,
uint64_t nq,
bool ascending, bool ascending,
const milvus::scheduler::ResultSet& result) { const milvus::scheduler::ResultSet& result) {
ASSERT_EQ(result.size(), nq); ASSERT_EQ(result.size(), nq);
...@@ -96,32 +120,32 @@ TEST(DBSearchTest, TOPK_TEST) { ...@@ -96,32 +120,32 @@ TEST(DBSearchTest, TOPK_TEST) {
/* test1, id1/dist1 valid, id2/dist2 empty */ /* test1, id1/dist1 valid, id2/dist2 empty */
ascending = true; ascending = true;
BuildResult(NQ, TOP_K, ascending, ids1, dist1); BuildResult(ids1, dist1, TOP_K, NQ, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
/* test2, id1/dist1 valid, id2/dist2 valid */ /* test2, id1/dist1 valid, id2/dist2 valid */
BuildResult(NQ, TOP_K, ascending, ids2, dist2); BuildResult(ids2, dist2, TOP_K, NQ, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
/* test3, id1/dist1 small topk */ /* test3, id1/dist1 small topk */
ids1.clear(); ids1.clear();
dist1.clear(); dist1.clear();
result.clear(); result.clear();
BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); BuildResult(ids1, dist1, TOP_K/2, NQ, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
/* test4, id1/dist1 small topk, id2/dist2 small topk */ /* test4, id1/dist1 small topk, id2/dist2 small topk */
ids2.clear(); ids2.clear();
dist2.clear(); dist2.clear();
result.clear(); result.clear();
BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); BuildResult(ids2, dist2, TOP_K/3, NQ, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
///////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////
ascending = false; ascending = false;
...@@ -132,145 +156,199 @@ TEST(DBSearchTest, TOPK_TEST) { ...@@ -132,145 +156,199 @@ TEST(DBSearchTest, TOPK_TEST) {
result.clear(); result.clear();
/* test1, id1/dist1 valid, id2/dist2 empty */ /* test1, id1/dist1 valid, id2/dist2 empty */
BuildResult(NQ, TOP_K, ascending, ids1, dist1); BuildResult(ids1, dist1, TOP_K, NQ, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
/* test2, id1/dist1 valid, id2/dist2 valid */ /* test2, id1/dist1 valid, id2/dist2 valid */
BuildResult(NQ, TOP_K, ascending, ids2, dist2); BuildResult(ids2, dist2, TOP_K, NQ, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
/* test3, id1/dist1 small topk */ /* test3, id1/dist1 small topk */
ids1.clear(); ids1.clear();
dist1.clear(); dist1.clear();
result.clear(); result.clear();
BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); BuildResult(ids1, dist1, TOP_K/2, NQ, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
/* test4, id1/dist1 small topk, id2/dist2 small topk */ /* test4, id1/dist1 small topk, id2/dist2 small topk */
ids2.clear(); ids2.clear();
dist2.clear(); dist2.clear();
result.clear(); result.clear();
BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); BuildResult(ids2, dist2, TOP_K/3, NQ, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
} }
TEST(DBSearchTest, REDUCE_PERF_TEST) { TEST(DBSearchTest, REDUCE_PERF_TEST) {
int32_t nq = 100;
int32_t top_k = 1000;
int32_t index_file_num = 478; /* sift1B dataset, index files num */ int32_t index_file_num = 478; /* sift1B dataset, index files num */
bool ascending = true; bool ascending = true;
std::vector<int32_t> thread_vec = {4, 8, 11};
std::vector<int32_t> nq_vec = {1, 10, 100, 1000};
std::vector<int32_t> topk_vec = {1, 4, 16, 64, 256, 1024};
int32_t NQ = nq_vec[nq_vec.size()-1];
int32_t TOPK = topk_vec[topk_vec.size()-1];
std::vector<std::vector<int64_t>> id_vec; std::vector<std::vector<int64_t>> id_vec;
std::vector<std::vector<float>> dist_vec; std::vector<std::vector<float>> dist_vec;
std::vector<uint64_t> k_vec;
std::vector<int64_t> input_ids; std::vector<int64_t> input_ids;
std::vector<float> input_distance; std::vector<float> input_distance;
ms::ResultSet final_result, final_result_2, final_result_3;
int32_t i, k, step; int32_t i, k, step;
double reduce_cost = 0.0;
milvus::TimeRecorder rc("");
/* generate testing data */
for (i = 0; i < index_file_num; i++) { for (i = 0; i < index_file_num; i++) {
BuildResult(nq, top_k, ascending, input_ids, input_distance); BuildResult(input_ids, input_distance, TOPK, NQ, ascending);
id_vec.push_back(input_ids); id_vec.push_back(input_ids);
dist_vec.push_back(input_distance); dist_vec.push_back(input_distance);
k_vec.push_back(top_k);
} }
rc.RecordSection("Method-1 result reduce start"); for (int32_t max_thread_num : thread_vec) {
milvus::ThreadPool threadPool(max_thread_num);
/* method-1 */ std::list<std::future<void>> threads_list;
for (i = 0; i < index_file_num; i++) {
ms::XSearchTask::MergeTopkToResultSet(id_vec[i], dist_vec[i], k_vec[i], nq, top_k, ascending, final_result);
ASSERT_EQ(final_result.size(), nq);
}
reduce_cost = rc.RecordSection("Method-1 result reduce done");
std::cout << "Method-1: total reduce time " << reduce_cost/1000 << " ms" << std::endl;
/* method-2 */
std::vector<std::vector<int64_t>> id_vec_2(id_vec);
std::vector<std::vector<float>> dist_vec_2(dist_vec);
std::vector<uint64_t> k_vec_2(k_vec);
rc.RecordSection("Method-2 result reduce start"); for (int32_t nq : nq_vec) {
for (int32_t top_k : topk_vec) {
ms::ResultSet final_result, final_result_2, final_result_3;
for (step = 1; step < index_file_num; step *= 2) { std::vector<std::vector<int64_t>> id_vec_1(index_file_num);
for (i = 0; i+step < index_file_num; i += step*2) { std::vector<std::vector<float>> dist_vec_1(index_file_num);
ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i], for (i = 0; i < index_file_num; i++) {
id_vec_2[i+step], dist_vec_2[i+step], k_vec_2[i+step], CopyResult(id_vec_1[i], dist_vec_1[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
nq, top_k, ascending); }
}
}
ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], dist_vec_2[0], k_vec_2[0], nq, top_k, ascending, final_result_2);
ASSERT_EQ(final_result_2.size(), nq);
reduce_cost = rc.RecordSection("Method-2 result reduce done");
std::cout << "Method-2: total reduce time " << reduce_cost/1000 << " ms" << std::endl;
for (i = 0; i < nq; i++) { std::string str1 = "Method-1 " + std::to_string(max_thread_num) + " " +
ASSERT_EQ(final_result[i].size(), final_result_2[i].size()); std::to_string(nq) + " " + std::to_string(top_k);
for (k = 0; k < final_result.size(); k++) { milvus::TimeRecorder rc1(str1);
ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first);
ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second); ///////////////////////////////////////////////////////////////////////////////////////
} /* method-1 */
} for (i = 0; i < index_file_num; i++) {
ms::XSearchTask::MergeTopkToResultSet(id_vec_1[i],
dist_vec_1[i],
top_k,
nq,
top_k,
ascending,
final_result);
ASSERT_EQ(final_result.size(), nq);
}
/* method-3 parallel */ rc1.RecordSection("reduce done");
std::vector<std::vector<int64_t>> id_vec_3(id_vec);
std::vector<std::vector<float>> dist_vec_3(dist_vec);
std::vector<uint64_t> k_vec_3(k_vec);
uint32_t max_thread_count = std::min(std::thread::hardware_concurrency() - 1, (uint32_t)MAX_THREADS_NUM); ///////////////////////////////////////////////////////////////////////////////////////
milvus::ThreadPool threadPool(max_thread_count); /* method-2 */
std::list<std::future<void>> threads_list; std::vector<std::vector<int64_t>> id_vec_2(index_file_num);
std::vector<std::vector<float>> dist_vec_2(index_file_num);
std::vector<uint64_t> k_vec_2(index_file_num);
for (i = 0; i < index_file_num; i++) {
CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
k_vec_2[i] = top_k;
}
rc.RecordSection("Method-3 parallel result reduce start"); std::string str2 = "Method-2 " + std::to_string(max_thread_num) + " " +
std::to_string(nq) + " " + std::to_string(top_k);
milvus::TimeRecorder rc2(str2);
for (step = 1; step < index_file_num; step *= 2) { for (step = 1; step < index_file_num; step *= 2) {
for (i = 0; i+step < index_file_num; i += step*2) { for (i = 0; i + step < index_file_num; i += step * 2) {
threads_list.push_back( ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i],
threadPool.enqueue(ms::XSearchTask::MergeTopkArray, id_vec_2[i + step], dist_vec_2[i + step], k_vec_2[i + step],
std::ref(id_vec_3[i]), std::ref(dist_vec_3[i]), std::ref(k_vec_3[i]), nq, top_k, ascending);
std::ref(id_vec_3[i+step]), std::ref(dist_vec_3[i+step]), std::ref(k_vec_3[i+step]), }
nq, top_k, ascending)); }
} ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0],
dist_vec_2[0],
k_vec_2[0],
nq,
top_k,
ascending,
final_result_2);
ASSERT_EQ(final_result_2.size(), nq);
rc2.RecordSection("reduce done");
for (i = 0; i < nq; i++) {
ASSERT_EQ(final_result[i].size(), final_result_2[i].size());
for (k = 0; k < final_result[i].size(); k++) {
if (final_result[i][k].first != final_result_2[i][k].first) {
std::cout << i << " " << k << std::endl;
}
ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first);
ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second);
}
}
while (threads_list.size() > 0) { ///////////////////////////////////////////////////////////////////////////////////////
int nready = 0; /* method-3 parallel */
for (auto it = threads_list.begin(); it != threads_list.end(); it = it) { std::vector<std::vector<int64_t>> id_vec_3(index_file_num);
auto &p = *it; std::vector<std::vector<float>> dist_vec_3(index_file_num);
std::chrono::milliseconds span(0); std::vector<uint64_t> k_vec_3(index_file_num);
if (p.wait_for(span) == std::future_status::ready) { for (i = 0; i < index_file_num; i++) {
threads_list.erase(it++); CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
++nready; k_vec_3[i] = top_k;
} else {
++it;
} }
}
if (nready == 0) { std::string str3 = "Method-3 " + std::to_string(max_thread_num) + " " +
std::this_thread::yield(); std::to_string(nq) + " " + std::to_string(top_k);
milvus::TimeRecorder rc3(str3);
for (step = 1; step < index_file_num; step *= 2) {
for (i = 0; i + step < index_file_num; i += step * 2) {
threads_list.push_back(
threadPool.enqueue(ms::XSearchTask::MergeTopkArray,
std::ref(id_vec_3[i]),
std::ref(dist_vec_3[i]),
std::ref(k_vec_3[i]),
std::ref(id_vec_3[i + step]),
std::ref(dist_vec_3[i + step]),
std::ref(k_vec_3[i + step]),
nq,
top_k,
ascending));
}
while (threads_list.size() > 0) {
int nready = 0;
for (auto it = threads_list.begin(); it != threads_list.end(); it = it) {
auto &p = *it;
std::chrono::milliseconds span(0);
if (p.wait_for(span) == std::future_status::ready) {
threads_list.erase(it++);
++nready;
} else {
++it;
}
}
if (nready == 0) {
std::this_thread::yield();
}
}
}
ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0],
dist_vec_3[0],
k_vec_3[0],
nq,
top_k,
ascending,
final_result_3);
ASSERT_EQ(final_result_3.size(), nq);
rc3.RecordSection("reduce done");
for (i = 0; i < nq; i++) {
ASSERT_EQ(final_result[i].size(), final_result_3[i].size());
for (k = 0; k < final_result[i].size(); k++) {
ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first);
ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second);
}
}
} }
} }
} }
ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], dist_vec_3[0], k_vec_3[0], nq, top_k, ascending, final_result_3);
ASSERT_EQ(final_result_3.size(), nq);
reduce_cost = rc.RecordSection("Method-3 parallel result reduce done");
std::cout << "Method-3 parallel: total reduce time " << reduce_cost/1000 << " ms" << std::endl;
for (i = 0; i < nq; i++) {
ASSERT_EQ(final_result[i].size(), final_result_3[i].size());
for (k = 0; k < final_result.size(); k++) {
ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first);
ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second);
}
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册