提交 4a25e618 编写于 作者: J JinHai-CN

Merge remote-tracking branch 'upstream/0.5.0' into 0.5.0


Former-commit-id: c0093fc80e2c855f1965286e0b79af1b2d3b3ebe
...@@ -307,8 +307,8 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const s ...@@ -307,8 +307,8 @@ XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const s
} }
} }
//void // void
//XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, // XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
// const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, // const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance,
// uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { // uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) {
// if (src_ids.empty() || src_distance.empty()) { // if (src_ids.empty() || src_distance.empty()) {
......
...@@ -42,10 +42,10 @@ class XSearchTask : public Task { ...@@ -42,10 +42,10 @@ class XSearchTask : public Task {
MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
// static void // static void
// MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, // MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
// const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k, // const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t
// uint64_t nq, uint64_t topk, bool ascending); // src_input_k, uint64_t nq, uint64_t topk, bool ascending);
public: public:
TableFileSchemaPtr file_; TableFileSchemaPtr file_;
......
...@@ -46,7 +46,7 @@ BuildResult(std::vector<int64_t>& output_ids, ...@@ -46,7 +46,7 @@ BuildResult(std::vector<int64_t>& output_ids,
output_distance[i * topk + j] = ascending ? (j + drand48()) : ((input_k - j) + drand48()); output_distance[i * topk + j] = ascending ? (j + drand48()) : ((input_k - j) + drand48());
} }
//insert invalid items //insert invalid items
for(uint64_t j = input_k; j < topk; j++) { for (uint64_t j = input_k; j < topk; j++) {
output_ids[i * topk + j] = -1; output_ids[i * topk + j] = -1;
output_distance[i * topk + j] = -1.0; output_distance[i * topk + j] = -1.0;
} }
...@@ -113,7 +113,7 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1, ...@@ -113,7 +113,7 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
uint64_t n = std::min(topk, result[i].size()); uint64_t n = std::min(topk, result[i].size());
for (uint64_t j = 0; j < n; j++) { for (uint64_t j = 0; j < n; j++) {
if(result[i][j].first < 0) { if (result[i][j].first < 0) {
continue; continue;
} }
if (src_vec[j] != result[i][j].second) { if (src_vec[j] != result[i][j].second) {
...@@ -126,7 +126,8 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1, ...@@ -126,7 +126,8 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
} // namespace } // namespace
void MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { void
MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) {
std::vector<int64_t> ids1, ids2; std::vector<int64_t> ids1, ids2;
std::vector<float> dist1, dist2; std::vector<float> dist1, dist2;
ms::ResultSet result; ms::ResultSet result;
...@@ -150,12 +151,12 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { ...@@ -150,12 +151,12 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) {
MergeTopkToResultSetTest(TOP_K, TOP_K, NQ, TOP_K, false); MergeTopkToResultSetTest(TOP_K, TOP_K, NQ, TOP_K, false);
/* test3, id1/dist1 small topk */ /* test3, id1/dist1 small topk */
MergeTopkToResultSetTest(TOP_K/2, TOP_K, NQ, TOP_K, true); MergeTopkToResultSetTest(TOP_K / 2, TOP_K, NQ, TOP_K, true);
MergeTopkToResultSetTest(TOP_K/2, TOP_K, NQ, TOP_K, false); MergeTopkToResultSetTest(TOP_K / 2, TOP_K, NQ, TOP_K, false);
/* test4, id1/dist1 small topk, id2/dist2 small topk */ /* test4, id1/dist1 small topk, id2/dist2 small topk */
MergeTopkToResultSetTest(TOP_K/2, TOP_K/3, NQ, TOP_K, true); MergeTopkToResultSetTest(TOP_K / 2, TOP_K / 3, NQ, TOP_K, true);
MergeTopkToResultSetTest(TOP_K/2, TOP_K/3, NQ, TOP_K, false); MergeTopkToResultSetTest(TOP_K / 2, TOP_K / 3, NQ, TOP_K, false);
} }
//void MergeTopkArrayTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { //void MergeTopkArrayTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) {
...@@ -224,8 +225,8 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { ...@@ -224,8 +225,8 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
std::vector<int32_t> thread_vec = {4, 8}; std::vector<int32_t> thread_vec = {4, 8};
std::vector<int32_t> nq_vec = {1, 10, 100}; std::vector<int32_t> nq_vec = {1, 10, 100};
std::vector<int32_t> topk_vec = {1, 4, 16, 64}; std::vector<int32_t> topk_vec = {1, 4, 16, 64};
int32_t NQ = nq_vec[nq_vec.size()-1]; int32_t NQ = nq_vec[nq_vec.size() - 1];
int32_t TOPK = topk_vec[topk_vec.size()-1]; int32_t TOPK = topk_vec[topk_vec.size() - 1];
std::vector<std::vector<int64_t>> id_vec; std::vector<std::vector<int64_t>> id_vec;
std::vector<std::vector<float>> dist_vec; std::vector<std::vector<float>> dist_vec;
...@@ -255,7 +256,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { ...@@ -255,7 +256,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
} }
std::string str1 = "Method-1 " + std::to_string(max_thread_num) + " " + std::string str1 = "Method-1 " + std::to_string(max_thread_num) + " " +
std::to_string(nq) + " " + std::to_string(top_k); std::to_string(nq) + " " + std::to_string(top_k);
milvus::TimeRecorder rc1(str1); milvus::TimeRecorder rc1(str1);
/////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册