From b9c11671c4fd59b887f3df8095af1e805e98b9e0 Mon Sep 17 00:00:00 2001 From: kun yu Date: Mon, 29 Jul 2019 18:05:14 +0800 Subject: [PATCH] fix search error handle bug Former-commit-id: dfded9c68611b0622581deebd1e94363566af432 --- .../examples/grpcsimple/src/ClientTest.cpp | 8 ++------ cpp/src/server/grpc_impl/RequestHandler.cpp | 20 +++++++++++++++---- cpp/src/server/grpc_impl/RequestTask.cpp | 12 ++++++++++- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp index b2fe7bcd..4c24cddc 100644 --- a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp +++ b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp @@ -174,7 +174,7 @@ namespace { std::vector topk_query_result_array; { TimeRecorder rc(phase_name); - Status stat = conn->SearchVector(TABLE_NAME, record_array, query_range_array, TOP_K, topk_query_result_array); + Status stat = conn->SearchVector("qqq", record_array, query_range_array, TOP_K, topk_query_result_array); std::cout << "SearchVector function call status: " << stat.ToString() << std::endl; } @@ -244,11 +244,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { auto start = std::chrono::high_resolution_clock::now(); - std::vector null_record; - RowRecord rowRecord; - rowRecord.data.resize(0); - null_record.push_back(rowRecord); - Status stat = conn->InsertVector(TABLE_NAME, null_record, record_ids); + Status stat = conn->InsertVector(TABLE_NAME, record_array, record_ids); auto finish = std::chrono::high_resolution_clock::now(); std::cout << "InsertVector cost: " << std::chrono::duration_cast>(finish - start).count() << "s\n"; diff --git a/cpp/src/server/grpc_impl/RequestHandler.cpp b/cpp/src/server/grpc_impl/RequestHandler.cpp index 748461a3..43252179 100644 --- a/cpp/src/server/grpc_impl/RequestHandler.cpp +++ b/cpp/src/server/grpc_impl/RequestHandler.cpp @@ -59,8 +59,14 @@ RequestHandler::InsertVector(::grpc::ServerContext* context, const ::milvus::grp RequestHandler::SearchVector(::grpc::ServerContext* context, const ::milvus::grpc::SearchVectorInfos* request, ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>* writer) { std::vector file_id_array; BaseTaskPtr task_ptr = SearchVectorTask::Create(*request, file_id_array, *writer); - RequestScheduler::ExecTask(task_ptr, nullptr); - return ::grpc::Status::OK; + ::milvus::grpc::Status grpc_status; + RequestScheduler::ExecTask(task_ptr, &grpc_status); + if (grpc_status.error_code() != SERVER_SUCCESS) { + ::grpc::Status status(::grpc::INVALID_ARGUMENT, grpc_status.reason()); + return status; + } else { + return ::grpc::Status::OK; + } } ::grpc::Status @@ -96,8 +102,14 @@ RequestHandler::GetTableRowCount(::grpc::ServerContext* context, const ::milvus: ::grpc::Status RequestHandler::ShowTables(::grpc::ServerContext* context, const ::milvus::grpc::Command* request, ::grpc::ServerWriter< ::milvus::grpc::TableName>* writer) { BaseTaskPtr task_ptr = ShowTablesTask::Create(*writer); - RequestScheduler::ExecTask(task_ptr, nullptr); - return ::grpc::Status::OK; + ::milvus::grpc::Status grpc_status; + RequestScheduler::ExecTask(task_ptr, &grpc_status); + if (grpc_status.error_code() != SERVER_SUCCESS) { + ::grpc::Status status(::grpc::UNKNOWN, grpc_status.reason()); + return status; + } else { + return ::grpc::Status::OK; + } } ::grpc::Status diff --git a/cpp/src/server/grpc_impl/RequestTask.cpp b/cpp/src/server/grpc_impl/RequestTask.cpp index e0374d70..e03e9274 100644 --- a/cpp/src/server/grpc_impl/RequestTask.cpp +++ b/cpp/src/server/grpc_impl/RequestTask.cpp @@ -486,7 +486,7 @@ ServerError SearchVectorTask::OnExecute() { int top_k_ = search_vector_infos_.topk(); - if(top_k_ <= 0) { + if(top_k_ <= 0 || top_k_ > 1024) { return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string( top_k_)); } @@ -535,6 +535,16 @@ ServerError SearchVectorTask::OnExecute() { //TODO for (size_t i = 0; i < record_array_size; i++) { for (size_t j = 0; j < table_info.dimension_; j++) { + if (search_vector_infos_.query_record_array(i).vector_data().empty()) { + return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Query record float array is empty"); + } + uint64_t query_vec_dim = search_vector_infos_.query_record_array(i).vector_data().size(); + if (query_vec_dim != table_info.dimension_) { + ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION; + std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(query_vec_dim) + + " vs. table dimension:" + std::to_string(table_info.dimension_); + return SetError(error_code, error_msg); + } vec_f[i * table_info.dimension_ + j] = search_vector_infos_.query_record_array(i).vector_data(j); } } -- GitLab