diff --git a/cpp/src/server/grpc_impl/GrpcRequestHandler.cpp b/cpp/src/server/grpc_impl/GrpcRequestHandler.cpp index be0c6adebb094dd621704c0c3d7c9c08002ea5b2..0ae711c32a2828e8b9aabfc7354d75852a0f3f8c 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/cpp/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -18,7 +18,7 @@ GrpcRequestHandler::CreateTable(::grpc::ServerContext *context, const ::milvus::grpc::TableSchema *request, ::milvus::grpc::Status *response) { - BaseTaskPtr task_ptr = CreateTableTask::Create(*request); + BaseTaskPtr task_ptr = CreateTableTask::Create(request); GrpcRequestScheduler::ExecTask(task_ptr, response); return ::grpc::Status::OK; } @@ -52,7 +52,7 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext *context, const ::milvus::grpc::IndexParam *request, ::milvus::grpc::Status *response) { - BaseTaskPtr task_ptr = CreateIndexTask::Create(*request); + BaseTaskPtr task_ptr = CreateIndexTask::Create(request); GrpcRequestScheduler::ExecTask(task_ptr, response); return ::grpc::Status::OK; } @@ -62,7 +62,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext *context, const ::milvus::grpc::InsertParam *request, ::milvus::grpc::VectorIds *response) { - BaseTaskPtr task_ptr = InsertTask::Create(*request, *response); + BaseTaskPtr task_ptr = InsertTask::Create(request, response); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); response->mutable_status()->set_reason(grpc_status.reason()); @@ -76,7 +76,7 @@ GrpcRequestHandler::Search(::grpc::ServerContext *context, ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) { std::vector file_id_array; - BaseTaskPtr task_ptr = SearchTask::Create(*request, file_id_array, *writer); + BaseTaskPtr task_ptr = SearchTask::Create(request, file_id_array, writer); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); if (grpc_status.error_code() != SERVER_SUCCESS) { @@ -93,7 +93,11 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext *context, ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) { std::vector file_id_array; - BaseTaskPtr task_ptr = SearchTask::Create(request->search_param(), file_id_array, *writer); + for(int i = 0; i < request->file_id_array_size(); i++) { + file_id_array.push_back(request->file_id_array(i)); + } + ::milvus::grpc::SearchInFilesParam *request_mutable = const_cast<::milvus::grpc::SearchInFilesParam *>(request); + BaseTaskPtr task_ptr = SearchTask::Create(request_mutable->mutable_search_param(), file_id_array, writer); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); if (grpc_status.error_code() != SERVER_SUCCESS) { @@ -109,7 +113,7 @@ GrpcRequestHandler::DescribeTable(::grpc::ServerContext *context, const ::milvus::grpc::TableName *request, ::milvus::grpc::TableSchema *response) { - BaseTaskPtr task_ptr = DescribeTableTask::Create(request->table_name(), *response); + BaseTaskPtr task_ptr = DescribeTableTask::Create(request->table_name(), response); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); response->mutable_table_name()->mutable_status()->set_error_code(grpc_status.error_code()); @@ -137,7 +141,7 @@ GrpcRequestHandler::ShowTables(::grpc::ServerContext *context, const ::milvus::grpc::Command *request, ::grpc::ServerWriter<::milvus::grpc::TableName> *writer) { - BaseTaskPtr task_ptr = ShowTablesTask::Create(*writer); + BaseTaskPtr task_ptr = ShowTablesTask::Create(writer); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); if (grpc_status.error_code() != SERVER_SUCCESS) { @@ -167,7 +171,7 @@ GrpcRequestHandler::Cmd(::grpc::ServerContext *context, GrpcRequestHandler::DeleteByRange(::grpc::ServerContext *context, const ::milvus::grpc::DeleteByRangeParam *request, ::milvus::grpc::Status *response) { - BaseTaskPtr task_ptr = DeleteByRangeTask::Create(*request); + BaseTaskPtr task_ptr = DeleteByRangeTask::Create(request); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); response->set_error_code(grpc_status.error_code()); @@ -191,7 +195,7 @@ GrpcRequestHandler::PreloadTable(::grpc::ServerContext *context, GrpcRequestHandler::DescribeIndex(::grpc::ServerContext *context, const ::milvus::grpc::TableName *request, ::milvus::grpc::IndexParam *response) { - BaseTaskPtr task_ptr = DescribeIndexTask::Create(request->table_name(), *response); + BaseTaskPtr task_ptr = DescribeIndexTask::Create(request->table_name(), response); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); response->mutable_table_name()->mutable_status()->set_reason(grpc_status.reason()); diff --git a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp index a185272f269f2048f20b38326dcfbefad9eefefa..a4b8d68c296e4066f263ecf41f148e18845e67bc 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp +++ b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp @@ -107,14 +107,18 @@ namespace { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema &schema) +CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema *schema) : GrpcBaseTask(DDL_DML_TASK_GROUP), schema_(schema) { } BaseTaskPtr -CreateTableTask::Create(const ::milvus::grpc::TableSchema &schema) { +CreateTableTask::Create(const ::milvus::grpc::TableSchema *schema) { + if(schema == nullptr) { + SERVER_LOG_ERROR << "grpc input is null!"; + return nullptr; + } return std::shared_ptr(new CreateTableTask(schema)); } @@ -124,26 +128,26 @@ CreateTableTask::OnExecute() { try { //step 1: check arguments - ServerError res = ValidationUtil::ValidateTableName(schema_.table_name().table_name()); + ServerError res = ValidationUtil::ValidateTableName(schema_->table_name().table_name()); if (res != SERVER_SUCCESS) { - return SetError(res, "Invalid table name: " + schema_.table_name().table_name()); + return SetError(res, "Invalid table name: " + schema_->table_name().table_name()); } - res = ValidationUtil::ValidateTableDimension(schema_.dimension()); + res = ValidationUtil::ValidateTableDimension(schema_->dimension()); if (res != SERVER_SUCCESS) { - return SetError(res, "Invalid table dimension: " + std::to_string(schema_.dimension())); + return SetError(res, "Invalid table dimension: " + std::to_string(schema_->dimension())); } - res = ValidationUtil::ValidateTableIndexFileSize(schema_.index_file_size()); + res = ValidationUtil::ValidateTableIndexFileSize(schema_->index_file_size()); if(res != SERVER_SUCCESS) { - return SetError(res, "Invalid index file size: " + std::to_string(schema_.index_file_size())); + return SetError(res, "Invalid index file size: " + std::to_string(schema_->index_file_size())); } //step 2: construct table schema engine::meta::TableSchema table_info; - table_info.table_id_ = schema_.table_name().table_name(); - table_info.dimension_ = (uint16_t) schema_.dimension(); - table_info.index_file_size_ = schema_.index_file_size(); + table_info.table_id_ = schema_->table_name().table_name(); + table_info.dimension_ = (uint16_t) schema_->dimension(); + table_info.index_file_size_ = schema_->index_file_size(); //step 3: create table engine::Status stat = DBWrapper::DB()->CreateTable(table_info); @@ -162,14 +166,14 @@ CreateTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -DescribeTableTask::DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema &schema) +DescribeTableTask::DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema *schema) : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name), schema_(schema) { } BaseTaskPtr -DescribeTableTask::Create(const std::string &table_name, ::milvus::grpc::TableSchema &schema) { +DescribeTableTask::Create(const std::string &table_name, ::milvus::grpc::TableSchema *schema) { return std::shared_ptr(new DescribeTableTask(table_name, schema)); } @@ -192,8 +196,8 @@ DescribeTableTask::OnExecute() { return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); } - schema_.mutable_table_name()->set_table_name(table_info.table_id_); - schema_.set_dimension(table_info.dimension_); + schema_->mutable_table_name()->set_table_name(table_info.table_id_); + schema_->set_dimension(table_info.dimension_); } catch (std::exception &ex) { return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); @@ -205,13 +209,17 @@ DescribeTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -CreateIndexTask::CreateIndexTask(const ::milvus::grpc::IndexParam &index_param) +CreateIndexTask::CreateIndexTask(const ::milvus::grpc::IndexParam *index_param) : GrpcBaseTask(DDL_DML_TASK_GROUP), index_param_(index_param) { } BaseTaskPtr -CreateIndexTask::Create(const ::milvus::grpc::IndexParam &index_param) { +CreateIndexTask::Create(const ::milvus::grpc::IndexParam *index_param) { + if(index_param == nullptr) { + SERVER_LOG_ERROR << "grpc input is null!"; + return nullptr; + } return std::shared_ptr(new CreateIndexTask(index_param)); } @@ -221,7 +229,7 @@ CreateIndexTask::OnExecute() { TimeRecorder rc("CreateIndexTask"); //step 1: check arguments - std::string table_name_ = index_param_.table_name().table_name(); + std::string table_name_ = index_param_->table_name().table_name(); ServerError res = ValidationUtil::ValidateTableName(table_name_); if (res != SERVER_SUCCESS) { return SetError(res, "Invalid table name: " + table_name_); @@ -237,26 +245,27 @@ CreateIndexTask::OnExecute() { return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); } - res = ValidationUtil::ValidateTableIndexType(index_param_.mutable_index()->index_type()); + auto &grpc_index = index_param_->index(); + res = ValidationUtil::ValidateTableIndexType(grpc_index.index_type()); if(res != SERVER_SUCCESS) { - return SetError(res, "Invalid index type: " + std::to_string(index_param_.mutable_index()->index_type())); + return SetError(res, "Invalid index type: " + std::to_string(grpc_index.index_type())); } - res = ValidationUtil::ValidateTableIndexNlist(index_param_.mutable_index()->nlist()); + res = ValidationUtil::ValidateTableIndexNlist(grpc_index.nlist()); if(res != SERVER_SUCCESS) { - return SetError(res, "Invalid index nlist: " + std::to_string(index_param_.mutable_index()->nlist())); + return SetError(res, "Invalid index nlist: " + std::to_string(grpc_index.nlist())); } - res = ValidationUtil::ValidateTableIndexMetricType(index_param_.mutable_index()->metric_type()); + res = ValidationUtil::ValidateTableIndexMetricType(grpc_index.metric_type()); if(res != SERVER_SUCCESS) { - return SetError(res, "Invalid index metric type: " + std::to_string(index_param_.mutable_index()->metric_type())); + return SetError(res, "Invalid index metric type: " + std::to_string(grpc_index.metric_type())); } //step 2: check table existence engine::TableIndex index; - index.engine_type_ = index_param_.mutable_index()->index_type(); - index.nlist_ = index_param_.mutable_index()->nlist(); - index.metric_type_ = index_param_.mutable_index()->metric_type(); + index.engine_type_ = grpc_index.index_type(); + index.nlist_ = grpc_index.nlist(); + index.metric_type_ = grpc_index.metric_type(); stat = DBWrapper::DB()->CreateIndex(table_name_, index); if (!stat.ok()) { return SetError(SERVER_BUILD_INDEX_ERROR, "Engine failed: " + stat.ToString()); @@ -361,14 +370,14 @@ DropTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -ShowTablesTask::ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> &writer) +ShowTablesTask::ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> *writer) : GrpcBaseTask(DDL_DML_TASK_GROUP), writer_(writer) { } BaseTaskPtr -ShowTablesTask::Create(::grpc::ServerWriter<::milvus::grpc::TableName> &writer) { +ShowTablesTask::Create(::grpc::ServerWriter<::milvus::grpc::TableName> *writer) { return std::shared_ptr(new ShowTablesTask(writer)); } @@ -383,7 +392,7 @@ ShowTablesTask::OnExecute() { for (auto &schema : schema_array) { ::milvus::grpc::TableName tableName; tableName.set_table_name(schema.table_id_); - if (!writer_.Write(tableName)) { + if (!writer_->Write(tableName)) { return SetError(SERVER_WRITE_ERROR, "Write table name failed!"); } } @@ -391,17 +400,21 @@ ShowTablesTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -InsertTask::InsertTask(const ::milvus::grpc::InsertParam &insert_param, - ::milvus::grpc::VectorIds &record_ids) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - insert_param_(insert_param), - record_ids_(record_ids) { - record_ids_.Clear(); +InsertTask::InsertTask(const ::milvus::grpc::InsertParam *insert_param, + ::milvus::grpc::VectorIds *record_ids) + : GrpcBaseTask(DDL_DML_TASK_GROUP), + insert_param_(insert_param), + record_ids_(record_ids) { + record_ids_->Clear(); } BaseTaskPtr -InsertTask::Create(const ::milvus::grpc::InsertParam &insert_param, - ::milvus::grpc::VectorIds &record_ids) { +InsertTask::Create(const ::milvus::grpc::InsertParam *insert_param, + ::milvus::grpc::VectorIds *record_ids) { + if(insert_param == nullptr) { + SERVER_LOG_ERROR << "grpc input is null!"; + return nullptr; + } return std::shared_ptr(new InsertTask(insert_param, record_ids)); } @@ -411,16 +424,16 @@ InsertTask::OnExecute() { TimeRecorder rc("InsertVectorTask"); //step 1: check arguments - ServerError res = ValidationUtil::ValidateTableName(insert_param_.table_name()); + ServerError res = ValidationUtil::ValidateTableName(insert_param_->table_name()); if (res != SERVER_SUCCESS) { - return SetError(res, "Invalid table name: " + insert_param_.table_name()); + return SetError(res, "Invalid table name: " + insert_param_->table_name()); } - if (insert_param_.row_record_array().empty()) { + if (insert_param_->row_record_array().empty()) { return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); } - if (!record_ids_.vector_id_array().empty()) { - if (record_ids_.vector_id_array().size() != insert_param_.row_record_array_size()) { + if (!record_ids_->vector_id_array().empty()) { + if (record_ids_->vector_id_array().size() != insert_param_->row_record_array_size()) { return SetError(SERVER_ILLEGAL_VECTOR_ID, "Size of vector ids is not equal to row record array size"); } @@ -428,12 +441,12 @@ InsertTask::OnExecute() { //step 2: check table existence engine::meta::TableSchema table_info; - table_info.table_id_ = insert_param_.table_name(); + table_info.table_id_ = insert_param_->table_name(); engine::Status stat = DBWrapper::DB()->DescribeTable(table_info); if (!stat.ok()) { if (stat.IsNotFound()) { return SetError(SERVER_TABLE_NOT_EXIST, - "Table " + insert_param_.table_name() + " not exists"); + "Table " + insert_param_->table_name() + " not exists"); } else { return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); } @@ -443,7 +456,7 @@ InsertTask::OnExecute() { uint64_t row_count = 0; DBWrapper::DB()->GetTableRowCount(table_info.table_id_, row_count); bool empty_table = (row_count == 0); - bool user_provide_ids = !insert_param_.row_id_array().empty(); + bool user_provide_ids = !insert_param_->row_id_array().empty(); if(!empty_table) { //user already provided id before, all insert action require user id if(engine::utils::UserDefinedId(table_info.flag_) && !user_provide_ids) { @@ -465,14 +478,14 @@ InsertTask::OnExecute() { #endif //step 3: prepare float data - std::vector vec_f(insert_param_.row_record_array_size() * table_info.dimension_, 0); + std::vector vec_f(insert_param_->row_record_array_size() * table_info.dimension_, 0); // TODO: change to one dimension array in protobuf or use multiple-thread to copy the data - for (size_t i = 0; i < insert_param_.row_record_array_size(); i++) { - if (insert_param_.row_record_array(i).vector_data().empty()) { + for (size_t i = 0; i < insert_param_->row_record_array_size(); i++) { + if (insert_param_->row_record_array(i).vector_data().empty()) { return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record float array is empty"); } - uint64_t vec_dim = insert_param_.row_record_array(i).vector_data().size(); + uint64_t vec_dim = insert_param_->row_record_array(i).vector_data().size(); if (vec_dim != table_info.dimension_) { ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION; std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(vec_dim) @@ -481,31 +494,31 @@ InsertTask::OnExecute() { return SetError(error_code, error_msg); } memcpy(&vec_f[i * table_info.dimension_], - insert_param_.row_record_array(i).vector_data().data(), + insert_param_->row_record_array(i).vector_data().data(), table_info.dimension_ * sizeof(float)); } rc.ElapseFromBegin("prepare vectors data"); //step 4: insert vectors - auto vec_count = (uint64_t) insert_param_.row_record_array_size(); - std::vector vec_ids(insert_param_.row_id_array_size(), 0); - if(!insert_param_.row_id_array().empty()) { - const int64_t* src_data = insert_param_.row_id_array().data(); + auto vec_count = (uint64_t) insert_param_->row_record_array_size(); + std::vector vec_ids(insert_param_->row_id_array_size(), 0); + if(!insert_param_->row_id_array().empty()) { + const int64_t* src_data = insert_param_->row_id_array().data(); int64_t* target_data = vec_ids.data(); - memcpy(target_data, src_data, (size_t)(sizeof(int64_t)*insert_param_.row_id_array_size())); + memcpy(target_data, src_data, (size_t)(sizeof(int64_t)*insert_param_->row_id_array_size())); } - stat = DBWrapper::DB()->InsertVectors(insert_param_.table_name(), vec_count, vec_f.data(), vec_ids); + stat = DBWrapper::DB()->InsertVectors(insert_param_->table_name(), vec_count, vec_f.data(), vec_ids); rc.ElapseFromBegin("add vectors to engine"); if (!stat.ok()) { return SetError(SERVER_CACHE_ERROR, "Cache error: " + stat.ToString()); } for (int64_t id : vec_ids) { - record_ids_.add_vector_id_array(id); + record_ids_->add_vector_id_array(id); } - auto ids_size = record_ids_.vector_id_array_size(); + auto ids_size = record_ids_->vector_id_array_size(); if (ids_size != vec_count) { std::string msg = "Add " + std::to_string(vec_count) + " vectors but only return " + std::to_string(ids_size) + " id"; @@ -514,7 +527,7 @@ InsertTask::OnExecute() { //step 5: update table flag if(empty_table && user_provide_ids) { - stat = DBWrapper::DB()->UpdateTableFlag(insert_param_.table_name(), + stat = DBWrapper::DB()->UpdateTableFlag(insert_param_->table_name(), table_info.flag_ | engine::meta::FLAG_MASK_USERID); } @@ -533,9 +546,9 @@ InsertTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -SearchTask::SearchTask(const ::milvus::grpc::SearchParam &search_vector_infos, +SearchTask::SearchTask(const ::milvus::grpc::SearchParam *search_vector_infos, const std::vector &file_id_array, - ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer) + ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) : GrpcBaseTask(DQL_TASK_GROUP), search_param_(search_vector_infos), file_id_array_(file_id_array), @@ -544,9 +557,13 @@ SearchTask::SearchTask(const ::milvus::grpc::SearchParam &search_vector_infos, } BaseTaskPtr -SearchTask::Create(const ::milvus::grpc::SearchParam &search_vector_infos, - const std::vector &file_id_array, - ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer) { +SearchTask::Create(const ::milvus::grpc::SearchParam *search_vector_infos, + const std::vector &file_id_array, + ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) { + if(search_vector_infos == nullptr) { + SERVER_LOG_ERROR << "grpc input is null!"; + return nullptr; + } return std::shared_ptr(new SearchTask(search_vector_infos, file_id_array, writer)); } @@ -557,24 +574,24 @@ SearchTask::OnExecute() { TimeRecorder rc("SearchTask"); //step 1: check arguments - std::string table_name_ = search_param_.table_name(); + std::string table_name_ = search_param_->table_name(); ServerError res = ValidationUtil::ValidateTableName(table_name_); if (res != SERVER_SUCCESS) { return SetError(res, "Invalid table name: " + table_name_); } - int64_t top_k_ = search_param_.topk(); + int64_t top_k_ = search_param_->topk(); if (top_k_ <= 0 || top_k_ > 1024) { return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(top_k_)); } - int64_t nprobe = search_param_.nprobe(); + int64_t nprobe = search_param_->nprobe(); if (nprobe <= 0) { return SetError(SERVER_INVALID_NPROBE, "Invalid nprobe: " + std::to_string(nprobe)); } - if (search_param_.query_record_array().empty()) { + if (search_param_->query_record_array().empty()) { return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); } @@ -596,8 +613,8 @@ SearchTask::OnExecute() { std::string error_msg; std::vector<::milvus::grpc::Range> range_array; - for (size_t i = 0; i < search_param_.query_range_array_size(); i++) { - range_array.emplace_back(search_param_.query_range_array(i)); + for (size_t i = 0; i < search_param_->query_range_array_size(); i++) { + range_array.emplace_back(search_param_->query_range_array(i)); } ConvertTimeRangeToDBDates(range_array, dates, error_code, error_msg); if (error_code != SERVER_SUCCESS) { @@ -614,13 +631,13 @@ SearchTask::OnExecute() { #endif //step 3: prepare float data - auto record_array_size = search_param_.query_record_array_size(); + auto record_array_size = search_param_->query_record_array_size(); std::vector vec_f(record_array_size * table_info.dimension_, 0); for (size_t i = 0; i < record_array_size; i++) { - if (search_param_.query_record_array(i).vector_data().empty()) { + if (search_param_->query_record_array(i).vector_data().empty()) { return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Query record float array is empty"); } - uint64_t query_vec_dim = search_param_.query_record_array(i).vector_data().size(); + uint64_t query_vec_dim = search_param_->query_record_array(i).vector_data().size(); if (query_vec_dim != table_info.dimension_) { ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION; std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(query_vec_dim) @@ -629,14 +646,14 @@ SearchTask::OnExecute() { } memcpy(&vec_f[i * table_info.dimension_], - search_param_.query_record_array(i).vector_data().data(), + search_param_->query_record_array(i).vector_data().data(), table_info.dimension_ * sizeof(float)); } rc.ElapseFromBegin("prepare vector data"); //step 4: search vectors engine::QueryResults results; - auto record_count = (uint64_t) search_param_.query_record_array().size(); + auto record_count = (uint64_t) search_param_->query_record_array().size(); if (file_id_array_.empty()) { stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, nprobe, vec_f.data(), @@ -666,14 +683,14 @@ SearchTask::OnExecute() { //step 5: construct result array for (uint64_t i = 0; i < record_count; i++) { auto &result = results[i]; - const auto &record = search_param_.query_record_array(i); + const auto &record = search_param_->query_record_array(i); ::milvus::grpc::TopKQueryResult grpc_topk_result; for (auto &pair : result) { ::milvus::grpc::QueryResult *grpc_result = grpc_topk_result.add_query_result_arrays(); grpc_result->set_id(pair.first); grpc_result->set_distance(pair.second); } - if (!writer_.Write(grpc_topk_result)) { + if (!writer_->Write(grpc_topk_result)) { return SetError(SERVER_WRITE_ERROR, "Write topk result failed!"); } } @@ -765,13 +782,17 @@ CmdTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -DeleteByRangeTask::DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam &delete_by_range_param) +DeleteByRangeTask::DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param) : GrpcBaseTask(DDL_DML_TASK_GROUP), delete_by_range_param_(delete_by_range_param){ } BaseTaskPtr -DeleteByRangeTask::Create(const ::milvus::grpc::DeleteByRangeParam &delete_by_range_param) { +DeleteByRangeTask::Create(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param) { + if(delete_by_range_param == nullptr) { + SERVER_LOG_ERROR << "grpc input is null!"; + return nullptr; + } return std::shared_ptr(new DeleteByRangeTask(delete_by_range_param)); } @@ -781,7 +802,7 @@ DeleteByRangeTask::OnExecute() { TimeRecorder rc("DeleteByRangeTask"); //step 1: check arguments - std::string table_name = delete_by_range_param_.table_name(); + std::string table_name = delete_by_range_param_->table_name(); ServerError res = ValidationUtil::ValidateTableName(table_name); if (res != SERVER_SUCCESS) { return SetError(res, "Invalid table name: " + table_name); @@ -807,7 +828,7 @@ DeleteByRangeTask::OnExecute() { std::string error_msg; std::vector<::milvus::grpc::Range> range_array; - range_array.emplace_back(delete_by_range_param_.range()); + range_array.emplace_back(delete_by_range_param_->range()); ConvertTimeRangeToDBDates(range_array, dates, error_code, error_msg); if (error_code != SERVER_SUCCESS) { return SetError(error_code, error_msg); @@ -870,7 +891,7 @@ PreloadTableTask::OnExecute() { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// DescribeIndexTask::DescribeIndexTask(const std::string &table_name, - ::milvus::grpc::IndexParam &index_param) + ::milvus::grpc::IndexParam *index_param) : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name), index_param_(index_param) { @@ -879,7 +900,7 @@ DescribeIndexTask::DescribeIndexTask(const std::string &table_name, BaseTaskPtr DescribeIndexTask::Create(const std::string &table_name, - ::milvus::grpc::IndexParam &index_param){ + ::milvus::grpc::IndexParam *index_param){ return std::shared_ptr(new DescribeIndexTask(table_name, index_param)); } @@ -901,10 +922,10 @@ DescribeIndexTask::OnExecute() { return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); } - index_param_.mutable_table_name()->set_table_name(table_name_); - index_param_.mutable_index()->set_index_type(index.engine_type_); - index_param_.mutable_index()->set_nlist(index.nlist_); - index_param_.mutable_index()->set_metric_type(index.metric_type_); + index_param_->mutable_table_name()->set_table_name(table_name_); + index_param_->mutable_index()->set_index_type(index.engine_type_); + index_param_->mutable_index()->set_nlist(index.nlist_); + index_param_->mutable_index()->set_metric_type(index.metric_type_); rc.ElapseFromBegin("totally cost"); } catch (std::exception &ex) { diff --git a/cpp/src/server/grpc_impl/GrpcRequestTask.h b/cpp/src/server/grpc_impl/GrpcRequestTask.h index e43b9fba6056f2ed4faec659f093ad2b4886448a..91d7bbfea0f962f4d2fa2c0c2a67bdafc7b27eaf 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestTask.h +++ b/cpp/src/server/grpc_impl/GrpcRequestTask.h @@ -23,17 +23,17 @@ namespace grpc { class CreateTableTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::TableSchema &schema); + Create(const ::milvus::grpc::TableSchema *schema); protected: explicit - CreateTableTask(const ::milvus::grpc::TableSchema &request); + CreateTableTask(const ::milvus::grpc::TableSchema *request); ServerError OnExecute() override; private: - const ::milvus::grpc::TableSchema schema_; + const ::milvus::grpc::TableSchema *schema_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -58,10 +58,10 @@ private: class DescribeTableTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &table_name, ::milvus::grpc::TableSchema &schema); + Create(const std::string &table_name, ::milvus::grpc::TableSchema *schema); protected: - DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema &schema); + DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema *schema); ServerError OnExecute() override; @@ -69,7 +69,7 @@ protected: private: std::string table_name_; - ::milvus::grpc::TableSchema &schema_; + ::milvus::grpc::TableSchema *schema_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -94,76 +94,76 @@ private: class CreateIndexTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::IndexParam &index_Param); + Create(const ::milvus::grpc::IndexParam *index_Param); protected: explicit - CreateIndexTask(const ::milvus::grpc::IndexParam &index_Param); + CreateIndexTask(const ::milvus::grpc::IndexParam *index_Param); ServerError OnExecute() override; private: - ::milvus::grpc::IndexParam index_param_; + const ::milvus::grpc::IndexParam *index_param_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class ShowTablesTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(::grpc::ServerWriter<::milvus::grpc::TableName> &writer); + Create(::grpc::ServerWriter<::milvus::grpc::TableName> *writer); protected: explicit - ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> &writer); + ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> *writer); ServerError OnExecute() override; private: - ::grpc::ServerWriter<::milvus::grpc::TableName> writer_; + ::grpc::ServerWriter<::milvus::grpc::TableName> *writer_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class InsertTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::InsertParam &insert_Param, - ::milvus::grpc::VectorIds &record_ids_); + Create(const ::milvus::grpc::InsertParam *insert_Param, + ::milvus::grpc::VectorIds *record_ids_); protected: - InsertTask(const ::milvus::grpc::InsertParam &insert_Param, - ::milvus::grpc::VectorIds &record_ids_); + InsertTask(const ::milvus::grpc::InsertParam *insert_Param, + ::milvus::grpc::VectorIds *record_ids_); ServerError OnExecute() override; private: - const ::milvus::grpc::InsertParam insert_param_; - ::milvus::grpc::VectorIds &record_ids_; + const ::milvus::grpc::InsertParam *insert_param_; + ::milvus::grpc::VectorIds *record_ids_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class SearchTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::SearchParam &search_param, + Create(const ::milvus::grpc::SearchParam *search_param, const std::vector &file_id_array, - ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer); + ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer); protected: - SearchTask(const ::milvus::grpc::SearchParam &search_param, + SearchTask(const ::milvus::grpc::SearchParam *search_param, const std::vector &file_id_array, - ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer); + ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer); ServerError OnExecute() override; private: - const ::milvus::grpc::SearchParam search_param_; + const ::milvus::grpc::SearchParam *search_param_; std::vector file_id_array_; - ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> writer_; + ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -204,16 +204,16 @@ private: class DeleteByRangeTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::DeleteByRangeParam &delete_by_range_param); + Create(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param); protected: - DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam &delete_by_range_param); + DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param); ServerError OnExecute() override; private: - ::milvus::grpc::DeleteByRangeParam delete_by_range_param_; + const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -237,18 +237,18 @@ class DescribeIndexTask : public GrpcBaseTask { public: static BaseTaskPtr Create(const std::string &table_name, - ::milvus::grpc::IndexParam &index_param); + ::milvus::grpc::IndexParam *index_param); protected: DescribeIndexTask(const std::string &table_name, - ::milvus::grpc::IndexParam &index_param); + ::milvus::grpc::IndexParam *index_param); ServerError OnExecute() override; private: std::string table_name_; - ::milvus::grpc::IndexParam& index_param_; + ::milvus::grpc::IndexParam *index_param_; }; ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////