未验证 提交 c4b7be8e 编写于 作者: G groot 提交者: GitHub

#1728 Optimize request handler to combine similar query (#1743)

* modify changelog
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* improve search qps
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* changelog
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* fix hang bug
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* fix combine request result bug
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* add unittest for combine request
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* fix python test failure
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* fix python test failure
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* fix unittest failure
Signed-off-by: Ngroot <yihua.mo@zilliz.com>
上级 578d2633
......@@ -16,6 +16,7 @@ Please mark all change in change log and use the issue from GitHub
- \#1667 Create index failed with type: rnsg if metric_type is IP
- \#1708 NSG search crashed
- \#1724 Remove unused unittests
- \#1728 Optimize request handler to combine similar query
- \#1734 Opentracing for combined search request
- \#1735 Fix search out of memory with ivf_flat
- \#1756 Fix memory exhausted during searching
......
......@@ -210,6 +210,9 @@ Status
SqliteMetaImpl::DescribeTable(TableSchema& table_schema) {
try {
server::MetricCollector metric;
// multi-threads call sqlite update may get exception('bad logic', etc), so we add a lock here
std::lock_guard<std::mutex> meta_lock(meta_mutex_);
fiu_do_on("SqliteMetaImpl.DescribeTable.throw_exception", throw std::exception());
auto groups = ConnectorPtr->select(
columns(&TableSchema::id_, &TableSchema::state_, &TableSchema::dimension_, &TableSchema::created_on_,
......
......@@ -48,15 +48,21 @@ ContextChild::ContextChild(const ContextPtr& context, const std::string& operati
}
ContextChild::~ContextChild() {
Finish();
}
void
ContextChild::Finish() {
if (context_) {
context_->GetTraceContext()->GetSpan()->Finish();
context_ = nullptr;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
ContextFollower::ContextFollower(const ContextPtr& context, const std::string& operation_name) {
if (context) {
context_ = context->Child(operation_name);
context_ = context->Follower(operation_name);
}
}
......@@ -66,5 +72,13 @@ ContextFollower::~ContextFollower() {
}
}
void
ContextFollower::Finish() {
if (context_) {
context_->GetTraceContext()->GetSpan()->Finish();
context_ = nullptr;
}
}
} // namespace server
} // namespace milvus
......@@ -53,6 +53,9 @@ class ContextChild {
return context_;
}
void
Finish();
private:
ContextPtr context_;
};
......@@ -67,6 +70,9 @@ class ContextFollower {
return context_;
}
void
Finish();
private:
ContextPtr context_;
};
......
......@@ -36,10 +36,6 @@ RequestScheduler::ExecRequest(BaseRequestPtr& request_ptr) {
RequestScheduler& scheduler = RequestScheduler::GetInstance();
scheduler.ExecuteRequest(request_ptr);
if (!request_ptr->IsAsync()) {
request_ptr->WaitToFinish();
}
}
void
......@@ -84,18 +80,31 @@ RequestScheduler::ExecuteRequest(const BaseRequestPtr& request_ptr) {
return Status::OK();
}
auto status = PutToQueue(request_ptr);
auto status = request_ptr->PreExecute();
if (!status.ok()) {
request_ptr->Done();
return status;
}
status = PutToQueue(request_ptr);
fiu_do_on("RequestScheduler.ExecuteRequest.push_queue_fail", status = Status(SERVER_INVALID_ARGUMENT, ""));
if (!status.ok()) {
SERVER_LOG_ERROR << "Put request to queue failed with code: " << status.ToString();
request_ptr->Done();
return status;
}
if (request_ptr->IsAsync()) {
return Status::OK(); // async execution, caller need to call WaitToFinish at somewhere
}
return request_ptr->WaitToFinish(); // sync execution
status = request_ptr->WaitToFinish(); // sync execution
if (!status.ok()) {
return status;
}
return request_ptr->PostExecute();
}
void
......
......@@ -83,6 +83,15 @@ BaseRequest::~BaseRequest() {
WaitToFinish();
}
Status
BaseRequest::PreExecute() {
status_ = OnPreExecute();
if (!status_.ok()) {
Done();
}
return status_;
}
Status
BaseRequest::Execute() {
status_ = OnExecute();
......@@ -90,6 +99,22 @@ BaseRequest::Execute() {
return status_;
}
Status
BaseRequest::PostExecute() {
status_ = OnPostExecute();
return status_;
}
Status
BaseRequest::OnPreExecute() {
return Status::OK();
}
Status
BaseRequest::OnPostExecute() {
return Status::OK();
}
void
BaseRequest::Done() {
done_ = true;
......
......@@ -159,9 +159,15 @@ class BaseRequest {
virtual ~BaseRequest();
public:
Status
PreExecute();
Status
Execute();
Status
PostExecute();
void
Done();
......@@ -192,9 +198,15 @@ class BaseRequest {
}
protected:
virtual Status
OnPreExecute();
virtual Status
OnExecute() = 0;
virtual Status
OnPostExecute();
std::string
TableNotExistMsg(const std::string& table_name);
......
......@@ -27,6 +27,7 @@ namespace server {
namespace {
constexpr int64_t MAX_TOPK_GAP = 200;
constexpr uint64_t MAX_NQ = 200;
void
GetUniqueList(const std::vector<std::string>& list, std::set<std::string>& unique_list) {
......@@ -61,6 +62,7 @@ FreeRequest(SearchRequestPtr& request, const Status& status) {
class TracingContextList {
public:
TracingContextList() = default;
~TracingContextList() {
Finish();
}
......@@ -181,6 +183,15 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchReque
return false;
}
// sum of nq must less-equal than MAX_NQ
if (left->VectorsData().vector_count_ > MAX_NQ || right->VectorsData().vector_count_ > MAX_NQ) {
return false;
}
uint64_t total_nq = left->VectorsData().vector_count_ + right->VectorsData().vector_count_;
if (total_nq > MAX_NQ) {
return false;
}
// partition list must be equal for each one
std::set<std::string> left_partition_list, right_partition_list;
GetUniqueList(left->PartitionList(), left_partition_list);
......@@ -213,24 +224,17 @@ Status
SearchCombineRequest::OnExecute() {
try {
size_t combined_request = request_list_.size();
SERVER_LOG_DEBUG << "SearchCombineRequest begin execute, combined requests=" << combined_request
SERVER_LOG_DEBUG << "SearchCombineRequest execute, request count=" << combined_request
<< ", extra_params=" << extra_params_.dump();
std::string hdr = "SearchCombineRequest(table=" + table_name_ + ")";
TimeRecorder rc(hdr);
TimeRecorderAuto rc(hdr);
// step 1: check table name
auto status = ValidationUtil::ValidateTableName(table_name_);
if (!status.ok()) {
FreeRequests(status);
return status;
}
// step 2: check table existence
// step 1: check table existence
// only process root table, ignore partition table
engine::meta::TableSchema table_schema;
table_schema.table_id_ = table_name_;
status = DBWrapper::DB()->DescribeTable(table_schema);
auto status = DBWrapper::DB()->DescribeTable(table_schema);
if (!status.ok()) {
if (status.code() == DB_NOT_FOUND) {
status = Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(table_name_));
......@@ -248,7 +252,7 @@ SearchCombineRequest::OnExecute() {
}
}
// step 3: check input
// step 2: check input
size_t run_request = 0;
std::vector<SearchRequestPtr>::iterator iter = request_list_.begin();
for (; iter != request_list_.end();) {
......@@ -303,7 +307,7 @@ SearchCombineRequest::OnExecute() {
SERVER_LOG_DEBUG << "reset topk to " << search_topk_;
rc.RecordSection("check validation");
// step 5: construct vectors_data and set search_topk
// step 3: construct vectors_data
SearchRequestPtr& first_request = *request_list_.begin();
uint64_t total_count = 0;
for (auto& request : request_list_) {
......@@ -323,25 +327,26 @@ SearchCombineRequest::OnExecute() {
int64_t offset = 0;
for (auto& request : request_list_) {
const engine::VectorsData& src = request->VectorsData();
size_t data_size = 0;
if (is_float) {
data_size = src.vector_count_ * dimension;
memcpy(vectors_data_.float_data_.data() + offset, src.float_data_.data(), data_size);
size_t element_cnt = src.vector_count_ * dimension;
memcpy(vectors_data_.float_data_.data() + offset, src.float_data_.data(), element_cnt * sizeof(float));
offset += element_cnt;
} else {
data_size = src.vector_count_ * dimension / 8;
memcpy(vectors_data_.binary_data_.data() + offset, src.binary_data_.data(), data_size);
size_t element_cnt = src.vector_count_ * dimension / 8;
memcpy(vectors_data_.binary_data_.data() + offset, src.binary_data_.data(), element_cnt);
offset += element_cnt;
}
offset += data_size;
}
SERVER_LOG_DEBUG << total_count << " query vectors combined";
rc.RecordSection("combined query vectors");
// step 6: search vectors
// step 4: search vectors
const std::vector<std::string>& partition_list = first_request->PartitionList();
const std::vector<std::string>& file_id_list = first_request->FileIDList();
engine::ResultIds result_ids;
engine::ResultDistances result_distances;
{
TracingContextList context_list;
context_list.CreateChild(request_list_, "Combine Query");
......@@ -355,7 +360,7 @@ SearchCombineRequest::OnExecute() {
}
}
rc.RecordSection("search combined vectors from engine");
rc.RecordSection("search vectors from engine");
if (!status.ok()) {
// let all request return
......@@ -369,7 +374,7 @@ SearchCombineRequest::OnExecute() {
return status;
}
// step 6: construct result array
// step 5: construct result array
offset = 0;
for (auto& request : request_list_) {
uint64_t count = request->VectorsData().vector_count_;
......@@ -387,9 +392,10 @@ SearchCombineRequest::OnExecute() {
}
rc.RecordSection("construct result and send");
rc.ElapseFromBegin("totally cost");
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
Status status = Status(SERVER_UNEXPECTED_ERROR, ex.what());
FreeRequests(status);
return status;
}
return Status::OK();
......
......@@ -51,35 +51,46 @@ SearchRequest::Create(const std::shared_ptr<milvus::server::Context>& context, c
}
Status
SearchRequest::OnExecute() {
try {
fiu_do_on("SearchRequest.OnExecute.throw_std_exception", throw std::exception());
uint64_t vector_count = vectors_data_.vector_count_;
auto pre_query_ctx = context_->Child("Pre query");
SearchRequest::OnPreExecute() {
std::string hdr = "SearchRequest pre-execute(table=" + table_name_ + ")";
TimeRecorderAuto rc(hdr);
milvus::server::ContextChild tracer_pre(context_, "Pre Query");
// step 1: check table name
auto status = ValidationUtil::ValidateTableName(table_name_);
if (!status.ok()) {
return status;
}
SERVER_LOG_DEBUG << "SearchRequest begin execute, extra_params=" << extra_params_.dump();
std::string hdr = "SearchRequest(table=" + table_name_ + ", nq=" + std::to_string(vector_count) +
", k=" + std::to_string(topk_) + ")";
// step 2: check search topk
status = ValidationUtil::ValidateSearchTopk(topk_);
if (!status.ok()) {
return status;
}
TimeRecorder rc(hdr);
// step 3: check partition tags
status = ValidationUtil::ValidatePartitionTags(partition_list_);
fiu_do_on("SearchRequest.OnExecute.invalid_partition_tags", status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));
if (!status.ok()) {
return status;
}
// step 1: check table name
auto status = ValidationUtil::ValidateTableName(table_name_);
if (!status.ok()) {
return status;
}
return Status::OK();
}
// step 2: check search topk
status = ValidationUtil::ValidateSearchTopk(topk_);
if (!status.ok()) {
return status;
}
Status
SearchRequest::OnExecute() {
try {
uint64_t vector_count = vectors_data_.vector_count_;
fiu_do_on("SearchRequest.OnExecute.throw_std_exception", throw std::exception());
std::string hdr = "SearchRequest execute(table=" + table_name_ + ", nq=" + std::to_string(vector_count) +
", k=" + std::to_string(topk_) + ")";
TimeRecorderAuto rc(hdr);
// step 3: check table existence
// step 4: check table existence
// only process root table, ignore partition table
engine::meta::TableSchema table_schema;
table_schema.table_id_ = table_name_;
status = DBWrapper::DB()->DescribeTable(table_schema);
table_schema_.table_id_ = table_name_;
auto status = DBWrapper::DB()->DescribeTable(table_schema_);
fiu_do_on("SearchRequest.OnExecute.describe_table_fail", status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));
if (!status.ok()) {
if (status.code() == DB_NOT_FOUND) {
......@@ -88,27 +99,19 @@ SearchRequest::OnExecute() {
return status;
}
} else {
if (!table_schema.owner_table_.empty()) {
if (!table_schema_.owner_table_.empty()) {
return Status(SERVER_INVALID_TABLE_NAME, TableNotExistMsg(table_name_));
}
}
// step 4: check search parameters
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema, topk_);
// step 5: check search parameters
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema_, topk_);
if (!status.ok()) {
return status;
}
// step 5: check vector data according to metric type
status = ValidationUtil::ValidateVectorData(vectors_data_, table_schema);
if (!status.ok()) {
return status;
}
// step 6: check partition tags
status = ValidationUtil::ValidatePartitionTags(partition_list_);
fiu_do_on("SearchRequest.OnExecute.invalid_partition_tags",
status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));
// step 6: check vector data according to metric type
status = ValidationUtil::ValidateVectorData(vectors_data_, table_schema_);
if (!status.ok()) {
return status;
}
......@@ -116,15 +119,13 @@ SearchRequest::OnExecute() {
rc.RecordSection("check validation");
// step 7: search vectors
engine::ResultIds result_ids;
engine::ResultDistances result_distances;
#ifdef MILVUS_ENABLE_PROFILING
std::string fname = "/tmp/search_" + CommonUtil::GetCurrentTimeStr() + ".profiling";
ProfilerStart(fname.c_str());
#endif
pre_query_ctx->GetTraceContext()->GetSpan()->Finish();
engine::ResultIds result_ids;
engine::ResultDistances result_distances;
if (file_id_list_.empty()) {
status = DBWrapper::DB()->Query(context_, table_name_, partition_list_, (size_t)topk_, extra_params_,
......@@ -134,11 +135,11 @@ SearchRequest::OnExecute() {
vectors_data_, result_ids, result_distances);
}
rc.RecordSection("query vectors from engine");
#ifdef MILVUS_ENABLE_PROFILING
ProfilerStop();
#endif
rc.RecordSection("search vectors from engine");
fiu_do_on("SearchRequest.OnExecute.query_fail", status = Status(milvus::SERVER_UNEXPECTED_ERROR, ""));
if (!status.ok()) {
return status;
......@@ -148,23 +149,17 @@ SearchRequest::OnExecute() {
return Status::OK(); // empty table
}
auto post_query_ctx = context_->Child("Constructing result");
// step 8: construct result array
result_.row_num_ = vector_count;
result_.distance_list_ = result_distances;
result_.id_list_ = result_ids;
post_query_ctx->GetTraceContext()->GetSpan()->Finish();
rc.RecordSection("construct result and send");
rc.ElapseFromBegin("totally cost");
milvus::server::ContextChild tracer(context_, "Constructing result");
result_.row_num_ = vectors_data_.vector_count_;
result_.id_list_.swap(result_ids);
result_.distance_list_.swap(result_distances);
rc.RecordSection("construct result");
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
}
return Status::OK();
}
} // namespace server
} // namespace milvus
......@@ -63,12 +63,20 @@ class SearchRequest : public BaseRequest {
return result_;
}
const milvus::engine::meta::TableSchema&
TableSchema() const {
return table_schema_;
}
protected:
SearchRequest(const std::shared_ptr<milvus::server::Context>& context, const std::string& table_name,
const engine::VectorsData& vectors, int64_t topk, const milvus::json& extra_params,
const std::vector<std::string>& partition_list, const std::vector<std::string>& file_id_list,
TopKQueryResult& result);
Status
OnPreExecute() override;
Status
OnExecute() override;
......@@ -81,6 +89,9 @@ class SearchRequest : public BaseRequest {
const std::vector<std::string> file_id_list_;
TopKQueryResult& result_;
// for validation
milvus::engine::meta::TableSchema table_schema_;
};
using SearchRequestPtr = std::shared_ptr<SearchRequest>;
......
......@@ -50,6 +50,21 @@ CopyRowRecord(::milvus::grpc::RowRecord* target, const std::vector<float>& src)
memcpy(vector_data->mutable_data(), src.data(), src.size() * sizeof(float));
}
void
CopyBinRowRecord(::milvus::grpc::RowRecord* target, const std::vector<uint8_t>& src) {
auto vector_data = target->mutable_binary_data();
vector_data->resize(static_cast<int>(src.size()));
memcpy(vector_data->data(), src.data(), src.size());
}
void
SearchFunc(std::shared_ptr<milvus::server::grpc::GrpcRequestHandler> handler,
::grpc::ServerContext* context,
std::shared_ptr<::milvus::grpc::SearchParam> request,
std::shared_ptr<::milvus::grpc::TopKQueryResult> result) {
handler->Search(context, request.get(), result.get());
}
class RpcHandlerTest : public testing::Test {
protected:
void
......@@ -135,7 +150,25 @@ BuildVectors(int64_t from, int64_t to, std::vector<std::vector<float>>& vector_r
std::vector<float> record;
record.resize(TABLE_DIM);
for (int64_t i = 0; i < TABLE_DIM; i++) {
record[i] = (float)(k % (i + 1));
record[i] = (float)(i + k);
}
vector_record_array.emplace_back(record);
}
}
void
BuildBinVectors(int64_t from, int64_t to, std::vector<std::vector<uint8_t>>& vector_record_array) {
if (to <= from) {
return;
}
vector_record_array.clear();
for (int64_t k = from; k < to; k++) {
std::vector<uint8_t> record;
record.resize(TABLE_DIM / 8);
for (int64_t i = 0; i < TABLE_DIM / 8; i++) {
record[i] = (i + k) % 256;
}
vector_record_array.emplace_back(record);
......@@ -367,6 +400,7 @@ TEST_F(RpcHandlerTest, SEARCH_TEST) {
kv->set_key(milvus::server::grpc::EXTRA_PARAM_KEY);
kv->set_value("{ \"nprobe\": 32 }");
handler->Search(&context, &request, &response);
ASSERT_EQ(response.ids_size(), 0UL);
std::vector<std::vector<float>> record_array;
BuildVectors(0, VECTOR_COUNT, record_array);
......@@ -380,17 +414,228 @@ TEST_F(RpcHandlerTest, SEARCH_TEST) {
::milvus::grpc::VectorIds vector_ids;
handler->Insert(&context, &insert_param, &vector_ids);
// flush
::milvus::grpc::Status grpc_status;
::milvus::grpc::FlushParam flush_param;
flush_param.add_table_name_array(TABLE_NAME);
handler->Flush(&context, &flush_param, &grpc_status);
// search
BuildVectors(0, 10, record_array);
for (auto& record : record_array) {
::milvus::grpc::RowRecord* row_record = request.add_query_record_array();
CopyRowRecord(row_record, record);
}
handler->Search(&context, &request, &response);
ASSERT_NE(response.ids_size(), 0UL);
// wrong file id
::milvus::grpc::SearchInFilesParam search_in_files_param;
std::string* file_id = search_in_files_param.add_file_id_array();
*file_id = "test_tbl";
handler->SearchInFiles(&context, &search_in_files_param, &response);
ASSERT_EQ(response.ids_size(), 0UL);
}
TEST_F(RpcHandlerTest, COMBINE_SEARCH_TEST) {
::grpc::ServerContext context;
handler->SetContext(&context, dummy_context);
handler->RegisterRequestHandler(milvus::server::RequestHandler());
// create table
std::string table_name = "combine";
::milvus::grpc::TableSchema tableschema;
tableschema.set_table_name(table_name);
tableschema.set_dimension(TABLE_DIM);
tableschema.set_index_file_size(INDEX_FILE_SIZE);
tableschema.set_metric_type(1); // L2 metric
::milvus::grpc::Status status;
handler->CreateTable(&context, &tableschema, &status);
ASSERT_EQ(status.error_code(), 0);
// insert vectors
std::vector<std::vector<float>> record_array;
BuildVectors(0, VECTOR_COUNT, record_array);
::milvus::grpc::InsertParam insert_param;
int64_t vec_id = 0;
for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array();
CopyRowRecord(grpc_record, record);
insert_param.add_row_id_array(++vec_id);
}
insert_param.set_table_name(table_name);
::milvus::grpc::VectorIds vector_ids;
handler->Insert(&context, &insert_param, &vector_ids);
// flush
::milvus::grpc::Status grpc_status;
::milvus::grpc::FlushParam flush_param;
flush_param.add_table_name_array(table_name);
handler->Flush(&context, &flush_param, &grpc_status);
// multi thread search requests will be combined
int QUERY_COUNT = 10;
int64_t NQ = 2;
int64_t TOPK = 5;
using RequestPtr = std::shared_ptr<::milvus::grpc::SearchParam>;
std::vector<RequestPtr> request_array;
for (int i = 0; i < QUERY_COUNT; i++) {
RequestPtr request = std::make_shared<::milvus::grpc::SearchParam>();
request->set_table_name(table_name);
request->set_topk(TOPK);
milvus::grpc::KeyValuePair* kv = request->add_extra_params();
kv->set_key(milvus::server::grpc::EXTRA_PARAM_KEY);
kv->set_value("{}");
BuildVectors(i * NQ, (i + 1) * NQ, record_array);
for (auto& record : record_array) {
::milvus::grpc::RowRecord* row_record = request->add_query_record_array();
CopyRowRecord(row_record, record);
}
request_array.emplace_back(request);
}
using ResultPtr = std::shared_ptr<::milvus::grpc::TopKQueryResult>;
std::vector<ResultPtr> result_array;
using ThreadPtr = std::shared_ptr<std::thread>;
std::vector<ThreadPtr> thread_list;
for (int i = 0; i < QUERY_COUNT; i++) {
ResultPtr result_ptr = std::make_shared<::milvus::grpc::TopKQueryResult>();
result_array.push_back(result_ptr);
ThreadPtr
thread = std::make_shared<std::thread>(SearchFunc, handler, &context, request_array[i], result_ptr);
thread_list.emplace_back(thread);
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
// wait search finish
for (auto& iter : thread_list) {
iter->join();
}
// check result
int64_t index = 0;
for (auto& result_ptr : result_array) {
ASSERT_NE(result_ptr->ids_size(), 0);
std::string msg = "Result no." + std::to_string(index) + ": \n";
for (int64_t i = 0; i < NQ; i++) {
for (int64_t k = 0; k < TOPK; k++) {
msg += "[";
msg += std::to_string(result_ptr->ids(i * TOPK + k));
msg += ", ";
msg += std::to_string(result_ptr->distances(i * TOPK + k));
msg += "]";
msg += ", ";
}
msg += "\n";
ASSERT_NE(result_ptr->ids(i * TOPK), 0);
ASSERT_LT(result_ptr->distances(i * TOPK), 0.00001);
}
std::cout << msg << std::endl;
index++;
}
}
TEST_F(RpcHandlerTest, COMBINE_SEARCH_BINARY_TEST) {
::grpc::ServerContext context;
handler->SetContext(&context, dummy_context);
handler->RegisterRequestHandler(milvus::server::RequestHandler());
// create table
std::string table_name = "combine_bin";
::milvus::grpc::TableSchema tableschema;
tableschema.set_table_name(table_name);
tableschema.set_dimension(TABLE_DIM);
tableschema.set_index_file_size(INDEX_FILE_SIZE);
tableschema.set_metric_type(5); // tanimoto metric
::milvus::grpc::Status status;
handler->CreateTable(&context, &tableschema, &status);
ASSERT_EQ(status.error_code(), 0);
// insert vectors
std::vector<std::vector<uint8_t>> record_array;
BuildBinVectors(0, VECTOR_COUNT, record_array);
::milvus::grpc::InsertParam insert_param;
int64_t vec_id = 0;
for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array();
CopyBinRowRecord(grpc_record, record);
insert_param.add_row_id_array(++vec_id);
}
insert_param.set_table_name(table_name);
::milvus::grpc::VectorIds vector_ids;
handler->Insert(&context, &insert_param, &vector_ids);
// flush
::milvus::grpc::Status grpc_status;
::milvus::grpc::FlushParam flush_param;
flush_param.add_table_name_array(table_name);
handler->Flush(&context, &flush_param, &grpc_status);
// multi thread search requests will be combined
int QUERY_COUNT = 10;
int64_t NQ = 2;
int64_t TOPK = 5;
using RequestPtr = std::shared_ptr<::milvus::grpc::SearchParam>;
std::vector<RequestPtr> request_array;
for (int i = 0; i < QUERY_COUNT; i++) {
RequestPtr request = std::make_shared<::milvus::grpc::SearchParam>();
request->set_table_name(table_name);
request->set_topk(TOPK);
milvus::grpc::KeyValuePair* kv = request->add_extra_params();
kv->set_key(milvus::server::grpc::EXTRA_PARAM_KEY);
kv->set_value("{}");
BuildBinVectors(i * NQ, (i + 1) * NQ, record_array);
for (auto& record : record_array) {
::milvus::grpc::RowRecord* row_record = request->add_query_record_array();
CopyBinRowRecord(row_record, record);
}
request_array.emplace_back(request);
}
using ResultPtr = std::shared_ptr<::milvus::grpc::TopKQueryResult>;
std::vector<ResultPtr> result_array;
using ThreadPtr = std::shared_ptr<std::thread>;
std::vector<ThreadPtr> thread_list;
for (int i = 0; i < QUERY_COUNT; i++) {
ResultPtr result_ptr = std::make_shared<::milvus::grpc::TopKQueryResult>();
result_array.push_back(result_ptr);
ThreadPtr
thread = std::make_shared<std::thread>(SearchFunc, handler, &context, request_array[i], result_ptr);
thread_list.emplace_back(thread);
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
// wait search finish
for (auto& iter : thread_list) {
iter->join();
}
// check result
int64_t index = 0;
for (auto& result_ptr : result_array) {
ASSERT_NE(result_ptr->ids_size(), 0);
std::string msg = "Result no." + std::to_string(++index) + ": \n";
for (int64_t i = 0; i < NQ; i++) {
for (int64_t k = 0; k < TOPK; k++) {
msg += "[";
msg += std::to_string(result_ptr->ids(i * TOPK + k));
msg += ", ";
msg += std::to_string(result_ptr->distances(i * TOPK + k));
msg += "]";
msg += ", ";
}
msg += "\n";
ASSERT_NE(result_ptr->ids(i * TOPK), 0);
ASSERT_LT(result_ptr->distances(i * TOPK), 0.00001);
}
std::cout << msg << std::endl;
}
}
TEST_F(RpcHandlerTest, TABLES_TEST) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册