提交 e79d2395 编写于 作者: J jinhai

Merge branch 'branch-0.4.0' into 'branch-0.4.0'

grpc and thrift server run concurrently

See merge request megasearch/milvus!327

Former-commit-id: 505108b219e5bdf2c74a65d5bf7ec069daea2aba
...@@ -60,6 +60,7 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -60,6 +60,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-310 - Add milvus CPU utilization ratio and CPU/GPU temperature metrics - MS-310 - Add milvus CPU utilization ratio and CPU/GPU temperature metrics
- MS-324 - Show error when there is not enough gpu memory to build index - MS-324 - Show error when there is not enough gpu memory to build index
- MS-328 - Check metric type on server start - MS-328 - Check metric type on server start
- MS-332 - Set grpc and thrift server run concurrently
## New Feature ## New Feature
- MS-180 - Add new mem manager - MS-180 - Add new mem manager
......
...@@ -90,7 +90,7 @@ define_option(MILVUS_WITH_SQLITE_ORM "Build with SQLite ORM library" ON) ...@@ -90,7 +90,7 @@ define_option(MILVUS_WITH_SQLITE_ORM "Build with SQLite ORM library" ON)
define_option(MILVUS_WITH_MYSQLPP "Build with MySQL++" ON) define_option(MILVUS_WITH_MYSQLPP "Build with MySQL++" ON)
define_option(MILVUS_WITH_THRIFT "Build with Apache Thrift library" OFF) define_option(MILVUS_WITH_THRIFT "Build with Apache Thrift library" ON)
define_option(MILVUS_WITH_YAMLCPP "Build with yaml-cpp library" ON) define_option(MILVUS_WITH_YAMLCPP "Build with yaml-cpp library" ON)
......
...@@ -2679,7 +2679,7 @@ macro(build_grpc) ...@@ -2679,7 +2679,7 @@ macro(build_grpc)
add_dependencies(grpc_protoc grpc_ep) add_dependencies(grpc_protoc grpc_ep)
endmacro() endmacro()
if(NOT MILVUS_WITH_THRIFT STREQUAL "ON") #if(NOT MILVUS_WITH_THRIFT STREQUAL "ON")
resolve_dependency(GRPC) resolve_dependency(GRPC)
get_target_property(GRPC_INCLUDE_DIR grpc INTERFACE_INCLUDE_DIRECTORIES) get_target_property(GRPC_INCLUDE_DIR grpc INTERFACE_INCLUDE_DIRECTORIES)
...@@ -2690,4 +2690,4 @@ if(NOT MILVUS_WITH_THRIFT STREQUAL "ON") ...@@ -2690,4 +2690,4 @@ if(NOT MILVUS_WITH_THRIFT STREQUAL "ON")
include_directories(SYSTEM ${GRPC_THIRD_PARTY_DIR}/protobuf/src) include_directories(SYSTEM ${GRPC_THIRD_PARTY_DIR}/protobuf/src)
link_directories(SYSTEM ${GRPC_PROTOBUF_LIB_DIR}) link_directories(SYSTEM ${GRPC_PROTOBUF_LIB_DIR})
endif() #endif()
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# Proprietary and confidential. # Proprietary and confidential.
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
aux_source_directory(cache cache_files) aux_source_directory(cache cache_files)
aux_source_directory(config config_files) aux_source_directory(config config_files)
aux_source_directory(server server_files) aux_source_directory(server server_files)
...@@ -78,17 +77,17 @@ include_directories(/usr/include/mysql) ...@@ -78,17 +77,17 @@ include_directories(/usr/include/mysql)
include_directories(grpc/gen-status) include_directories(grpc/gen-status)
include_directories(grpc/gen-milvus) include_directories(grpc/gen-milvus)
if (MILVUS_WITH_THRIFT STREQUAL "ON") #if (MILVUS_WITH_THRIFT STREQUAL "ON")
set(client_lib set(client_thrift_lib
thrift) thrift)
else() #else()
set(client_lib set(client_grpc_lib
grpcpp_channelz grpcpp_channelz
grpc++ grpc++
grpc grpc
grpc_protobuf grpc_protobuf
grpc_protoc) grpc_protoc)
endif() #endif()
set(third_party_libs set(third_party_libs
knowhere knowhere
...@@ -100,7 +99,8 @@ set(third_party_libs ...@@ -100,7 +99,8 @@ set(third_party_libs
lapack lapack
easyloggingpp easyloggingpp
sqlite sqlite
${client_lib} ${client_thrift_lib}
${client_grpc_lib}
yaml-cpp yaml-cpp
prometheus-cpp-push prometheus-cpp-push
prometheus-cpp-pull prometheus-cpp-pull
...@@ -197,7 +197,7 @@ set(knowhere_libs ...@@ -197,7 +197,7 @@ set(knowhere_libs
tbb tbb
) )
if (MILVUS_WITH_THRIFT STREQUAL "ON") #if (MILVUS_WITH_THRIFT STREQUAL "ON")
add_executable(milvus_thrift_server add_executable(milvus_thrift_server
${config_files} ${config_files}
${server_files} ${server_files}
...@@ -206,7 +206,7 @@ if (MILVUS_WITH_THRIFT STREQUAL "ON") ...@@ -206,7 +206,7 @@ if (MILVUS_WITH_THRIFT STREQUAL "ON")
${thrift_service_files} ${thrift_service_files}
${metrics_files} ${metrics_files}
) )
else() #else()
add_executable(milvus_grpc_server add_executable(milvus_grpc_server
${config_files} ${config_files}
${server_files} ${server_files}
...@@ -215,7 +215,16 @@ else() ...@@ -215,7 +215,16 @@ else()
${grpc_service_files} ${grpc_service_files}
${metrics_files} ${metrics_files}
) )
endif() #endif()
add_executable(milvus_server
${config_files}
${server_files}
${thriftserver_files}
${grpcserver_files}
${utils_files}
${thrift_service_files}
${grpc_service_files}
${metrics_files})
if (ENABLE_LICENSE STREQUAL "ON") if (ENABLE_LICENSE STREQUAL "ON")
add_executable(get_sys_info ${get_sys_info_files}) add_executable(get_sys_info ${get_sys_info_files})
...@@ -224,25 +233,28 @@ if (ENABLE_LICENSE STREQUAL "ON") ...@@ -224,25 +233,28 @@ if (ENABLE_LICENSE STREQUAL "ON")
target_link_libraries(get_sys_info ${license_libs} license_check ${third_party_libs}) target_link_libraries(get_sys_info ${license_libs} license_check ${third_party_libs})
target_link_libraries(license_generator ${license_libs} ${third_party_libs}) target_link_libraries(license_generator ${license_libs} ${third_party_libs})
if(MILVUS_WITH_THRIFT STREQUAL "ON") # if(MILVUS_WITH_THRIFT STREQUAL "ON")
target_link_libraries(milvus_thrift_server ${server_libs} license_check ${knowhere_libs} ${third_party_libs}) target_link_libraries(milvus_thrift_server ${server_libs} license_check ${knowhere_libs} ${third_party_libs})
else() # else()
target_link_libraries(milvus_grpc_server ${server_libs} license_check ${knowhere_libs} ${third_party_libs}) target_link_libraries(milvus_grpc_server ${server_libs} license_check ${knowhere_libs} ${third_party_libs})
endif() # endif()
target_link_libraries(milvus_server ${server_libs} license_check ${knowhere_libs} ${third_party_libs})
else () else ()
if(MILVUS_WITH_THRIFT STREQUAL "ON") # if(MILVUS_WITH_THRIFT STREQUAL "ON")
target_link_libraries(milvus_thrift_server ${server_libs} ${knowhere_libs} ${third_party_libs}) target_link_libraries(milvus_thrift_server ${server_libs} ${knowhere_libs} ${third_party_libs})
else() # else()
target_link_libraries(milvus_grpc_server ${server_libs} ${knowhere_libs} ${third_party_libs}) target_link_libraries(milvus_grpc_server ${server_libs} ${knowhere_libs} ${third_party_libs})
endif() # endif()
target_link_libraries(milvus_server ${server_libs} ${knowhere_libs} ${third_party_libs})
endif() endif()
if (MILVUS_WITH_THRIFT STREQUAL "ON") #if (MILVUS_WITH_THRIFT STREQUAL "ON")
install(TARGETS milvus_thrift_server DESTINATION bin) install(TARGETS milvus_thrift_server DESTINATION bin)
else() #else()
install(TARGETS milvus_grpc_server DESTINATION bin) install(TARGETS milvus_grpc_server DESTINATION bin)
endif() #endif()
install(TARGETS milvus_server DESTINATION bin)
install(FILES install(FILES
${KNOWHERE_BUILD_DIR}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}tbb${CMAKE_SHARED_LIBRARY_SUFFIX} ${KNOWHERE_BUILD_DIR}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}tbb${CMAKE_SHARED_LIBRARY_SUFFIX}
......
...@@ -12,7 +12,7 @@ include_directories(/usr/include) ...@@ -12,7 +12,7 @@ include_directories(/usr/include)
include_directories(include) include_directories(include)
include_directories(/usr/local/include) include_directories(/usr/local/include)
if (MILVUS_WITH_THRIFT STREQUAL "ON") #if (MILVUS_WITH_THRIFT STREQUAL "ON")
aux_source_directory(thrift thrift_client_files) aux_source_directory(thrift thrift_client_files)
include_directories(thrift) include_directories(thrift)
include_directories(${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp) include_directories(${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp)
...@@ -34,7 +34,7 @@ if (MILVUS_WITH_THRIFT STREQUAL "ON") ...@@ -34,7 +34,7 @@ if (MILVUS_WITH_THRIFT STREQUAL "ON")
${third_party_libs} ${third_party_libs}
) )
install(TARGETS milvus_thrift_sdk DESTINATION lib) install(TARGETS milvus_thrift_sdk DESTINATION lib)
else() #else()
aux_source_directory(grpc grpc_client_files) aux_source_directory(grpc grpc_client_files)
include_directories(${CMAKE_SOURCE_DIR}/src/grpc/gen-milvus) include_directories(${CMAKE_SOURCE_DIR}/src/grpc/gen-milvus)
...@@ -58,6 +58,6 @@ else() ...@@ -58,6 +58,6 @@ else()
${third_party_libs} ${third_party_libs}
) )
install(TARGETS milvus_grpc_sdk DESTINATION lib) install(TARGETS milvus_grpc_sdk DESTINATION lib)
endif() #endif()
add_subdirectory(examples) add_subdirectory(examples)
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
# Proprietary and confidential. # Proprietary and confidential.
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
if (MILVUS_WITH_THRIFT STREQUAL "ON") #if (MILVUS_WITH_THRIFT STREQUAL "ON")
add_subdirectory(thriftsimple) add_subdirectory(thriftsimple)
else() #else()
add_subdirectory(grpcsimple) add_subdirectory(grpcsimple)
endif() #endif()
\ No newline at end of file \ No newline at end of file
...@@ -25,7 +25,7 @@ main(int argc, char *argv[]) { ...@@ -25,7 +25,7 @@ main(int argc, char *argv[]) {
{NULL, 0, 0, 0}}; {NULL, 0, 0, 0}};
int option_index = 0; int option_index = 0;
std::string address = "127.0.0.1", port = "19530"; std::string address = "127.0.0.1", port = "19531";
app_name = argv[0]; app_name = argv[0];
int value; int value;
......
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
// Unauthorized copying of this file, via any medium is strictly prohibited. // Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential. // Proprietary and confidential.
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
#include <thread>
#include "Server.h" #include "Server.h"
//#include "ServerConfig.h" //#include "ServerConfig.h"
#ifdef MILVUS_ENABLE_THRIFT //#ifdef MILVUS_ENABLE_THRIFT
#include "server/thrift_impl/MilvusServer.h" #include "server/thrift_impl/MilvusServer.h"
#else //#else
#include "server/grpc_impl/MilvusServer.h" #include "server/grpc_impl/GrpcMilvusServer.h"
#endif //#endif
#include "utils/Log.h" #include "utils/Log.h"
#include "utils/SignalUtil.h" #include "utils/SignalUtil.h"
...@@ -224,12 +225,19 @@ Server::LoadConfig() { ...@@ -224,12 +225,19 @@ Server::LoadConfig() {
void void
Server::StartService() { Server::StartService() {
MilvusServer::StartService(); std::thread thrift_thread = std::thread(&MilvusServer::StartService);
std::thread grpc_thread = std::thread(&grpc::GrpcMilvusServer::StartService);
thrift_thread.join();
grpc_thread.join();
//
// MilvusServer::StartService();
// grpc::GrpcMilvusServer::StartService();
} }
void void
Server::StopService() { Server::StopService() {
MilvusServer::StopService(); MilvusServer::StartService();
grpc::GrpcMilvusServer::StopService();
} }
} }
......
/******************************************************************************* /*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#include "milvus.grpc.pb.h" #include "milvus.grpc.pb.h"
#include "MilvusServer.h" #include "GrpcMilvusServer.h"
#include "../ServerConfig.h" #include "../ServerConfig.h"
#include "../DBWrapper.h" #include "../DBWrapper.h"
#include "utils/Log.h" #include "utils/Log.h"
#include "faiss/utils.h" #include "faiss/utils.h"
#include "RequestHandler.h" #include "GrpcRequestHandler.h"
#include <chrono> #include <chrono>
#include <iostream> #include <iostream>
...@@ -28,14 +28,15 @@ ...@@ -28,14 +28,15 @@
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace server { namespace server {
namespace grpc {
static std::unique_ptr<grpc::Server> server; static std::unique_ptr<::grpc::Server> server;
constexpr long MESSAGE_SIZE = -1; constexpr long MESSAGE_SIZE = -1;
void void
MilvusServer::StartService() { GrpcMilvusServer::StartService() {
if (server != nullptr){ if (server != nullptr) {
std::cout << "stopservice!\n"; std::cout << "stopservice!\n";
StopService(); StopService();
} }
...@@ -50,9 +51,9 @@ MilvusServer::StartService() { ...@@ -50,9 +51,9 @@ MilvusServer::StartService() {
DBWrapper::DB();//initialize db DBWrapper::DB();//initialize db
std::string server_address(address + ":" + std::to_string(port)); std::string server_address(address + ":" + std::to_string(port + 1));
grpc::ServerBuilder builder; ::grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(MESSAGE_SIZE); //default 4 * 1024 * 1024 builder.SetMaxReceiveMessageSize(MESSAGE_SIZE); //default 4 * 1024 * 1024
builder.SetMaxSendMessageSize(MESSAGE_SIZE); builder.SetMaxSendMessageSize(MESSAGE_SIZE);
...@@ -60,9 +61,9 @@ MilvusServer::StartService() { ...@@ -60,9 +61,9 @@ MilvusServer::StartService() {
builder.SetDefaultCompressionAlgorithm(GRPC_COMPRESS_STREAM_GZIP); builder.SetDefaultCompressionAlgorithm(GRPC_COMPRESS_STREAM_GZIP);
builder.SetDefaultCompressionLevel(GRPC_COMPRESS_LEVEL_HIGH); builder.SetDefaultCompressionLevel(GRPC_COMPRESS_LEVEL_HIGH);
RequestHandler service; GrpcRequestHandler service;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
builder.RegisterService(&service); builder.RegisterService(&service);
server = builder.BuildAndStart(); server = builder.BuildAndStart();
...@@ -71,7 +72,7 @@ MilvusServer::StartService() { ...@@ -71,7 +72,7 @@ MilvusServer::StartService() {
} }
void void
MilvusServer::StopService() { GrpcMilvusServer::StopService() {
if (server != nullptr) { if (server != nullptr) {
server->Shutdown(); server->Shutdown();
} }
...@@ -80,3 +81,4 @@ MilvusServer::StopService() { ...@@ -80,3 +81,4 @@ MilvusServer::StopService() {
} }
} }
} }
}
\ No newline at end of file
/******************************************************************************* /*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#pragma once #pragma once
#include <cstdint> #include <cstdint>
...@@ -11,7 +11,9 @@ ...@@ -11,7 +11,9 @@
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace server { namespace server {
class MilvusServer { namespace grpc {
class GrpcMilvusServer {
public: public:
static void static void
StartService(); StartService();
...@@ -23,3 +25,4 @@ public: ...@@ -23,3 +25,4 @@ public:
} }
} }
} }
}
/******************************************************************************* /*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#include "RequestHandler.h" #include "GrpcRequestHandler.h"
#include "RequestTask.h" #include "GrpcRequestTask.h"
#include "utils/TimeRecorder.h" #include "utils/TimeRecorder.h"
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace server { namespace server {
namespace grpc {
::grpc::Status ::grpc::Status
RequestHandler::CreateTable(::grpc::ServerContext *context, GrpcRequestHandler::CreateTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableSchema *request, const ::milvus::grpc::TableSchema *request,
::milvus::grpc::Status *response) { ::milvus::grpc::Status *response) {
BaseTaskPtr task_ptr = CreateTableTask::Create(*request); BaseTaskPtr task_ptr = CreateTableTask::Create(*request);
RequestScheduler::ExecTask(task_ptr, response); GrpcRequestScheduler::ExecTask(task_ptr, response);
return ::grpc::Status::OK; return ::grpc::Status::OK;
} }
::grpc::Status ::grpc::Status
RequestHandler::HasTable(::grpc::ServerContext *context, GrpcRequestHandler::HasTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableName *request, const ::milvus::grpc::TableName *request,
::milvus::grpc::BoolReply *response) { ::milvus::grpc::BoolReply *response) {
bool has_table = false; bool has_table = false;
BaseTaskPtr task_ptr = HasTableTask::Create(request->table_name(), has_table); BaseTaskPtr task_ptr = HasTableTask::Create(request->table_name(), has_table);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
RequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->set_bool_reply(has_table); response->set_bool_reply(has_table);
response->mutable_status()->set_reason(grpc_status.reason()); response->mutable_status()->set_reason(grpc_status.reason());
response->mutable_status()->set_error_code(grpc_status.error_code()); response->mutable_status()->set_error_code(grpc_status.error_code());
...@@ -38,47 +39,47 @@ RequestHandler::HasTable(::grpc::ServerContext *context, ...@@ -38,47 +39,47 @@ RequestHandler::HasTable(::grpc::ServerContext *context,
} }
::grpc::Status ::grpc::Status
RequestHandler::DropTable(::grpc::ServerContext* context, GrpcRequestHandler::DropTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, const ::milvus::grpc::TableName *request,
::milvus::grpc::Status* response) { ::milvus::grpc::Status *response) {
BaseTaskPtr task_ptr = DropTableTask::Create(request->table_name()); BaseTaskPtr task_ptr = DropTableTask::Create(request->table_name());
RequestScheduler::ExecTask(task_ptr, response); GrpcRequestScheduler::ExecTask(task_ptr, response);
return ::grpc::Status::OK; return ::grpc::Status::OK;
} }
::grpc::Status ::grpc::Status
RequestHandler::BuildIndex(::grpc::ServerContext* context, GrpcRequestHandler::BuildIndex(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, const ::milvus::grpc::TableName *request,
::milvus::grpc::Status* response) { ::milvus::grpc::Status *response) {
BaseTaskPtr task_ptr = BuildIndexTask::Create(request->table_name()); BaseTaskPtr task_ptr = BuildIndexTask::Create(request->table_name());
RequestScheduler::ExecTask(task_ptr, response); GrpcRequestScheduler::ExecTask(task_ptr, response);
return ::grpc::Status::OK; return ::grpc::Status::OK;
} }
::grpc::Status ::grpc::Status
RequestHandler::InsertVector(::grpc::ServerContext* context, GrpcRequestHandler::InsertVector(::grpc::ServerContext *context,
const ::milvus::grpc::InsertInfos* request, const ::milvus::grpc::InsertInfos *request,
::milvus::grpc::VectorIds* response) { ::milvus::grpc::VectorIds *response) {
BaseTaskPtr task_ptr = InsertVectorTask::Create(*request, *response); BaseTaskPtr task_ptr = InsertVectorTask::Create(*request, *response);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
RequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->mutable_status()->set_reason(grpc_status.reason()); response->mutable_status()->set_reason(grpc_status.reason());
response->mutable_status()->set_error_code(grpc_status.error_code()); response->mutable_status()->set_error_code(grpc_status.error_code());
return ::grpc::Status::OK; return ::grpc::Status::OK;
} }
::grpc::Status ::grpc::Status
RequestHandler::SearchVector(::grpc::ServerContext* context, GrpcRequestHandler::SearchVector(::grpc::ServerContext *context,
const ::milvus::grpc::SearchVectorInfos* request, const ::milvus::grpc::SearchVectorInfos *request,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>* writer) { ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) {
std::vector<std::string> file_id_array; std::vector<std::string> file_id_array;
BaseTaskPtr task_ptr = SearchVectorTask::Create(*request, file_id_array, *writer); BaseTaskPtr task_ptr = SearchVectorTask::Create(*request, file_id_array, *writer);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
RequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
if (grpc_status.error_code() != SERVER_SUCCESS) { if (grpc_status.error_code() != SERVER_SUCCESS) {
::grpc::Status status(::grpc::INVALID_ARGUMENT, grpc_status.reason()); ::grpc::Status status(::grpc::INVALID_ARGUMENT, grpc_status.reason());
return status; return status;
...@@ -88,14 +89,14 @@ RequestHandler::SearchVector(::grpc::ServerContext* context, ...@@ -88,14 +89,14 @@ RequestHandler::SearchVector(::grpc::ServerContext* context,
} }
::grpc::Status ::grpc::Status
RequestHandler::SearchVectorInFiles(::grpc::ServerContext* context, GrpcRequestHandler::SearchVectorInFiles(::grpc::ServerContext *context,
const ::milvus::grpc::SearchVectorInFilesInfos* request, const ::milvus::grpc::SearchVectorInFilesInfos *request,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>* writer) { ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) {
std::vector<std::string> file_id_array; std::vector<std::string> file_id_array;
BaseTaskPtr task_ptr = SearchVectorTask::Create(request->search_vector_infos(), file_id_array, *writer); BaseTaskPtr task_ptr = SearchVectorTask::Create(request->search_vector_infos(), file_id_array, *writer);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
RequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
if (grpc_status.error_code() != SERVER_SUCCESS) { if (grpc_status.error_code() != SERVER_SUCCESS) {
::grpc::Status status(::grpc::INVALID_ARGUMENT, grpc_status.reason()); ::grpc::Status status(::grpc::INVALID_ARGUMENT, grpc_status.reason());
return status; return status;
...@@ -105,27 +106,27 @@ RequestHandler::SearchVectorInFiles(::grpc::ServerContext* context, ...@@ -105,27 +106,27 @@ RequestHandler::SearchVectorInFiles(::grpc::ServerContext* context,
} }
::grpc::Status ::grpc::Status
RequestHandler::DescribeTable(::grpc::ServerContext* context, GrpcRequestHandler::DescribeTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, const ::milvus::grpc::TableName *request,
::milvus::grpc::TableSchema* response) { ::milvus::grpc::TableSchema *response) {
BaseTaskPtr task_ptr = DescribeTableTask::Create(request->table_name(), *response); BaseTaskPtr task_ptr = DescribeTableTask::Create(request->table_name(), *response);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
RequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->mutable_table_name()->mutable_status()->set_error_code(grpc_status.error_code()); response->mutable_table_name()->mutable_status()->set_error_code(grpc_status.error_code());
response->mutable_table_name()->mutable_status()->set_reason(grpc_status.reason()); response->mutable_table_name()->mutable_status()->set_reason(grpc_status.reason());
return ::grpc::Status::OK; return ::grpc::Status::OK;
} }
::grpc::Status ::grpc::Status
RequestHandler::GetTableRowCount(::grpc::ServerContext* context, GrpcRequestHandler::GetTableRowCount(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, const ::milvus::grpc::TableName *request,
::milvus::grpc::TableRowCount* response) { ::milvus::grpc::TableRowCount *response) {
int64_t row_count = 0; int64_t row_count = 0;
BaseTaskPtr task_ptr = GetTableRowCountTask::Create(request->table_name(), row_count); BaseTaskPtr task_ptr = GetTableRowCountTask::Create(request->table_name(), row_count);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
RequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->set_table_row_count(row_count); response->set_table_row_count(row_count);
response->mutable_status()->set_reason(grpc_status.reason()); response->mutable_status()->set_reason(grpc_status.reason());
response->mutable_status()->set_error_code(grpc_status.error_code()); response->mutable_status()->set_error_code(grpc_status.error_code());
...@@ -133,13 +134,13 @@ RequestHandler::GetTableRowCount(::grpc::ServerContext* context, ...@@ -133,13 +134,13 @@ RequestHandler::GetTableRowCount(::grpc::ServerContext* context,
} }
::grpc::Status ::grpc::Status
RequestHandler::ShowTables(::grpc::ServerContext* context, GrpcRequestHandler::ShowTables(::grpc::ServerContext *context,
const ::milvus::grpc::Command* request, const ::milvus::grpc::Command *request,
::grpc::ServerWriter<::milvus::grpc::TableName>* writer) { ::grpc::ServerWriter<::milvus::grpc::TableName> *writer) {
BaseTaskPtr task_ptr = ShowTablesTask::Create(*writer); BaseTaskPtr task_ptr = ShowTablesTask::Create(*writer);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
RequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
if (grpc_status.error_code() != SERVER_SUCCESS) { if (grpc_status.error_code() != SERVER_SUCCESS) {
::grpc::Status status(::grpc::UNKNOWN, grpc_status.reason()); ::grpc::Status status(::grpc::UNKNOWN, grpc_status.reason());
return status; return status;
...@@ -149,21 +150,21 @@ RequestHandler::ShowTables(::grpc::ServerContext* context, ...@@ -149,21 +150,21 @@ RequestHandler::ShowTables(::grpc::ServerContext* context,
} }
::grpc::Status ::grpc::Status
RequestHandler::Ping(::grpc::ServerContext* context, GrpcRequestHandler::Ping(::grpc::ServerContext *context,
const ::milvus::grpc::Command* request, const ::milvus::grpc::Command *request,
::milvus::grpc::ServerStatus* response) { ::milvus::grpc::ServerStatus *response) {
std::string result; std::string result;
BaseTaskPtr task_ptr = PingTask::Create(request->cmd(), result); BaseTaskPtr task_ptr = PingTask::Create(request->cmd(), result);
::milvus::grpc::Status grpc_status; ::milvus::grpc::Status grpc_status;
RequestScheduler::ExecTask(task_ptr, &grpc_status); GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status);
response->set_info(result); response->set_info(result);
response->mutable_status()->set_reason(grpc_status.reason()); response->mutable_status()->set_reason(grpc_status.reason());
response->mutable_status()->set_error_code(grpc_status.error_code()); response->mutable_status()->set_error_code(grpc_status.error_code());
return ::grpc::Status::OK; return ::grpc::Status::OK;
} }
}
} }
} }
} }
\ No newline at end of file
/******************************************************************************* /*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#pragma once #pragma once
#include <cstdint> #include <cstdint>
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace server { namespace server {
class RequestHandler final : public ::milvus::grpc::MilvusService::Service { namespace grpc {
class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service {
public: public:
/** /**
* @brief Create table method * @brief Create table method
...@@ -32,8 +33,8 @@ public: ...@@ -32,8 +33,8 @@ public:
* @param context * @param context
*/ */
::grpc::Status ::grpc::Status
CreateTable(::grpc::ServerContext* context, CreateTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableSchema* request, ::milvus::grpc::Status* response) override ; const ::milvus::grpc::TableSchema *request, ::milvus::grpc::Status *response) override;
/** /**
* @brief Test table existence method * @brief Test table existence method
...@@ -51,8 +52,8 @@ public: ...@@ -51,8 +52,8 @@ public:
* @param context * @param context
*/ */
::grpc::Status ::grpc::Status
HasTable(::grpc::ServerContext* context, HasTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, ::milvus::grpc::BoolReply* response) override ; const ::milvus::grpc::TableName *request, ::milvus::grpc::BoolReply *response) override;
/** /**
* @brief Drop table method * @brief Drop table method
...@@ -70,8 +71,8 @@ public: ...@@ -70,8 +71,8 @@ public:
* @param context * @param context
*/ */
::grpc::Status ::grpc::Status
DropTable(::grpc::ServerContext* context, DropTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, ::milvus::grpc::Status* response) override; const ::milvus::grpc::TableName *request, ::milvus::grpc::Status *response) override;
/** /**
* @brief build index by table method * @brief build index by table method
...@@ -89,8 +90,8 @@ public: ...@@ -89,8 +90,8 @@ public:
* @param context * @param context
*/ */
::grpc::Status ::grpc::Status
BuildIndex(::grpc::ServerContext* context, BuildIndex(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, ::milvus::grpc::Status* response) override; const ::milvus::grpc::TableName *request, ::milvus::grpc::Status *response) override;
/** /**
...@@ -109,8 +110,9 @@ public: ...@@ -109,8 +110,9 @@ public:
* @param response * @param response
*/ */
::grpc::Status ::grpc::Status
InsertVector(::grpc::ServerContext* context, InsertVector(::grpc::ServerContext *context,
const ::milvus::grpc::InsertInfos* request, ::milvus::grpc::VectorIds* response) override; const ::milvus::grpc::InsertInfos *request,
::milvus::grpc::VectorIds *response) override;
/** /**
* @brief Query vector * @brief Query vector
...@@ -133,8 +135,9 @@ public: ...@@ -133,8 +135,9 @@ public:
* @param writer * @param writer
*/ */
::grpc::Status ::grpc::Status
SearchVector(::grpc::ServerContext* context, SearchVector(::grpc::ServerContext *context,
const ::milvus::grpc::SearchVectorInfos* request, ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>* writer) override; const ::milvus::grpc::SearchVectorInfos *request,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) override;
/** /**
* @brief Internal use query interface * @brief Internal use query interface
...@@ -157,8 +160,9 @@ public: ...@@ -157,8 +160,9 @@ public:
* @param writer * @param writer
*/ */
::grpc::Status ::grpc::Status
SearchVectorInFiles(::grpc::ServerContext* context, SearchVectorInFiles(::grpc::ServerContext *context,
const ::milvus::grpc::SearchVectorInFilesInfos* request, ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>* writer) override; const ::milvus::grpc::SearchVectorInFilesInfos *request,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> *writer) override;
/** /**
* @brief Get table schema * @brief Get table schema
...@@ -176,8 +180,9 @@ public: ...@@ -176,8 +180,9 @@ public:
* @param response * @param response
*/ */
::grpc::Status ::grpc::Status
DescribeTable(::grpc::ServerContext* context, DescribeTable(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, ::milvus::grpc::TableSchema* response) override; const ::milvus::grpc::TableName *request,
::milvus::grpc::TableSchema *response) override;
/** /**
* @brief Get table row count * @brief Get table row count
...@@ -195,8 +200,9 @@ public: ...@@ -195,8 +200,9 @@ public:
* @param context * @param context
*/ */
::grpc::Status ::grpc::Status
GetTableRowCount(::grpc::ServerContext* context, GetTableRowCount(::grpc::ServerContext *context,
const ::milvus::grpc::TableName* request, ::milvus::grpc::TableRowCount* response) override; const ::milvus::grpc::TableName *request,
::milvus::grpc::TableRowCount *response) override;
/** /**
* @brief List all tables in database * @brief List all tables in database
...@@ -214,8 +220,9 @@ public: ...@@ -214,8 +220,9 @@ public:
* @param writer * @param writer
*/ */
::grpc::Status ::grpc::Status
ShowTables(::grpc::ServerContext* context, ShowTables(::grpc::ServerContext *context,
const ::milvus::grpc::Command* request, ::grpc::ServerWriter< ::milvus::grpc::TableName>* writer) override; const ::milvus::grpc::Command *request,
::grpc::ServerWriter<::milvus::grpc::TableName> *writer) override;
/** /**
* @brief Give the server status * @brief Give the server status
...@@ -233,13 +240,12 @@ public: ...@@ -233,13 +240,12 @@ public:
* @param response * @param response
*/ */
::grpc::Status ::grpc::Status
Ping(::grpc::ServerContext* context, Ping(::grpc::ServerContext *context,
const ::milvus::grpc::Command* request, ::milvus::grpc::ServerStatus* response) override; const ::milvus::grpc::Command *request, ::milvus::grpc::ServerStatus *response) override;
}; };
} }
} }
} }
}
/******************************************************************************* /*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#include "RequestScheduler.h" #include "GrpcRequestScheduler.h"
#include "utils/Log.h" #include "utils/Log.h"
#include "src/grpc/gen-status/status.pb.h" #include "src/grpc/gen-status/status.pb.h"
...@@ -11,11 +11,12 @@ ...@@ -11,11 +11,12 @@
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace server { namespace server {
namespace grpc {
using namespace ::milvus; using namespace ::milvus;
namespace { namespace {
const std::map<ServerError, ::milvus::grpc::ErrorCode > &ErrorMap() { const std::map<ServerError, ::milvus::grpc::ErrorCode> &ErrorMap() {
static const std::map<ServerError, ::milvus::grpc::ErrorCode> code_map = { static const std::map<ServerError, ::milvus::grpc::ErrorCode> code_map = {
{SERVER_UNEXPECTED_ERROR, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR}, {SERVER_UNEXPECTED_ERROR, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
{SERVER_UNSUPPORTED_ERROR, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR}, {SERVER_UNSUPPORTED_ERROR, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
...@@ -50,7 +51,7 @@ namespace { ...@@ -50,7 +51,7 @@ namespace {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BaseTask::BaseTask(const std::string& task_group, bool async) GrpcBaseTask::GrpcBaseTask(const std::string &task_group, bool async)
: task_group_(task_group), : task_group_(task_group),
async_(async), async_(async),
done_(false), done_(false),
...@@ -58,12 +59,12 @@ BaseTask::BaseTask(const std::string& task_group, bool async) ...@@ -58,12 +59,12 @@ BaseTask::BaseTask(const std::string& task_group, bool async)
} }
BaseTask::~BaseTask() { GrpcBaseTask::~GrpcBaseTask() {
WaitToFinish(); WaitToFinish();
} }
ServerError ServerError
BaseTask::Execute() { GrpcBaseTask::Execute() {
error_code_ = OnExecute(); error_code_ = OnExecute();
done_ = true; done_ = true;
finish_cond_.notify_all(); finish_cond_.notify_all();
...@@ -71,7 +72,7 @@ BaseTask::Execute() { ...@@ -71,7 +72,7 @@ BaseTask::Execute() {
} }
ServerError ServerError
BaseTask::SetError(ServerError error_code, const std::string& error_msg) { GrpcBaseTask::SetError(ServerError error_code, const std::string &error_msg) {
error_code_ = error_code; error_code_ = error_code;
error_msg_ = error_msg; error_msg_ = error_msg;
...@@ -80,33 +81,33 @@ BaseTask::SetError(ServerError error_code, const std::string& error_msg) { ...@@ -80,33 +81,33 @@ BaseTask::SetError(ServerError error_code, const std::string& error_msg) {
} }
ServerError ServerError
BaseTask::WaitToFinish() { GrpcBaseTask::WaitToFinish() {
std::unique_lock <std::mutex> lock(finish_mtx_); std::unique_lock<std::mutex> lock(finish_mtx_);
finish_cond_.wait(lock, [this] { return done_; }); finish_cond_.wait(lock, [this] { return done_; });
return error_code_; return error_code_;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
RequestScheduler::RequestScheduler() GrpcRequestScheduler::GrpcRequestScheduler()
: stopped_(false) { : stopped_(false) {
Start(); Start();
} }
RequestScheduler::~RequestScheduler() { GrpcRequestScheduler::~GrpcRequestScheduler() {
Stop(); Stop();
} }
void void
RequestScheduler::ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status *grpc_status) { GrpcRequestScheduler::ExecTask(BaseTaskPtr &task_ptr, ::milvus::grpc::Status *grpc_status) {
if(task_ptr == nullptr) { if (task_ptr == nullptr) {
return; return;
} }
RequestScheduler& scheduler = RequestScheduler::GetInstance(); GrpcRequestScheduler &scheduler = GrpcRequestScheduler::GetInstance();
scheduler.ExecuteTask(task_ptr); scheduler.ExecuteTask(task_ptr);
if(!task_ptr->IsAsync()) { if (!task_ptr->IsAsync()) {
task_ptr->WaitToFinish(); task_ptr->WaitToFinish();
ServerError err = task_ptr->ErrorCode(); ServerError err = task_ptr->ErrorCode();
if (err != SERVER_SUCCESS) { if (err != SERVER_SUCCESS) {
...@@ -117,8 +118,8 @@ RequestScheduler::ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status *grpc_s ...@@ -117,8 +118,8 @@ RequestScheduler::ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status *grpc_s
} }
void void
RequestScheduler::Start() { GrpcRequestScheduler::Start() {
if(!stopped_) { if (!stopped_) {
return; return;
} }
...@@ -126,23 +127,23 @@ RequestScheduler::Start() { ...@@ -126,23 +127,23 @@ RequestScheduler::Start() {
} }
void void
RequestScheduler::Stop() { GrpcRequestScheduler::Stop() {
if(stopped_) { if (stopped_) {
return; return;
} }
SERVER_LOG_INFO << "Scheduler gonna stop..."; SERVER_LOG_INFO << "Scheduler gonna stop...";
{ {
std::lock_guard<std::mutex> lock(queue_mtx_); std::lock_guard<std::mutex> lock(queue_mtx_);
for(auto iter : task_groups_) { for (auto iter : task_groups_) {
if(iter.second != nullptr) { if (iter.second != nullptr) {
iter.second->Put(nullptr); iter.second->Put(nullptr);
} }
} }
} }
for(auto iter : execute_threads_) { for (auto iter : execute_threads_) {
if(iter == nullptr) if (iter == nullptr)
continue; continue;
iter->join(); iter->join();
...@@ -152,18 +153,18 @@ RequestScheduler::Stop() { ...@@ -152,18 +153,18 @@ RequestScheduler::Stop() {
} }
ServerError ServerError
RequestScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { GrpcRequestScheduler::ExecuteTask(const BaseTaskPtr &task_ptr) {
if(task_ptr == nullptr) { if (task_ptr == nullptr) {
return SERVER_NULL_POINTER; return SERVER_NULL_POINTER;
} }
ServerError err = PutTaskToQueue(task_ptr); ServerError err = PutTaskToQueue(task_ptr);
if(err != SERVER_SUCCESS) { if (err != SERVER_SUCCESS) {
SERVER_LOG_ERROR << "Put task to queue failed with code: " << err ; SERVER_LOG_ERROR << "Put task to queue failed with code: " << err;
return err; return err;
} }
if(task_ptr->IsAsync()) { if (task_ptr->IsAsync()) {
return SERVER_SUCCESS;//async execution, caller need to call WaitToFinish at somewhere return SERVER_SUCCESS;//async execution, caller need to call WaitToFinish at somewhere
} }
...@@ -172,11 +173,11 @@ RequestScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { ...@@ -172,11 +173,11 @@ RequestScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
namespace { namespace {
void TakeTaskToExecute(TaskQueuePtr task_queue) { void TakeTaskToExecute(TaskQueuePtr task_queue) {
if(task_queue == nullptr) { if (task_queue == nullptr) {
return; return;
} }
while(true) { while (true) {
BaseTaskPtr task = task_queue->Take(); BaseTaskPtr task = task_queue->Take();
if (task == nullptr) { if (task == nullptr) {
SERVER_LOG_ERROR << "Take null from task queue, stop thread"; SERVER_LOG_ERROR << "Take null from task queue, stop thread";
...@@ -185,10 +186,10 @@ namespace { ...@@ -185,10 +186,10 @@ namespace {
try { try {
ServerError err = task->Execute(); ServerError err = task->Execute();
if(err != SERVER_SUCCESS) { if (err != SERVER_SUCCESS) {
SERVER_LOG_ERROR << "Task failed with code: " << err; SERVER_LOG_ERROR << "Task failed with code: " << err;
} }
} catch (std::exception& ex) { } catch (std::exception &ex) {
SERVER_LOG_ERROR << "Task failed to execute: " << ex.what(); SERVER_LOG_ERROR << "Task failed to execute: " << ex.what();
} }
} }
...@@ -196,11 +197,11 @@ namespace { ...@@ -196,11 +197,11 @@ namespace {
} }
ServerError ServerError
RequestScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) { GrpcRequestScheduler::PutTaskToQueue(const BaseTaskPtr &task_ptr) {
std::lock_guard<std::mutex> lock(queue_mtx_); std::lock_guard<std::mutex> lock(queue_mtx_);
std::string group_name = task_ptr->TaskGroup(); std::string group_name = task_ptr->TaskGroup();
if(task_groups_.count(group_name) > 0) { if (task_groups_.count(group_name) > 0) {
task_groups_[group_name]->Put(task_ptr); task_groups_[group_name]->Put(task_ptr);
} else { } else {
TaskQueuePtr queue = std::make_shared<TaskQueue>(); TaskQueuePtr queue = std::make_shared<TaskQueue>();
...@@ -219,3 +220,4 @@ RequestScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) { ...@@ -219,3 +220,4 @@ RequestScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
} }
} }
} }
}
/******************************************************************************* /*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#pragma once #pragma once
#include "utils/BlockingQueue.h" #include "utils/BlockingQueue.h"
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace server { namespace server {
namespace grpc {
class BaseTask { class GrpcBaseTask {
protected: protected:
BaseTask(const std::string& task_group, bool async = false); GrpcBaseTask(const std::string &task_group, bool async = false);
virtual ~BaseTask();
virtual ~GrpcBaseTask();
public: public:
ServerError ServerError
...@@ -46,7 +48,7 @@ protected: ...@@ -46,7 +48,7 @@ protected:
OnExecute() = 0; OnExecute() = 0;
ServerError ServerError
SetError(ServerError error_code, const std::string& msg); SetError(ServerError error_code, const std::string &msg);
protected: protected:
mutable std::mutex finish_mtx_; mutable std::mutex finish_mtx_;
...@@ -59,33 +61,35 @@ protected: ...@@ -59,33 +61,35 @@ protected:
std::string error_msg_; std::string error_msg_;
}; };
using BaseTaskPtr = std::shared_ptr<BaseTask>; using BaseTaskPtr = std::shared_ptr<GrpcBaseTask>;
using TaskQueue = BlockingQueue<BaseTaskPtr>; using TaskQueue = BlockingQueue<BaseTaskPtr>;
using TaskQueuePtr = std::shared_ptr<TaskQueue>; using TaskQueuePtr = std::shared_ptr<TaskQueue>;
using ThreadPtr = std::shared_ptr<std::thread>; using ThreadPtr = std::shared_ptr<std::thread>;
class RequestScheduler { class GrpcRequestScheduler {
public: public:
static RequestScheduler& GetInstance() { static GrpcRequestScheduler &GetInstance() {
static RequestScheduler scheduler; static GrpcRequestScheduler scheduler;
return scheduler; return scheduler;
} }
void Start(); void Start();
void Stop(); void Stop();
ServerError ServerError
ExecuteTask(const BaseTaskPtr& task_ptr); ExecuteTask(const BaseTaskPtr &task_ptr);
static void static void
ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status* grpc_status); ExecTask(BaseTaskPtr &task_ptr, ::milvus::grpc::Status *grpc_status);
protected: protected:
RequestScheduler(); GrpcRequestScheduler();
virtual ~RequestScheduler();
virtual ~GrpcRequestScheduler();
ServerError ServerError
PutTaskToQueue(const BaseTaskPtr& task_ptr); PutTaskToQueue(const BaseTaskPtr &task_ptr);
private: private:
mutable std::mutex queue_mtx_; mutable std::mutex queue_mtx_;
...@@ -97,7 +101,7 @@ private: ...@@ -97,7 +101,7 @@ private:
bool stopped_; bool stopped_;
}; };
}
} }
} }
} }
/******************************************************************************* /*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#include "RequestTask.h" #include "GrpcRequestTask.h"
#include "../ServerConfig.h" #include "../ServerConfig.h"
#include "utils/CommonUtil.h" #include "utils/CommonUtil.h"
#include "utils/Log.h" #include "utils/Log.h"
...@@ -11,16 +11,18 @@ ...@@ -11,16 +11,18 @@
#include "utils/ValidationUtil.h" #include "utils/ValidationUtil.h"
#include "../DBWrapper.h" #include "../DBWrapper.h"
#include "version.h" #include "version.h"
#include "MilvusServer.h" #include "GrpcMilvusServer.h"
#include "src/server/Server.h" #include "src/server/Server.h"
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace server { namespace server {
static const char* DQL_TASK_GROUP = "dql"; namespace grpc {
static const char* DDL_DML_TASK_GROUP = "ddl_dml";
static const char* PING_TASK_GROUP = "ping"; static const char *DQL_TASK_GROUP = "dql";
static const char *DDL_DML_TASK_GROUP = "ddl_dml";
static const char *PING_TASK_GROUP = "ping";
using DB_META = zilliz::milvus::engine::meta::Meta; using DB_META = zilliz::milvus::engine::meta::Meta;
using DB_DATE = zilliz::milvus::engine::meta::DateT; using DB_DATE = zilliz::milvus::engine::meta::DateT;
...@@ -34,7 +36,7 @@ namespace { ...@@ -34,7 +36,7 @@ namespace {
{3, engine::EngineType::FAISS_IVFSQ8}, {3, engine::EngineType::FAISS_IVFSQ8},
}; };
if(map_type.find(type) == map_type.end()) { if (map_type.find(type) == map_type.end()) {
return engine::EngineType::INVALID; return engine::EngineType::INVALID;
} }
...@@ -49,7 +51,7 @@ namespace { ...@@ -49,7 +51,7 @@ namespace {
{engine::EngineType::FAISS_IVFSQ8, 3}, {engine::EngineType::FAISS_IVFSQ8, 3},
}; };
if(map_type.find(type) == map_type.end()) { if (map_type.find(type) == map_type.end()) {
return 0; return 0;
} }
...@@ -60,38 +62,40 @@ namespace { ...@@ -60,38 +62,40 @@ namespace {
void void
ConvertTimeRangeToDBDates(const std::vector<::milvus::grpc::Range> &range_array, ConvertTimeRangeToDBDates(const std::vector<::milvus::grpc::Range> &range_array,
std::vector<DB_DATE>& dates, std::vector<DB_DATE> &dates,
ServerError& error_code, ServerError &error_code,
std::string& error_msg) { std::string &error_msg) {
dates.clear(); dates.clear();
for(auto& range : range_array) { for (auto &range : range_array) {
time_t tt_start, tt_end; time_t tt_start, tt_end;
tm tm_start, tm_end; tm tm_start, tm_end;
if(!CommonUtil::TimeStrToTime(range.start_value(), tt_start, tm_start)){ if (!CommonUtil::TimeStrToTime(range.start_value(), tt_start, tm_start)) {
error_code = SERVER_INVALID_TIME_RANGE; error_code = SERVER_INVALID_TIME_RANGE;
error_msg = "Invalid time range: " + range.start_value(); error_msg = "Invalid time range: " + range.start_value();
return; return;
} }
if(!CommonUtil::TimeStrToTime(range.end_value(), tt_end, tm_end)){ if (!CommonUtil::TimeStrToTime(range.end_value(), tt_end, tm_end)) {
error_code = SERVER_INVALID_TIME_RANGE; error_code = SERVER_INVALID_TIME_RANGE;
error_msg = "Invalid time range: " + range.start_value(); error_msg = "Invalid time range: " + range.start_value();
return; return;
} }
long days = (tt_end > tt_start) ? (tt_end - tt_start)/DAY_SECONDS : (tt_start - tt_end)/DAY_SECONDS; long days = (tt_end > tt_start) ? (tt_end - tt_start) / DAY_SECONDS : (tt_start - tt_end) /
if(days == 0) { DAY_SECONDS;
if (days == 0) {
error_code = SERVER_INVALID_TIME_RANGE; error_code = SERVER_INVALID_TIME_RANGE;
error_msg = "Invalid time range: " + range.start_value() + " to " + range.end_value(); error_msg = "Invalid time range: " + range.start_value() + " to " + range.end_value();
return ; return;
} }
for(long i = 0; i < days; i++) { for (long i = 0; i < days; i++) {
time_t tt_day = tt_start + DAY_SECONDS*i; time_t tt_day = tt_start + DAY_SECONDS * i;
tm tm_day; tm tm_day;
CommonUtil::ConvertTime(tt_day, tm_day); CommonUtil::ConvertTime(tt_day, tm_day);
long date = tm_day.tm_year*10000 + tm_day.tm_mon*100 + tm_day.tm_mday;//according to db logic long date = tm_day.tm_year * 10000 + tm_day.tm_mon * 100 +
tm_day.tm_mday;//according to db logic
dates.push_back(date); dates.push_back(date);
} }
} }
...@@ -99,17 +103,15 @@ namespace { ...@@ -99,17 +103,15 @@ namespace {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema& schema) CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema &schema)
: BaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
schema_(schema) { schema_(schema) {
} }
BaseTaskPtr BaseTaskPtr
CreateTableTask::Create(const ::milvus::grpc::TableSchema& schema) { CreateTableTask::Create(const ::milvus::grpc::TableSchema &schema) {
// BaseTaskPtr create_table_task_ptr = std::make_shared<CreateTableTask>(schema); return std::shared_ptr<GrpcBaseTask>(new CreateTableTask(schema));
// return create_table_task_ptr;
return std::shared_ptr<BaseTask>(new CreateTableTask(schema));
} }
ServerError ServerError
...@@ -119,35 +121,35 @@ CreateTableTask::OnExecute() { ...@@ -119,35 +121,35 @@ CreateTableTask::OnExecute() {
try { try {
//step 1: check arguments //step 1: check arguments
ServerError res = ValidationUtil::ValidateTableName(schema_.table_name().table_name()); ServerError res = ValidationUtil::ValidateTableName(schema_.table_name().table_name());
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + schema_.table_name().table_name()); return SetError(res, "Invalid table name: " + schema_.table_name().table_name());
} }
res = ValidationUtil::ValidateTableDimension(schema_.dimension()); res = ValidationUtil::ValidateTableDimension(schema_.dimension());
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table dimension: " + std::to_string(schema_.dimension())); return SetError(res, "Invalid table dimension: " + std::to_string(schema_.dimension()));
} }
res = ValidationUtil::ValidateTableIndexType(schema_.index_type()); res = ValidationUtil::ValidateTableIndexType(schema_.index_type());
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid index type: " + std::to_string(schema_.index_type())); return SetError(res, "Invalid index type: " + std::to_string(schema_.index_type()));
} }
//step 2: construct table schema //step 2: construct table schema
engine::meta::TableSchema table_info; engine::meta::TableSchema table_info;
table_info.dimension_ = (uint16_t)schema_.dimension(); table_info.dimension_ = (uint16_t) schema_.dimension();
table_info.table_id_ = schema_.table_name().table_name(); table_info.table_id_ = schema_.table_name().table_name();
table_info.engine_type_ = (int)EngineType(schema_.index_type()); table_info.engine_type_ = (int) EngineType(schema_.index_type());
table_info.store_raw_data_ = schema_.store_raw_vector(); table_info.store_raw_data_ = schema_.store_raw_vector();
//step 3: create table //step 3: create table
engine::Status stat = DBWrapper::DB()->CreateTable(table_info); engine::Status stat = DBWrapper::DB()->CreateTable(table_info);
if(!stat.ok()) { if (!stat.ok()) {
//table could exist //table could exist
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
} catch (std::exception& ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
} }
...@@ -157,15 +159,15 @@ CreateTableTask::OnExecute() { ...@@ -157,15 +159,15 @@ CreateTableTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DescribeTableTask::DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema& schema) DescribeTableTask::DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema &schema)
: BaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name), table_name_(table_name),
schema_(schema) { schema_(schema) {
} }
BaseTaskPtr BaseTaskPtr
DescribeTableTask::Create(const std::string& table_name, ::milvus::grpc::TableSchema& schema) { DescribeTableTask::Create(const std::string &table_name, ::milvus::grpc::TableSchema &schema) {
return std::shared_ptr<BaseTask>(new DescribeTableTask(table_name, schema)); return std::shared_ptr<GrpcBaseTask>(new DescribeTableTask(table_name, schema));
} }
ServerError ServerError
...@@ -175,7 +177,7 @@ DescribeTableTask::OnExecute() { ...@@ -175,7 +177,7 @@ DescribeTableTask::OnExecute() {
try { try {
//step 1: check arguments //step 1: check arguments
ServerError res = ValidationUtil::ValidateTableName(table_name_); ServerError res = ValidationUtil::ValidateTableName(table_name_);
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
} }
...@@ -183,17 +185,17 @@ DescribeTableTask::OnExecute() { ...@@ -183,17 +185,17 @@ DescribeTableTask::OnExecute() {
engine::meta::TableSchema table_info; engine::meta::TableSchema table_info;
table_info.table_id_ = table_name_; table_info.table_id_ = table_name_;
engine::Status stat = DBWrapper::DB()->DescribeTable(table_info); engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
if(!stat.ok()) { if (!stat.ok()) {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
schema_.mutable_table_name()->set_table_name(table_info.table_id_); schema_.mutable_table_name()->set_table_name(table_info.table_id_);
schema_.set_index_type(IndexType((engine::EngineType)table_info.engine_type_)); schema_.set_index_type(IndexType((engine::EngineType) table_info.engine_type_));
schema_.set_dimension(table_info.dimension_); schema_.set_dimension(table_info.dimension_);
schema_.set_store_raw_vector(table_info.store_raw_data_); schema_.set_store_raw_vector(table_info.store_raw_data_);
} catch (std::exception& ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
} }
...@@ -203,14 +205,14 @@ DescribeTableTask::OnExecute() { ...@@ -203,14 +205,14 @@ DescribeTableTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BuildIndexTask::BuildIndexTask(const std::string& table_name) BuildIndexTask::BuildIndexTask(const std::string &table_name)
: BaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name) { table_name_(table_name) {
} }
BaseTaskPtr BaseTaskPtr
BuildIndexTask::Create(const std::string& table_name) { BuildIndexTask::Create(const std::string &table_name) {
return std::shared_ptr<BaseTask>(new BuildIndexTask(table_name)); return std::shared_ptr<GrpcBaseTask>(new BuildIndexTask(table_name));
} }
ServerError ServerError
...@@ -220,28 +222,28 @@ BuildIndexTask::OnExecute() { ...@@ -220,28 +222,28 @@ BuildIndexTask::OnExecute() {
//step 1: check arguments //step 1: check arguments
ServerError res = ValidationUtil::ValidateTableName(table_name_); ServerError res = ValidationUtil::ValidateTableName(table_name_);
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
} }
bool has_table = false; bool has_table = false;
engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table); engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table);
if(!stat.ok()) { if (!stat.ok()) {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
if(!has_table) { if (!has_table) {
return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
} }
//step 2: check table existence //step 2: check table existence
stat = DBWrapper::DB()->BuildIndex(table_name_); stat = DBWrapper::DB()->BuildIndex(table_name_);
if(!stat.ok()) { if (!stat.ok()) {
return SetError(SERVER_BUILD_INDEX_ERROR, "Engine failed: " + stat.ToString()); return SetError(SERVER_BUILD_INDEX_ERROR, "Engine failed: " + stat.ToString());
} }
rc.ElapseFromBegin("totally cost"); rc.ElapseFromBegin("totally cost");
} catch (std::exception& ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
} }
...@@ -249,16 +251,16 @@ BuildIndexTask::OnExecute() { ...@@ -249,16 +251,16 @@ BuildIndexTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
HasTableTask::HasTableTask(const std::string& table_name, bool& has_table) HasTableTask::HasTableTask(const std::string &table_name, bool &has_table)
: BaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name), table_name_(table_name),
has_table_(has_table) { has_table_(has_table) {
} }
BaseTaskPtr BaseTaskPtr
HasTableTask::Create(const std::string& table_name, bool& has_table) { HasTableTask::Create(const std::string &table_name, bool &has_table) {
return std::shared_ptr<BaseTask>(new HasTableTask(table_name, has_table)); return std::shared_ptr<GrpcBaseTask>(new HasTableTask(table_name, has_table));
} }
ServerError ServerError
...@@ -268,18 +270,18 @@ HasTableTask::OnExecute() { ...@@ -268,18 +270,18 @@ HasTableTask::OnExecute() {
//step 1: check arguments //step 1: check arguments
ServerError res = ValidationUtil::ValidateTableName(table_name_); ServerError res = ValidationUtil::ValidateTableName(table_name_);
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
} }
//step 2: check table existence //step 2: check table existence
engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table_); engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table_);
if(!stat.ok()) { if (!stat.ok()) {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
rc.ElapseFromBegin("totally cost"); rc.ElapseFromBegin("totally cost");
} catch (std::exception& ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
} }
...@@ -287,15 +289,15 @@ HasTableTask::OnExecute() { ...@@ -287,15 +289,15 @@ HasTableTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DropTableTask::DropTableTask(const std::string& table_name) DropTableTask::DropTableTask(const std::string &table_name)
: BaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name) { table_name_(table_name) {
} }
BaseTaskPtr BaseTaskPtr
DropTableTask::Create(const std::string& table_name) { DropTableTask::Create(const std::string &table_name) {
return std::shared_ptr<BaseTask>(new DropTableTask(table_name)); return std::shared_ptr<GrpcBaseTask>(new DropTableTask(table_name));
} }
ServerError ServerError
...@@ -305,7 +307,7 @@ DropTableTask::OnExecute() { ...@@ -305,7 +307,7 @@ DropTableTask::OnExecute() {
//step 1: check arguments //step 1: check arguments
ServerError res = ValidationUtil::ValidateTableName(table_name_); ServerError res = ValidationUtil::ValidateTableName(table_name_);
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
} }
...@@ -313,8 +315,8 @@ DropTableTask::OnExecute() { ...@@ -313,8 +315,8 @@ DropTableTask::OnExecute() {
engine::meta::TableSchema table_info; engine::meta::TableSchema table_info;
table_info.table_id_ = table_name_; table_info.table_id_ = table_name_;
engine::Status stat = DBWrapper::DB()->DescribeTable(table_info); engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
if(!stat.ok()) { if (!stat.ok()) {
if(stat.IsNotFound()) { if (stat.IsNotFound()) {
return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
} else { } else {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
...@@ -326,12 +328,12 @@ DropTableTask::OnExecute() { ...@@ -326,12 +328,12 @@ DropTableTask::OnExecute() {
//step 3: Drop table //step 3: Drop table
std::vector<DB_DATE> dates; std::vector<DB_DATE> dates;
stat = DBWrapper::DB()->DeleteTable(table_name_, dates); stat = DBWrapper::DB()->DeleteTable(table_name_, dates);
if(!stat.ok()) { if (!stat.ok()) {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
rc.ElapseFromBegin("total cost"); rc.ElapseFromBegin("total cost");
} catch (std::exception& ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
} }
...@@ -339,26 +341,26 @@ DropTableTask::OnExecute() { ...@@ -339,26 +341,26 @@ DropTableTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
ShowTablesTask::ShowTablesTask(::grpc::ServerWriter< ::milvus::grpc::TableName>& writer) ShowTablesTask::ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> &writer)
: BaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
writer_(writer) { writer_(writer) {
} }
BaseTaskPtr BaseTaskPtr
ShowTablesTask::Create(::grpc::ServerWriter< ::milvus::grpc::TableName>& writer) { ShowTablesTask::Create(::grpc::ServerWriter<::milvus::grpc::TableName> &writer) {
return std::shared_ptr<BaseTask>(new ShowTablesTask(writer)); return std::shared_ptr<GrpcBaseTask>(new ShowTablesTask(writer));
} }
ServerError ServerError
ShowTablesTask::OnExecute() { ShowTablesTask::OnExecute() {
std::vector<engine::meta::TableSchema> schema_array; std::vector<engine::meta::TableSchema> schema_array;
engine::Status stat = DBWrapper::DB()->AllTables(schema_array); engine::Status stat = DBWrapper::DB()->AllTables(schema_array);
if(!stat.ok()) { if (!stat.ok()) {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
for(auto& schema : schema_array) { for (auto &schema : schema_array) {
::milvus::grpc::TableName tableName; ::milvus::grpc::TableName tableName;
tableName.set_table_name(schema.table_id_); tableName.set_table_name(schema.table_id_);
if (!writer_.Write(tableName)) { if (!writer_.Write(tableName)) {
...@@ -369,18 +371,18 @@ ShowTablesTask::OnExecute() { ...@@ -369,18 +371,18 @@ ShowTablesTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
InsertVectorTask::InsertVectorTask(const ::milvus::grpc::InsertInfos& insert_infos, InsertVectorTask::InsertVectorTask(const ::milvus::grpc::InsertInfos &insert_infos,
::milvus::grpc::VectorIds& record_ids) ::milvus::grpc::VectorIds &record_ids)
: BaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
insert_infos_(insert_infos), insert_infos_(insert_infos),
record_ids_(record_ids) { record_ids_(record_ids) {
record_ids_.Clear(); record_ids_.Clear();
} }
BaseTaskPtr BaseTaskPtr
InsertVectorTask::Create(const ::milvus::grpc::InsertInfos& insert_infos, InsertVectorTask::Create(const ::milvus::grpc::InsertInfos &insert_infos,
::milvus::grpc::VectorIds& record_ids) { ::milvus::grpc::VectorIds &record_ids) {
return std::shared_ptr<BaseTask>(new InsertVectorTask(insert_infos, record_ids)); return std::shared_ptr<GrpcBaseTask>(new InsertVectorTask(insert_infos, record_ids));
} }
ServerError ServerError
...@@ -390,10 +392,10 @@ InsertVectorTask::OnExecute() { ...@@ -390,10 +392,10 @@ InsertVectorTask::OnExecute() {
//step 1: check arguments //step 1: check arguments
ServerError res = ValidationUtil::ValidateTableName(insert_infos_.table_name()); ServerError res = ValidationUtil::ValidateTableName(insert_infos_.table_name());
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + insert_infos_.table_name()); return SetError(res, "Invalid table name: " + insert_infos_.table_name());
} }
if(insert_infos_.row_record_array().empty()) { if (insert_infos_.row_record_array().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
} }
...@@ -401,9 +403,10 @@ InsertVectorTask::OnExecute() { ...@@ -401,9 +403,10 @@ InsertVectorTask::OnExecute() {
engine::meta::TableSchema table_info; engine::meta::TableSchema table_info;
table_info.table_id_ = insert_infos_.table_name(); table_info.table_id_ = insert_infos_.table_name();
engine::Status stat = DBWrapper::DB()->DescribeTable(table_info); engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
if(!stat.ok()) { if (!stat.ok()) {
if(stat.IsNotFound()) { if (stat.IsNotFound()) {
return SetError(SERVER_TABLE_NOT_EXIST, "Table " + insert_infos_.table_name() + " not exists"); return SetError(SERVER_TABLE_NOT_EXIST,
"Table " + insert_infos_.table_name() + " not exists");
} else { } else {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
...@@ -430,7 +433,8 @@ InsertVectorTask::OnExecute() { ...@@ -430,7 +433,8 @@ InsertVectorTask::OnExecute() {
if (vec_dim != table_info.dimension_) { if (vec_dim != table_info.dimension_) {
ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION; ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION;
std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(vec_dim) std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(vec_dim)
+ " vs. table dimension:" + std::to_string(table_info.dimension_); + " vs. table dimension:" +
std::to_string(table_info.dimension_);
return SetError(error_code, error_msg); return SetError(error_code, error_msg);
} }
vec_f[i * table_info.dimension_ + j] = insert_infos_.row_record_array(i).vector_data(j); vec_f[i * table_info.dimension_ + j] = insert_infos_.row_record_array(i).vector_data(j);
...@@ -440,12 +444,13 @@ InsertVectorTask::OnExecute() { ...@@ -440,12 +444,13 @@ InsertVectorTask::OnExecute() {
rc.ElapseFromBegin("prepare vectors data"); rc.ElapseFromBegin("prepare vectors data");
//step 4: insert vectors //step 4: insert vectors
auto vec_count = (uint64_t)insert_infos_.row_record_array_size(); auto vec_count = (uint64_t) insert_infos_.row_record_array_size();
std::vector<int64_t> vec_ids(record_ids_.vector_id_array_size(), 0); std::vector<int64_t> vec_ids(record_ids_.vector_id_array_size(), 0);
stat = DBWrapper::DB()->InsertVectors(insert_infos_.table_name(), vec_count, vec_f.data(), vec_ids); stat = DBWrapper::DB()->InsertVectors(insert_infos_.table_name(), vec_count, vec_f.data(),
vec_ids);
rc.ElapseFromBegin("add vectors to engine"); rc.ElapseFromBegin("add vectors to engine");
if(!stat.ok()) { if (!stat.ok()) {
return SetError(SERVER_CACHE_ERROR, "Cache error: " + stat.ToString()); return SetError(SERVER_CACHE_ERROR, "Cache error: " + stat.ToString());
} }
for (int64_t id : vec_ids) { for (int64_t id : vec_ids) {
...@@ -453,7 +458,7 @@ InsertVectorTask::OnExecute() { ...@@ -453,7 +458,7 @@ InsertVectorTask::OnExecute() {
} }
auto ids_size = record_ids_.vector_id_array_size(); auto ids_size = record_ids_.vector_id_array_size();
if(ids_size != vec_count) { if (ids_size != vec_count) {
std::string msg = "Add " + std::to_string(vec_count) + " vectors but only return " std::string msg = "Add " + std::to_string(vec_count) + " vectors but only return "
+ std::to_string(ids_size) + " id"; + std::to_string(ids_size) + " id";
return SetError(SERVER_ILLEGAL_VECTOR_ID, msg); return SetError(SERVER_ILLEGAL_VECTOR_ID, msg);
...@@ -466,7 +471,7 @@ InsertVectorTask::OnExecute() { ...@@ -466,7 +471,7 @@ InsertVectorTask::OnExecute() {
rc.RecordSection("add vectors to engine"); rc.RecordSection("add vectors to engine");
rc.ElapseFromBegin("total cost"); rc.ElapseFromBegin("total cost");
} catch (std::exception& ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
} }
...@@ -474,10 +479,10 @@ InsertVectorTask::OnExecute() { ...@@ -474,10 +479,10 @@ InsertVectorTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchVectorTask::SearchVectorTask(const ::milvus::grpc::SearchVectorInfos& search_vector_infos, SearchVectorTask::SearchVectorTask(const ::milvus::grpc::SearchVectorInfos &search_vector_infos,
const std::vector<std::string>& file_id_array, const std::vector<std::string> &file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>& writer) ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer)
: BaseTask(DQL_TASK_GROUP), : GrpcBaseTask(DQL_TASK_GROUP),
search_vector_infos_(search_vector_infos), search_vector_infos_(search_vector_infos),
file_id_array_(file_id_array), file_id_array_(file_id_array),
writer_(writer) { writer_(writer) {
...@@ -485,10 +490,10 @@ SearchVectorTask::SearchVectorTask(const ::milvus::grpc::SearchVectorInfos& sear ...@@ -485,10 +490,10 @@ SearchVectorTask::SearchVectorTask(const ::milvus::grpc::SearchVectorInfos& sear
} }
BaseTaskPtr BaseTaskPtr
SearchVectorTask::Create(const ::milvus::grpc::SearchVectorInfos& search_vector_infos, SearchVectorTask::Create(const ::milvus::grpc::SearchVectorInfos &search_vector_infos,
const std::vector<std::string>& file_id_array, const std::vector<std::string> &file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>& writer) { ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer) {
return std::shared_ptr<BaseTask>(new SearchVectorTask(search_vector_infos, file_id_array, return std::shared_ptr<GrpcBaseTask>(new SearchVectorTask(search_vector_infos, file_id_array,
writer)); writer));
} }
...@@ -500,17 +505,17 @@ SearchVectorTask::OnExecute() { ...@@ -500,17 +505,17 @@ SearchVectorTask::OnExecute() {
//step 1: check arguments //step 1: check arguments
std::string table_name_ = search_vector_infos_.table_name(); std::string table_name_ = search_vector_infos_.table_name();
ServerError res = ValidationUtil::ValidateTableName(table_name_); ServerError res = ValidationUtil::ValidateTableName(table_name_);
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
} }
int top_k_ = search_vector_infos_.topk(); int top_k_ = search_vector_infos_.topk();
if(top_k_ <= 0 || top_k_ > 1024) { if (top_k_ <= 0 || top_k_ > 1024) {
return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string( return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(
top_k_)); top_k_));
} }
if(search_vector_infos_.query_record_array().empty()) { if (search_vector_infos_.query_record_array().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
} }
...@@ -518,8 +523,8 @@ SearchVectorTask::OnExecute() { ...@@ -518,8 +523,8 @@ SearchVectorTask::OnExecute() {
engine::meta::TableSchema table_info; engine::meta::TableSchema table_info;
table_info.table_id_ = table_name_; table_info.table_id_ = table_name_;
engine::Status stat = DBWrapper::DB()->DescribeTable(table_info); engine::Status stat = DBWrapper::DB()->DescribeTable(table_info);
if(!stat.ok()) { if (!stat.ok()) {
if(stat.IsNotFound()) { if (stat.IsNotFound()) {
return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists");
} else { } else {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
...@@ -536,7 +541,7 @@ SearchVectorTask::OnExecute() { ...@@ -536,7 +541,7 @@ SearchVectorTask::OnExecute() {
range_array.emplace_back(search_vector_infos_.query_range_array(i)); range_array.emplace_back(search_vector_infos_.query_range_array(i));
} }
ConvertTimeRangeToDBDates(range_array, dates, error_code, error_msg); ConvertTimeRangeToDBDates(range_array, dates, error_code, error_msg);
if(error_code != SERVER_SUCCESS) { if (error_code != SERVER_SUCCESS) {
return SetError(error_code, error_msg); return SetError(error_code, error_msg);
} }
...@@ -555,41 +560,46 @@ SearchVectorTask::OnExecute() { ...@@ -555,41 +560,46 @@ SearchVectorTask::OnExecute() {
for (size_t i = 0; i < record_array_size; i++) { for (size_t i = 0; i < record_array_size; i++) {
for (size_t j = 0; j < table_info.dimension_; j++) { for (size_t j = 0; j < table_info.dimension_; j++) {
if (search_vector_infos_.query_record_array(i).vector_data().empty()) { if (search_vector_infos_.query_record_array(i).vector_data().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Query record float array is empty"); return SetError(SERVER_INVALID_ROWRECORD_ARRAY,
"Query record float array is empty");
} }
uint64_t query_vec_dim = search_vector_infos_.query_record_array(i).vector_data().size(); uint64_t query_vec_dim = search_vector_infos_.query_record_array(
i).vector_data().size();
if (query_vec_dim != table_info.dimension_) { if (query_vec_dim != table_info.dimension_) {
ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION; ServerError error_code = SERVER_INVALID_VECTOR_DIMENSION;
std::string error_msg = "Invalid rowrecord dimension: " + std::to_string(query_vec_dim) std::string error_msg =
"Invalid rowrecord dimension: " + std::to_string(query_vec_dim)
+ " vs. table dimension:" + std::to_string(table_info.dimension_); + " vs. table dimension:" + std::to_string(table_info.dimension_);
return SetError(error_code, error_msg); return SetError(error_code, error_msg);
} }
vec_f[i * table_info.dimension_ + j] = search_vector_infos_.query_record_array(i).vector_data(j); vec_f[i * table_info.dimension_ + j] = search_vector_infos_.query_record_array(
i).vector_data(j);
} }
} }
rc.ElapseFromBegin("prepare vector data"); rc.ElapseFromBegin("prepare vector data");
//step 4: search vectors //step 4: search vectors
engine::QueryResults results; engine::QueryResults results;
auto record_count = (uint64_t)search_vector_infos_.query_record_array().size(); auto record_count = (uint64_t) search_vector_infos_.query_record_array().size();
if(file_id_array_.empty()) { if (file_id_array_.empty()) {
stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, vec_f.data(), dates, results); stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, vec_f.data(),
dates, results);
} else { } else {
stat = DBWrapper::DB()->Query(table_name_, file_id_array_, stat = DBWrapper::DB()->Query(table_name_, file_id_array_,
(size_t) top_k_, record_count, vec_f.data(), dates, results); (size_t) top_k_, record_count, vec_f.data(), dates, results);
} }
rc.ElapseFromBegin("search vectors from engine"); rc.ElapseFromBegin("search vectors from engine");
if(!stat.ok()) { if (!stat.ok()) {
return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString()); return SetError(DB_META_TRANSACTION_FAILED, "Engine failed: " + stat.ToString());
} }
if(results.empty()) { if (results.empty()) {
return SERVER_SUCCESS; //empty table return SERVER_SUCCESS; //empty table
} }
if(results.size() != record_count) { if (results.size() != record_count) {
std::string msg = "Search " + std::to_string(record_count) + " vectors but only return " std::string msg = "Search " + std::to_string(record_count) + " vectors but only return "
+ std::to_string(results.size()) + " results"; + std::to_string(results.size()) + " results";
return SetError(SERVER_ILLEGAL_SEARCH_RESULT, msg); return SetError(SERVER_ILLEGAL_SEARCH_RESULT, msg);
...@@ -598,11 +608,11 @@ SearchVectorTask::OnExecute() { ...@@ -598,11 +608,11 @@ SearchVectorTask::OnExecute() {
rc.ElapseFromBegin("do search"); rc.ElapseFromBegin("do search");
//step 5: construct result array //step 5: construct result array
for(uint64_t i = 0; i < record_count; i++) { for (uint64_t i = 0; i < record_count; i++) {
auto& result = results[i]; auto &result = results[i];
const auto &record = search_vector_infos_.query_record_array(i); const auto &record = search_vector_infos_.query_record_array(i);
::milvus::grpc::TopKQueryResult grpc_topk_result; ::milvus::grpc::TopKQueryResult grpc_topk_result;
for(auto& pair : result) { for (auto &pair : result) {
::milvus::grpc::QueryResult *grpc_result = grpc_topk_result.add_query_result_arrays(); ::milvus::grpc::QueryResult *grpc_result = grpc_topk_result.add_query_result_arrays();
grpc_result->set_id(pair.first); grpc_result->set_id(pair.first);
grpc_result->set_distance(pair.second); grpc_result->set_distance(pair.second);
...@@ -621,7 +631,7 @@ SearchVectorTask::OnExecute() { ...@@ -621,7 +631,7 @@ SearchVectorTask::OnExecute() {
//step 6: print time cost percent //step 6: print time cost percent
} catch (std::exception& ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
} }
...@@ -629,16 +639,16 @@ SearchVectorTask::OnExecute() { ...@@ -629,16 +639,16 @@ SearchVectorTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
GetTableRowCountTask::GetTableRowCountTask(const std::string& table_name, int64_t& row_count) GetTableRowCountTask::GetTableRowCountTask(const std::string &table_name, int64_t &row_count)
: BaseTask(DDL_DML_TASK_GROUP), : GrpcBaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name), table_name_(table_name),
row_count_(row_count) { row_count_(row_count) {
} }
BaseTaskPtr BaseTaskPtr
GetTableRowCountTask::Create(const std::string& table_name, int64_t& row_count) { GetTableRowCountTask::Create(const std::string &table_name, int64_t &row_count) {
return std::shared_ptr<BaseTask>(new GetTableRowCountTask(table_name, row_count)); return std::shared_ptr<GrpcBaseTask>(new GetTableRowCountTask(table_name, row_count));
} }
ServerError ServerError
...@@ -649,7 +659,7 @@ GetTableRowCountTask::OnExecute() { ...@@ -649,7 +659,7 @@ GetTableRowCountTask::OnExecute() {
//step 1: check arguments //step 1: check arguments
ServerError res = SERVER_SUCCESS; ServerError res = SERVER_SUCCESS;
res = ValidationUtil::ValidateTableName(table_name_); res = ValidationUtil::ValidateTableName(table_name_);
if(res != SERVER_SUCCESS) { if (res != SERVER_SUCCESS) {
return SetError(res, "Invalid table name: " + table_name_); return SetError(res, "Invalid table name: " + table_name_);
} }
...@@ -664,7 +674,7 @@ GetTableRowCountTask::OnExecute() { ...@@ -664,7 +674,7 @@ GetTableRowCountTask::OnExecute() {
rc.ElapseFromBegin("total cost"); rc.ElapseFromBegin("total cost");
} catch (std::exception& ex) { } catch (std::exception &ex) {
return SetError(SERVER_UNEXPECTED_ERROR, ex.what()); return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
} }
...@@ -672,21 +682,21 @@ GetTableRowCountTask::OnExecute() { ...@@ -672,21 +682,21 @@ GetTableRowCountTask::OnExecute() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
PingTask::PingTask(const std::string& cmd, std::string& result) PingTask::PingTask(const std::string &cmd, std::string &result)
: BaseTask(PING_TASK_GROUP), : GrpcBaseTask(PING_TASK_GROUP),
cmd_(cmd), cmd_(cmd),
result_(result) { result_(result) {
} }
BaseTaskPtr BaseTaskPtr
PingTask::Create(const std::string& cmd, std::string& result) { PingTask::Create(const std::string &cmd, std::string &result) {
return std::shared_ptr<BaseTask>(new PingTask(cmd, result)); return std::shared_ptr<GrpcBaseTask>(new PingTask(cmd, result));
} }
ServerError ServerError
PingTask::OnExecute() { PingTask::OnExecute() {
if(cmd_ == "version") { if (cmd_ == "version") {
result_ = MILVUS_VERSION; result_ = MILVUS_VERSION;
} else { } else {
result_ = "OK"; result_ = "OK";
...@@ -698,3 +708,4 @@ PingTask::OnExecute() { ...@@ -698,3 +708,4 @@ PingTask::OnExecute() {
} }
} }
} }
}
\ No newline at end of file
/******************************************************************************* /*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited. * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#pragma once #pragma once
#include "RequestScheduler.h" #include "GrpcRequestScheduler.h"
#include "utils/Error.h" #include "utils/Error.h"
#include "db/Types.h" #include "db/Types.h"
...@@ -17,16 +17,17 @@ ...@@ -17,16 +17,17 @@
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace server { namespace server {
namespace grpc {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CreateTableTask : public BaseTask { class CreateTableTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::TableSchema& schema); Create(const ::milvus::grpc::TableSchema &schema);
protected: protected:
explicit explicit
CreateTableTask(const ::milvus::grpc::TableSchema& request); CreateTableTask(const ::milvus::grpc::TableSchema &request);
ServerError ServerError
OnExecute() override; OnExecute() override;
...@@ -36,13 +37,13 @@ private: ...@@ -36,13 +37,13 @@ private:
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class HasTableTask : public BaseTask { class HasTableTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const std::string& table_name, bool& has_table); Create(const std::string &table_name, bool &has_table);
protected: protected:
HasTableTask(const std::string& request, bool& has_table); HasTableTask(const std::string &request, bool &has_table);
ServerError ServerError
OnExecute() override; OnExecute() override;
...@@ -50,17 +51,17 @@ protected: ...@@ -50,17 +51,17 @@ protected:
private: private:
std::string table_name_; std::string table_name_;
bool& has_table_; bool &has_table_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class DescribeTableTask : public BaseTask { class DescribeTableTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const std::string& table_name, ::milvus::grpc::TableSchema& schema); Create(const std::string &table_name, ::milvus::grpc::TableSchema &schema);
protected: protected:
DescribeTableTask(const std::string& table_name, ::milvus::grpc::TableSchema& schema); DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema &schema);
ServerError ServerError
OnExecute() override; OnExecute() override;
...@@ -68,18 +69,18 @@ protected: ...@@ -68,18 +69,18 @@ protected:
private: private:
std::string table_name_; std::string table_name_;
::milvus::grpc::TableSchema& schema_; ::milvus::grpc::TableSchema &schema_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class DropTableTask : public BaseTask { class DropTableTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const std::string& table_name); Create(const std::string &table_name);
protected: protected:
explicit explicit
DropTableTask(const std::string& table_name); DropTableTask(const std::string &table_name);
ServerError ServerError
OnExecute() override; OnExecute() override;
...@@ -90,14 +91,14 @@ private: ...@@ -90,14 +91,14 @@ private:
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class BuildIndexTask : public BaseTask { class BuildIndexTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const std::string& table_name); Create(const std::string &table_name);
protected: protected:
explicit explicit
BuildIndexTask(const std::string& table_name); BuildIndexTask(const std::string &table_name);
ServerError ServerError
OnExecute() override; OnExecute() override;
...@@ -108,53 +109,53 @@ private: ...@@ -108,53 +109,53 @@ private:
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class ShowTablesTask : public BaseTask { class ShowTablesTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(::grpc::ServerWriter< ::milvus::grpc::TableName>& writer); Create(::grpc::ServerWriter<::milvus::grpc::TableName> &writer);
protected: protected:
explicit explicit
ShowTablesTask(::grpc::ServerWriter< ::milvus::grpc::TableName>& writer); ShowTablesTask(::grpc::ServerWriter<::milvus::grpc::TableName> &writer);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
::grpc::ServerWriter< ::milvus::grpc::TableName> writer_; ::grpc::ServerWriter<::milvus::grpc::TableName> writer_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class InsertVectorTask : public BaseTask { class InsertVectorTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::InsertInfos& insert_infos, Create(const ::milvus::grpc::InsertInfos &insert_infos,
::milvus::grpc::VectorIds& record_ids_); ::milvus::grpc::VectorIds &record_ids_);
protected: protected:
InsertVectorTask(const ::milvus::grpc::InsertInfos& insert_infos, InsertVectorTask(const ::milvus::grpc::InsertInfos &insert_infos,
::milvus::grpc::VectorIds& record_ids_); ::milvus::grpc::VectorIds &record_ids_);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
const ::milvus::grpc::InsertInfos insert_infos_; const ::milvus::grpc::InsertInfos insert_infos_;
::milvus::grpc::VectorIds& record_ids_; ::milvus::grpc::VectorIds &record_ids_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class SearchVectorTask : public BaseTask { class SearchVectorTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const ::milvus::grpc::SearchVectorInfos& searchVectorInfos, Create(const ::milvus::grpc::SearchVectorInfos &searchVectorInfos,
const std::vector<std::string>& file_id_array, const std::vector<std::string> &file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>& writer); ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer);
protected: protected:
SearchVectorTask(const ::milvus::grpc::SearchVectorInfos& searchVectorInfos, SearchVectorTask(const ::milvus::grpc::SearchVectorInfos &searchVectorInfos,
const std::vector<std::string>& file_id_array, const std::vector<std::string> &file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>& writer); ::grpc::ServerWriter<::milvus::grpc::TopKQueryResult> &writer);
ServerError ServerError
OnExecute() override; OnExecute() override;
...@@ -166,39 +167,39 @@ private: ...@@ -166,39 +167,39 @@ private:
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class GetTableRowCountTask : public BaseTask { class GetTableRowCountTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const std::string& table_name, int64_t& row_count); Create(const std::string &table_name, int64_t &row_count);
protected: protected:
GetTableRowCountTask(const std::string& table_name, int64_t& row_count); GetTableRowCountTask(const std::string &table_name, int64_t &row_count);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
std::string table_name_; std::string table_name_;
int64_t& row_count_; int64_t &row_count_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class PingTask : public BaseTask { class PingTask : public GrpcBaseTask {
public: public:
static BaseTaskPtr static BaseTaskPtr
Create(const std::string& cmd, std::string& result); Create(const std::string &cmd, std::string &result);
protected: protected:
PingTask(const std::string& cmd, std::string& result); PingTask(const std::string &cmd, std::string &result);
ServerError ServerError
OnExecute() override; OnExecute() override;
private: private:
std::string cmd_; std::string cmd_;
std::string& result_; std::string &result_;
}; };
}
} }
} }
} }
\ No newline at end of file
#!/bin/bash #!/bin/bash
./cmake_build/src/milvus_grpc_server -c ./conf/server_config.yaml -l ./conf/log_config.conf & ./cmake_build/src/milvus_server -c ./conf/server_config.yaml -l ./conf/log_config.conf &
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册