提交 43d2609c 作者: Jin Hai 提交者: GitHub

Merge pull request #172 from cydrain/caiyd_reduce_opt

#168 improve result reduce

Former-commit-id: 80644a5c84c295b90b5c20b921bd6eada6ea6e3c
...@@ -29,6 +29,7 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -29,6 +29,7 @@ Please mark all change in change log and use the ticket from JIRA.
- \#149 - Improve large query optimizer pass - \#149 - Improve large query optimizer pass
- \#156 - Not return error when search_resources and index_build_device set cpu - \#156 - Not return error when search_resources and index_build_device set cpu
- \#159 - Change the configuration name from 'use_gpu_threshold' to 'gpu_search_threshold' - \#159 - Change the configuration name from 'use_gpu_threshold' to 'gpu_search_threshold'
- \#168 - Improve result reduce
## Task ## Task
......
...@@ -67,15 +67,16 @@ class DB { ...@@ -67,15 +67,16 @@ class DB {
virtual Status virtual Status
Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
QueryResults& results) = 0; ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status virtual Status
Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) = 0; const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) = 0;
virtual Status virtual Status
Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) = 0; uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) = 0;
virtual Status virtual Status
Size(uint64_t& result) = 0; Size(uint64_t& result) = 0;
......
...@@ -336,20 +336,20 @@ DBImpl::DropIndex(const std::string& table_id) { ...@@ -336,20 +336,20 @@ DBImpl::DropIndex(const std::string& table_id) {
Status Status
DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
QueryResults& results) { ResultIds& result_ids, ResultDistances& result_distances) {
if (shutting_down_.load(std::memory_order_acquire)) { if (shutting_down_.load(std::memory_order_acquire)) {
return Status(DB_ERROR, "Milsvus server is shutdown!"); return Status(DB_ERROR, "Milsvus server is shutdown!");
} }
meta::DatesT dates = {utils::GetDate()}; meta::DatesT dates = {utils::GetDate()};
Status result = Query(table_id, k, nq, nprobe, vectors, dates, results); Status result = Query(table_id, k, nq, nprobe, vectors, dates, result_ids, result_distances);
return result; return result;
} }
Status Status
DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) { const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) {
if (shutting_down_.load(std::memory_order_acquire)) { if (shutting_down_.load(std::memory_order_acquire)) {
return Status(DB_ERROR, "Milsvus server is shutdown!"); return Status(DB_ERROR, "Milsvus server is shutdown!");
} }
...@@ -372,14 +372,15 @@ DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t npr ...@@ -372,14 +372,15 @@ DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t npr
} }
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, results); status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
return status; return status;
} }
Status Status
DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) { uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) {
if (shutting_down_.load(std::memory_order_acquire)) { if (shutting_down_.load(std::memory_order_acquire)) {
return Status(DB_ERROR, "Milsvus server is shutdown!"); return Status(DB_ERROR, "Milsvus server is shutdown!");
} }
...@@ -413,7 +414,7 @@ DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ ...@@ -413,7 +414,7 @@ DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_
} }
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, results); status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, result_ids, result_distances);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
return status; return status;
} }
...@@ -432,7 +433,7 @@ DBImpl::Size(uint64_t& result) { ...@@ -432,7 +433,7 @@ DBImpl::Size(uint64_t& result) {
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status Status
DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, QueryResults& results) { uint64_t nprobe, const float* vectors, ResultIds& result_ids, ResultDistances& result_distances) {
server::CollectQueryMetrics metrics(nq); server::CollectQueryMetrics metrics(nq);
TimeRecorder rc(""); TimeRecorder rc("");
...@@ -453,7 +454,8 @@ DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& fi ...@@ -453,7 +454,8 @@ DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& fi
} }
// step 3: construct results // step 3: construct results
results = job->GetResult(); result_ids = job->GetResultIds();
result_distances = job->GetResultDistances();
rc.ElapseFromBegin("Engine query totally cost"); rc.ElapseFromBegin("Engine query totally cost");
return Status::OK(); return Status::OK();
......
...@@ -91,15 +91,16 @@ class DBImpl : public DB { ...@@ -91,15 +91,16 @@ class DBImpl : public DB {
Status Status
Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
QueryResults& results) override; ResultIds& result_ids, ResultDistances& result_distances) override;
Status Status
Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) override; const meta::DatesT& dates, ResultIds& result_ids, ResultDistances& result_distances) override;
Status Status
Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq, Query(const std::string& table_id, const std::vector<std::string>& file_ids, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) override; uint64_t nprobe, const float* vectors, const meta::DatesT& dates, ResultIds& result_ids,
ResultDistances& result_distances) override;
Status Status
Size(uint64_t& result) override; Size(uint64_t& result) override;
...@@ -107,7 +108,7 @@ class DBImpl : public DB { ...@@ -107,7 +108,7 @@ class DBImpl : public DB {
private: private:
Status Status
QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, QueryResults& results); uint64_t nprobe, const float* vectors, ResultIds& result_ids, ResultDistances& result_distances);
void void
BackgroundTimerTask(); BackgroundTimerTask();
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "db/engine/ExecutionEngine.h" #include "db/engine/ExecutionEngine.h"
#include <faiss/Index.h>
#include <stdint.h> #include <stdint.h>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -26,12 +27,13 @@ ...@@ -26,12 +27,13 @@
namespace milvus { namespace milvus {
namespace engine { namespace engine {
typedef int64_t IDNumber; using IDNumber = faiss::Index::idx_t;
typedef IDNumber* IDNumberPtr; typedef IDNumber* IDNumberPtr;
typedef std::vector<IDNumber> IDNumbers; typedef std::vector<IDNumber> IDNumbers;
typedef std::vector<std::pair<IDNumber, double>> QueryResult; typedef std::vector<faiss::Index::idx_t> ResultIds;
typedef std::vector<QueryResult> QueryResults; typedef std::vector<faiss::Index::distance_t> ResultDistances;
struct TableIndex { struct TableIndex {
int32_t engine_type_ = (int)EngineType::FAISS_IDMAP; int32_t engine_type_ = (int)EngineType::FAISS_IDMAP;
......
...@@ -53,9 +53,14 @@ SearchJob::SearchDone(size_t index_id) { ...@@ -53,9 +53,14 @@ SearchJob::SearchDone(size_t index_id) {
SERVER_LOG_DEBUG << "SearchJob " << id() << " finish index file: " << index_id; SERVER_LOG_DEBUG << "SearchJob " << id() << " finish index file: " << index_id;
} }
ResultSet& ResultIds&
SearchJob::GetResult() { SearchJob::GetResultIds() {
return result_; return result_ids_;
}
ResultDistances&
SearchJob::GetResultDistances() {
return result_distances_;
} }
Status& Status&
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <vector> #include <vector>
#include "Job.h" #include "Job.h"
#include "db/Types.h"
#include "db/meta/MetaTypes.h" #include "db/meta/MetaTypes.h"
namespace milvus { namespace milvus {
...@@ -37,9 +38,9 @@ namespace scheduler { ...@@ -37,9 +38,9 @@ namespace scheduler {
using engine::meta::TableFileSchemaPtr; using engine::meta::TableFileSchemaPtr;
using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>; using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>; using ResultIds = engine::ResultIds;
using ResultSet = std::vector<Id2DistVec>; using ResultDistances = engine::ResultDistances;
class SearchJob : public Job { class SearchJob : public Job {
public: public:
...@@ -55,8 +56,11 @@ class SearchJob : public Job { ...@@ -55,8 +56,11 @@ class SearchJob : public Job {
void void
SearchDone(size_t index_id); SearchDone(size_t index_id);
ResultSet& ResultIds&
GetResult(); GetResultIds();
ResultDistances&
GetResultDistances();
Status& Status&
GetStatus(); GetStatus();
...@@ -104,7 +108,8 @@ class SearchJob : public Job { ...@@ -104,7 +108,8 @@ class SearchJob : public Job {
Id2IndexMap index_files_; Id2IndexMap index_files_;
// TODO: column-base better ? // TODO: column-base better ?
ResultSet result_; ResultIds result_ids_;
ResultDistances result_distances_;
Status status_; Status status_;
std::mutex mutex_; std::mutex mutex_;
......
...@@ -222,7 +222,7 @@ XSearchTask::Execute() { ...@@ -222,7 +222,7 @@ XSearchTask::Execute() {
{ {
std::unique_lock<std::mutex> lock(search_job->mutex()); std::unique_lock<std::mutex> lock(search_job->mutex());
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2,
search_job->GetResult()); search_job->GetResultIds(), search_job->GetResultDistances());
} }
span = rc.RecordSection(hdr + ", reduce topk"); span = rc.RecordSection(hdr + ", reduce topk");
...@@ -243,71 +243,69 @@ XSearchTask::Execute() { ...@@ -243,71 +243,69 @@ XSearchTask::Execute() {
} }
void void
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, XSearchTask::MergeTopkToResultSet(const scheduler::ResultIds& src_ids, const scheduler::ResultDistances& src_distances,
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, size_t src_k, size_t nq, size_t topk, bool ascending, scheduler::ResultIds& tar_ids,
scheduler::ResultSet& result) { scheduler::ResultDistances& tar_distances) {
if (result.empty()) { if (src_ids.empty()) {
result.resize(nq); return;
} }
size_t tar_k = tar_ids.size() / nq;
size_t buf_k = std::min(topk, src_k + tar_k);
scheduler::ResultIds buf_ids(nq * buf_k, -1);
scheduler::ResultDistances buf_distances(nq * buf_k, 0.0);
for (uint64_t i = 0; i < nq; i++) { for (uint64_t i = 0; i < nq; i++) {
scheduler::Id2DistVec result_buf; size_t buf_k_j = 0, src_k_j = 0, tar_k_j = 0;
auto& result_i = result[i]; size_t buf_idx, src_idx, tar_idx;
if (result[i].empty()) { size_t buf_k_multi_i = buf_k * i;
result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0)); size_t src_k_multi_i = topk * i;
uint64_t input_k_multi_i = topk * i; size_t tar_k_multi_i = tar_k * i;
for (auto k = 0; k < input_k; ++k) {
uint64_t idx = input_k_multi_i + k; while (buf_k_j < buf_k && src_k_j < src_k && tar_k_j < tar_k) {
auto& result_buf_item = result_buf[k]; src_idx = src_k_multi_i + src_k_j;
result_buf_item.first = input_ids[idx]; tar_idx = tar_k_multi_i + tar_k_j;
result_buf_item.second = input_distance[idx]; buf_idx = buf_k_multi_i + buf_k_j;
}
} else { if ((ascending && src_distances[src_idx] < tar_distances[tar_idx]) ||
size_t tar_size = result_i.size(); (!ascending && src_distances[src_idx] > tar_distances[tar_idx])) {
uint64_t output_k = std::min(topk, input_k + tar_size); buf_ids[buf_idx] = src_ids[src_idx];
result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0)); buf_distances[buf_idx] = src_distances[src_idx];
size_t buf_k = 0, src_k = 0, tar_k = 0; src_k_j++;
uint64_t src_idx; } else {
uint64_t input_k_multi_i = topk * i; buf_ids[buf_idx] = tar_ids[tar_idx];
while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { buf_distances[buf_idx] = tar_distances[tar_idx];
src_idx = input_k_multi_i + src_k; tar_k_j++;
auto& result_buf_item = result_buf[buf_k];
auto& result_item = result_i[tar_k];
if ((ascending && input_distance[src_idx] < result_item.second) ||
(!ascending && input_distance[src_idx] > result_item.second)) {
result_buf_item.first = input_ids[src_idx];
result_buf_item.second = input_distance[src_idx];
src_k++;
} else {
result_buf_item = result_item;
tar_k++;
}
buf_k++;
} }
buf_k_j++;
}
if (buf_k < output_k) { if (buf_k_j < buf_k) {
if (src_k < input_k) { if (src_k_j < src_k) {
while (buf_k < output_k && src_k < input_k) { while (buf_k_j < buf_k && src_k_j < src_k) {
src_idx = input_k_multi_i + src_k; buf_idx = buf_k_multi_i + buf_k_j;
auto& result_buf_item = result_buf[buf_k]; src_idx = src_k_multi_i + src_k_j;
result_buf_item.first = input_ids[src_idx]; buf_ids[buf_idx] = src_ids[src_idx];
result_buf_item.second = input_distance[src_idx]; buf_distances[buf_idx] = src_distances[src_idx];
src_k++; src_k_j++;
buf_k++; buf_k_j++;
} }
} else { } else {
while (buf_k < output_k && tar_k < tar_size) { while (buf_k_j < buf_k && tar_k_j < tar_k) {
result_buf[buf_k] = result_i[tar_k]; buf_idx = buf_k_multi_i + buf_k_j;
tar_k++; tar_idx = tar_k_multi_i + tar_k_j;
buf_k++; buf_ids[buf_idx] = tar_ids[tar_idx];
} buf_distances[buf_idx] = tar_distances[tar_idx];
tar_k_j++;
buf_k_j++;
} }
} }
} }
result_i.swap(result_buf);
} }
tar_ids.swap(buf_ids);
tar_distances.swap(buf_distances);
} }
// void // void
......
...@@ -39,8 +39,9 @@ class XSearchTask : public Task { ...@@ -39,8 +39,9 @@ class XSearchTask : public Task {
public: public:
static void static void
MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, MergeTopkToResultSet(const scheduler::ResultIds& src_ids, const scheduler::ResultDistances& src_distances,
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); size_t src_k, size_t nq, size_t topk, bool ascending, scheduler::ResultIds& tar_ids,
scheduler::ResultDistances& tar_distances);
// static void // static void
// MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k, // MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
......
...@@ -637,7 +637,8 @@ SearchTask::OnExecute() { ...@@ -637,7 +637,8 @@ SearchTask::OnExecute() {
rc.RecordSection("prepare vector data"); rc.RecordSection("prepare vector data");
// step 6: search vectors // step 6: search vectors
engine::QueryResults results; engine::ResultIds result_ids;
engine::ResultDistances result_distances;
auto record_count = (uint64_t)search_param_->query_record_array().size(); auto record_count = (uint64_t)search_param_->query_record_array().size();
#ifdef MILVUS_ENABLE_PROFILING #ifdef MILVUS_ENABLE_PROFILING
...@@ -647,11 +648,11 @@ SearchTask::OnExecute() { ...@@ -647,11 +648,11 @@ SearchTask::OnExecute() {
#endif #endif
if (file_id_array_.empty()) { if (file_id_array_.empty()) {
status = status = DBWrapper::DB()->Query(table_name_, (size_t)top_k, record_count, nprobe, vec_f.data(), dates,
DBWrapper::DB()->Query(table_name_, (size_t)top_k, record_count, nprobe, vec_f.data(), dates, results); result_ids, result_distances);
} else { } else {
status = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t)top_k, record_count, nprobe, status = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t)top_k, record_count, nprobe,
vec_f.data(), dates, results); vec_f.data(), dates, result_ids, result_distances);
} }
#ifdef MILVUS_ENABLE_PROFILING #ifdef MILVUS_ENABLE_PROFILING
...@@ -663,23 +664,20 @@ SearchTask::OnExecute() { ...@@ -663,23 +664,20 @@ SearchTask::OnExecute() {
return status; return status;
} }
if (results.empty()) { if (result_ids.empty()) {
return Status::OK(); // empty table return Status::OK(); // empty table
} }
if (results.size() != record_count) { size_t result_k = result_ids.size() / record_count;
std::string msg = "Search " + std::to_string(record_count) + " vectors but only return " +
std::to_string(results.size()) + " results";
return Status(SERVER_ILLEGAL_SEARCH_RESULT, msg);
}
// step 7: construct result array // step 7: construct result array
for (auto& result : results) { for (size_t i = 0; i < record_count; i++) {
::milvus::grpc::TopKQueryResult* topk_query_result = topk_result_list->add_topk_query_result(); ::milvus::grpc::TopKQueryResult* topk_query_result = topk_result_list->add_topk_query_result();
for (auto& pair : result) { for (size_t j = 0; j < result_k; j++) {
::milvus::grpc::QueryResult* grpc_result = topk_query_result->add_query_result_arrays(); ::milvus::grpc::QueryResult* grpc_result = topk_query_result->add_query_result_arrays();
grpc_result->set_id(pair.first); size_t idx = i * result_k + j;
grpc_result->set_distance(pair.second); grpc_result->set_id(result_ids[idx]);
grpc_result->set_distance(result_distances[idx]);
} }
} }
......
...@@ -175,7 +175,8 @@ TEST_F(DBTest, DB_TEST) { ...@@ -175,7 +175,8 @@ TEST_F(DBTest, DB_TEST) {
BuildVectors(qb, qxb); BuildVectors(qb, qxb);
std::thread search([&]() { std::thread search([&]() {
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
int k = 10; int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(2)); std::this_thread::sleep_for(std::chrono::seconds(2));
...@@ -190,17 +191,17 @@ TEST_F(DBTest, DB_TEST) { ...@@ -190,17 +191,17 @@ TEST_F(DBTest, DB_TEST) {
prev_count = count; prev_count = count;
START_TIMER; START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M"; ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str()); STOP_TIMER(ss.str());
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
for (auto k = 0; k < qb; ++k) { for (auto i = 0; i < qb; ++i) {
ASSERT_EQ(results[k][0].first, target_ids[k]); ASSERT_EQ(result_ids[i*k], target_ids[i]);
ss.str(""); ss.str("");
ss << "Result [" << k << "]:"; ss << "Result [" << i << "]:";
for (auto result : results[k]) { for (auto t = 0; t < k; t++) {
ss << result.first << " "; ss << result_ids[i * k + t] << " ";
} }
/* LOG(DEBUG) << ss.str(); */ /* LOG(DEBUG) << ss.str(); */
} }
...@@ -284,16 +285,18 @@ TEST_F(DBTest, SEARCH_TEST) { ...@@ -284,16 +285,18 @@ TEST_F(DBTest, SEARCH_TEST) {
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
{ {
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
} }
{//search by specify index file {//search by specify index file
milvus::engine::meta::DatesT dates; milvus::engine::meta::DatesT dates;
std::vector<std::string> file_ids = {"1", "2", "3", "4", "5", "6"}; std::vector<std::string> file_ids = {"1", "2", "3", "4", "5", "6"};
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results); milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids, result_distances);
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
} }
...@@ -303,22 +306,25 @@ TEST_F(DBTest, SEARCH_TEST) { ...@@ -303,22 +306,25 @@ TEST_F(DBTest, SEARCH_TEST) {
db_->CreateIndex(TABLE_NAME, index); // wait until build index finish db_->CreateIndex(TABLE_NAME, index); // wait until build index finish
{ {
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
} }
{ {
milvus::engine::QueryResults large_nq_results; milvus::engine::ResultIds result_ids;
stat = db_->Query(TABLE_NAME, k, 200, 10, xq.data(), large_nq_results); milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, k, 200, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
} }
{//search by specify index file {//search by specify index file
milvus::engine::meta::DatesT dates; milvus::engine::meta::DatesT dates;
std::vector<std::string> file_ids = {"1", "2", "3", "4", "5", "6"}; std::vector<std::string> file_ids = {"1", "2", "3", "4", "5", "6"};
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results); milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, result_ids, result_distances);
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
} }
...@@ -391,11 +397,12 @@ TEST_F(DBTest, SHUTDOWN_TEST) { ...@@ -391,11 +397,12 @@ TEST_F(DBTest, SHUTDOWN_TEST) {
ASSERT_FALSE(stat.ok()); ASSERT_FALSE(stat.ok());
milvus::engine::meta::DatesT dates; milvus::engine::meta::DatesT dates;
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
stat = db_->Query(table_info.table_id_, 1, 1, 1, nullptr, dates, results); milvus::engine::ResultDistances result_distances;
stat = db_->Query(table_info.table_id_, 1, 1, 1, nullptr, dates, result_ids, result_distances);
ASSERT_FALSE(stat.ok()); ASSERT_FALSE(stat.ok());
std::vector<std::string> file_ids; std::vector<std::string> file_ids;
stat = db_->Query(table_info.table_id_, file_ids, 1, 1, 1, nullptr, dates, results); stat = db_->Query(table_info.table_id_, file_ids, 1, 1, 1, nullptr, dates, result_ids, result_distances);
ASSERT_FALSE(stat.ok()); ASSERT_FALSE(stat.ok());
stat = db_->DeleteTable(table_info.table_id_, dates); stat = db_->DeleteTable(table_info.table_id_, dates);
......
...@@ -81,7 +81,8 @@ TEST_F(MySqlDBTest, DB_TEST) { ...@@ -81,7 +81,8 @@ TEST_F(MySqlDBTest, DB_TEST) {
ASSERT_EQ(target_ids.size(), qb); ASSERT_EQ(target_ids.size(), qb);
std::thread search([&]() { std::thread search([&]() {
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
int k = 10; int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(5)); std::this_thread::sleep_for(std::chrono::seconds(5));
...@@ -96,25 +97,25 @@ TEST_F(MySqlDBTest, DB_TEST) { ...@@ -96,25 +97,25 @@ TEST_F(MySqlDBTest, DB_TEST) {
prev_count = count; prev_count = count;
START_TIMER; START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M"; ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str()); STOP_TIMER(ss.str());
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
for (auto k = 0; k < qb; ++k) { for (auto i = 0; i < qb; ++i) {
// std::cout << results[k][0].first << " " << target_ids[k] << std::endl; // std::cout << results[k][0].first << " " << target_ids[k] << std::endl;
// ASSERT_EQ(results[k][0].first, target_ids[k]); // ASSERT_EQ(results[k][0].first, target_ids[k]);
bool exists = false; bool exists = false;
for (auto &result : results[k]) { for (auto t = 0; t < k; t++) {
if (result.first == target_ids[k]) { if (result_ids[i * k + t] == target_ids[i]) {
exists = true; exists = true;
} }
} }
ASSERT_TRUE(exists); ASSERT_TRUE(exists);
ss.str(""); ss.str("");
ss << "Result [" << k << "]:"; ss << "Result [" << i << "]:";
for (auto result : results[k]) { for (auto t = 0; t < k; t++) {
ss << result.first << " "; ss << result_ids[i * k + t] << " ";
} }
/* LOG(DEBUG) << ss.str(); */ /* LOG(DEBUG) << ss.str(); */
} }
...@@ -188,8 +189,9 @@ TEST_F(MySqlDBTest, SEARCH_TEST) { ...@@ -188,8 +189,9 @@ TEST_F(MySqlDBTest, SEARCH_TEST) {
sleep(2); // wait until build index finish sleep(2); // wait until build index finish
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); milvus::engine::ResultDistances result_distances;
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), result_ids, result_distances);
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
} }
......
...@@ -259,10 +259,11 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) { ...@@ -259,10 +259,11 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) {
int topk = 10, nprobe = 10; int topk = 10, nprobe = 10;
for (auto& pair : search_vectors) { for (auto& pair : search_vectors) {
auto& search = pair.second; auto& search = pair.second;
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
stat = db_->Query(GetTableName(), topk, 1, nprobe, search.data(), results); milvus::engine::ResultDistances result_distances;
ASSERT_EQ(results[0][0].first, pair.first); stat = db_->Query(GetTableName(), topk, 1, nprobe, search.data(), result_ids, result_distances);
ASSERT_LT(results[0][0].second, 1e-4); ASSERT_EQ(result_ids[0], pair.first);
ASSERT_LT(result_distances[0], 1e-4);
} }
} }
...@@ -314,7 +315,8 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) { ...@@ -314,7 +315,8 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
BuildVectors(qb, qxb); BuildVectors(qb, qxb);
std::thread search([&]() { std::thread search([&]() {
milvus::engine::QueryResults results; milvus::engine::ResultIds result_ids;
milvus::engine::ResultDistances result_distances;
int k = 10; int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(2)); std::this_thread::sleep_for(std::chrono::seconds(2));
...@@ -329,17 +331,17 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) { ...@@ -329,17 +331,17 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) {
prev_count = count; prev_count = count;
START_TIMER; START_TIMER;
stat = db_->Query(GetTableName(), k, qb, 10, qxb.data(), results); stat = db_->Query(GetTableName(), k, qb, 10, qxb.data(), result_ids, result_distances);
ss << "Search " << j << " With Size " << count / milvus::engine::M << " M"; ss << "Search " << j << " With Size " << count / milvus::engine::M << " M";
STOP_TIMER(ss.str()); STOP_TIMER(ss.str());
ASSERT_TRUE(stat.ok()); ASSERT_TRUE(stat.ok());
for (auto k = 0; k < qb; ++k) { for (auto i = 0; i < qb; ++i) {
ASSERT_EQ(results[k][0].first, target_ids[k]); ASSERT_EQ(result_ids[i * k], target_ids[i]);
ss.str(""); ss.str("");
ss << "Result [" << k << "]:"; ss << "Result [" << i << "]:";
for (auto result : results[k]) { for (auto t = 0; t < k; t++) {
ss << result.first << " "; ss << result_ids[i * k + t] << " ";
} }
/* LOG(DEBUG) << ss.str(); */ /* LOG(DEBUG) << ss.str(); */
} }
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include "scheduler/job/SearchJob.h"
#include "scheduler/task/SearchTask.h" #include "scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h" #include "utils/TimeRecorder.h"
#include "utils/ThreadPool.h" #include "utils/ThreadPool.h"
...@@ -28,74 +29,80 @@ namespace { ...@@ -28,74 +29,80 @@ namespace {
namespace ms = milvus::scheduler; namespace ms = milvus::scheduler;
void void
BuildResult(std::vector<int64_t>& output_ids, BuildResult(ms::ResultIds& output_ids,
std::vector<float>& output_distance, ms::ResultDistances & output_distances,
uint64_t input_k, size_t input_k,
uint64_t topk, size_t topk,
uint64_t nq, size_t nq,
bool ascending) { bool ascending) {
output_ids.clear(); output_ids.clear();
output_ids.resize(nq * topk); output_ids.resize(nq * topk);
output_distance.clear(); output_distances.clear();
output_distance.resize(nq * topk); output_distances.resize(nq * topk);
for (uint64_t i = 0; i < nq; i++) { for (size_t i = 0; i < nq; i++) {
//insert valid items //insert valid items
for (uint64_t j = 0; j < input_k; j++) { for (size_t j = 0; j < input_k; j++) {
output_ids[i * topk + j] = (int64_t)(drand48() * 100000); output_ids[i * topk + j] = (int64_t)(drand48() * 100000);
output_distance[i * topk + j] = ascending ? (j + drand48()) : ((input_k - j) + drand48()); output_distances[i * topk + j] = ascending ? (j + drand48()) : ((input_k - j) + drand48());
} }
//insert invalid items //insert invalid items
for (uint64_t j = input_k; j < topk; j++) { for (size_t j = input_k; j < topk; j++) {
output_ids[i * topk + j] = -1; output_ids[i * topk + j] = -1;
output_distance[i * topk + j] = -1.0; output_distances[i * topk + j] = -1.0;
} }
} }
} }
void void
CopyResult(std::vector<int64_t>& output_ids, CopyResult(ms::ResultIds& output_ids,
std::vector<float>& output_distance, ms::ResultDistances& output_distances,
uint64_t output_topk, size_t output_topk,
std::vector<int64_t>& input_ids, ms::ResultIds& input_ids,
std::vector<float>& input_distance, ms::ResultDistances& input_distances,
uint64_t input_topk, size_t input_topk,
uint64_t nq) { size_t nq) {
ASSERT_TRUE(input_ids.size() >= nq * input_topk); ASSERT_TRUE(input_ids.size() >= nq * input_topk);
ASSERT_TRUE(input_distance.size() >= nq * input_topk); ASSERT_TRUE(input_distances.size() >= nq * input_topk);
ASSERT_TRUE(output_topk <= input_topk); ASSERT_TRUE(output_topk <= input_topk);
output_ids.clear(); output_ids.clear();
output_ids.resize(nq * output_topk); output_ids.resize(nq * output_topk);
output_distance.clear(); output_distances.clear();
output_distance.resize(nq * output_topk); output_distances.resize(nq * output_topk);
for (uint64_t i = 0; i < nq; i++) { for (size_t i = 0; i < nq; i++) {
for (uint64_t j = 0; j < output_topk; j++) { for (size_t j = 0; j < output_topk; j++) {
output_ids[i * output_topk + j] = input_ids[i * input_topk + j]; output_ids[i * output_topk + j] = input_ids[i * input_topk + j];
output_distance[i * output_topk + j] = input_distance[i * input_topk + j]; output_distances[i * output_topk + j] = input_distances[i * input_topk + j];
} }
} }
} }
void void
CheckTopkResult(const std::vector<int64_t>& input_ids_1, CheckTopkResult(const ms::ResultIds& input_ids_1,
const std::vector<float>& input_distance_1, const ms::ResultDistances& input_distances_1,
const std::vector<int64_t>& input_ids_2, size_t input_k_1,
const std::vector<float>& input_distance_2, const ms::ResultIds& input_ids_2,
uint64_t topk, const ms::ResultDistances& input_distances_2,
uint64_t nq, size_t input_k_2,
size_t topk,
size_t nq,
bool ascending, bool ascending,
const milvus::scheduler::ResultSet& result) { const ms::ResultIds& result_ids,
ASSERT_EQ(result.size(), nq); const ms::ResultDistances& result_distances) {
ASSERT_EQ(input_ids_1.size(), input_distance_1.size()); ASSERT_EQ(result_ids.size(), result_distances.size());
ASSERT_EQ(input_ids_2.size(), input_distance_2.size()); ASSERT_EQ(input_ids_1.size(), input_distances_1.size());
ASSERT_EQ(input_ids_2.size(), input_distances_2.size());
for (int64_t i = 0; i < nq; i++) { size_t result_k = result_distances.size() / nq;
ASSERT_EQ(result_k, std::min(topk, input_k_1 + input_k_2));
for (size_t i = 0; i < nq; i++) {
std::vector<float> std::vector<float>
src_vec(input_distance_1.begin() + i * topk, input_distance_1.begin() + (i + 1) * topk); src_vec(input_distances_1.begin() + i * topk, input_distances_1.begin() + (i + 1) * topk);
src_vec.insert(src_vec.end(), src_vec.insert(src_vec.end(),
input_distance_2.begin() + i * topk, input_distances_2.begin() + i * topk,
input_distance_2.begin() + (i + 1) * topk); input_distances_2.begin() + (i + 1) * topk);
if (ascending) { if (ascending) {
std::sort(src_vec.begin(), src_vec.end()); std::sort(src_vec.begin(), src_vec.end());
} else { } else {
...@@ -111,15 +118,16 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1, ...@@ -111,15 +118,16 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
++iter; ++iter;
} }
uint64_t n = std::min(topk, result[i].size()); size_t n = std::min(topk, result_ids.size() / nq);
for (uint64_t j = 0; j < n; j++) { for (size_t j = 0; j < n; j++) {
if (result[i][j].first < 0) { size_t idx = i * n + j;
if (result_ids[idx] < 0) {
continue; continue;
} }
if (src_vec[j] != result[i][j].second) { if (src_vec[j] != result_distances[idx]) {
std::cout << src_vec[j] << " " << result[i][j].second << std::endl; std::cout << src_vec[j] << " " << result_distances[idx] << std::endl;
} }
ASSERT_TRUE(src_vec[j] == result[i][j].second); ASSERT_TRUE(src_vec[j] == result_distances[idx]);
} }
} }
} }
...@@ -127,20 +135,21 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1, ...@@ -127,20 +135,21 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
} // namespace } // namespace
void void
MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { MergeTopkToResultSetTest(size_t topk_1, size_t topk_2, size_t nq, size_t topk, bool ascending) {
std::vector<int64_t> ids1, ids2; ms::ResultIds ids1, ids2;
std::vector<float> dist1, dist2; ms::ResultDistances dist1, dist2;
ms::ResultSet result; ms::ResultIds result_ids;
ms::ResultDistances result_distances;
BuildResult(ids1, dist1, topk_1, topk, nq, ascending); BuildResult(ids1, dist1, topk_1, topk, nq, ascending);
BuildResult(ids2, dist2, topk_2, topk, nq, ascending); BuildResult(ids2, dist2, topk_2, topk, nq, ascending);
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result_ids, result_distances);
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result_ids, result_distances);
CheckTopkResult(ids1, dist1, ids2, dist2, topk, nq, ascending, result); CheckTopkResult(ids1, dist1, topk_1, ids2, dist2, topk_2, topk, nq, ascending, result_ids, result_distances);
} }
TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { TEST(DBSearchTest, MERGE_RESULT_SET_TEST) {
uint64_t NQ = 15; size_t NQ = 15;
uint64_t TOP_K = 64; size_t TOP_K = 64;
/* test1, id1/dist1 valid, id2/dist2 empty */ /* test1, id1/dist1 valid, id2/dist2 empty */
MergeTopkToResultSetTest(TOP_K, 0, NQ, TOP_K, true); MergeTopkToResultSetTest(TOP_K, 0, NQ, TOP_K, true);
...@@ -159,21 +168,21 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { ...@@ -159,21 +168,21 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) {
MergeTopkToResultSetTest(TOP_K / 2, TOP_K / 3, NQ, TOP_K, false); MergeTopkToResultSetTest(TOP_K / 2, TOP_K / 3, NQ, TOP_K, false);
} }
//void MergeTopkArrayTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { //void MergeTopkArrayTest(size_t topk_1, size_t topk_2, size_t nq, size_t topk, bool ascending) {
// std::vector<int64_t> ids1, ids2; // std::vector<int64_t> ids1, ids2;
// std::vector<float> dist1, dist2; // std::vector<float> dist1, dist2;
// ms::ResultSet result; // ms::ResultSet result;
// BuildResult(ids1, dist1, topk_1, topk, nq, ascending); // BuildResult(ids1, dist1, topk_1, topk, nq, ascending);
// BuildResult(ids2, dist2, topk_2, topk, nq, ascending); // BuildResult(ids2, dist2, topk_2, topk, nq, ascending);
// uint64_t result_topk = std::min(topk, topk_1 + topk_2); // size_t result_topk = std::min(topk, topk_1 + topk_2);
// ms::XSearchTask::MergeTopkArray(ids1, dist1, topk_1, ids2, dist2, topk_2, nq, topk, ascending); // ms::XSearchTask::MergeTopkArray(ids1, dist1, topk_1, ids2, dist2, topk_2, nq, topk, ascending);
// if (ids1.size() != result_topk * nq) { // if (ids1.size() != result_topk * nq) {
// std::cout << ids1.size() << " " << result_topk * nq << std::endl; // std::cout << ids1.size() << " " << result_topk * nq << std::endl;
// } // }
// ASSERT_TRUE(ids1.size() == result_topk * nq); // ASSERT_TRUE(ids1.size() == result_topk * nq);
// ASSERT_TRUE(dist1.size() == result_topk * nq); // ASSERT_TRUE(dist1.size() == result_topk * nq);
// for (uint64_t i = 0; i < nq; i++) { // for (size_t i = 0; i < nq; i++) {
// for (uint64_t k = 1; k < result_topk; k++) { // for (size_t k = 1; k < result_topk; k++) {
// float f0 = dist1[i * topk + k - 1]; // float f0 = dist1[i * topk + k - 1];
// float f1 = dist1[i * topk + k]; // float f1 = dist1[i * topk + k];
// if (ascending) { // if (ascending) {
...@@ -192,8 +201,8 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { ...@@ -192,8 +201,8 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) {
//} //}
//TEST(DBSearchTest, MERGE_ARRAY_TEST) { //TEST(DBSearchTest, MERGE_ARRAY_TEST) {
// uint64_t NQ = 15; // size_t NQ = 15;
// uint64_t TOP_K = 64; // size_t TOP_K = 64;
// //
// /* test1, id1/dist1 valid, id2/dist2 empty */ // /* test1, id1/dist1 valid, id2/dist2 empty */
// MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, true); // MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, true);
...@@ -222,23 +231,23 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { ...@@ -222,23 +231,23 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
int32_t index_file_num = 478; /* sift1B dataset, index files num */ int32_t index_file_num = 478; /* sift1B dataset, index files num */
bool ascending = true; bool ascending = true;
std::vector<int32_t> thread_vec = {4, 8}; std::vector<size_t> thread_vec = {4};
std::vector<int32_t> nq_vec = {1, 10, 100}; std::vector<size_t> nq_vec = {1000};
std::vector<int32_t> topk_vec = {1, 4, 16, 64}; std::vector<size_t> topk_vec = {64};
int32_t NQ = nq_vec[nq_vec.size() - 1]; size_t NQ = nq_vec[nq_vec.size() - 1];
int32_t TOPK = topk_vec[topk_vec.size() - 1]; size_t TOPK = topk_vec[topk_vec.size() - 1];
std::vector<std::vector<int64_t>> id_vec; std::vector<ms::ResultIds> id_vec;
std::vector<std::vector<float>> dist_vec; std::vector<ms::ResultDistances> dist_vec;
std::vector<int64_t> input_ids; ms::ResultIds input_ids;
std::vector<float> input_distance; ms::ResultDistances input_distances;
int32_t i, k, step; int32_t i, k, step;
/* generate testing data */ /* generate testing data */
for (i = 0; i < index_file_num; i++) { for (i = 0; i < index_file_num; i++) {
BuildResult(input_ids, input_distance, TOPK, TOPK, NQ, ascending); BuildResult(input_ids, input_distances, TOPK, TOPK, NQ, ascending);
id_vec.push_back(input_ids); id_vec.push_back(input_ids);
dist_vec.push_back(input_distance); dist_vec.push_back(input_distances);
} }
for (int32_t max_thread_num : thread_vec) { for (int32_t max_thread_num : thread_vec) {
...@@ -247,10 +256,11 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { ...@@ -247,10 +256,11 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
for (int32_t nq : nq_vec) { for (int32_t nq : nq_vec) {
for (int32_t top_k : topk_vec) { for (int32_t top_k : topk_vec) {
ms::ResultSet final_result, final_result_2, final_result_3; ms::ResultIds final_result_ids, final_result_ids_2, final_result_ids_3;
ms::ResultDistances final_result_distances, final_result_distances_2, final_result_distances_3;
std::vector<std::vector<int64_t>> id_vec_1(index_file_num); std::vector<ms::ResultIds> id_vec_1(index_file_num);
std::vector<std::vector<float>> dist_vec_1(index_file_num); std::vector<ms::ResultDistances> dist_vec_1(index_file_num);
for (i = 0; i < index_file_num; i++) { for (i = 0; i < index_file_num; i++) {
CopyResult(id_vec_1[i], dist_vec_1[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); CopyResult(id_vec_1[i], dist_vec_1[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
} }
...@@ -268,8 +278,10 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { ...@@ -268,8 +278,10 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
nq, nq,
top_k, top_k,
ascending, ascending,
final_result); final_result_ids,
ASSERT_EQ(final_result.size(), nq); final_result_distances);
ASSERT_EQ(final_result_ids.size(), nq * top_k);
ASSERT_EQ(final_result_distances.size(), nq * top_k);
} }
rc1.RecordSection("reduce done"); rc1.RecordSection("reduce done");
...@@ -278,7 +290,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { ...@@ -278,7 +290,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
// /* method-2 */ // /* method-2 */
// std::vector<std::vector<int64_t>> id_vec_2(index_file_num); // std::vector<std::vector<int64_t>> id_vec_2(index_file_num);
// std::vector<std::vector<float>> dist_vec_2(index_file_num); // std::vector<std::vector<float>> dist_vec_2(index_file_num);
// std::vector<uint64_t> k_vec_2(index_file_num); // std::vector<size_t> k_vec_2(index_file_num);
// for (i = 0; i < index_file_num; i++) { // for (i = 0; i < index_file_num; i++) {
// CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); // CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
// k_vec_2[i] = top_k; // k_vec_2[i] = top_k;
...@@ -321,7 +333,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { ...@@ -321,7 +333,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) {
// /* method-3 parallel */ // /* method-3 parallel */
// std::vector<std::vector<int64_t>> id_vec_3(index_file_num); // std::vector<std::vector<int64_t>> id_vec_3(index_file_num);
// std::vector<std::vector<float>> dist_vec_3(index_file_num); // std::vector<std::vector<float>> dist_vec_3(index_file_num);
// std::vector<uint64_t> k_vec_3(index_file_num); // std::vector<size_t> k_vec_3(index_file_num);
// for (i = 0; i < index_file_num; i++) { // for (i = 0; i < index_file_num; i++) {
// CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); // CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
// k_vec_3[i] = top_k; // k_vec_3[i] = top_k;
......
...@@ -75,7 +75,8 @@ TEST_F(MetricTest, METRIC_TEST) { ...@@ -75,7 +75,8 @@ TEST_F(MetricTest, METRIC_TEST) {
} }
std::thread search([&]() { std::thread search([&]() {
milvus::engine::QueryResults results; // milvus::engine::ResultIds result_ids;
// milvus::engine::ResultDistances result_distances;
int k = 10; int k = 10;
std::this_thread::sleep_for(std::chrono::seconds(2)); std::this_thread::sleep_for(std::chrono::seconds(2));
...@@ -90,7 +91,7 @@ TEST_F(MetricTest, METRIC_TEST) { ...@@ -90,7 +91,7 @@ TEST_F(MetricTest, METRIC_TEST) {
prev_count = count; prev_count = count;
START_TIMER; START_TIMER;
// stat = db_->Query(group_name, k, qb, qxb, results); // stat = db_->Query(group_name, k, qb, qxb, result_ids, result_distances);
ss << "Search " << j << " With Size " << (float) (count * group_dim * sizeof(float)) / (1024 * 1024) ss << "Search " << j << " With Size " << (float) (count * group_dim * sizeof(float)) / (1024 * 1024)
<< " M"; << " M";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册