提交 07dfe09a 编写于 作者: F fishpenguin

#204 - improve grpc performance in search

上级 28810ad0
...@@ -8,6 +8,7 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -8,6 +8,7 @@ Please mark all change in change log and use the ticket from JIRA.
## Feature ## Feature
## Improvement ## Improvement
- \#204 - improve grpc performance in search
- \#207 - Add more unittest for config set/get - \#207 - Add more unittest for config set/get
- \#208 - optimize unittest to support run single test more easily - \#208 - optimize unittest to support run single test more easily
......
We manually add two APIs in "milvus.pd.h":
add_vector_data
add_row_id_array
If proto files need be generated again, remember to re-add above APIs.
\ No newline at end of file
...@@ -201,60 +201,60 @@ void MilvusService::Stub::experimental_async::Insert(::grpc::ClientContext* cont ...@@ -201,60 +201,60 @@ void MilvusService::Stub::experimental_async::Insert(::grpc::ClientContext* cont
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::VectorIds>::Create(channel_.get(), cq, rpcmethod_Insert_, context, request, false); return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::VectorIds>::Create(channel_.get(), cq, rpcmethod_Insert_, context, request, false);
} }
::grpc::Status MilvusService::Stub::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::milvus::grpc::TopKQueryResultList* response) { ::grpc::Status MilvusService::Stub::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::milvus::grpc::TopKQueryResult* response) {
return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Search_, context, request, response); return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Search_, context, request, response);
} }
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) { void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, std::move(f)); ::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, std::move(f));
} }
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) { void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, std::move(f)); ::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, std::move(f));
} }
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) { void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, reactor); ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, reactor);
} }
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) { void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, reactor); ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, reactor);
} }
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::AsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) { ::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::AsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, true); return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, true);
} }
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::PrepareAsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) { ::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::PrepareAsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, false); return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, false);
} }
::grpc::Status MilvusService::Stub::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::milvus::grpc::TopKQueryResultList* response) { ::grpc::Status MilvusService::Stub::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::milvus::grpc::TopKQueryResult* response) {
return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_SearchInFiles_, context, request, response); return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_SearchInFiles_, context, request, response);
} }
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) { void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, std::move(f)); ::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, std::move(f));
} }
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) { void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, std::move(f)); ::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, std::move(f));
} }
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) { void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, reactor); ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, reactor);
} }
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) { void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, reactor); ::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, reactor);
} }
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::AsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) { ::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::AsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, true); return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, true);
} }
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::PrepareAsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) { ::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::PrepareAsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, false); return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, false);
} }
::grpc::Status MilvusService::Stub::DescribeTable(::grpc::ClientContext* context, const ::milvus::grpc::TableName& request, ::milvus::grpc::TableSchema* response) { ::grpc::Status MilvusService::Stub::DescribeTable(::grpc::ClientContext* context, const ::milvus::grpc::TableName& request, ::milvus::grpc::TableSchema* response) {
...@@ -510,12 +510,12 @@ MilvusService::Service::Service() { ...@@ -510,12 +510,12 @@ MilvusService::Service::Service() {
AddMethod(new ::grpc::internal::RpcServiceMethod( AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[5], MilvusService_method_names[5],
::grpc::internal::RpcMethod::NORMAL_RPC, ::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchParam, ::milvus::grpc::TopKQueryResultList>( new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchParam, ::milvus::grpc::TopKQueryResult>(
std::mem_fn(&MilvusService::Service::Search), this))); std::mem_fn(&MilvusService::Service::Search), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod( AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[6], MilvusService_method_names[6],
::grpc::internal::RpcMethod::NORMAL_RPC, ::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchInFilesParam, ::milvus::grpc::TopKQueryResultList>( new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchInFilesParam, ::milvus::grpc::TopKQueryResult>(
std::mem_fn(&MilvusService::Service::SearchInFiles), this))); std::mem_fn(&MilvusService::Service::SearchInFiles), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod( AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[7], MilvusService_method_names[7],
...@@ -597,14 +597,14 @@ MilvusService::Service::~Service() { ...@@ -597,14 +597,14 @@ MilvusService::Service::~Service() {
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
} }
::grpc::Status MilvusService::Service::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response) { ::grpc::Status MilvusService::Service::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResult* response) {
(void) context; (void) context;
(void) request; (void) request;
(void) response; (void) response;
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
} }
::grpc::Status MilvusService::Service::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response) { ::grpc::Status MilvusService::Service::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResult* response) {
(void) context; (void) context;
(void) request; (void) request;
(void) response; (void) response;
......
...@@ -84,24 +84,12 @@ message SearchInFilesParam { ...@@ -84,24 +84,12 @@ message SearchInFilesParam {
/** /**
* @brief Query result params * @brief Query result params
*/ */
message QueryResult {
int64 id = 1;
double distance = 2;
}
/**
* @brief TopK query result
*/
message TopKQueryResult { message TopKQueryResult {
repeated QueryResult query_result_arrays = 1;
}
/**
* @brief List of topK query result
*/
message TopKQueryResultList {
Status status = 1; Status status = 1;
repeated TopKQueryResult topk_query_result = 2; int64 nq = 2;
int64 topk = 3;
bytes ids_binary = 4;
bytes distances_binary = 5;
} }
/** /**
...@@ -227,7 +215,7 @@ service MilvusService { ...@@ -227,7 +215,7 @@ service MilvusService {
* *
* @return query result array. * @return query result array.
*/ */
rpc Search(SearchParam) returns (TopKQueryResultList) {} rpc Search(SearchParam) returns (TopKQueryResult) {}
/** /**
* @brief Internal use query interface * @brief Internal use query interface
...@@ -241,7 +229,7 @@ service MilvusService { ...@@ -241,7 +229,7 @@ service MilvusService {
* *
* @return query result array. * @return query result array.
*/ */
rpc SearchInFiles(SearchInFilesParam) returns (TopKQueryResultList) {} rpc SearchInFiles(SearchInFilesParam) returns (TopKQueryResult) {}
/** /**
* @brief Get table schema * @brief Get table schema
......
...@@ -57,22 +57,22 @@ PrintTableSchema(const milvus::TableSchema& tb_schema) { ...@@ -57,22 +57,22 @@ PrintTableSchema(const milvus::TableSchema& tb_schema) {
void void
PrintSearchResult(const std::vector<std::pair<int64_t, milvus::RowRecord>>& search_record_array, PrintSearchResult(const std::vector<std::pair<int64_t, milvus::RowRecord>>& search_record_array,
const std::vector<milvus::TopKQueryResult>& topk_query_result_array) { const milvus::TopKQueryResult& topk_query_result) {
BLOCK_SPLITER BLOCK_SPLITER
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl; size_t nq = topk_query_result.row_num;
size_t topk = topk_query_result.topk;
std::cout << "Returned result count: " << nq * topk << std::endl;
int32_t index = 0; int32_t index = 0;
for (auto& result : topk_query_result_array) { for (size_t i = 0; i < nq; i++) {
auto search_id = search_record_array[index].first; auto search_id = search_record_array[index].first;
index++; index++;
std::cout << "No." << std::to_string(index) << " vector " << std::to_string(search_id) << " top " std::cout << "No." << index << " vector " << search_id << " top " << topk << " search result:" << std::endl;
<< std::to_string(result.query_result_arrays.size()) << " search result:" << std::endl; for (size_t j = 0; j < topk; j++) {
for (auto& item : result.query_result_arrays) { size_t idx = i * nq + j;
std::cout << "\t" << std::to_string(item.id) << "\tdistance:" << std::to_string(item.distance); std::cout << "\t" << topk_query_result.ids[idx] << "\t" << topk_query_result.distances[idx] << std::endl;
std::cout << std::endl;
} }
} }
BLOCK_SPLITER BLOCK_SPLITER
} }
...@@ -166,11 +166,13 @@ class TimeRecorder { ...@@ -166,11 +166,13 @@ class TimeRecorder {
void void
CheckResult(const std::vector<std::pair<int64_t, milvus::RowRecord>>& search_record_array, CheckResult(const std::vector<std::pair<int64_t, milvus::RowRecord>>& search_record_array,
const std::vector<milvus::TopKQueryResult>& topk_query_result_array) { const milvus::TopKQueryResult& topk_query_result) {
BLOCK_SPLITER BLOCK_SPLITER
size_t nq = topk_query_result.row_num;
size_t result_k = topk_query_result.topk;
int64_t index = 0; int64_t index = 0;
for (auto& result : topk_query_result_array) { for (size_t i = 0; i < nq; i++) {
auto result_id = result.query_result_arrays[0].id; auto result_id = topk_query_result.ids[i * result_k];
auto search_id = search_record_array[index++].first; auto search_id = search_record_array[index++].first;
if (result_id != search_id) { if (result_id != search_id) {
std::cout << "The top 1 result is wrong: " << result_id << " vs. " << search_id << std::endl; std::cout << "The top 1 result is wrong: " << result_id << " vs. " << search_id << std::endl;
...@@ -196,19 +198,18 @@ DoSearch(std::shared_ptr<milvus::Connection> conn, ...@@ -196,19 +198,18 @@ DoSearch(std::shared_ptr<milvus::Connection> conn,
} }
auto start = std::chrono::high_resolution_clock::now(); auto start = std::chrono::high_resolution_clock::now();
std::vector<milvus::TopKQueryResult> topk_query_result_array; milvus::TopKQueryResult topk_query_result;
{ {
TimeRecorder rc(phase_name); TimeRecorder rc(phase_name);
milvus::Status stat = milvus::Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 32, topk_query_result);
conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 32, topk_query_result_array);
std::cout << "SearchVector function call status: " << stat.message() << std::endl; std::cout << "SearchVector function call status: " << stat.message() << std::endl;
} }
auto finish = std::chrono::high_resolution_clock::now(); auto finish = std::chrono::high_resolution_clock::now();
std::cout << "SEARCHVECTOR COST: " std::cout << "SEARCHVECTOR COST: "
<< std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n"; << std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
PrintSearchResult(search_record_array, topk_query_result_array); PrintSearchResult(search_record_array, topk_query_result);
CheckResult(search_record_array, topk_query_result_array); CheckResult(search_record_array, topk_query_result);
} }
} // namespace } // namespace
......
...@@ -189,22 +189,19 @@ ClientProxy::Insert(const std::string& table_name, const std::vector<RowRecord>& ...@@ -189,22 +189,19 @@ ClientProxy::Insert(const std::string& table_name, const std::vector<RowRecord>&
for (auto& record : record_array) { for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array(); ::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array();
for (size_t i = 0; i < record.data.size(); i++) { for (size_t i = 0; i < record.data.size(); i++) {
grpc_record->add_vector_data(record.data[i]); grpc_record->add_vector_data(record.data.begin(), record.data.end());
} }
} }
// Single thread // Single thread
::milvus::grpc::VectorIds vector_ids; ::milvus::grpc::VectorIds vector_ids;
if (!id_array.empty()) { if (!id_array.empty()) {
for (auto i = 0; i < id_array.size(); i++) { insert_param.add_row_id_array(id_array.begin(), id_array.end());
insert_param.add_row_id_array(id_array[i]);
}
client_ptr_->Insert(vector_ids, insert_param, status); client_ptr_->Insert(vector_ids, insert_param, status);
} else { } else {
client_ptr_->Insert(vector_ids, insert_param, status); client_ptr_->Insert(vector_ids, insert_param, status);
for (size_t i = 0; i < vector_ids.vector_id_array_size(); i++) { /* return Milvus generated ids back to user */
id_array.push_back(vector_ids.vector_id_array(i)); id_array.insert(id_array.end(), vector_ids.vector_id_array().begin(), vector_ids.vector_id_array().end());
}
} }
#endif #endif
} catch (std::exception& ex) { } catch (std::exception& ex) {
...@@ -217,7 +214,7 @@ ClientProxy::Insert(const std::string& table_name, const std::vector<RowRecord>& ...@@ -217,7 +214,7 @@ ClientProxy::Insert(const std::string& table_name, const std::vector<RowRecord>&
Status Status
ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array, ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe, const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) { TopKQueryResult& topk_query_result) {
try { try {
// step 1: convert vectors data // step 1: convert vectors data
::milvus::grpc::SearchParam search_param; ::milvus::grpc::SearchParam search_param;
...@@ -226,9 +223,7 @@ ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>& ...@@ -226,9 +223,7 @@ ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>&
search_param.set_nprobe(nprobe); search_param.set_nprobe(nprobe);
for (auto& record : query_record_array) { for (auto& record : query_record_array) {
::milvus::grpc::RowRecord* row_record = search_param.add_query_record_array(); ::milvus::grpc::RowRecord* row_record = search_param.add_query_record_array();
for (auto& rec : record.data) { row_record->add_vector_data(record.data.begin(), record.data.end());
row_record->add_vector_data(rec);
}
} }
// step 2: convert range array // step 2: convert range array
...@@ -239,21 +234,17 @@ ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>& ...@@ -239,21 +234,17 @@ ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>&
} }
// step 3: search vectors // step 3: search vectors
::milvus::grpc::TopKQueryResultList topk_query_result_list; ::milvus::grpc::TopKQueryResult result;
Status status = client_ptr_->Search(topk_query_result_list, search_param); Status status = client_ptr_->Search(result, search_param);
// step 4: convert result array // step 4: convert result array
for (uint64_t i = 0; i < topk_query_result_list.topk_query_result_size(); ++i) { topk_query_result.row_num = result.nq();
TopKQueryResult result; topk_query_result.topk = result.topk();
for (uint64_t j = 0; j < topk_query_result_list.topk_query_result(i).query_result_arrays_size(); ++j) { topk_query_result.ids.resize(result.ids_binary().size());
QueryResult query_result; memcpy(topk_query_result.ids.data(), result.ids_binary().data(), result.ids_binary().size());
query_result.id = topk_query_result_list.topk_query_result(i).query_result_arrays(j).id(); topk_query_result.distances.resize(result.distances_binary().size());
query_result.distance = topk_query_result_list.topk_query_result(i).query_result_arrays(j).distance(); memcpy(topk_query_result.distances.data(), result.distances_binary().data(), result.distances_binary().size());
result.query_result_arrays.emplace_back(query_result);
}
topk_query_result_array.emplace_back(result);
}
return status; return status;
} catch (std::exception& ex) { } catch (std::exception& ex) {
return Status(StatusCode::UnknownError, "fail to search vectors: " + std::string(ex.what())); return Status(StatusCode::UnknownError, "fail to search vectors: " + std::string(ex.what()));
......
...@@ -60,7 +60,7 @@ class ClientProxy : public Connection { ...@@ -60,7 +60,7 @@ class ClientProxy : public Connection {
Status Status
Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array, Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe, const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) override; TopKQueryResult& topk_query_result) override;
Status Status
DescribeTable(const std::string& table_name, TableSchema& table_schema) override; DescribeTable(const std::string& table_name, TableSchema& table_schema) override;
......
...@@ -134,20 +134,20 @@ GrpcClient::Insert(::milvus::grpc::VectorIds& vector_ids, const ::milvus::grpc:: ...@@ -134,20 +134,20 @@ GrpcClient::Insert(::milvus::grpc::VectorIds& vector_ids, const ::milvus::grpc::
} }
Status Status
GrpcClient::Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list, GrpcClient::Search(::milvus::grpc::TopKQueryResult& topk_query_result,
const ::milvus::grpc::SearchParam& search_param) { const ::milvus::grpc::SearchParam& search_param) {
::milvus::grpc::TopKQueryResult query_result; ::milvus::grpc::TopKQueryResult query_result;
ClientContext context; ClientContext context;
::grpc::Status grpc_status = stub_->Search(&context, search_param, &topk_query_result_list); ::grpc::Status grpc_status = stub_->Search(&context, search_param, &topk_query_result);
if (!grpc_status.ok()) { if (!grpc_status.ok()) {
std::cerr << "SearchVector rpc failed!" << std::endl; std::cerr << "SearchVector rpc failed!" << std::endl;
std::cerr << grpc_status.error_message() << std::endl; std::cerr << grpc_status.error_message() << std::endl;
return Status(StatusCode::RPCFailed, grpc_status.error_message()); return Status(StatusCode::RPCFailed, grpc_status.error_message());
} }
if (topk_query_result_list.status().error_code() != grpc::SUCCESS) { if (topk_query_result.status().error_code() != grpc::SUCCESS) {
std::cerr << topk_query_result_list.status().reason() << std::endl; std::cerr << topk_query_result.status().reason() << std::endl;
return Status(StatusCode::ServerFailed, topk_query_result_list.status().reason()); return Status(StatusCode::ServerFailed, topk_query_result.status().reason());
} }
return Status::OK(); return Status::OK();
......
...@@ -57,7 +57,7 @@ class GrpcClient { ...@@ -57,7 +57,7 @@ class GrpcClient {
Insert(grpc::VectorIds& vector_ids, const grpc::InsertParam& insert_param, Status& status); Insert(grpc::VectorIds& vector_ids, const grpc::InsertParam& insert_param, Status& status);
Status Status
Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list, const grpc::SearchParam& search_param); Search(::milvus::grpc::TopKQueryResult& topk_query_result, const grpc::SearchParam& search_param);
Status Status
DescribeTable(grpc::TableSchema& grpc_schema, const std::string& table_name); DescribeTable(grpc::TableSchema& grpc_schema, const std::string& table_name);
......
...@@ -78,19 +78,14 @@ struct RowRecord { ...@@ -78,19 +78,14 @@ struct RowRecord {
std::vector<float> data; ///< Vector raw data std::vector<float> data; ///< Vector raw data
}; };
/**
* @brief Query result
*/
struct QueryResult {
int64_t id; ///< Output result
double distance; ///< Vector similarity distance
};
/** /**
* @brief TopK query result * @brief TopK query result
*/ */
struct TopKQueryResult { struct TopKQueryResult {
std::vector<QueryResult> query_result_arrays; ///< TopK query result int64_t row_num;
int64_t topk;
std::vector<int64_t> ids;
std::vector<float> distances;
}; };
/** /**
...@@ -261,7 +256,7 @@ class Connection { ...@@ -261,7 +256,7 @@ class Connection {
virtual Status virtual Status
Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array, Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe, const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) = 0; TopKQueryResult& topk_query_result) = 0;
/** /**
* @brief Show table description * @brief Show table description
......
...@@ -91,9 +91,8 @@ ConnectionImpl::Insert(const std::string& table_name, const std::vector<RowRecor ...@@ -91,9 +91,8 @@ ConnectionImpl::Insert(const std::string& table_name, const std::vector<RowRecor
Status Status
ConnectionImpl::Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array, ConnectionImpl::Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe, const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) { TopKQueryResult& topk_query_result) {
return client_proxy_->Search(table_name, query_record_array, query_range_array, topk, nprobe, return client_proxy_->Search(table_name, query_record_array, query_range_array, topk, nprobe, topk_query_result);
topk_query_result_array);
} }
Status Status
......
...@@ -62,7 +62,7 @@ class ConnectionImpl : public Connection { ...@@ -62,7 +62,7 @@ class ConnectionImpl : public Connection {
Status Status
Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array, Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe, const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) override; TopKQueryResult& topk_query_result) override;
Status Status
DescribeTable(const std::string& table_name, TableSchema& table_schema) override; DescribeTable(const std::string& table_name, TableSchema& table_schema) override;
......
...@@ -75,7 +75,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc: ...@@ -75,7 +75,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
::grpc::Status ::grpc::Status
GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request,
::milvus::grpc::TopKQueryResultList* response) { ::milvus::grpc::TopKQueryResult* response) {
std::vector<std::string> file_id_array; std::vector<std::string> file_id_array;
BaseTaskPtr task_ptr = SearchTask::Create(request, file_id_array, response); BaseTaskPtr task_ptr = SearchTask::Create(request, file_id_array, response);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
...@@ -87,7 +87,7 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc: ...@@ -87,7 +87,7 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc:
::grpc::Status ::grpc::Status
GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request,
::milvus::grpc::TopKQueryResultList* response) { ::milvus::grpc::TopKQueryResult* response) {
std::vector<std::string> file_id_array; std::vector<std::string> file_id_array;
for (int i = 0; i < request->file_id_array_size(); i++) { for (int i = 0; i < request->file_id_array_size(); i++) {
file_id_array.push_back(request->file_id_array(i)); file_id_array.push_back(request->file_id_array(i));
......
...@@ -145,7 +145,7 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { ...@@ -145,7 +145,7 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service {
*/ */
::grpc::Status ::grpc::Status
Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request,
::milvus::grpc::TopKQueryResultList* response) override; ::milvus::grpc::TopKQueryResult* response) override;
/** /**
* @brief Internal use query interface * @brief Internal use query interface
...@@ -169,7 +169,7 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { ...@@ -169,7 +169,7 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service {
*/ */
::grpc::Status ::grpc::Status
SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request,
::milvus::grpc::TopKQueryResultList* response) override; ::milvus::grpc::TopKQueryResult* response) override;
/** /**
* @brief Get table schema * @brief Get table schema
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <string.h> #include <string.h>
#include <map> #include <map>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
//#include <gperftools/profiler.h> //#include <gperftools/profiler.h>
...@@ -541,16 +542,16 @@ InsertTask::OnExecute() { ...@@ -541,16 +542,16 @@ InsertTask::OnExecute() {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchTask::SearchTask(const ::milvus::grpc::SearchParam* search_vector_infos, SearchTask::SearchTask(const ::milvus::grpc::SearchParam* search_vector_infos,
const std::vector<std::string>& file_id_array, ::milvus::grpc::TopKQueryResultList* response) const std::vector<std::string>& file_id_array, ::milvus::grpc::TopKQueryResult* response)
: GrpcBaseTask(DQL_TASK_GROUP), : GrpcBaseTask(DQL_TASK_GROUP),
search_param_(search_vector_infos), search_param_(search_vector_infos),
file_id_array_(file_id_array), file_id_array_(file_id_array),
topk_result_list(response) { topk_result_(response) {
} }
BaseTaskPtr BaseTaskPtr
SearchTask::Create(const ::milvus::grpc::SearchParam* search_vector_infos, SearchTask::Create(const ::milvus::grpc::SearchParam* search_vector_infos,
const std::vector<std::string>& file_id_array, ::milvus::grpc::TopKQueryResultList* response) { const std::vector<std::string>& file_id_array, ::milvus::grpc::TopKQueryResult* response) {
if (search_vector_infos == nullptr) { if (search_vector_infos == nullptr) {
SERVER_LOG_ERROR << "grpc input is null!"; SERVER_LOG_ERROR << "grpc input is null!";
return nullptr; return nullptr;
...@@ -671,15 +672,20 @@ SearchTask::OnExecute() { ...@@ -671,15 +672,20 @@ SearchTask::OnExecute() {
size_t result_k = result_ids.size() / record_count; size_t result_k = result_ids.size() / record_count;
// step 7: construct result array // step 7: construct result array
for (size_t i = 0; i < record_count; i++) { topk_result_->set_nq(record_count);
::milvus::grpc::TopKQueryResult* topk_query_result = topk_result_list->add_topk_query_result(); topk_result_->set_topk(result_ids.size() / record_count);
for (size_t j = 0; j < result_k; j++) {
::milvus::grpc::QueryResult* grpc_result = topk_query_result->add_query_result_arrays(); std::string ids_str;
size_t idx = i * result_k + j; size_t ids_len = sizeof(int64_t) * result_ids.size();
grpc_result->set_id(result_ids[idx]); ids_str.resize(ids_len);
grpc_result->set_distance(result_distances[idx]); memcpy((void*)(ids_str.data()), result_ids.data(), ids_len);
} topk_result_->set_ids_binary(std::move(ids_str));
}
std::string distances_str;
size_t distances_len = sizeof(float) * result_distances.size();
distances_str.resize(distances_len);
memcpy((void*)(distances_str.data()), result_distances.data(), distances_len);
topk_result_->set_distances_binary(std::move(distances_str));
// step 8: print time cost percent // step 8: print time cost percent
rc.RecordSection("construct result and send"); rc.RecordSection("construct result and send");
......
...@@ -153,11 +153,11 @@ class SearchTask : public GrpcBaseTask { ...@@ -153,11 +153,11 @@ class SearchTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::SearchParam* search_param, const std::vector<std::string>& file_id_array, Create(const ::milvus::grpc::SearchParam* search_param, const std::vector<std::string>& file_id_array,
::milvus::grpc::TopKQueryResultList* response); ::milvus::grpc::TopKQueryResult* response);
protected: protected:
SearchTask(const ::milvus::grpc::SearchParam* search_param, const std::vector<std::string>& file_id_array, SearchTask(const ::milvus::grpc::SearchParam* search_param, const std::vector<std::string>& file_id_array,
::milvus::grpc::TopKQueryResultList* response); ::milvus::grpc::TopKQueryResult* response);
Status Status
OnExecute() override; OnExecute() override;
...@@ -165,7 +165,7 @@ class SearchTask : public GrpcBaseTask { ...@@ -165,7 +165,7 @@ class SearchTask : public GrpcBaseTask {
private: private:
const ::milvus::grpc::SearchParam* search_param_; const ::milvus::grpc::SearchParam* search_param_;
std::vector<std::string> file_id_array_; std::vector<std::string> file_id_array_;
::milvus::grpc::TopKQueryResultList* topk_result_list; ::milvus::grpc::TopKQueryResult* topk_result_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
......
...@@ -93,7 +93,7 @@ GrpcServer::StartService() { ...@@ -93,7 +93,7 @@ GrpcServer::StartService() {
builder.SetCompressionAlgorithmSupportStatus(GRPC_COMPRESS_STREAM_GZIP, true); builder.SetCompressionAlgorithmSupportStatus(GRPC_COMPRESS_STREAM_GZIP, true);
builder.SetDefaultCompressionAlgorithm(GRPC_COMPRESS_STREAM_GZIP); builder.SetDefaultCompressionAlgorithm(GRPC_COMPRESS_STREAM_GZIP);
builder.SetDefaultCompressionLevel(GRPC_COMPRESS_LEVEL_HIGH); builder.SetDefaultCompressionLevel(GRPC_COMPRESS_LEVEL_NONE);
GrpcRequestHandler service; GrpcRequestHandler service;
......
...@@ -202,9 +202,7 @@ TEST_F(RpcHandlerTest, INSERT_TEST) { ...@@ -202,9 +202,7 @@ TEST_F(RpcHandlerTest, INSERT_TEST) {
::milvus::grpc::VectorIds vector_ids; ::milvus::grpc::VectorIds vector_ids;
for (auto& record : record_array) { for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array(); ::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array();
for (size_t i = 0; i < record.size(); i++) { grpc_record->add_vector_data(record.begin(), record.end());
grpc_record->add_vector_data(record[i]);
}
} }
handler->Insert(&context, &request, &vector_ids); handler->Insert(&context, &request, &vector_ids);
ASSERT_EQ(vector_ids.vector_id_array_size(), VECTOR_COUNT); ASSERT_EQ(vector_ids.vector_id_array_size(), VECTOR_COUNT);
...@@ -213,7 +211,7 @@ TEST_F(RpcHandlerTest, INSERT_TEST) { ...@@ -213,7 +211,7 @@ TEST_F(RpcHandlerTest, INSERT_TEST) {
TEST_F(RpcHandlerTest, SEARCH_TEST) { TEST_F(RpcHandlerTest, SEARCH_TEST) {
::grpc::ServerContext context; ::grpc::ServerContext context;
::milvus::grpc::SearchParam request; ::milvus::grpc::SearchParam request;
::milvus::grpc::TopKQueryResultList response; ::milvus::grpc::TopKQueryResult response;
//test null input //test null input
handler->Search(&context, nullptr, &response); handler->Search(&context, nullptr, &response);
...@@ -241,9 +239,7 @@ TEST_F(RpcHandlerTest, SEARCH_TEST) { ...@@ -241,9 +239,7 @@ TEST_F(RpcHandlerTest, SEARCH_TEST) {
::milvus::grpc::InsertParam insert_param; ::milvus::grpc::InsertParam insert_param;
for (auto& record : record_array) { for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array(); ::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array();
for (size_t i = 0; i < record.size(); i++) { grpc_record->add_vector_data(record.begin(), record.end());
grpc_record->add_vector_data(record[i]);
}
} }
//insert vectors //insert vectors
insert_param.set_table_name(TABLE_NAME); insert_param.set_table_name(TABLE_NAME);
...@@ -254,9 +250,7 @@ TEST_F(RpcHandlerTest, SEARCH_TEST) { ...@@ -254,9 +250,7 @@ TEST_F(RpcHandlerTest, SEARCH_TEST) {
BuildVectors(0, 10, record_array); BuildVectors(0, 10, record_array);
for (auto& record : record_array) { for (auto& record : record_array) {
::milvus::grpc::RowRecord* row_record = request.add_query_record_array(); ::milvus::grpc::RowRecord* row_record = request.add_query_record_array();
for (auto& rec : record) { row_record->add_vector_data(record.begin(), record.end());
row_record->add_vector_data(rec);
}
} }
handler->Search(&context, &request, &response); handler->Search(&context, &request, &response);
...@@ -325,9 +319,7 @@ TEST_F(RpcHandlerTest, TABLES_TEST) { ...@@ -325,9 +319,7 @@ TEST_F(RpcHandlerTest, TABLES_TEST) {
for (auto& record : record_array) { for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array(); ::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array();
for (size_t i = 0; i < record.size(); i++) { grpc_record->add_vector_data(record.begin(), record.end());
grpc_record->add_vector_data(record[i]);
}
} }
//test vector_id size not equal to row record size //test vector_id size not equal to row record size
vector_ids.clear_vector_id_array(); vector_ids.clear_vector_id_array();
...@@ -342,9 +334,7 @@ TEST_F(RpcHandlerTest, TABLES_TEST) { ...@@ -342,9 +334,7 @@ TEST_F(RpcHandlerTest, TABLES_TEST) {
vector_ids.clear_vector_id_array(); vector_ids.clear_vector_id_array();
for (uint64_t i = 0; i < 10; ++i) { for (uint64_t i = 0; i < 10; ++i) {
::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array(); ::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array();
for (size_t j = 0; j < 10; j++) { grpc_record->add_vector_data(record_array[i].begin(), record_array[i].end());
grpc_record->add_vector_data(record_array[i][j]);
}
} }
handler->Insert(&context, &request, &vector_ids); handler->Insert(&context, &request, &vector_ids);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册