未验证 提交 6c8c4d18 编写于 作者: J Jin Hai 提交者: GitHub

Merge pull request #239 from fishpenguin/0.5.3-yk

improve grpc performance in search 
......@@ -8,6 +8,7 @@ Please mark all change in change log and use the ticket from JIRA.
## Feature
## Improvement
- \#204 - improve grpc performance in search
- \#207 - Add more unittest for config set/get
- \#208 - optimize unittest to support run single test more easily
......
sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts'
sh 'helm repo update'
dir ('milvus-helm') {
checkout([$class: 'GitSCM', branches: [[name: "0.5.0"]], userRemoteConfigs: [[url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/0.5.0:refs/remotes/origin/0.5.0"]]])
checkout([$class: 'GitSCM', branches: [[name: "0.5.3"]], userRemoteConfigs: [[url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/0.5.3:refs/remotes/origin/0.5.3"]]])
dir ("milvus-gpu") {
sh "helm install --wait --timeout 300 --set engine.image.tag=${DOCKER_VERSION} --set expose.type=clusterIP --name ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu -f ci/db_backend/sqlite_values.yaml -f ci/filebeat/values.yaml --namespace milvus ."
}
......
......@@ -8,7 +8,7 @@ timeout(time: 90, unit: 'MINUTES') {
if (!fileExists('milvus-helm')) {
dir ("milvus-helm") {
checkout([$class: 'GitSCM', branches: [[name: "0.5.0"]], userRemoteConfigs: [[url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/0.5.0:refs/remotes/origin/0.5.0"]]])
checkout([$class: 'GitSCM', branches: [[name: "0.5.3"]], userRemoteConfigs: [[url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/0.5.3:refs/remotes/origin/0.5.3"]]])
}
}
dir ("milvus-helm") {
......
......@@ -10,7 +10,7 @@ timeout(time: 60, unit: 'MINUTES') {
// if (!fileExists('milvus-helm')) {
// dir ("milvus-helm") {
// checkout([$class: 'GitSCM', branches: [[name: "0.5.0"]], userRemoteConfigs: [[url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/0.5.0:refs/remotes/origin/0.5.0"]]])
// checkout([$class: 'GitSCM', branches: [[name: "0.5.3"]], userRemoteConfigs: [[url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/0.5.3:refs/remotes/origin/0.5.3"]]])
// }
// }
// dir ("milvus-helm") {
......
......@@ -71,7 +71,7 @@ if(MILVUS_VERSION_MAJOR STREQUAL ""
OR MILVUS_VERSION_MINOR STREQUAL ""
OR MILVUS_VERSION_PATCH STREQUAL "")
message(WARNING "Failed to determine Milvus version from git branch name")
set(MILVUS_VERSION "0.5.2")
set(MILVUS_VERSION "0.5.3")
endif()
message(STATUS "Build version = ${MILVUS_VERSION}")
......
We manually change two APIs in "milvus.pd.h":
add_vector_data()
add_row_id_array()
If proto files need be generated again, remember to re-change above APIs.
\ No newline at end of file
......@@ -201,60 +201,60 @@ void MilvusService::Stub::experimental_async::Insert(::grpc::ClientContext* cont
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::VectorIds>::Create(channel_.get(), cq, rpcmethod_Insert_, context, request, false);
}
::grpc::Status MilvusService::Stub::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::milvus::grpc::TopKQueryResultList* response) {
::grpc::Status MilvusService::Stub::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::milvus::grpc::TopKQueryResult* response) {
return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Search_, context, request, response);
}
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) {
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) {
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, reactor);
}
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, reactor);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::AsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, true);
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::AsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, true);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::PrepareAsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, false);
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::PrepareAsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, false);
}
::grpc::Status MilvusService::Stub::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::milvus::grpc::TopKQueryResultList* response) {
::grpc::Status MilvusService::Stub::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::milvus::grpc::TopKQueryResult* response) {
return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_SearchInFiles_, context, request, response);
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) {
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) {
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, reactor);
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, reactor);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::AsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, true);
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::AsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, true);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::PrepareAsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, false);
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::PrepareAsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, false);
}
::grpc::Status MilvusService::Stub::DescribeTable(::grpc::ClientContext* context, const ::milvus::grpc::TableName& request, ::milvus::grpc::TableSchema* response) {
......@@ -510,12 +510,12 @@ MilvusService::Service::Service() {
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[5],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchParam, ::milvus::grpc::TopKQueryResultList>(
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchParam, ::milvus::grpc::TopKQueryResult>(
std::mem_fn(&MilvusService::Service::Search), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[6],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchInFilesParam, ::milvus::grpc::TopKQueryResultList>(
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchInFilesParam, ::milvus::grpc::TopKQueryResult>(
std::mem_fn(&MilvusService::Service::SearchInFiles), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[7],
......@@ -597,14 +597,14 @@ MilvusService::Service::~Service() {
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
::grpc::Status MilvusService::Service::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response) {
::grpc::Status MilvusService::Service::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResult* response) {
(void) context;
(void) request;
(void) response;
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
::grpc::Status MilvusService::Service::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response) {
::grpc::Status MilvusService::Service::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResult* response) {
(void) context;
(void) request;
(void) response;
......
......@@ -84,24 +84,12 @@ message SearchInFilesParam {
/**
* @brief Query result params
*/
message QueryResult {
int64 id = 1;
double distance = 2;
}
/**
* @brief TopK query result
*/
message TopKQueryResult {
repeated QueryResult query_result_arrays = 1;
}
/**
* @brief List of topK query result
*/
message TopKQueryResultList {
Status status = 1;
repeated TopKQueryResult topk_query_result = 2;
int64 nq = 2;
int64 topk = 3;
bytes ids_binary = 4;
bytes distances_binary = 5;
}
/**
......@@ -227,7 +215,7 @@ service MilvusService {
*
* @return query result array.
*/
rpc Search(SearchParam) returns (TopKQueryResultList) {}
rpc Search(SearchParam) returns (TopKQueryResult) {}
/**
* @brief Internal use query interface
......@@ -241,7 +229,7 @@ service MilvusService {
*
* @return query result array.
*/
rpc SearchInFiles(SearchInFilesParam) returns (TopKQueryResultList) {}
rpc SearchInFiles(SearchInFilesParam) returns (TopKQueryResult) {}
/**
* @brief Get table schema
......
......@@ -57,22 +57,22 @@ PrintTableSchema(const milvus::TableSchema& tb_schema) {
void
PrintSearchResult(const std::vector<std::pair<int64_t, milvus::RowRecord>>& search_record_array,
const std::vector<milvus::TopKQueryResult>& topk_query_result_array) {
const milvus::TopKQueryResult& topk_query_result) {
BLOCK_SPLITER
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
size_t nq = topk_query_result.row_num;
size_t topk = topk_query_result.topk;
std::cout << "Returned result count: " << nq * topk << std::endl;
int32_t index = 0;
for (auto& result : topk_query_result_array) {
for (size_t i = 0; i < nq; i++) {
auto search_id = search_record_array[index].first;
index++;
std::cout << "No." << std::to_string(index) << " vector " << std::to_string(search_id) << " top "
<< std::to_string(result.query_result_arrays.size()) << " search result:" << std::endl;
for (auto& item : result.query_result_arrays) {
std::cout << "\t" << std::to_string(item.id) << "\tdistance:" << std::to_string(item.distance);
std::cout << std::endl;
std::cout << "No." << index << " vector " << search_id << " top " << topk << " search result:" << std::endl;
for (size_t j = 0; j < topk; j++) {
size_t idx = i * nq + j;
std::cout << "\t" << topk_query_result.ids[idx] << "\t" << topk_query_result.distances[idx] << std::endl;
}
}
BLOCK_SPLITER
}
......@@ -166,11 +166,13 @@ class TimeRecorder {
void
CheckResult(const std::vector<std::pair<int64_t, milvus::RowRecord>>& search_record_array,
const std::vector<milvus::TopKQueryResult>& topk_query_result_array) {
const milvus::TopKQueryResult& topk_query_result) {
BLOCK_SPLITER
size_t nq = topk_query_result.row_num;
size_t result_k = topk_query_result.topk;
int64_t index = 0;
for (auto& result : topk_query_result_array) {
auto result_id = result.query_result_arrays[0].id;
for (size_t i = 0; i < nq; i++) {
auto result_id = topk_query_result.ids[i * result_k];
auto search_id = search_record_array[index++].first;
if (result_id != search_id) {
std::cout << "The top 1 result is wrong: " << result_id << " vs. " << search_id << std::endl;
......@@ -196,19 +198,18 @@ DoSearch(std::shared_ptr<milvus::Connection> conn,
}
auto start = std::chrono::high_resolution_clock::now();
std::vector<milvus::TopKQueryResult> topk_query_result_array;
milvus::TopKQueryResult topk_query_result;
{
TimeRecorder rc(phase_name);
milvus::Status stat =
conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 32, topk_query_result_array);
milvus::Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 32, topk_query_result);
std::cout << "SearchVector function call status: " << stat.message() << std::endl;
}
auto finish = std::chrono::high_resolution_clock::now();
std::cout << "SEARCHVECTOR COST: "
<< std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
PrintSearchResult(search_record_array, topk_query_result_array);
CheckResult(search_record_array, topk_query_result_array);
PrintSearchResult(search_record_array, topk_query_result);
CheckResult(search_record_array, topk_query_result);
}
} // namespace
......
......@@ -188,23 +188,19 @@ ClientProxy::Insert(const std::string& table_name, const std::vector<RowRecord>&
for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array();
for (size_t i = 0; i < record.data.size(); i++) {
grpc_record->add_vector_data(record.data[i]);
}
grpc_record->add_vector_data(record.data.begin(), record.data.end());
}
// Single thread
::milvus::grpc::VectorIds vector_ids;
if (!id_array.empty()) {
for (auto i = 0; i < id_array.size(); i++) {
insert_param.add_row_id_array(id_array[i]);
}
/* set user's ids */
insert_param.add_row_id_array(id_array.begin(), id_array.end());
client_ptr_->Insert(vector_ids, insert_param, status);
} else {
client_ptr_->Insert(vector_ids, insert_param, status);
for (size_t i = 0; i < vector_ids.vector_id_array_size(); i++) {
id_array.push_back(vector_ids.vector_id_array(i));
}
/* return Milvus generated ids back to user */
id_array.insert(id_array.end(), vector_ids.vector_id_array().begin(), vector_ids.vector_id_array().end());
}
#endif
} catch (std::exception& ex) {
......@@ -217,7 +213,7 @@ ClientProxy::Insert(const std::string& table_name, const std::vector<RowRecord>&
Status
ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) {
TopKQueryResult& topk_query_result) {
try {
// step 1: convert vectors data
::milvus::grpc::SearchParam search_param;
......@@ -226,9 +222,7 @@ ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>&
search_param.set_nprobe(nprobe);
for (auto& record : query_record_array) {
::milvus::grpc::RowRecord* row_record = search_param.add_query_record_array();
for (auto& rec : record.data) {
row_record->add_vector_data(rec);
}
row_record->add_vector_data(record.data.begin(), record.data.end());
}
// step 2: convert range array
......@@ -239,21 +233,17 @@ ClientProxy::Search(const std::string& table_name, const std::vector<RowRecord>&
}
// step 3: search vectors
::milvus::grpc::TopKQueryResultList topk_query_result_list;
Status status = client_ptr_->Search(topk_query_result_list, search_param);
::milvus::grpc::TopKQueryResult result;
Status status = client_ptr_->Search(result, search_param);
// step 4: convert result array
for (uint64_t i = 0; i < topk_query_result_list.topk_query_result_size(); ++i) {
TopKQueryResult result;
for (uint64_t j = 0; j < topk_query_result_list.topk_query_result(i).query_result_arrays_size(); ++j) {
QueryResult query_result;
query_result.id = topk_query_result_list.topk_query_result(i).query_result_arrays(j).id();
query_result.distance = topk_query_result_list.topk_query_result(i).query_result_arrays(j).distance();
result.query_result_arrays.emplace_back(query_result);
}
topk_query_result.row_num = result.nq();
topk_query_result.topk = result.topk();
topk_query_result.ids.resize(result.ids_binary().size());
memcpy(topk_query_result.ids.data(), result.ids_binary().data(), result.ids_binary().size());
topk_query_result.distances.resize(result.distances_binary().size());
memcpy(topk_query_result.distances.data(), result.distances_binary().data(), result.distances_binary().size());
topk_query_result_array.emplace_back(result);
}
return status;
} catch (std::exception& ex) {
return Status(StatusCode::UnknownError, "fail to search vectors: " + std::string(ex.what()));
......
......@@ -60,7 +60,7 @@ class ClientProxy : public Connection {
Status
Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) override;
TopKQueryResult& topk_query_result) override;
Status
DescribeTable(const std::string& table_name, TableSchema& table_schema) override;
......
......@@ -134,20 +134,20 @@ GrpcClient::Insert(::milvus::grpc::VectorIds& vector_ids, const ::milvus::grpc::
}
Status
GrpcClient::Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list,
GrpcClient::Search(::milvus::grpc::TopKQueryResult& topk_query_result,
const ::milvus::grpc::SearchParam& search_param) {
::milvus::grpc::TopKQueryResult query_result;
ClientContext context;
::grpc::Status grpc_status = stub_->Search(&context, search_param, &topk_query_result_list);
::grpc::Status grpc_status = stub_->Search(&context, search_param, &topk_query_result);
if (!grpc_status.ok()) {
std::cerr << "SearchVector rpc failed!" << std::endl;
std::cerr << grpc_status.error_message() << std::endl;
return Status(StatusCode::RPCFailed, grpc_status.error_message());
}
if (topk_query_result_list.status().error_code() != grpc::SUCCESS) {
std::cerr << topk_query_result_list.status().reason() << std::endl;
return Status(StatusCode::ServerFailed, topk_query_result_list.status().reason());
if (topk_query_result.status().error_code() != grpc::SUCCESS) {
std::cerr << topk_query_result.status().reason() << std::endl;
return Status(StatusCode::ServerFailed, topk_query_result.status().reason());
}
return Status::OK();
......
......@@ -57,7 +57,7 @@ class GrpcClient {
Insert(grpc::VectorIds& vector_ids, const grpc::InsertParam& insert_param, Status& status);
Status
Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list, const grpc::SearchParam& search_param);
Search(::milvus::grpc::TopKQueryResult& topk_query_result, const grpc::SearchParam& search_param);
Status
DescribeTable(grpc::TableSchema& grpc_schema, const std::string& table_name);
......
......@@ -78,19 +78,14 @@ struct RowRecord {
std::vector<float> data; ///< Vector raw data
};
/**
* @brief Query result
*/
struct QueryResult {
int64_t id; ///< Output result
double distance; ///< Vector similarity distance
};
/**
* @brief TopK query result
*/
struct TopKQueryResult {
std::vector<QueryResult> query_result_arrays; ///< TopK query result
int64_t row_num;
int64_t topk;
std::vector<int64_t> ids;
std::vector<float> distances;
};
/**
......@@ -261,7 +256,7 @@ class Connection {
virtual Status
Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) = 0;
TopKQueryResult& topk_query_result) = 0;
/**
* @brief Show table description
......
......@@ -91,9 +91,8 @@ ConnectionImpl::Insert(const std::string& table_name, const std::vector<RowRecor
Status
ConnectionImpl::Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) {
return client_proxy_->Search(table_name, query_record_array, query_range_array, topk, nprobe,
topk_query_result_array);
TopKQueryResult& topk_query_result) {
return client_proxy_->Search(table_name, query_record_array, query_range_array, topk, nprobe, topk_query_result);
}
Status
......
......@@ -62,7 +62,7 @@ class ConnectionImpl : public Connection {
Status
Search(const std::string& table_name, const std::vector<RowRecord>& query_record_array,
const std::vector<Range>& query_range_array, int64_t topk, int64_t nprobe,
std::vector<TopKQueryResult>& topk_query_result_array) override;
TopKQueryResult& topk_query_result) override;
Status
DescribeTable(const std::string& table_name, TableSchema& table_schema) override;
......
......@@ -75,7 +75,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
::grpc::Status
GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request,
::milvus::grpc::TopKQueryResultList* response) {
::milvus::grpc::TopKQueryResult* response) {
std::vector<std::string> file_id_array;
BaseTaskPtr task_ptr = SearchTask::Create(request, file_id_array, response);
::milvus::grpc::Status grpc_status;
......@@ -87,7 +87,7 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc:
::grpc::Status
GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request,
::milvus::grpc::TopKQueryResultList* response) {
::milvus::grpc::TopKQueryResult* response) {
std::vector<std::string> file_id_array;
for (int i = 0; i < request->file_id_array_size(); i++) {
file_id_array.push_back(request->file_id_array(i));
......
......@@ -145,7 +145,7 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service {
*/
::grpc::Status
Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request,
::milvus::grpc::TopKQueryResultList* response) override;
::milvus::grpc::TopKQueryResult* response) override;
/**
* @brief Internal use query interface
......@@ -169,7 +169,7 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service {
*/
::grpc::Status
SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request,
::milvus::grpc::TopKQueryResultList* response) override;
::milvus::grpc::TopKQueryResult* response) override;
/**
* @brief Get table schema
......
......@@ -20,6 +20,7 @@
#include <string.h>
#include <map>
#include <string>
#include <utility>
#include <vector>
//#include <gperftools/profiler.h>
......@@ -541,16 +542,16 @@ InsertTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchTask::SearchTask(const ::milvus::grpc::SearchParam* search_vector_infos,
const std::vector<std::string>& file_id_array, ::milvus::grpc::TopKQueryResultList* response)
const std::vector<std::string>& file_id_array, ::milvus::grpc::TopKQueryResult* response)
: GrpcBaseTask(DQL_TASK_GROUP),
search_param_(search_vector_infos),
file_id_array_(file_id_array),
topk_result_list(response) {
topk_result_(response) {
}
BaseTaskPtr
SearchTask::Create(const ::milvus::grpc::SearchParam* search_vector_infos,
const std::vector<std::string>& file_id_array, ::milvus::grpc::TopKQueryResultList* response) {
const std::vector<std::string>& file_id_array, ::milvus::grpc::TopKQueryResult* response) {
if (search_vector_infos == nullptr) {
SERVER_LOG_ERROR << "grpc input is null!";
return nullptr;
......@@ -671,15 +672,20 @@ SearchTask::OnExecute() {
size_t result_k = result_ids.size() / record_count;
// step 7: construct result array
for (size_t i = 0; i < record_count; i++) {
::milvus::grpc::TopKQueryResult* topk_query_result = topk_result_list->add_topk_query_result();
for (size_t j = 0; j < result_k; j++) {
::milvus::grpc::QueryResult* grpc_result = topk_query_result->add_query_result_arrays();
size_t idx = i * result_k + j;
grpc_result->set_id(result_ids[idx]);
grpc_result->set_distance(result_distances[idx]);
}
}
topk_result_->set_nq(record_count);
topk_result_->set_topk(result_ids.size() / record_count);
std::string ids_str;
size_t ids_len = sizeof(int64_t) * result_ids.size();
ids_str.resize(ids_len);
memcpy((void*)(ids_str.data()), result_ids.data(), ids_len);
topk_result_->set_ids_binary(std::move(ids_str));
std::string distances_str;
size_t distances_len = sizeof(float) * result_distances.size();
distances_str.resize(distances_len);
memcpy((void*)(distances_str.data()), result_distances.data(), distances_len);
topk_result_->set_distances_binary(std::move(distances_str));
// step 8: print time cost percent
rc.RecordSection("construct result and send");
......
......@@ -153,11 +153,11 @@ class SearchTask : public GrpcBaseTask {
public:
static BaseTaskPtr
Create(const ::milvus::grpc::SearchParam* search_param, const std::vector<std::string>& file_id_array,
::milvus::grpc::TopKQueryResultList* response);
::milvus::grpc::TopKQueryResult* response);
protected:
SearchTask(const ::milvus::grpc::SearchParam* search_param, const std::vector<std::string>& file_id_array,
::milvus::grpc::TopKQueryResultList* response);
::milvus::grpc::TopKQueryResult* response);
Status
OnExecute() override;
......@@ -165,7 +165,7 @@ class SearchTask : public GrpcBaseTask {
private:
const ::milvus::grpc::SearchParam* search_param_;
std::vector<std::string> file_id_array_;
::milvus::grpc::TopKQueryResultList* topk_result_list;
::milvus::grpc::TopKQueryResult* topk_result_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -93,7 +93,7 @@ GrpcServer::StartService() {
builder.SetCompressionAlgorithmSupportStatus(GRPC_COMPRESS_STREAM_GZIP, true);
builder.SetDefaultCompressionAlgorithm(GRPC_COMPRESS_STREAM_GZIP);
builder.SetDefaultCompressionLevel(GRPC_COMPRESS_LEVEL_HIGH);
builder.SetDefaultCompressionLevel(GRPC_COMPRESS_LEVEL_NONE);
GrpcRequestHandler service;
......
......@@ -202,9 +202,7 @@ TEST_F(RpcHandlerTest, INSERT_TEST) {
::milvus::grpc::VectorIds vector_ids;
for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array();
for (size_t i = 0; i < record.size(); i++) {
grpc_record->add_vector_data(record[i]);
}
grpc_record->add_vector_data(record.begin(), record.end());
}
handler->Insert(&context, &request, &vector_ids);
ASSERT_EQ(vector_ids.vector_id_array_size(), VECTOR_COUNT);
......@@ -213,7 +211,7 @@ TEST_F(RpcHandlerTest, INSERT_TEST) {
TEST_F(RpcHandlerTest, SEARCH_TEST) {
::grpc::ServerContext context;
::milvus::grpc::SearchParam request;
::milvus::grpc::TopKQueryResultList response;
::milvus::grpc::TopKQueryResult response;
//test null input
handler->Search(&context, nullptr, &response);
......@@ -241,22 +239,17 @@ TEST_F(RpcHandlerTest, SEARCH_TEST) {
::milvus::grpc::InsertParam insert_param;
for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array();
for (size_t i = 0; i < record.size(); i++) {
grpc_record->add_vector_data(record[i]);
}
grpc_record->add_vector_data(record.begin(), record.end());
}
//insert vectors
insert_param.set_table_name(TABLE_NAME);
::milvus::grpc::VectorIds vector_ids;
handler->Insert(&context, &insert_param, &vector_ids);
sleep(7);
BuildVectors(0, 10, record_array);
for (auto& record : record_array) {
::milvus::grpc::RowRecord* row_record = request.add_query_record_array();
for (auto& rec : record) {
row_record->add_vector_data(rec);
}
row_record->add_vector_data(record.begin(), record.end());
}
handler->Search(&context, &request, &response);
......@@ -325,9 +318,7 @@ TEST_F(RpcHandlerTest, TABLES_TEST) {
for (auto& record : record_array) {
::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array();
for (size_t i = 0; i < record.size(); i++) {
grpc_record->add_vector_data(record[i]);
}
grpc_record->add_vector_data(record.begin(), record.end());
}
//test vector_id size not equal to row record size
vector_ids.clear_vector_id_array();
......@@ -342,9 +333,7 @@ TEST_F(RpcHandlerTest, TABLES_TEST) {
vector_ids.clear_vector_id_array();
for (uint64_t i = 0; i < 10; ++i) {
::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array();
for (size_t j = 0; j < 10; j++) {
grpc_record->add_vector_data(record_array[i][j]);
}
grpc_record->add_vector_data(record_array[i].begin(), record_array[i].end());
}
handler->Insert(&context, &request, &vector_ids);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册