From a0faf1a78693b5b5cd3987e38c4eef53e0174803 Mon Sep 17 00:00:00 2001 From: BossZou <40255591+BossZou@users.noreply.github.com> Date: Mon, 13 Jan 2020 13:55:48 +0800 Subject: [PATCH] Add http server (#956) * refactoring(create_table done) * refactoring * refactor server delivery (insert done) * refactoring server module (count_table done) * server refactor done * cmake pass * refactor server module done. * set grpc response status correctly * format done. * fix redefine ErrorMap() * optimize insert reducing ids data copy * optimize grpc request with reducing data copy * clang format * [skip ci] Refactor server module done. update changlog. prepare for PR * remove explicit and change int32_t to int64_t * add web server * [skip ci] add license in web module * modify header include & comment oatpp environment config * add port configure & create table in handler * modify web url * simple url complation done & add swagger * make sure web url * web functionality done. debuging * add web unittest * web test pass * add web server port * add web server port in template * update unittest cmake file * change web server default port to 19121 * rename method in web module & unittest pass * add search case in unittest for web module * rename some variables * fix bug * unittest pass * web prepare * fix cmd bug(check server status) * update changlog * add web port validate & default set * clang-format pass * add web port test in unittest * add CORS & redirect root to swagger ui * add web status * web table method func cascade test pass * add config url in web module * modify thirdparty cmake to avoid building oatpp test * clang format * update changlog * add constants in web module * reserve Config.cpp * fix constants reference bug * replace web server with async module * modify component to support async * format * developing controller & add test clent into unittest * add web port into demo/server_config * modify thirdparty cmake to allow build test * remove unnecessary comment * add endpoint info in controller * finish web test(bug here) * clang format * add web test cpp to lint exclusions * check null field in GetConfig * add macro RETURN STATUS DTo * fix cmake conflict * fix crash when exit server * remove surplus comments & add http param check * add uri /docs to direct swagger * format * change cmd to system * add default value & unittest in web module * add macros to judge if GPU supported * add macros in unit & add default in index dto & print error message when bind http port fail * format (fix #788) * fix cors bug (not completed) * comment cors * change web framework to simple api * comments optimize * change to simple API * remove comments in controller.hpp * remove EP_COMMON_CMAKE_ARGS in oatpp and oatpp-swagger * add ep cmake args to sqlite * clang-format * change a format * test pass * change name to * fix compiler issue(oatpp-swagger depend on oatpp) * add & in start_server.h * specify lib location with oatpp and oatpp-swagger * add comments * add swagger definition * [skip ci] change http method options status code * remove oatpp swagger(fix #970) * remove comments * check Start web behavior * add default to cpu_cache_capacity * remove swagger component.hpp & /docs url * remove /docs info * remove /docs in unittest * remove space in test rpc * remove repeate info in CHANGLOG * change cache_insert_data default value as a constant * [skip ci] Fix some broken links (#960) * [skip ci] Fix broken link * [skip ci] Fix broken link * [skip ci] Fix broken link * [skip ci] Fix broken links * fix issue 373 (#964) * fix issue 373 * Adjustment format * Adjustment format * Adjustment format * change readme * #966 update NOTICE.md (#967) * remove comments * check Start web behavior * add default to cpu_cache_capacity * remove swagger component.hpp & /docs url * remove /docs info * remove /docs in unittest * remove space in test rpc * remove repeate info in CHANGLOG * change cache_insert_data default value as a constant * adjust web port cofig place * rename web_port variable * set advanced config name add DEFAULT Co-authored-by: jielinxu <52057195+jielinxu@users.noreply.github.com> Co-authored-by: JackLCL <53512883+JackLCL@users.noreply.github.com> Co-authored-by: Cai Yudong --- CHANGELOG.md | 1 + core/build-support/lint_exclusions.txt | 3 +- core/cmake/DefineOptions.cmake | 2 + core/cmake/ThirdPartyPackages.cmake | 60 +- core/conf/demo/server_config.yaml | 4 + core/conf/server_cpu_config.template | 4 + core/conf/server_gpu_config.template | 4 + core/src/CMakeLists.txt | 15 + core/src/server/Config.cpp | 38 + core/src/server/Config.h | 8 + core/src/server/Server.cpp | 3 + core/src/server/delivery/RequestHandler.h | 2 +- core/src/server/delivery/RequestScheduler.cpp | 8 +- .../request/ShowPartitionsRequest.cpp | 6 +- .../server/grpc_impl/GrpcRequestHandler.cpp | 1 + core/src/server/grpc_impl/GrpcServer.cpp | 3 +- core/src/server/web_impl/Constants.h | 79 ++ core/src/server/web_impl/Types.h | 104 ++ core/src/server/web_impl/WebServer.cpp | 108 ++ core/src/server/web_impl/WebServer.h | 69 ++ .../web_impl/component/AppComponent.hpp | 82 ++ .../web_impl/controller/WebController.hpp | 639 +++++++++++ core/src/server/web_impl/dto/CmdDto.hpp | 45 + core/src/server/web_impl/dto/ConfigDto.hpp | 55 + core/src/server/web_impl/dto/DevicesDto.hpp | 45 + core/src/server/web_impl/dto/Dto.h | 30 + core/src/server/web_impl/dto/IndexDto.hpp | 44 + core/src/server/web_impl/dto/PartitionDto.hpp | 48 + core/src/server/web_impl/dto/StatusDto.hpp | 40 + core/src/server/web_impl/dto/TableDto.hpp | 76 ++ core/src/server/web_impl/dto/VectorDto.hpp | 91 ++ .../web_impl/handler/WebRequestHandler.cpp | 695 ++++++++++++ .../web_impl/handler/WebRequestHandler.h | 165 +++ core/thirdparty/versions.txt | 1 + core/unittest/CMakeLists.txt | 13 + core/unittest/server/CMakeLists.txt | 6 +- core/unittest/server/test_config.cpp | 9 + core/unittest/server/test_rpc.cpp | 2 +- core/unittest/server/test_web.cpp | 991 ++++++++++++++++++ 39 files changed, 3585 insertions(+), 14 deletions(-) create mode 100644 core/src/server/web_impl/Constants.h create mode 100644 core/src/server/web_impl/Types.h create mode 100644 core/src/server/web_impl/WebServer.cpp create mode 100644 core/src/server/web_impl/WebServer.h create mode 100644 core/src/server/web_impl/component/AppComponent.hpp create mode 100644 core/src/server/web_impl/controller/WebController.hpp create mode 100644 core/src/server/web_impl/dto/CmdDto.hpp create mode 100644 core/src/server/web_impl/dto/ConfigDto.hpp create mode 100644 core/src/server/web_impl/dto/DevicesDto.hpp create mode 100644 core/src/server/web_impl/dto/Dto.h create mode 100644 core/src/server/web_impl/dto/IndexDto.hpp create mode 100644 core/src/server/web_impl/dto/PartitionDto.hpp create mode 100644 core/src/server/web_impl/dto/StatusDto.hpp create mode 100644 core/src/server/web_impl/dto/TableDto.hpp create mode 100644 core/src/server/web_impl/dto/VectorDto.hpp create mode 100644 core/src/server/web_impl/handler/WebRequestHandler.cpp create mode 100644 core/src/server/web_impl/handler/WebRequestHandler.h create mode 100644 core/unittest/server/test_web.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fb8d0c2..1b2396b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Please mark all change in change log and use the issue from GitHub - \#766 - If partition tag is similar, wrong partition is searched - \#771 - Add server build commit info interface - \#759 - Put C++ sdk out of milvus/core +- \#788 - Add web server into server module - \#813 - Add push mode for prometheus monitor - \#815 - Support MinIO storage - \#910 - Change Milvus c++ standard to c++17 diff --git a/core/build-support/lint_exclusions.txt b/core/build-support/lint_exclusions.txt index 34d469fc..36a872f9 100644 --- a/core/build-support/lint_exclusions.txt +++ b/core/build-support/lint_exclusions.txt @@ -7,4 +7,5 @@ *SqliteMetaImpl.cpp *src/grpc* *thirdparty* -*milvus/include* \ No newline at end of file +*milvus/include* +*unittest/server/test_web.cpp \ No newline at end of file diff --git a/core/cmake/DefineOptions.cmake b/core/cmake/DefineOptions.cmake index c51edcf8..f5e013b2 100644 --- a/core/cmake/DefineOptions.cmake +++ b/core/cmake/DefineOptions.cmake @@ -91,6 +91,8 @@ define_option(MILVUS_WITH_FIU "Build with fiu" OFF) define_option(MILVUS_WITH_AWS "Build with aws" ON) +define_option(MILVUS_WITH_OATPP "Build with oatpp" ON) + #---------------------------------------------------------------------- set_option_category("Test and benchmark") diff --git a/core/cmake/ThirdPartyPackages.cmake b/core/cmake/ThirdPartyPackages.cmake index 22454950..dfe98cd4 100644 --- a/core/cmake/ThirdPartyPackages.cmake +++ b/core/cmake/ThirdPartyPackages.cmake @@ -28,7 +28,8 @@ set(MILVUS_THIRDPARTY_DEPENDENCIES ZLIB Opentracing fiu - AWS) + AWS + oatpp) message(STATUS "Using ${MILVUS_DEPENDENCY_SOURCE} approach to find dependencies") @@ -64,6 +65,8 @@ macro(build_dependency DEPENDENCY_NAME) build_opentracing() elseif ("${DEPENDENCY_NAME}" STREQUAL "fiu") build_fiu() + elseif ("${DEPENDENCY_NAME}" STREQUAL "oatpp") + build_oatpp() elseif("${DEPENDENCY_NAME}" STREQUAL "AWS") build_aws() else () @@ -330,6 +333,13 @@ else () "https://gitee.com/quicksilver/libfiu/repository/archive/${FIU_VERSION}.zip") endif () +if (DEFINED ENV{MILVUS_OATPP_URL}) + set(MILVUS_OATPP_URL "$ENV{MILVUS_OATPP_URL}") +else () +# set(OATPP_SOURCE_URL "https://github.com/oatpp/oatpp/archive/${OATPP_VERSION}.tar.gz") + set(OATPP_SOURCE_URL "https://github.com/BossZou/oatpp/archive/master.zip") +endif () + if (DEFINED ENV{MILVUS_AWS_URL}) set(AWS_SOURCE_URL "$ENV{MILVUS_AWS_URL}") else () @@ -973,7 +983,6 @@ endif () # ---------------------------------------------------------------------- # fiu - macro(build_fiu) message(STATUS "Building FIU-${FIU_VERSION} from source") set(FIU_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/fiu_ep-prefix/src/fiu_ep") @@ -1013,6 +1022,53 @@ resolve_dependency(fiu) get_target_property(FIU_INCLUDE_DIR fiu INTERFACE_INCLUDE_DIRECTORIES) include_directories(SYSTEM ${FIU_INCLUDE_DIR}) +# ---------------------------------------------------------------------- +# oatpp +macro(build_oatpp) + message(STATUS "Building oatpp-${OATPP_VERSION} from source") + set(OATPP_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/oatpp_ep-prefix/src/oatpp_ep") + set(OATPP_STATIC_LIB "${OATPP_PREFIX}/lib/oatpp-${OATPP_VERSION}/${CMAKE_STATIC_LIBRARY_PREFIX}oatpp${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(OATPP_INCLUDE_DIR "${OATPP_PREFIX}/include/oatpp-${OATPP_VERSION}/oatpp") + set(OATPP_DIR_SRC "${OATPP_PREFIX}/src") + set(OATPP_DIR_LIB "${OATPP_PREFIX}/lib") + + set(OATPP_CMAKE_ARGS + "-DCMAKE_INSTALL_PREFIX=${OATPP_PREFIX}" + -DCMAKE_INSTALL_LIBDIR=lib + -DBUILD_SHARED_LIBS=OFF + -DOATPP_BUILD_TESTS=OFF + ) + + + externalproject_add(oatpp_ep + URL + ${OATPP_SOURCE_URL} + ${EP_LOG_OPTIONS} + CMAKE_ARGS + ${OATPP_CMAKE_ARGS} + BUILD_COMMAND + ${MAKE} + ${MAKE_BUILD_ARGS} + BUILD_BYPRODUCTS + ${OATPP_STATIC_LIB} + ) + + file(MAKE_DIRECTORY "${OATPP_INCLUDE_DIR}") + add_library(oatpp STATIC IMPORTED) + set_target_properties(oatpp + PROPERTIES IMPORTED_LOCATION "${OATPP_STATIC_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${OATPP_INCLUDE_DIR}") + + add_dependencies(oatpp oatpp_ep) +endmacro() + +if (MILVUS_WITH_OATPP) + resolve_dependency(oatpp) + + get_target_property(OATPP_INCLUDE_DIR oatpp INTERFACE_INCLUDE_DIRECTORIES) + include_directories(SYSTEM ${OATPP_INCLUDE_DIR}) +endif () + # ---------------------------------------------------------------------- # aws macro(build_aws) diff --git a/core/conf/demo/server_config.yaml b/core/conf/demo/server_config.yaml index 40cc7855..60110444 100644 --- a/core/conf/demo/server_config.yaml +++ b/core/conf/demo/server_config.yaml @@ -29,11 +29,15 @@ version: 0.1 #----------------------+------------------------------------------------------------+------------+-----------------+ # time_zone | Use UTC-x or UTC+x to specify a time zone. | Timezone | UTC+8 | #----------------------+------------------------------------------------------------+------------+-----------------+ +# web_port | Port that Milvus web server monitors. | Integer | 19121 | +# | Port range (1024, 65535) | | | +#----------------------+------------------------------------------------------------+------------+-----------------+ server_config: address: 0.0.0.0 port: 19530 deploy_mode: single time_zone: UTC+8 + web_port: 19121 #----------------------+------------------------------------------------------------+------------+-----------------+ # DataBase Config | Description | Type | Default | diff --git a/core/conf/server_cpu_config.template b/core/conf/server_cpu_config.template index 7f375b11..b31cf5cb 100644 --- a/core/conf/server_cpu_config.template +++ b/core/conf/server_cpu_config.template @@ -29,11 +29,15 @@ version: 0.1 #----------------------+------------------------------------------------------------+------------+-----------------+ # time_zone | Use UTC-x or UTC+x to specify a time zone. | Timezone | UTC+8 | #----------------------+------------------------------------------------------------+------------+-----------------+ +# web_port | Port that Milvus web server monitors. | Integer | 19121 | +# | Port range (1024, 65535) | | | +#----------------------+------------------------------------------------------------+------------+-----------------+ server_config: address: 0.0.0.0 port: 19530 deploy_mode: single time_zone: UTC+8 + web_port: 19121 #----------------------+------------------------------------------------------------+------------+-----------------+ # DataBase Config | Description | Type | Default | diff --git a/core/conf/server_gpu_config.template b/core/conf/server_gpu_config.template index 03861be2..cbaa6be0 100644 --- a/core/conf/server_gpu_config.template +++ b/core/conf/server_gpu_config.template @@ -29,11 +29,15 @@ version: 0.1 #----------------------+------------------------------------------------------------+------------+-----------------+ # time_zone | Use UTC-x or UTC+x to specify a time zone. | Timezone | UTC+8 | #----------------------+------------------------------------------------------------+------------+-----------------+ +# web_port | Port that Milvus web server monitors. | Integer | 19121 | +# | Port range (1024, 65535) | | | +#----------------------+------------------------------------------------------------+------------+-----------------+ server_config: address: 0.0.0.0 port: 19530 deploy_mode: single time_zone: UTC+8 + web_port: 19121 #----------------------+------------------------------------------------------------+------------+-----------------+ # DataBase Config | Description | Type | Default | diff --git a/core/src/CMakeLists.txt b/core/src/CMakeLists.txt index 1199479f..9b02c55a 100644 --- a/core/src/CMakeLists.txt +++ b/core/src/CMakeLists.txt @@ -92,6 +92,19 @@ set(grpc_server_files ${grpc_interceptor_files} ) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/handler web_handler_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/component web_conponent_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/controller web_controller_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/dto web_dto_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl web_impl_files) +set(web_server_files + ${web_handler_files} + ${web_conponent_files} + ${web_controller_files} + ${web_dto_files} + ${web_impl_files} + ) + aux_source_directory(${MILVUS_ENGINE_SRC}/storage storage_main_files) aux_source_directory(${MILVUS_ENGINE_SRC}/storage/file storage_file_files) aux_source_directory(${MILVUS_ENGINE_SRC}/storage/s3 storage_s3_files) @@ -247,6 +260,7 @@ set(server_libs milvus_engine metrics tracing + oatpp ) add_executable(milvus_server @@ -256,6 +270,7 @@ add_executable(milvus_server ${server_files} ${grpc_server_files} ${grpc_service_files} + ${web_server_files} ${server_context_files} ${utils_files} ${tracing_files} diff --git a/core/src/server/Config.cpp b/core/src/server/Config.cpp index 96597ad4..ece734db 100644 --- a/core/src/server/Config.cpp +++ b/core/src/server/Config.cpp @@ -88,6 +88,9 @@ Config::ValidateConfig() { std::string server_time_zone; CONFIG_CHECK(GetServerConfigTimeZone(server_time_zone)); + std::string server_web_port; + CONFIG_CHECK(GetServerConfigWebPort(server_web_port)); + /* db config */ std::string db_backend_url; CONFIG_CHECK(GetDBConfigBackendUrl(db_backend_url)); @@ -194,6 +197,7 @@ Config::ResetDefaultConfig() { CONFIG_CHECK(SetServerConfigPort(CONFIG_SERVER_PORT_DEFAULT)); CONFIG_CHECK(SetServerConfigDeployMode(CONFIG_SERVER_DEPLOY_MODE_DEFAULT)); CONFIG_CHECK(SetServerConfigTimeZone(CONFIG_SERVER_TIME_ZONE_DEFAULT)); + CONFIG_CHECK(SetServerConfigWebPort(CONFIG_SERVER_WEB_PORT_DEFAULT)); /* db config */ CONFIG_CHECK(SetDBConfigBackendUrl(CONFIG_DB_BACKEND_URL_DEFAULT)); @@ -404,6 +408,23 @@ Config::CheckServerConfigTimeZone(const std::string& value) { return Status::OK(); } +Status +Config::CheckServerConfigWebPort(const std::string& value) { + if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { + std::string msg = + "Invalid web server port: " + value + ". Possible reason: server_config.web_port is not a number."; + return Status(SERVER_INVALID_ARGUMENT, msg); + } else { + int32_t port = std::stoi(value); + if (!(port > 1024 && port < 65535)) { + std::string msg = "Invalid web server port: " + value + + ". Possible reason: server_config.web_port is not in range [1025, 65534]."; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } + return Status::OK(); +} + /* DB config */ Status Config::CheckDBConfigBackendUrl(const std::string& value) { @@ -668,6 +689,7 @@ Config::CheckEngineConfigOmpThreadNum(const std::string& value) { } #ifdef MILVUS_GPU_VERSION + Status Config::CheckEngineConfigGpuSearchThreshold(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { @@ -789,6 +811,7 @@ Config::CheckGpuResourceConfigBuildIndexResources(const std::vector return Status::OK(); } + #endif //////////////////////////////////////////////////////////////////////////////// @@ -890,6 +913,12 @@ Config::GetServerConfigTimeZone(std::string& value) { return CheckServerConfigTimeZone(value); } +Status +Config::GetServerConfigWebPort(std::string& value) { + value = GetConfigStr(CONFIG_SERVER, CONFIG_SERVER_WEB_PORT, CONFIG_SERVER_WEB_PORT_DEFAULT); + return CheckServerConfigWebPort(value); +} + /* DB config */ Status Config::GetDBConfigBackendUrl(std::string& value) { @@ -1051,6 +1080,7 @@ Config::GetEngineConfigOmpThreadNum(int64_t& value) { } #ifdef MILVUS_GPU_VERSION + Status Config::GetEngineConfigGpuSearchThreshold(int64_t& value) { std::string str = @@ -1140,6 +1170,7 @@ Config::GetGpuResourceConfigBuildIndexResources(std::vector& value) { } return Status::OK(); } + #endif /* tracing config */ @@ -1183,6 +1214,12 @@ Config::SetServerConfigTimeZone(const std::string& value) { return SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_TIME_ZONE, value); } +Status +Config::SetServerConfigWebPort(const std::string& value) { + CONFIG_CHECK(CheckServerConfigWebPort(value)); + return SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_WEB_PORT, value); +} + /* db config */ Status Config::SetDBConfigBackendUrl(const std::string& value) { @@ -1309,6 +1346,7 @@ Config::SetEngineConfigOmpThreadNum(const std::string& value) { } #ifdef MILVUS_GPU_VERSION +/* gpu resource config */ Status Config::SetEngineConfigGpuSearchThreshold(const std::string& value) { CONFIG_CHECK(CheckEngineConfigGpuSearchThreshold(value)); diff --git a/core/src/server/Config.h b/core/src/server/Config.h index d3e542db..4401c6d8 100644 --- a/core/src/server/Config.h +++ b/core/src/server/Config.h @@ -49,6 +49,8 @@ static const char* CONFIG_SERVER_DEPLOY_MODE = "deploy_mode"; static const char* CONFIG_SERVER_DEPLOY_MODE_DEFAULT = "single"; static const char* CONFIG_SERVER_TIME_ZONE = "time_zone"; static const char* CONFIG_SERVER_TIME_ZONE_DEFAULT = "UTC+8"; +static const char* CONFIG_SERVER_WEB_PORT = "web_port"; +static const char* CONFIG_SERVER_WEB_PORT_DEFAULT = "19121"; /* db config */ static const char* CONFIG_DB = "db_config"; @@ -176,6 +178,8 @@ class Config { CheckServerConfigDeployMode(const std::string& value); Status CheckServerConfigTimeZone(const std::string& value); + Status + CheckServerConfigWebPort(const std::string& value); /* db config */ Status @@ -262,6 +266,8 @@ class Config { GetServerConfigDeployMode(std::string& value); Status GetServerConfigTimeZone(std::string& value); + Status + GetServerConfigWebPort(std::string& value); /* db config */ Status @@ -346,6 +352,8 @@ class Config { SetServerConfigDeployMode(const std::string& value); Status SetServerConfigTimeZone(const std::string& value); + Status + SetServerConfigWebPort(const std::string& value); /* db config */ Status diff --git a/core/src/server/Server.cpp b/core/src/server/Server.cpp index 85aeaccc..3b21b21c 100644 --- a/core/src/server/Server.cpp +++ b/core/src/server/Server.cpp @@ -26,6 +26,7 @@ #include "server/Config.h" #include "server/DBWrapper.h" #include "server/grpc_impl/GrpcServer.h" +#include "server/web_impl/WebServer.h" #include "src/version.h" #include "storage/s3/S3ClientWrapper.h" #include "tracing/TracerUtil.h" @@ -264,12 +265,14 @@ Server::StartService() { scheduler::StartSchedulerService(); DBWrapper::GetInstance().StartService(); grpc::GrpcServer::GetInstance().Start(); + web::WebServer::GetInstance().Start(); storage::S3ClientWrapper::GetInstance().StartService(); } void Server::StopService() { storage::S3ClientWrapper::GetInstance().StopService(); + web::WebServer::GetInstance().Stop(); grpc::GrpcServer::GetInstance().Stop(); DBWrapper::GetInstance().StopService(); scheduler::StopSchedulerService(); diff --git a/core/src/server/delivery/RequestHandler.h b/core/src/server/delivery/RequestHandler.h index c6b97184..fa701378 100644 --- a/core/src/server/delivery/RequestHandler.h +++ b/core/src/server/delivery/RequestHandler.h @@ -18,7 +18,7 @@ #pragma once #include "server/delivery/request/BaseRequest.h" -#include "src/utils/Status.h" +#include "utils/Status.h" #include #include diff --git a/core/src/server/delivery/RequestScheduler.cpp b/core/src/server/delivery/RequestScheduler.cpp index ff374c49..fed0a621 100644 --- a/core/src/server/delivery/RequestScheduler.cpp +++ b/core/src/server/delivery/RequestScheduler.cpp @@ -57,26 +57,28 @@ RequestScheduler::Start() { void RequestScheduler::Stop() { - if (stopped_) { + if (stopped_ && request_groups_.empty() && execute_threads_.empty()) { return; } SERVER_LOG_INFO << "Scheduler gonna stop..."; { std::lock_guard lock(queue_mtx_); - for (auto iter : request_groups_) { + for (auto& iter : request_groups_) { if (iter.second != nullptr) { iter.second->Put(nullptr); } } } - for (auto iter : execute_threads_) { + for (auto& iter : execute_threads_) { if (iter == nullptr) continue; iter->join(); } + request_groups_.clear(); + execute_threads_.clear(); stopped_ = true; SERVER_LOG_INFO << "Scheduler stopped"; } diff --git a/core/src/server/delivery/request/ShowPartitionsRequest.cpp b/core/src/server/delivery/request/ShowPartitionsRequest.cpp index 481ac4cc..eddbd67b 100644 --- a/core/src/server/delivery/request/ShowPartitionsRequest.cpp +++ b/core/src/server/delivery/request/ShowPartitionsRequest.cpp @@ -49,9 +49,9 @@ ShowPartitionsRequest::OnExecute() { } std::vector schema_array; - auto statuts = DBWrapper::DB()->ShowPartitions(table_name_, schema_array); - if (!statuts.ok()) { - return statuts; + status = DBWrapper::DB()->ShowPartitions(table_name_, schema_array); + if (!status.ok()) { + return status; } for (auto& schema : schema_array) { diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index 24fe5882..0bb9151d 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -23,6 +23,7 @@ #include "server/grpc_impl/GrpcRequestHandler.h" #include "tracing/TextMapCarrier.h" #include "tracing/TracerUtil.h" +#include "utils/Log.h" #include "utils/TimeRecorder.h" namespace milvus { diff --git a/core/src/server/grpc_impl/GrpcServer.cpp b/core/src/server/grpc_impl/GrpcServer.cpp index 52cc48b9..769ba893 100644 --- a/core/src/server/grpc_impl/GrpcServer.cpp +++ b/core/src/server/grpc_impl/GrpcServer.cpp @@ -103,9 +103,8 @@ GrpcServer::StartService() { builder.SetDefaultCompressionAlgorithm(GRPC_COMPRESS_STREAM_GZIP); builder.SetDefaultCompressionLevel(GRPC_COMPRESS_LEVEL_NONE); - RequestHandler handler; GrpcRequestHandler service(opentracing::Tracer::Global()); - service.RegisterRequestHandler(handler); + service.RegisterRequestHandler(RequestHandler()); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(&service); diff --git a/core/src/server/web_impl/Constants.h b/core/src/server/web_impl/Constants.h new file mode 100644 index 00000000..31941c39 --- /dev/null +++ b/core/src/server/web_impl/Constants.h @@ -0,0 +1,79 @@ + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace milvus { +namespace server { +namespace web { + +//////////////////////////////////////////////////// + +static const char* CORS_KEY_METHODS = "Access-Control-Allow-Methods"; +static const char* CORS_KEY_ORIGIN = "Access-Control-Allow-Origin"; +static const char* CORS_KEY_HEADERS = "Access-Control-Allow-Headers"; +static const char* CORS_KEY_AGE = "Access-Control-Max-Age"; + +static const char* CORS_VALUE_METHODS = "GET, POST, PUT, OPTIONS, DELETE"; +static const char* CORS_VALUE_ORIGIN = "*"; +static const char* CORS_VALUE_HEADERS = + "DNT, User-Agent, X-Requested-With, If-Modified-Since, Cache-Control, Content-Type, Range, Authorization"; +static const char* CORS_VALUE_AGE = "1728000"; + +//////////////////////////////////////////////////// + +static const char* NAME_ENGINE_TYPE_FLAT = "FLAT"; +static const char* NAME_ENGINE_TYPE_IVFFLAT = "IVFFLAT"; +static const char* NAME_ENGINE_TYPE_IVFSQ8 = "IVFSQ8"; +static const char* NAME_ENGINE_TYPE_IVFSQ8H = "IVFSQ8H"; +static const char* NAME_ENGINE_TYPE_RNSG = "RNSG"; +static const char* NAME_ENGINE_TYPE_IVFPQ = "IVFPQ"; + +static const char* NAME_METRIC_TYPE_L2 = "L2"; +static const char* NAME_METRIC_TYPE_IP = "IP"; + +//////////////////////////////////////////////////// + +static const char* KEY_TABLE_TABLE_NAME = "table_name"; +static const char* KEY_TABLE_DIMENSION = "dimension"; +static const char* KEY_TABLE_INDEX_FILE_SIZE = "index_file_size"; +static const char* KEY_TABLE_INDEX_METRIC_TYPE = "metric_type"; +static const char* KEY_TABLE_COUNT = "count"; + +static const char* KEY_INDEX_INDEX_TYPE = "index_type"; +static const char* KEY_INDEX_NLIST = "nlist"; + +static const char* KEY_PARTITION_NAME = "partition_name"; +static const char* KEY_PARTITION_TAG = "partition_tag"; + +//////////////////////////////////////////////////// + +static const int64_t VALUE_TABLE_INDEX_FILE_SIZE_DEFAULT = 1024; +static const char* VALUE_TABLE_METRIC_TYPE_DEFAULT = "L2"; + +static const char* VALUE_PARTITION_TAG_DEFAULT = ""; + +static const char* VALUE_INDEX_INDEX_TYPE_DEFAULT = NAME_ENGINE_TYPE_FLAT; +static const int64_t VALUE_INDEX_NLIST_DEFAULT = 16384; + +static const int64_t VALUE_CONFIG_CPU_CACHE_CAPACITY_DEFAULT = 4; +static const bool VALUE_CONFIG_CACHE_INSERT_DATA_DEFAULT = false; + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/Types.h b/core/src/server/web_impl/Types.h new file mode 100644 index 00000000..97b09961 --- /dev/null +++ b/core/src/server/web_impl/Types.h @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include + +#include "db/engine/ExecutionEngine.h" + +#include "server/web_impl/Constants.h" + +namespace milvus { +namespace server { +namespace web { + +using OString = oatpp::data::mapping::type::String; +using OInt64 = oatpp::data::mapping::type::Int64; +using OFloat32 = oatpp::data::mapping::type::Float32; +template +using OList = oatpp::data::mapping::type::List; + +using OQueryParams = oatpp::web::protocol::http::QueryParams; + +enum StatusCode : int { + SUCCESS = 0, + UNEXPECTED_ERROR = 1, + CONNECT_FAILED = 2, // reserved. + PERMISSION_DENIED = 3, + TABLE_NOT_EXISTS = 4, // DB_NOT_FOUND || TABLE_NOT_EXISTS + ILLEGAL_ARGUMENT = 5, + ILLEGAL_RANGE = 6, + ILLEGAL_DIMENSION = 7, + ILLEGAL_INDEX_TYPE = 8, + ILLEGAL_TABLE_NAME = 9, + ILLEGAL_TOPK = 10, + ILLEGAL_ROWRECORD = 11, + ILLEGAL_VECTOR_ID = 12, + ILLEGAL_SEARCH_RESULT = 13, + FILE_NOT_FOUND = 14, + META_FAILED = 15, + CACHE_FAILED = 16, + CANNOT_CREATE_FOLDER = 17, + CANNOT_CREATE_FILE = 18, + CANNOT_DELETE_FOLDER = 19, + CANNOT_DELETE_FILE = 20, + BUILD_INDEX_ERROR = 21, + ILLEGAL_NLIST = 22, + ILLEGAL_METRIC_TYPE = 23, + OUT_OF_MEMORY = 24, + + // HTTP status code + PATH_PARAM_LOSS = 31, + QUERY_PARAM_LOSS = 32, + BODY_FIELD_LOSS = 33, +}; + +static const std::unordered_map IndexMap = { + {engine::EngineType::FAISS_IDMAP, NAME_ENGINE_TYPE_FLAT}, + {engine::EngineType::FAISS_IVFFLAT, NAME_ENGINE_TYPE_IVFFLAT}, + {engine::EngineType::FAISS_IVFSQ8, NAME_ENGINE_TYPE_IVFSQ8}, + {engine::EngineType::FAISS_IVFSQ8H, NAME_ENGINE_TYPE_IVFSQ8H}, + {engine::EngineType::NSG_MIX, NAME_ENGINE_TYPE_RNSG}, + {engine::EngineType::FAISS_PQ, NAME_ENGINE_TYPE_IVFPQ}, +}; + +static const std::unordered_map IndexNameMap = { + {NAME_ENGINE_TYPE_FLAT, engine::EngineType::FAISS_IDMAP}, + {NAME_ENGINE_TYPE_IVFFLAT, engine::EngineType::FAISS_IVFFLAT}, + {NAME_ENGINE_TYPE_IVFSQ8, engine::EngineType::FAISS_IVFSQ8}, + {NAME_ENGINE_TYPE_IVFSQ8H, engine::EngineType::FAISS_IVFSQ8H}, + {NAME_ENGINE_TYPE_RNSG, engine::EngineType::NSG_MIX}, + {NAME_ENGINE_TYPE_IVFPQ, engine::EngineType::FAISS_PQ}, +}; + +static const std::unordered_map MetricMap = { + {engine::MetricType::L2, NAME_METRIC_TYPE_L2}, + {engine::MetricType::IP, NAME_METRIC_TYPE_IP}, +}; + +static const std::unordered_map MetricNameMap = { + {NAME_METRIC_TYPE_L2, engine::MetricType::L2}, + {NAME_METRIC_TYPE_IP, engine::MetricType::IP}, +}; + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/WebServer.cpp b/core/src/server/web_impl/WebServer.cpp new file mode 100644 index 00000000..f631cd1a --- /dev/null +++ b/core/src/server/web_impl/WebServer.cpp @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "server/web_impl/WebServer.h" +#include "server/web_impl/controller/WebController.hpp" + +#include "server/Config.h" + +namespace milvus { +namespace server { +namespace web { + +void +WebServer::Start() { + if (nullptr == thread_ptr_) { + thread_ptr_ = std::make_shared(&WebServer::StartService, this); + } +} + +void +WebServer::Stop() { + StopService(); + + if (thread_ptr_ != nullptr) { + thread_ptr_->join(); + thread_ptr_ = nullptr; + } +} + +Status +WebServer::StartService() { + oatpp::base::Environment::init(); + + Config& config = Config::GetInstance(); + std::string port; + Status status; + + status = config.GetServerConfigWebPort(port); + + { + AppComponent components = AppComponent(std::stoi(port)); + + auto user_controller = WebController::createShared(); + + /* create ApiControllers and add endpoints to router */ + OATPP_COMPONENT(std::shared_ptr, router); + user_controller->addEndpointsToRouter(router); + + /* Get connection handler component */ + OATPP_COMPONENT(std::shared_ptr, connection_handler); + + /* Get connection provider component */ + OATPP_COMPONENT(std::shared_ptr, connection_provider); + + /* create server */ + auto server = oatpp::network::server::Server(connection_provider, connection_handler); + + std::thread stop_thread([&server, this] { + while (!this->try_stop_.load()) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + + server.stop(); + OATPP_COMPONENT(std::shared_ptr, client_provider); + client_provider->getConnection(); + }); + + // start synchronously + server.run(); + + connection_handler->stop(); + + stop_thread.join(); + } + + oatpp::base::Environment::destroy(); + + return Status::OK(); +} + +Status +WebServer::StopService() { + try_stop_.store(true); + + return Status::OK(); +} + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/WebServer.h b/core/src/server/web_impl/WebServer.h new file mode 100644 index 00000000..a99f5482 --- /dev/null +++ b/core/src/server/web_impl/WebServer.h @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include + +#include "server/web_impl/component/AppComponent.hpp" + +#include "utils/Status.h" + +namespace milvus { +namespace server { +namespace web { + +class WebServer { + public: + static WebServer& + GetInstance() { + static WebServer web_server; + return web_server; + } + + void + Start(); + + void + Stop(); + + private: + WebServer() { + try_stop_.store(false); + } + + ~WebServer() = default; + + Status + StartService(); + Status + StopService(); + + private: + std::atomic_bool try_stop_; + + std::shared_ptr thread_ptr_; +}; + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/component/AppComponent.hpp b/core/src/server/web_impl/component/AppComponent.hpp new file mode 100644 index 00000000..8ac037fa --- /dev/null +++ b/core/src/server/web_impl/component/AppComponent.hpp @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "server/web_impl/handler/WebRequestHandler.h" + +namespace milvus { +namespace server { +namespace web { + +class AppComponent { + + public: + + explicit AppComponent(int port) : port_(port) { + } + + private: + const int port_; + + public: + OATPP_CREATE_COMPONENT(std::shared_ptr, server_connection_provider_) + ([this] { + try { + return oatpp::network::server::SimpleTCPConnectionProvider::createShared(this->port_); + } catch (std::exception& e) { + std::string error_msg = "Cannot bind http port " + std::to_string(this->port_) + + ". Check if the port is already used"; + std::cout << error_msg << std::endl; + throw std::runtime_error(error_msg); + } + }()); + + OATPP_CREATE_COMPONENT(std::shared_ptr, client_connection_provider_) + ([this] { + return oatpp::network::client::SimpleTCPConnectionProvider::createShared("localhost", this->port_); + }()); + + OATPP_CREATE_COMPONENT(std::shared_ptr, http_router_)([] { + return oatpp::web::server::HttpRouter::createShared(); + }()); + + OATPP_CREATE_COMPONENT(std::shared_ptr, server_connection_handler_)([] { + OATPP_COMPONENT(std::shared_ptr, router); // get Router component + return oatpp::web::server::HttpConnectionHandler::createShared(router); + }()); + + OATPP_CREATE_COMPONENT(std::shared_ptr, api_object_mapper_)([] { + auto serializerConfig = oatpp::parser::json::mapping::Serializer::Config::createShared(); + auto deserializerConfig = oatpp::parser::json::mapping::Deserializer::Config::createShared(); + deserializerConfig->allowUnknownFields = false; + return oatpp::parser::json::mapping::ObjectMapper::createShared(serializerConfig, + deserializerConfig); + }()); +}; + +} //namespace web +} //namespace server +} //namespace milvus diff --git a/core/src/server/web_impl/controller/WebController.hpp b/core/src/server/web_impl/controller/WebController.hpp new file mode 100644 index 00000000..585569f7 --- /dev/null +++ b/core/src/server/web_impl/controller/WebController.hpp @@ -0,0 +1,639 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "server/web_impl/dto/ConfigDto.hpp" +#include "server/web_impl/dto/TableDto.hpp" +#include "server/web_impl/dto/CmdDto.hpp" +#include "server/web_impl/dto/IndexDto.hpp" +#include "server/web_impl/dto/PartitionDto.hpp" +#include "server/web_impl/dto/VectorDto.hpp" +#include "server/web_impl/dto/ConfigDto.hpp" + +#include "utils/Log.h" +#include "server/delivery/RequestHandler.h" +#include "server/web_impl/handler/WebRequestHandler.h" + +namespace milvus { +namespace server { +namespace web { + +class WebController : public oatpp::web::server::api::ApiController { + public: + WebController(const std::shared_ptr& objectMapper) + : oatpp::web::server::api::ApiController(objectMapper) {} + + public: + + static std::shared_ptr createShared(OATPP_COMPONENT(std::shared_ptr, + objectMapper)) { + return std::make_shared(objectMapper); + } + + /** + * Begin ENDPOINTs generation ('ApiController' codegen) + */ +#include OATPP_CODEGEN_BEGIN(ApiController) + + ENDPOINT_INFO(root) { + info->summary = "Index.html page"; + } + + ADD_CORS(root) + + ENDPOINT("GET", "/", root) { + auto response = createResponse(Status::CODE_200, "Welcome to milvus"); + response->putHeader(Header::CONTENT_TYPE, "text/plain"); + return response; + } + + ENDPOINT_INFO(State) { + info->summary = "Server state"; + info->description = "Check web server whether is ready."; + + info->addResponse(Status::CODE_200, "application/json"); + } + + ADD_CORS(State) + + ENDPOINT("GET", "/state", State) { + return createDtoResponse(Status::CODE_200, StatusDto::createShared()); + } + + ENDPOINT_INFO(GetDevices) { + info->summary = "Obtain system devices info"; + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(GetDevices) + + ENDPOINT("GET", "/devices", GetDevices) { + auto devices_dto = DevicesDto::createShared(); + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.GetDevices(devices_dto); + std::shared_ptr response; + if (0 == status_dto->code->getValue()) { + response = createDtoResponse(Status::CODE_200, devices_dto); + } else { + response = createDtoResponse(Status::CODE_400, status_dto); + } + + return response; + } + + ADD_CORS(AdvancedConfigOptions) + + ENDPOINT("OPTIONS", "/config/advanced", AdvancedConfigOptions) { + return createResponse(Status::CODE_204, "No Content"); + } + + ENDPOINT_INFO(GetAdvancedConfig) { + info->summary = "Obtain cache config and enging config"; + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(GetAdvancedConfig) + + ENDPOINT("GET", "/config/advanced", GetAdvancedConfig) { + auto config_dto = AdvancedConfigDto::createShared(); + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.GetAdvancedConfig(config_dto); + std::shared_ptr response; + if (0 == status_dto->code->getValue()) { + response = createDtoResponse(Status::CODE_200, config_dto); + } else { + response = createDtoResponse(Status::CODE_400, status_dto); + } + + return response; + } + + ENDPOINT_INFO(SetAdvancedConfig) { + info->summary = "Modify cache config and enging config"; + + info->addConsumes("application/json"); + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(SetAdvancedConfig) + + ENDPOINT("PUT", "/config/advanced", SetAdvancedConfig, BODY_DTO(AdvancedConfigDto::ObjectWrapper, body)) { + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + + auto status_dto = handler.SetAdvancedConfig(body); + std::shared_ptr response; + if (0 == status_dto->code->getValue()) { + return createDtoResponse(Status::CODE_200, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + +#ifdef MILVUS_GPU_VERSION + + ADD_CORS(GPUConfigOptions) + + ENDPOINT("OPTIONS", "/config/gpu_resources", GPUConfigOptions) { + return createResponse(Status::CODE_204, "No Content"); + } + + ENDPOINT_INFO(GetGPUConfig) { + info->summary = "Obtain GPU resources config info"; + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(GetGPUConfig) + + ENDPOINT("GET", "/config/gpu_resources", GetGPUConfig) { + auto gpu_config_dto = GPUConfigDto::createShared(); + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.GetGpuConfig(gpu_config_dto); + + if (0 == status_dto->code->getValue()) { + return createDtoResponse(Status::CODE_200, gpu_config_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ENDPOINT_INFO(SetGPUConfig) { + info->summary = "Set GPU resources config"; + info->addConsumes("application/json"); + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(SetGPUConfig) + + ENDPOINT("PUT", "/config/gpu_resources", SetGPUConfig, BODY_DTO(GPUConfigDto::ObjectWrapper, body)) { + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.SetGpuConfig(body); + + std::shared_ptr response; + if (0 == status_dto->code->getValue()) { + return createDtoResponse(Status::CODE_200, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + +#endif + + ADD_CORS(TablesOptions) + + ENDPOINT("OPTIONS", "/tables", TablesOptions) { + return createResponse(Status::CODE_204, "No Content"); + } + + ENDPOINT_INFO(CreateTable) { + info->summary = "Create table"; + + info->addConsumes("application/json"); + + info->addResponse(Status::CODE_201, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(CreateTable) + + ENDPOINT("POST", "/tables", CreateTable, BODY_DTO(TableRequestDto::ObjectWrapper, body)) { + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + + auto status_dto = handler.CreateTable(body); + if (0 != status_dto->code) { + return createDtoResponse(Status::CODE_400, status_dto); + } else { + return createDtoResponse(Status::CODE_201, status_dto); + } + } + + ENDPOINT_INFO(ShowTables) { + info->summary = "Show whole tables"; + + info->queryParams.add("offset"); + info->queryParams.add("page_size"); + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(ShowTables) + + ENDPOINT("GET", "/tables", ShowTables, QUERY(Int64, offset, "offset"), QUERY(Int64, page_size, "page_size")) { + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto response_dto = TableListFieldsDto::createShared(); + auto status_dto = handler.ShowTables(offset, page_size, response_dto); + std::shared_ptr response; + if (0 == status_dto->code->getValue()) { + return createDtoResponse(Status::CODE_200, response_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ADD_CORS(TableOptions) + + ENDPOINT("OPTIONS", "/tables/{table_name}", TableOptions) { + return createResponse(Status::CODE_204, "No Content"); + } + + ENDPOINT_INFO(GetTable) { + info->summary = "Get table"; + + info->pathParams.add("table_name"); + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); + } + + ADD_CORS(GetTable) + + ENDPOINT("GET", "/tables/{table_name}", GetTable, PATH(String, table_name), QUERIES( + const QueryParams&, query_params)) { + auto error_status_dto = StatusDto::createShared(); + + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto fields_dto = TableFieldsDto::createShared(); + auto status_dto = handler.GetTable(table_name, query_params, fields_dto); + auto code = status_dto->code->getValue(); + if (0 == code) { + return createDtoResponse(Status::CODE_200, fields_dto); + } else if (StatusCode::TABLE_NOT_EXISTS == code) { + return createDtoResponse(Status::CODE_404, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ENDPOINT_INFO(DropTable) { + info->summary = "Drop table"; + + info->pathParams.add("table_name"); + + info->addResponse(Status::CODE_204, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); + } + + ADD_CORS(DropTable) + + ENDPOINT("DELETE", "/tables/{table_name}", DropTable, PATH(String, table_name)) { + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.DropTable(table_name); + auto code = status_dto->code->getValue(); + if (0 == code) { + return createDtoResponse(Status::CODE_204, status_dto); + } else if (StatusCode::TABLE_NOT_EXISTS == code) { + return createDtoResponse(Status::CODE_404, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ADD_CORS(IndexOptions) + + ENDPOINT("OPTIONS", "/tables/{table_name}/indexes", IndexOptions) { + return createResponse(Status::CODE_204, "No Content"); + } + + ENDPOINT_INFO(CreateIndex) { + info->summary = "Create index"; + + info->pathParams.add("table_name"); + + info->addConsumes("application/json"); + info->addResponse(Status::CODE_201, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(CreateIndex) + + ENDPOINT("POST", + "/tables/{table_name}/indexes", + CreateIndex, + PATH(String, table_name), + BODY_DTO(IndexRequestDto::ObjectWrapper, body)) { + auto handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.CreateIndex(table_name, body); + + if (0 == status_dto->code->getValue()) { + return createDtoResponse(Status::CODE_201, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ENDPOINT_INFO(GetIndex) { + info->summary = "Describe index"; + + info->pathParams.add("table_name"); + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); + } + + ADD_CORS(GetIndex) + + ENDPOINT("GET", "/tables/{table_name}/indexes", GetIndex, PATH(String, table_name)) { + auto index_dto = IndexDto::createShared(); + auto handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.GetIndex(table_name, index_dto); + auto code = status_dto->code->getValue(); + if (0 == code) { + return createDtoResponse(Status::CODE_200, index_dto); + } else if (StatusCode::TABLE_NOT_EXISTS == code) { + return createDtoResponse(Status::CODE_404, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ENDPOINT_INFO(DropIndex) { + info->summary = "Drop index"; + + info->pathParams.add("table_name"); + + info->addResponse(Status::CODE_204, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(DropIndex) + + ENDPOINT("DELETE", "/tables/{table_name}/indexes", DropIndex, PATH(String, table_name)) { + auto handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.DropIndex(table_name); + auto code = status_dto->code->getValue(); + if (0 == code) { + return createDtoResponse(Status::CODE_204, status_dto); + } else if (StatusCode::TABLE_NOT_EXISTS == code) { + return createDtoResponse(Status::CODE_404, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ADD_CORS(PartitionsOptions) + + ENDPOINT("OPTIONS", "/tables/{table_name}/partitions", PartitionsOptions) { + return createResponse(Status::CODE_204, "No Content"); + } + + ENDPOINT_INFO(CreatePartition) { + info->summary = "Create partition"; + + info->pathParams.add("table_name"); + + info->addConsumes("application/json"); + + info->addResponse(Status::CODE_201, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + } + + ADD_CORS(CreatePartition) + + ENDPOINT("POST", "/tables/{table_name}/partitions", + CreatePartition, PATH(String, table_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) { + auto handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.CreatePartition(table_name, body); + + if (0 == status_dto->code->getValue()) { + return createDtoResponse(Status::CODE_201, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ENDPOINT_INFO(ShowPartitions) { + info->summary = "Show partitions"; + + info->pathParams.add("table_name"); + + info->queryParams.add("offset"); + info->queryParams.add("page_size"); + + // + info->addResponse(Status::CODE_200, "application/json"); + // Error occurred. + info->addResponse(Status::CODE_400, "application/json"); + // + info->addResponse(Status::CODE_404, "application/json"); + } + + ADD_CORS(ShowPartitions) + + ENDPOINT("GET", + "/tables/{table_name}/partitions", + ShowPartitions, + PATH(String, table_name), + QUERY(Int64, offset, "offset"), + QUERY(Int64, page_size, "page_size")) { + auto status_dto = StatusDto::createShared(); + auto partition_list_dto = PartitionListDto::createShared(); + auto handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + + status_dto = handler.ShowPartitions(offset, page_size, table_name, partition_list_dto); + int64_t code = status_dto->code->getValue(); + if (0 == code) { + return createDtoResponse(Status::CODE_200, partition_list_dto); + } else if (StatusCode::TABLE_NOT_EXISTS == code) { + return createDtoResponse(Status::CODE_404, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ADD_CORS(PartitionOptions) + + ENDPOINT("OPTIONS", "/tables/{table_name}/partitions/{partition_tag}", PartitionOptions) { + return createResponse(Status::CODE_204, "No Content"); + } + + ENDPOINT_INFO(DropPartition) { + info->summary = "Drop partition"; + + info->pathParams.add("table_name"); + info->pathParams.add("partition_tag"); + + info->addResponse(Status::CODE_204, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); + } + + ADD_CORS(DropPartition) + + ENDPOINT("DELETE", + "/tables/{table_name}/partitions/{partition_tag}", + DropPartition, + PATH(String, table_name), + PATH(String, partition_tag)) { + auto handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.DropPartition(table_name, partition_tag); + auto code = status_dto->code->getValue(); + if (0 == code) { + return createDtoResponse(Status::CODE_204, status_dto); + } else if (StatusCode::TABLE_NOT_EXISTS == code) { + return createDtoResponse(Status::CODE_404, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ENDPOINT_INFO(Insert) { + info->summary = "Insert vectors"; + + info->pathParams.add("table_name"); + + info->addConsumes("application/json"); + + info->addResponse(Status::CODE_201, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); + } + + ADD_CORS(VectorsOptions) + + ENDPOINT("OPTIONS", "/tables/{table_name}/vectors", VectorsOptions) { + return createResponse(Status::CODE_204, "No Content"); + } + + ADD_CORS(Insert) + + ENDPOINT("POST", + "/tables/{table_name}/vectors", + Insert, + PATH(String, table_name), + BODY_DTO(InsertRequestDto::ObjectWrapper, body)) { + auto ids_dto = VectorIdsDto::createShared(); + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.Insert(table_name, body, ids_dto); + + int64_t code = status_dto->code->getValue(); + if (0 == code) { + return createDtoResponse(Status::CODE_201, ids_dto); + } else if (StatusCode::TABLE_NOT_EXISTS == code) { + return createDtoResponse(Status::CODE_404, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ENDPOINT_INFO(Search) { + info->summary = "Search"; + + info->pathParams.add("table_name"); + + info->addConsumes("application/json"); + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); + } + + ADD_CORS(Search) + + ENDPOINT("PUT", + "/tables/{table_name}/vectors", + Search, + PATH(String, table_name), + BODY_DTO(SearchRequestDto::ObjectWrapper, body)) { + auto results_dto = TopkResultsDto::createShared(); + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.Search(table_name, body, results_dto); + int64_t code = status_dto->code->getValue(); + if (0 == code) { + return createDtoResponse(Status::CODE_200, results_dto); + } else if (StatusCode::TABLE_NOT_EXISTS == code) { + return createDtoResponse(Status::CODE_404, status_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + + ENDPOINT_INFO(SystemMsg) { + info->summary = "Command"; + + info->pathParams.add("cmd_str"); + + info->addResponse(Status::CODE_200, "application/json"); + info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); + } + + ADD_CORS(SystemMsg) + + ENDPOINT("GET", "/system/{msg}", SystemMsg, PATH(String, msg)) { + auto cmd_dto = CommandDto::createShared(); + + WebRequestHandler handler = WebRequestHandler(); + handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + auto status_dto = handler.Cmd(msg, cmd_dto); + + if (0 == status_dto->code->getValue()) { + return createDtoResponse(Status::CODE_200, cmd_dto); + } else { + return createDtoResponse(Status::CODE_400, status_dto); + } + } + +/** + * Finish ENDPOINTs generation ('ApiController' codegen) + */ +#include OATPP_CODEGEN_END(ApiController) + +}; + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/dto/CmdDto.hpp b/core/src/server/web_impl/dto/CmdDto.hpp new file mode 100644 index 00000000..db4bd94b --- /dev/null +++ b/core/src/server/web_impl/dto/CmdDto.hpp @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "server/web_impl/dto/Dto.h" + +namespace milvus { +namespace server { +namespace web { + +#include OATPP_CODEGEN_BEGIN(DTO) + +class CommandDto: public oatpp::data::mapping::type::Object { + + DTO_INIT(CommandDto, Object) + + DTO_FIELD(String, reply, "reply"); +}; + +class CmdFieldsDto : public OObject { + DTO_INIT(CmdFieldsDto, Object) + + DTO_FIELD(Fields::ObjectWrapper, reply); +}; + +#include OATPP_CODEGEN_END(DTO) + +} // namespace web +} // namespace server +} // namespace milvus \ No newline at end of file diff --git a/core/src/server/web_impl/dto/ConfigDto.hpp b/core/src/server/web_impl/dto/ConfigDto.hpp new file mode 100644 index 00000000..524f36a0 --- /dev/null +++ b/core/src/server/web_impl/dto/ConfigDto.hpp @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "server/web_impl/Constants.h" +#include "server/web_impl/dto/Dto.h" + +namespace milvus { +namespace server { +namespace web { + +#include OATPP_CODEGEN_BEGIN(DTO) + +class AdvancedConfigDto : public OObject { + DTO_INIT(AdvancedConfigDto, Object); + + DTO_FIELD(Int64, cpu_cache_capacity) = VALUE_CONFIG_CPU_CACHE_CAPACITY_DEFAULT; + DTO_FIELD(Boolean, cache_insert_data) = VALUE_CONFIG_CACHE_INSERT_DATA_DEFAULT; + DTO_FIELD(Int64, use_blas_threshold) = 1100; + +#ifdef MILVUS_GPU_VERSION + DTO_FIELD(Int64, gpu_search_threshold) = 1000; + +#endif +}; + +class GPUConfigDto : public OObject { + DTO_INIT(GPUConfigDto, Object); + + DTO_FIELD(Boolean, enable) = true; + DTO_FIELD(Int64, cache_capacity) = 1; + DTO_FIELD(List::ObjectWrapper, search_resources); + DTO_FIELD(List::ObjectWrapper, build_index_resources); +}; + +#include OATPP_CODEGEN_END(DTO) + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/dto/DevicesDto.hpp b/core/src/server/web_impl/dto/DevicesDto.hpp new file mode 100644 index 00000000..fb67a454 --- /dev/null +++ b/core/src/server/web_impl/dto/DevicesDto.hpp @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "server/web_impl/dto/Dto.h" + +namespace milvus { +namespace server { +namespace web { + +#include OATPP_CODEGEN_BEGIN(DTO) + +class DeviceInfoDto : public OObject { + DTO_INIT(DeviceInfoDto, Object); + + DTO_FIELD(Int64, memory); +}; + +class DevicesDto : public OObject { + DTO_INIT(DevicesDto, Object); + + DTO_FIELD(DeviceInfoDto::ObjectWrapper, cpu); + DTO_FIELD(Fields::ObjectWrapper, gpus); +}; + +#include OATPP_CODEGEN_END(DTO) + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/dto/Dto.h b/core/src/server/web_impl/dto/Dto.h new file mode 100644 index 00000000..2e12eb5e --- /dev/null +++ b/core/src/server/web_impl/dto/Dto.h @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "oatpp/core/data/mapping/type/Object.hpp" +#include "oatpp/core/macro/codegen.hpp" + +namespace milvus { +namespace server { +namespace web { + +using OObject = oatpp::data::mapping::type::Object; +} +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/dto/IndexDto.hpp b/core/src/server/web_impl/dto/IndexDto.hpp new file mode 100644 index 00000000..ce9ed76b --- /dev/null +++ b/core/src/server/web_impl/dto/IndexDto.hpp @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "server/web_impl/dto/Dto.h" +#include "server/web_impl/dto/StatusDto.hpp" +#include "server/web_impl/Constants.h" + +namespace milvus { +namespace server { +namespace web { + +#include OATPP_CODEGEN_BEGIN(DTO) + +class IndexRequestDto : public oatpp::data::mapping::type::Object { + DTO_INIT(IndexRequestDto, Object) + + DTO_FIELD(String, index_type) = VALUE_INDEX_INDEX_TYPE_DEFAULT; + + DTO_FIELD(Int64, nlist) = VALUE_INDEX_NLIST_DEFAULT; +}; + +using IndexDto = IndexRequestDto; + +#include OATPP_CODEGEN_END(DTO) + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/dto/PartitionDto.hpp b/core/src/server/web_impl/dto/PartitionDto.hpp new file mode 100644 index 00000000..7d05db66 --- /dev/null +++ b/core/src/server/web_impl/dto/PartitionDto.hpp @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "server/web_impl/dto/Dto.h" +#include "server/web_impl/dto/StatusDto.hpp" + +namespace milvus { +namespace server { +namespace web { + +#include OATPP_CODEGEN_BEGIN(DTO) + +class PartitionRequestDto : public oatpp::data::mapping::type::Object { + DTO_INIT(PartitionRequestDto, Object) + + DTO_FIELD(String, partition_name); + DTO_FIELD(String, partition_tag); +}; + +using PartitionFieldsDto = PartitionRequestDto; + +class PartitionListDto : public oatpp::data::mapping::type::Object { + DTO_INIT(PartitionListDto, Object) + + DTO_FIELD(List::ObjectWrapper, partitions); +}; + +#include OATPP_CODEGEN_END(DTO) + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/dto/StatusDto.hpp b/core/src/server/web_impl/dto/StatusDto.hpp new file mode 100644 index 00000000..9e03c039 --- /dev/null +++ b/core/src/server/web_impl/dto/StatusDto.hpp @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "server/web_impl/dto/Dto.h" + +namespace milvus { +namespace server { +namespace web { + +#include OATPP_CODEGEN_BEGIN(DTO) + +class StatusDto: public oatpp::data::mapping::type::Object { + + DTO_INIT(StatusDto, Object) + + DTO_FIELD(String, message) = "Success"; + DTO_FIELD(Int64, code) = 0; +}; + +#include OATPP_CODEGEN_END(DTO) + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/dto/TableDto.hpp b/core/src/server/web_impl/dto/TableDto.hpp new file mode 100644 index 00000000..21b5e7bb --- /dev/null +++ b/core/src/server/web_impl/dto/TableDto.hpp @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "server/web_impl/dto/Dto.h" +#include "server/web_impl/dto/StatusDto.hpp" +#include "server/web_impl/Constants.h" + +namespace milvus { +namespace server { +namespace web { + +#include OATPP_CODEGEN_BEGIN(DTO) + + +class TableRequestDto : public oatpp::data::mapping::type::Object { + DTO_INIT(TableRequestDto, Object) + + DTO_FIELD(String, table_name, "table_name"); + DTO_FIELD(Int64, dimension, "dimension"); + DTO_FIELD(Int64, index_file_size, "index_file_size") = VALUE_TABLE_INDEX_FILE_SIZE_DEFAULT; + DTO_FIELD(String, metric_type, "metric_type") = VALUE_TABLE_METRIC_TYPE_DEFAULT; +}; + +class TableFieldsDto : public oatpp::data::mapping::type::Object { + DTO_INIT(TableFieldsDto, Object) + + DTO_FIELD(String, table_name); + DTO_FIELD(Int64, dimension); + DTO_FIELD(Int64, index_file_size); + DTO_FIELD(String, metric_type); + DTO_FIELD(Int64, count); + DTO_FIELD(String, index); + DTO_FIELD(Int64, nlist); +}; + +class TableListDto : public OObject { + DTO_INIT(TableListDto, Object) + + DTO_FIELD(List::ObjectWrapper, table_names); +}; + +class TableListFieldsDto : public OObject { + DTO_INIT(TableListFieldsDto, Object) + + DTO_FIELD(List::ObjectWrapper, tables); + DTO_FIELD(Int64, count); +}; + +class TablesResponseDto : public OObject { + DTO_INIT(TablesResponseDto, Object) + + DTO_FIELD(TableListFieldsDto::ObjectWrapper, tables_fields); + DTO_FIELD(Int64, page_num); +}; + +#include OATPP_CODEGEN_END(DTO) + +} // namespace web +} // namespace server +} // namespace milvus \ No newline at end of file diff --git a/core/src/server/web_impl/dto/VectorDto.hpp b/core/src/server/web_impl/dto/VectorDto.hpp new file mode 100644 index 00000000..8dea7c0a --- /dev/null +++ b/core/src/server/web_impl/dto/VectorDto.hpp @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "server/web_impl/dto/Dto.h" +#include "server/web_impl/Constants.h" + +namespace milvus { +namespace server { +namespace web { + +#include OATPP_CODEGEN_BEGIN(DTO) + +class RowRecordDto : public oatpp::data::mapping::type::Object { + DTO_INIT(RowRecordDto, Object) + + DTO_FIELD(List::ObjectWrapper, record); +}; + +class RecordsDto : public oatpp::data::mapping::type::Object { + DTO_INIT(RecordsDto, Object) + + DTO_FIELD(List::ObjectWrapper, records); +}; + +class SearchRequestDto : public OObject { + DTO_INIT(SearchRequestDto, Object) + + DTO_FIELD(Int64, topk); + DTO_FIELD(Int64, nprobe); + DTO_FIELD(List::ObjectWrapper, tags); + DTO_FIELD(List::ObjectWrapper, file_ids); + DTO_FIELD(List::ObjectWrapper>::ObjectWrapper, records); +}; + + +class InsertRequestDto : public oatpp::data::mapping::type::Object { + DTO_INIT(InsertRequestDto, Object) + + DTO_FIELD(String, tag) = VALUE_PARTITION_TAG_DEFAULT; + DTO_FIELD(List::ObjectWrapper>::ObjectWrapper, records); + DTO_FIELD(List::ObjectWrapper, ids); +}; + +class VectorIdsDto : public oatpp::data::mapping::type::Object { + DTO_INIT(VectorIdsDto, Object) + + DTO_FIELD(List::ObjectWrapper, ids); +}; + +class ResultDto : public oatpp::data::mapping::type::Object { + DTO_INIT(ResultDto, Object) + +// DTO_FIELD(Int64, num); + DTO_FIELD(String, id); + DTO_FIELD(String, dit, "distance"); +}; + +class RowResultsDto : public OObject { + DTO_INIT(RowResultsDto, Object) + +// DTO_FIELD(List::ObjectWrapper, ); +}; + +class TopkResultsDto : public OObject { + DTO_INIT(TopkResultsDto, Object); + + DTO_FIELD(Int64, num); + DTO_FIELD(List::ObjectWrapper>::ObjectWrapper, results); +}; + +#include OATPP_CODEGEN_END(DTO) + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/handler/WebRequestHandler.cpp b/core/src/server/web_impl/handler/WebRequestHandler.cpp new file mode 100644 index 00000000..23fb231a --- /dev/null +++ b/core/src/server/web_impl/handler/WebRequestHandler.cpp @@ -0,0 +1,695 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "server/web_impl/handler/WebRequestHandler.h" + +#include +#include +#include +#include + +#include "metrics/SystemInfo.h" +#include "utils/Log.h" + +#include "server/Config.h" +#include "server/delivery/request/BaseRequest.h" +#include "server/web_impl/Constants.h" +#include "server/web_impl/Types.h" +#include "server/web_impl/dto/PartitionDto.hpp" + +namespace milvus { +namespace server { +namespace web { + +StatusCode +WebErrorMap(ErrorCode code) { + static const std::map code_map = { + {SERVER_UNEXPECTED_ERROR, StatusCode::UNEXPECTED_ERROR}, + {SERVER_UNSUPPORTED_ERROR, StatusCode::UNEXPECTED_ERROR}, + {SERVER_NULL_POINTER, StatusCode::UNEXPECTED_ERROR}, + {SERVER_INVALID_ARGUMENT, StatusCode::ILLEGAL_ARGUMENT}, + {SERVER_FILE_NOT_FOUND, StatusCode::FILE_NOT_FOUND}, + {SERVER_NOT_IMPLEMENT, StatusCode::UNEXPECTED_ERROR}, + {SERVER_CANNOT_CREATE_FOLDER, StatusCode::CANNOT_CREATE_FOLDER}, + {SERVER_CANNOT_CREATE_FILE, StatusCode::CANNOT_CREATE_FILE}, + {SERVER_CANNOT_DELETE_FOLDER, StatusCode::CANNOT_DELETE_FOLDER}, + {SERVER_CANNOT_DELETE_FILE, StatusCode::CANNOT_DELETE_FILE}, + {SERVER_TABLE_NOT_EXIST, StatusCode::TABLE_NOT_EXISTS}, + {SERVER_INVALID_TABLE_NAME, StatusCode::ILLEGAL_TABLE_NAME}, + {SERVER_INVALID_TABLE_DIMENSION, StatusCode::ILLEGAL_DIMENSION}, + {SERVER_INVALID_TIME_RANGE, StatusCode::ILLEGAL_RANGE}, + {SERVER_INVALID_VECTOR_DIMENSION, StatusCode::ILLEGAL_DIMENSION}, + + {SERVER_INVALID_INDEX_TYPE, StatusCode::ILLEGAL_INDEX_TYPE}, + {SERVER_INVALID_ROWRECORD, StatusCode::ILLEGAL_ROWRECORD}, + {SERVER_INVALID_ROWRECORD_ARRAY, StatusCode::ILLEGAL_ROWRECORD}, + {SERVER_INVALID_TOPK, StatusCode::ILLEGAL_TOPK}, + {SERVER_INVALID_NPROBE, StatusCode::ILLEGAL_ARGUMENT}, + {SERVER_INVALID_INDEX_NLIST, StatusCode::ILLEGAL_NLIST}, + {SERVER_INVALID_INDEX_METRIC_TYPE, StatusCode::ILLEGAL_METRIC_TYPE}, + {SERVER_INVALID_INDEX_FILE_SIZE, StatusCode::ILLEGAL_ARGUMENT}, + {SERVER_ILLEGAL_VECTOR_ID, StatusCode::ILLEGAL_VECTOR_ID}, + {SERVER_ILLEGAL_SEARCH_RESULT, StatusCode::ILLEGAL_SEARCH_RESULT}, + {SERVER_CACHE_FULL, StatusCode::CACHE_FAILED}, + {SERVER_BUILD_INDEX_ERROR, StatusCode::BUILD_INDEX_ERROR}, + {SERVER_OUT_OF_MEMORY, StatusCode::OUT_OF_MEMORY}, + + {DB_NOT_FOUND, StatusCode::TABLE_NOT_EXISTS}, + {DB_META_TRANSACTION_FAILED, StatusCode::META_FAILED}, + }; + + if (code_map.find(code) != code_map.end()) { + return code_map.at(code); + } else { + return StatusCode::UNEXPECTED_ERROR; + } +} + +///////////////////////// WebRequestHandler methods /////////////////////////////////////// + +Status +WebRequestHandler::GetTaleInfo(const std::shared_ptr& context, const std::string& table_name, + std::map& table_info) { + TableSchema schema; + auto status = request_handler_.DescribeTable(context_ptr_, table_name, schema); + if (!status.ok()) { + return status; + } + + int64_t count; + status = request_handler_.CountTable(context_ptr_, table_name, count); + if (!status.ok()) { + return status; + } + + IndexParam index_param; + status = request_handler_.DescribeIndex(context_ptr_, table_name, index_param); + if (!status.ok()) { + return status; + } + + table_info[KEY_TABLE_TABLE_NAME] = schema.table_name_; + table_info[KEY_TABLE_DIMENSION] = std::to_string(schema.dimension_); + table_info[KEY_TABLE_INDEX_METRIC_TYPE] = std::string(MetricMap.at(engine::MetricType(schema.metric_type_))); + table_info[KEY_TABLE_INDEX_FILE_SIZE] = std::to_string(schema.index_file_size_); + + table_info[KEY_INDEX_INDEX_TYPE] = std::string(IndexMap.at(engine::EngineType(index_param.index_type_))); + table_info[KEY_INDEX_NLIST] = std::to_string(index_param.nlist_); + + table_info[KEY_TABLE_COUNT] = std::to_string(count); +} + +/////////////////////////////////////////// Router methods //////////////////////////////////////////// + +StatusDto::ObjectWrapper +WebRequestHandler::GetDevices(DevicesDto::ObjectWrapper& devices_dto) { + auto getgb = [](uint64_t x) -> uint64_t { return x / 1024 / 1024 / 1024; }; + auto system_info = SystemInfo::GetInstance(); + + devices_dto->cpu = devices_dto->cpu->createShared(); + devices_dto->cpu->memory = getgb(system_info.GetPhysicalMemory()); + + devices_dto->gpus = devices_dto->gpus->createShared(); + +#ifdef MILVUS_GPU_VERSION + + size_t count = system_info.num_device(); + std::vector device_mems = system_info.GPUMemoryTotal(); + + if (count != device_mems.size()) { + ASSIGN_RETURN_STATUS_DTO(Status(UNEXPECTED_ERROR, "Can't obtain GPU info")); + } + + for (size_t i = 0; i < count; i++) { + auto device_dto = DeviceInfoDto::createShared(); + device_dto->memory = getgb(device_mems.at(i)); + devices_dto->gpus->put("GPU" + OString(std::to_string(i).c_str()), device_dto); + } + +#endif + + ASSIGN_RETURN_STATUS_DTO(Status::OK()); +} + +StatusDto::ObjectWrapper +WebRequestHandler::GetAdvancedConfig(AdvancedConfigDto::ObjectWrapper& advanced_config) { + Config& config = Config::GetInstance(); + + int64_t value; + auto status = config.GetCacheConfigCpuCacheCapacity(value); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + advanced_config->cpu_cache_capacity = value; + + bool ok; + status = config.GetCacheConfigCacheInsertData(ok); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + advanced_config->cache_insert_data = ok; + + status = config.GetEngineConfigUseBlasThreshold(value); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + advanced_config->use_blas_threshold = value; + +#ifdef MILVUS_GPU_VERSION + + status = config.GetEngineConfigGpuSearchThreshold(value); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + advanced_config->gpu_search_threshold = value; + +#endif + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::SetAdvancedConfig(const AdvancedConfigDto::ObjectWrapper& advanced_config) { + Config& config = Config::GetInstance(); + + if (nullptr == advanced_config->cpu_cache_capacity.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cpu_cache_capacity\' miss."); + } + auto status = + config.SetCacheConfigCpuCacheCapacity(std::to_string(advanced_config->cpu_cache_capacity->getValue())); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + if (nullptr == advanced_config->cache_insert_data.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cache_insert_data\' miss."); + } + status = config.SetCacheConfigCacheInsertData(std::to_string(advanced_config->cache_insert_data->getValue())); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + if (nullptr == advanced_config->use_blas_threshold.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'use_blas_threshold\' miss."); + } + status = config.SetEngineConfigUseBlasThreshold(std::to_string(advanced_config->use_blas_threshold->getValue())); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + +#ifdef MILVUS_GPU_VERSION + + if (nullptr == advanced_config->gpu_search_threshold.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'gpu_search_threshold\' miss."); + } + status = + config.SetEngineConfigGpuSearchThreshold(std::to_string(advanced_config->gpu_search_threshold->getValue())); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + +#endif + + ASSIGN_RETURN_STATUS_DTO(status) +} + +#ifdef MILVUS_GPU_VERSION + +StatusDto::ObjectWrapper +WebRequestHandler::GetGpuConfig(GPUConfigDto::ObjectWrapper& gpu_config_dto) { + Config& config = Config::GetInstance(); + + bool enable; + auto status = config.GetGpuResourceConfigEnable(enable); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + gpu_config_dto->enable = enable; + + if (!enable) { + ASSIGN_RETURN_STATUS_DTO(Status::OK()); + } + + int64_t capacity; + status = config.GetGpuResourceConfigCacheCapacity(capacity); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + gpu_config_dto->cache_capacity = capacity; + + std::vector values; + status = config.GetGpuResourceConfigSearchResources(values); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + + gpu_config_dto->search_resources = gpu_config_dto->search_resources->createShared(); + for (auto& device_id : values) { + gpu_config_dto->search_resources->pushBack("GPU" + OString(std::to_string(device_id).c_str())); + } + + values.clear(); + status = config.GetGpuResourceConfigBuildIndexResources(values); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + + gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared(); + for (auto& device_id : values) { + gpu_config_dto->build_index_resources->pushBack("GPU" + OString(std::to_string(device_id).c_str())); + } + + ASSIGN_RETURN_STATUS_DTO(Status::OK()); +} + +#endif + +#ifdef MILVUS_GPU_VERSION + +StatusDto::ObjectWrapper +WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dto) { + Config& config = Config::GetInstance(); + + if (nullptr == gpu_config_dto->enable.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'enable\' miss") + } + auto status = config.SetGpuResourceConfigEnable(std::to_string(gpu_config_dto->enable->getValue())); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + + if (!gpu_config_dto->enable->getValue()) { + RETURN_STATUS_DTO(SUCCESS, "Set Gpu resources false"); + } + + if (nullptr == gpu_config_dto->cache_capacity.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cache_capacity\' miss") + } + status = config.SetGpuResourceConfigCacheCapacity(std::to_string(gpu_config_dto->cache_capacity->getValue())); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + + if (nullptr == gpu_config_dto->search_resources.get()) { + gpu_config_dto->search_resources = gpu_config_dto->search_resources->createShared(); + gpu_config_dto->search_resources->pushBack("GPU0"); + } + + std::vector search_resources; + gpu_config_dto->search_resources->forEach( + [&search_resources](const OString& res) { search_resources.emplace_back(res->toLowerCase()->std_str()); }); + + std::string search_resources_value; + for (auto& res : search_resources) { + search_resources_value += res + ","; + } + auto len = search_resources_value.size(); + if (len > 0) { + search_resources_value.erase(len - 1); + } + status = config.SetGpuResourceConfigSearchResources(search_resources_value); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + + if (nullptr == gpu_config_dto->build_index_resources.get()) { + gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared(); + gpu_config_dto->build_index_resources->pushBack("GPU0"); + } + std::vector build_resources; + gpu_config_dto->build_index_resources->forEach( + [&build_resources](const OString& res) { build_resources.emplace_back(res->toLowerCase()->std_str()); }); + + std::string build_resources_value; + for (auto& res : build_resources) { + build_resources_value += res + ","; + } + len = build_resources_value.size(); + if (len > 0) { + build_resources_value.erase(len - 1); + } + + status = config.SetGpuResourceConfigBuildIndexResources(build_resources_value); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); + } + + ASSIGN_RETURN_STATUS_DTO(Status::OK()); +} + +#endif + +StatusDto::ObjectWrapper +WebRequestHandler::CreateTable(const TableRequestDto::ObjectWrapper& table_schema) { + if (nullptr == table_schema->table_name.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'table_name\' is missing") + } + + if (nullptr == table_schema->dimension.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'dimension\' is missing") + } + + if (nullptr == table_schema->index_file_size.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_file_size\' is missing") + } + + if (nullptr == table_schema->metric_type.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'metric_type\' is missing") + } + + if (MetricNameMap.find(table_schema->metric_type->std_str()) == MetricNameMap.end()) { + RETURN_STATUS_DTO(ILLEGAL_METRIC_TYPE, "metric_type is illegal") + } + + auto status = request_handler_.CreateTable( + context_ptr_, table_schema->table_name->std_str(), table_schema->dimension, table_schema->index_file_size, + static_cast(MetricNameMap.at(table_schema->metric_type->std_str()))); + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query_params, + TableFieldsDto::ObjectWrapper& fields_dto) { + if (nullptr == table_name.get()) { + RETURN_STATUS_DTO(PATH_PARAM_LOSS, "Path param \'table_name\' is required!"); + } + + Status status = Status::OK(); + + // TODO: query string field `fields` npt used here + std::map table_info; + status = GetTaleInfo(context_ptr_, table_name->std_str(), table_info); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + fields_dto->table_name = table_info[KEY_TABLE_TABLE_NAME].c_str(); + fields_dto->dimension = std::stol(table_info[KEY_TABLE_DIMENSION]); + fields_dto->index = table_info[KEY_INDEX_INDEX_TYPE].c_str(); + fields_dto->nlist = std::stol(table_info[KEY_INDEX_NLIST]); + fields_dto->metric_type = table_info[KEY_TABLE_INDEX_METRIC_TYPE].c_str(); + fields_dto->index_file_size = std::stol(table_info[KEY_TABLE_INDEX_FILE_SIZE]); + fields_dto->count = std::stol(table_info[KEY_TABLE_COUNT]); + + ASSIGN_RETURN_STATUS_DTO(status); +} + +StatusDto::ObjectWrapper +WebRequestHandler::ShowTables(const OInt64& offset, const OInt64& page_size, + TableListFieldsDto::ObjectWrapper& response_dto) { + if (nullptr == offset.get()) { + RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'offset\' is required"); + } + + if (nullptr == page_size.get()) { + RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'page_size\' is required"); + } + std::vector tables; + Status status = Status::OK(); + + response_dto->tables = response_dto->tables->createShared(); + + if (offset < 0 || page_size < 0) { + ASSIGN_RETURN_STATUS_DTO( + Status(SERVER_UNEXPECTED_ERROR, "Query param 'offset' or 'page_size' should bigger than 0")); + } else { + status = request_handler_.ShowTables(context_ptr_, tables); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + if (offset < tables.size()) { + int64_t size = (page_size->getValue() + offset->getValue() > tables.size()) ? tables.size() - offset + : page_size->getValue(); + for (int64_t i = offset->getValue(); i < size + offset->getValue(); i++) { + std::map table_info; + + status = GetTaleInfo(context_ptr_, tables.at(i), table_info); + if (!status.ok()) { + break; + } + + auto table_fields_dto = TableFieldsDto::createShared(); + table_fields_dto->table_name = table_info[KEY_TABLE_TABLE_NAME].c_str(); + table_fields_dto->dimension = std::stol(table_info[std::string(KEY_TABLE_DIMENSION)]); + table_fields_dto->index_file_size = std::stol(table_info[std::string(KEY_TABLE_INDEX_FILE_SIZE)]); + table_fields_dto->index = table_info[KEY_INDEX_INDEX_TYPE].c_str(); + table_fields_dto->nlist = std::stol(table_info[KEY_INDEX_NLIST]); + table_fields_dto->metric_type = table_info[KEY_TABLE_INDEX_METRIC_TYPE].c_str(); + table_fields_dto->count = std::stol(table_info[KEY_TABLE_COUNT]); + + response_dto->tables->pushBack(table_fields_dto); + } + + response_dto->count = tables.size(); + } + } + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::DropTable(const OString& table_name) { + auto status = request_handler_.DropTable(context_ptr_, table_name->std_str()); + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::CreateIndex(const OString& table_name, const IndexRequestDto::ObjectWrapper& index_param) { + if (nullptr == index_param->index_type.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'index_type\' is required") + } + std::string index_type = index_param->index_type->std_str(); + if (IndexNameMap.find(index_type) == IndexNameMap.end()) { + RETURN_STATUS_DTO(ILLEGAL_INDEX_TYPE, "The index type is invalid.") + } + + if (nullptr == index_param->nlist.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'nlist\' is required") + } + + auto status = + request_handler_.CreateIndex(context_ptr_, table_name->std_str(), + static_cast(IndexNameMap.at(index_type)), index_param->nlist->getValue()); + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::GetIndex(const OString& table_name, IndexDto::ObjectWrapper& index_dto) { + IndexParam param; + auto status = request_handler_.DescribeIndex(context_ptr_, table_name->std_str(), param); + + if (status.ok()) { + index_dto->index_type = IndexMap.at(engine::EngineType(param.index_type_)).c_str(); + index_dto->nlist = param.nlist_; + } + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::DropIndex(const OString& table_name) { + auto status = request_handler_.DropIndex(context_ptr_, table_name->std_str()); + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::CreatePartition(const OString& table_name, const PartitionRequestDto::ObjectWrapper& param) { + if (nullptr == param->partition_name.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'partition_name\' is required") + } + + if (nullptr == param->partition_tag.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'partition_tag\' is required") + } + + auto status = request_handler_.CreatePartition(context_ptr_, table_name->std_str(), + param->partition_name->std_str(), param->partition_tag->std_str()); + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::ShowPartitions(const OInt64& offset, const OInt64& page_size, const OString& table_name, + PartitionListDto::ObjectWrapper& partition_list_dto) { + if (nullptr == offset.get()) { + RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'offset\' is required!"); + } + + if (nullptr == page_size.get()) { + RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'page_size\' is required!"); + } + + std::vector partitions; + auto status = request_handler_.ShowPartitions(context_ptr_, table_name->std_str(), partitions); + + if (status.ok()) { + partition_list_dto->partitions = partition_list_dto->partitions->createShared(); + + if (offset->getValue() < partitions.size()) { + int64_t size = (offset->getValue() + page_size->getValue() > partitions.size()) ? partitions.size() - offset + : page_size->getValue(); + for (int64_t i = offset->getValue(); i < size + offset->getValue(); i++) { + auto partition_dto = PartitionFieldsDto::createShared(); + partition_dto->partition_name = partitions.at(i).partition_name_.c_str(); + partition_dto->partition_tag = partitions.at(i).tag_.c_str(); + partition_list_dto->partitions->pushBack(partition_dto); + } + } + } + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::DropPartition(const OString& table_name, const OString& tag) { + auto status = request_handler_.DropPartition(context_ptr_, table_name->std_str(), "", tag->std_str()); + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::ObjectWrapper& param, + VectorIdsDto::ObjectWrapper& ids_dto) { + std::vector ids; + if (nullptr != param->ids.get() && param->ids->count() > 0) { + for (int64_t i = 0; i < param->ids->count(); i++) { + ids.emplace_back(param->ids->get(i)->getValue()); + } + } + + if (nullptr == param->records.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors") + } + + size_t tal_size = 0; + for (int64_t i = 0; i < param->records->count(); i++) { + tal_size += param->records->get(i)->count(); + } + + std::vector datas(tal_size); + size_t index_offset = 0; + param->records->forEach([&datas, &index_offset](const OList::ObjectWrapper& row_item) { + row_item->forEach([&datas, &index_offset](const OFloat32& item) { + datas[index_offset] = item->getValue(); + index_offset++; + }); + }); + + auto status = request_handler_.Insert(context_ptr_, table_name->std_str(), param->records->count(), datas, + param->tag->std_str(), ids); + + if (status.ok()) { + ids_dto->ids = ids_dto->ids->createShared(); + for (auto& id : ids) { + ids_dto->ids->pushBack(std::to_string(id).c_str()); + } + } + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::ObjectWrapper& search_request, + TopkResultsDto::ObjectWrapper& results_dto) { + if (nullptr == search_request->topk.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'topk\' is required in request body") + } + int64_t topk_t = search_request->topk->getValue(); + + if (nullptr == search_request->nprobe.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'nprobe\' is required in request body") + } + int64_t nprobe_t = search_request->nprobe->getValue(); + + std::vector tag_list; + std::vector file_id_list; + + if (nullptr != search_request->tags.get()) { + search_request->tags->forEach([&tag_list](const OString& tag) { tag_list.emplace_back(tag->std_str()); }); + } + + if (nullptr != search_request->file_ids.get()) { + search_request->file_ids->forEach( + [&file_id_list](const OString& id) { file_id_list.emplace_back(id->std_str()); }); + } + + if (nullptr == search_request->records.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill query vectors") + } + + size_t tal_size = 0; + search_request->records->forEach( + [&tal_size](const OList::ObjectWrapper& item) { tal_size += item->count(); }); + + std::vector datas(tal_size); + size_t index_offset = 0; + search_request->records->forEach([&datas, &index_offset](const OList::ObjectWrapper& elem) { + elem->forEach([&datas, &index_offset](const OFloat32& item) { + datas[index_offset] = item->getValue(); + index_offset++; + }); + }); + + std::vector range_list; + + TopKQueryResult result; + auto context_ptr = GenContextPtr("Web Handler"); + auto status = request_handler_.Search(context_ptr, table_name->std_str(), search_request->records->count(), datas, + range_list, topk_t, nprobe_t, tag_list, file_id_list, result); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + results_dto->num = result.row_num_; + results_dto->results = results_dto->results->createShared(); + if (0 == result.row_num_) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + auto step = result.id_list_.size() / result.row_num_; + for (size_t i = 0; i < result.row_num_; i++) { + auto row_result_dto = OList::createShared(); + for (size_t j = 0; j < step; j++) { + auto result_dto = ResultDto::createShared(); + result_dto->id = std::to_string(result.id_list_.at(i * step + j)).c_str(); + result_dto->dit = std::to_string(result.distance_list_.at(i * step + j)).c_str(); + row_result_dto->pushBack(result_dto); + } + results_dto->results->pushBack(row_result_dto); + } + + ASSIGN_RETURN_STATUS_DTO(status) +} + +StatusDto::ObjectWrapper +WebRequestHandler::Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto) { + std::string reply_str; + auto status = request_handler_.Cmd(context_ptr_, cmd->std_str(), reply_str); + + if (status.ok()) { + cmd_dto->reply = reply_str.c_str(); + } + + ASSIGN_RETURN_STATUS_DTO(status); +} + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/handler/WebRequestHandler.h b/core/src/server/web_impl/handler/WebRequestHandler.h new file mode 100644 index 00000000..af645531 --- /dev/null +++ b/core/src/server/web_impl/handler/WebRequestHandler.h @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include "server/web_impl/Types.h" +#include "server/web_impl/dto/CmdDto.hpp" +#include "server/web_impl/dto/ConfigDto.hpp" +#include "server/web_impl/dto/DevicesDto.hpp" +#include "server/web_impl/dto/IndexDto.hpp" +#include "server/web_impl/dto/PartitionDto.hpp" +#include "server/web_impl/dto/TableDto.hpp" +#include "server/web_impl/dto/VectorDto.hpp" + +#include "server/context/Context.h" +#include "server/delivery/RequestHandler.h" +#include "utils/Status.h" + +namespace milvus { +namespace server { +namespace web { + +#define RETURN_STATUS_DTO(STATUS_CODE, MESSAGE) \ + do { \ + auto status_dto = StatusDto::createShared(); \ + status_dto->code = (STATUS_CODE); \ + status_dto->message = (MESSAGE); \ + return status_dto; \ + } while (false); + +#define ASSIGN_RETURN_STATUS_DTO(STATUS) \ + do { \ + int code; \ + if (0 != (STATUS).code()) { \ + code = WebErrorMap((STATUS).code()); \ + } else { \ + code = 0; \ + } \ + RETURN_STATUS_DTO(code, (STATUS).message().c_str()) \ + } while (false); + +StatusCode +WebErrorMap(ErrorCode code); + +class WebRequestHandler { + private: + std::shared_ptr + GenContextPtr(const std::string& context_str) { + auto context_ptr = std::make_shared("dummy_request_id"); + opentracing::mocktracer::MockTracerOptions tracer_options; + auto mock_tracer = + std::shared_ptr{new opentracing::mocktracer::MockTracer{std::move(tracer_options)}}; + auto mock_span = mock_tracer->StartSpan("mock_span"); + auto trace_context = std::make_shared(mock_span); + context_ptr->SetTraceContext(trace_context); + + return context_ptr; + } + + public: + WebRequestHandler() { + context_ptr_ = GenContextPtr("Web Handler"); + } + + Status + GetTaleInfo(const std::shared_ptr& context, const std::string& table_name, + std::map& table_info); + + StatusDto::ObjectWrapper + GetDevices(DevicesDto::ObjectWrapper& devices); + + StatusDto::ObjectWrapper + GetAdvancedConfig(AdvancedConfigDto::ObjectWrapper& config); + + StatusDto::ObjectWrapper + SetAdvancedConfig(const AdvancedConfigDto::ObjectWrapper& config); + +#ifdef MILVUS_GPU_VERSION + StatusDto::ObjectWrapper + GetGpuConfig(GPUConfigDto::ObjectWrapper& gpu_config_dto); + + StatusDto::ObjectWrapper + SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dto); +#endif + + StatusDto::ObjectWrapper + CreateTable(const TableRequestDto::ObjectWrapper& table_schema); + + StatusDto::ObjectWrapper + GetTable(const OString& table_name, const OQueryParams& query_params, TableFieldsDto::ObjectWrapper& schema_dto); + + StatusDto::ObjectWrapper + ShowTables(const OInt64& offset, const OInt64& page_size, TableListFieldsDto::ObjectWrapper& table_list_dto); + + StatusDto::ObjectWrapper + DropTable(const OString& table_name); + + StatusDto::ObjectWrapper + CreateIndex(const OString& table_name, const IndexRequestDto::ObjectWrapper& index_param); + + StatusDto::ObjectWrapper + GetIndex(const OString& table_name, IndexDto::ObjectWrapper& index_dto); + + StatusDto::ObjectWrapper + DropIndex(const OString& table_name); + + StatusDto::ObjectWrapper + CreatePartition(const OString& table_name, const PartitionRequestDto::ObjectWrapper& param); + + StatusDto::ObjectWrapper + ShowPartitions(const OInt64& offset, const OInt64& page_size, const OString& table_name, + PartitionListDto::ObjectWrapper& partition_list_dto); + + StatusDto::ObjectWrapper + DropPartition(const OString& table_name, const OString& tag); + + StatusDto::ObjectWrapper + Insert(const OString& table_name, const InsertRequestDto::ObjectWrapper& param, + VectorIdsDto::ObjectWrapper& ids_dto); + + StatusDto::ObjectWrapper + Search(const OString& table_name, const SearchRequestDto::ObjectWrapper& search_request, + TopkResultsDto::ObjectWrapper& results_dto); + + StatusDto::ObjectWrapper + Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto); + + WebRequestHandler& + RegisterRequestHandler(const RequestHandler& handler) { + request_handler_ = handler; + } + + private: + std::shared_ptr context_ptr_; + RequestHandler request_handler_; +}; + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/thirdparty/versions.txt b/core/thirdparty/versions.txt index 551a68bb..a3790205 100644 --- a/core/thirdparty/versions.txt +++ b/core/thirdparty/versions.txt @@ -11,6 +11,7 @@ GRPC_VERSION=master ZLIB_VERSION=v1.2.11 OPENTRACING_VERSION=v1.5.1 FIU_VERSION=1.00 +OATPP_VERSION=0.19.11 AWS_VERSION=1.7.250 # vim: set filetype=sh: diff --git a/core/unittest/CMakeLists.txt b/core/unittest/CMakeLists.txt index ac5f5121..9103c927 100644 --- a/core/unittest/CMakeLists.txt +++ b/core/unittest/CMakeLists.txt @@ -74,6 +74,19 @@ set(grpc_server_files ${grpc_interceptor_files} ) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/handler web_handler_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/component web_conponent_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/controller web_controller_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/dto web_dto_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl web_impl_files) +set(web_server_files + ${web_handler_files} + ${web_conponent_files} + ${web_controller_files} + ${web_dto_files} + ${web_impl_files} + ) + aux_source_directory(${MILVUS_ENGINE_SRC}/server/delivery server_delivery_impl_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/delivery/request server_delivery_request_files) set(server_delivery_files diff --git a/core/unittest/server/CMakeLists.txt b/core/unittest/server/CMakeLists.txt index 8c01ab5a..f8cf03ea 100644 --- a/core/unittest/server/CMakeLists.txt +++ b/core/unittest/server/CMakeLists.txt @@ -21,6 +21,7 @@ set(test_files ${CMAKE_CURRENT_SOURCE_DIR}/test_cache.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_config.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_rpc.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_web.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_util.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) @@ -48,6 +49,7 @@ set(server_test_files ${grpc_server_files} ${grpc_service_files} ${server_delivery_files} + ${web_server_files} ${util_files} ${entry_file} ${test_files} @@ -60,13 +62,15 @@ set(grpc_lib grpc++ grpc grpc_protobuf - grpc_protoc) + grpc_protoc + ) target_link_libraries(test_server knowhere stdc++ ${grpc_lib} ${unittest_libs} + oatpp ) install(TARGETS test_server DESTINATION unittest) diff --git a/core/unittest/server/test_config.cpp b/core/unittest/server/test_config.cpp index 6ba839e0..7b3aba30 100644 --- a/core/unittest/server/test_config.cpp +++ b/core/unittest/server/test_config.cpp @@ -135,6 +135,11 @@ TEST_F(ConfigTest, SERVER_CONFIG_VALID_TEST) { ASSERT_TRUE(config.GetServerConfigPort(str_val).ok()); ASSERT_TRUE(str_val == server_port); + std::string web_port = "19999"; + ASSERT_TRUE(config.SetServerConfigWebPort(web_port).ok()); + ASSERT_TRUE(config.GetServerConfigWebPort(str_val).ok()); + ASSERT_TRUE(str_val == web_port); + std::string server_mode = "cluster_readonly"; ASSERT_TRUE(config.SetServerConfigDeployMode(server_mode).ok()); ASSERT_TRUE(config.GetServerConfigDeployMode(str_val).ok()); @@ -479,6 +484,10 @@ TEST_F(ConfigTest, SERVER_CONFIG_INVALID_TEST) { ASSERT_FALSE(config.SetServerConfigPort("a").ok()); ASSERT_FALSE(config.SetServerConfigPort("99999").ok()); + ASSERT_FALSE(config.SetServerConfigWebPort("a").ok()); + ASSERT_FALSE(config.SetServerConfigWebPort("99999").ok()); + ASSERT_FALSE(config.SetServerConfigWebPort("-1").ok()); + ASSERT_FALSE(config.SetServerConfigDeployMode("cluster").ok()); ASSERT_FALSE(config.SetServerConfigTimeZone("GM").ok()); diff --git a/core/unittest/server/test_rpc.cpp b/core/unittest/server/test_rpc.cpp index 531956e0..365fbf10 100644 --- a/core/unittest/server/test_rpc.cpp +++ b/core/unittest/server/test_rpc.cpp @@ -231,7 +231,7 @@ TEST_F(RpcHandlerTest, INSERT_TEST) { ASSERT_EQ(vector_ids.vector_id_array_size(), VECTOR_COUNT); // insert vectors with wrong dim - std::vector record_wrong_dim(TABLE_DIM - 1, 0.5f); + std::vector record_wrong_dim(TABLE_DIM - 1, 0.5f); ::milvus::grpc::RowRecord* grpc_record = request.add_row_record_array(); CopyRowRecord(grpc_record, record_wrong_dim); handler->Insert(&context, &request, &vector_ids); diff --git a/core/unittest/server/test_web.cpp b/core/unittest/server/test_web.cpp new file mode 100644 index 00000000..1624c7e1 --- /dev/null +++ b/core/unittest/server/test_web.cpp @@ -0,0 +1,991 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "wrapper/VecIndex.h" + +#include "server/Server.h" +#include "server/delivery/RequestScheduler.h" +#include "server/delivery/request/BaseRequest.h" +#include "server/delivery/RequestHandler.h" +#include "src/version.h" + +#include "server/web_impl/handler/WebRequestHandler.h" +#include "server/web_impl/dto/TableDto.hpp" +#include "server/web_impl/dto/StatusDto.hpp" +#include "server/web_impl/dto/VectorDto.hpp" +#include "server/web_impl/dto/IndexDto.hpp" +#include "server/web_impl/component/AppComponent.hpp" +#include "server/web_impl/controller/WebController.hpp" +#include "server/web_impl/Types.h" +#include "server/web_impl/WebServer.h" + +#include "scheduler/ResourceFactory.h" +#include "scheduler/SchedInst.h" +#include "server/Config.h" +#include "server/DBWrapper.h" +#include "utils/CommonUtil.h" + +static const char* TABLE_NAME = "test_web"; +static constexpr int64_t TABLE_DIM = 256; +static constexpr int64_t INDEX_FILE_SIZE = 1024; +static constexpr int64_t VECTOR_COUNT = 1000; +static constexpr int64_t INSERT_LOOP = 10; +constexpr int64_t SECONDS_EACH_HOUR = 3600; + +using OStatus = oatpp::web::protocol::http::Status; +using OString = milvus::server::web::OString; +using OQueryParams = milvus::server::web::OQueryParams; +using OChunkedBuffer = oatpp::data::stream::ChunkedBuffer; +using OOutputStream = oatpp::data::stream::BufferOutputStream; +using OFloat32 = milvus::server::web::OFloat32; +template +using OList = milvus::server::web::OList; + +using StatusCode = milvus::server::web::StatusCode; + +namespace { + +OList::ObjectWrapper +RandomRowRecordDto(int64_t dim) { + auto row_record_dto = OList::createShared(); + + std::default_random_engine e; + std::uniform_real_distribution u(0, 1); + for (size_t i = 0; i < dim; i++) { + row_record_dto->pushBack(u(e)); + } + + return row_record_dto; +} + +OList::ObjectWrapper>::ObjectWrapper +RandomRecordsDto(int64_t dim, int64_t num) { + auto records_dto = OList::ObjectWrapper>::createShared(); + for (size_t i = 0; i < num; i++) { + records_dto->pushBack(RandomRowRecordDto(dim)); + } + + return records_dto; +} + +std::string +RandomName() { + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::default_random_engine e(seed); + std::uniform_int_distribution u(0, 1000000); + + size_t name_len = u(e) % 16 + 3; + + char* name = new char[name_len + 1]; + name[name_len] = '\0'; + + for (size_t i = 0; i < name_len; i++) { + unsigned random_i = u(e); + char remainder = static_cast(random_i % 26); + name[i] = (random_i % 2 == 0) ? 'A' + remainder : 'a' + remainder; + } + + std::string random_name(name); + + delete[] name; + + return random_name; +} + +} // namespace + +namespace { + +class WebHandlerTest : public testing::Test { + protected: + static void + SetUpTestCase() { + auto res_mgr = milvus::scheduler::ResMgrInst::GetInstance(); + res_mgr->Clear(); + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("disk", "DISK", 0, false)); + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("cpu", "CPU", 0)); + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("gtx1660", "GPU", 0)); + + auto default_conn = milvus::scheduler::Connection("IO", 500.0); + auto PCIE = milvus::scheduler::Connection("IO", 11000.0); + res_mgr->Connect("disk", "cpu", default_conn); + res_mgr->Connect("cpu", "gtx1660", PCIE); + res_mgr->Start(); + milvus::scheduler::SchedInst::GetInstance()->Start(); + milvus::scheduler::JobMgrInst::GetInstance()->Start(); + + milvus::engine::DBOptions opt; + + milvus::server::Config::GetInstance().SetDBConfigBackendUrl("sqlite://:@:/"); + boost::filesystem::remove_all("/tmp/milvus_web_handler_test"); + milvus::server::Config::GetInstance().SetStorageConfigPrimaryPath("/tmp/milvus_web_handler_test"); + milvus::server::Config::GetInstance().SetStorageConfigSecondaryPath(""); + milvus::server::Config::GetInstance().SetDBConfigArchiveDiskThreshold(""); + milvus::server::Config::GetInstance().SetDBConfigArchiveDaysThreshold(""); + milvus::server::Config::GetInstance().SetCacheConfigCacheInsertData(""); + milvus::server::Config::GetInstance().SetEngineConfigOmpThreadNum(""); + + milvus::server::DBWrapper::GetInstance().StartService(); + } + + void + SetUp() override { + handler = std::make_shared(); + } + + void + TearDown() override { + } + + static void + TearDownTestCase() { + milvus::server::DBWrapper::GetInstance().StopService(); + milvus::scheduler::JobMgrInst::GetInstance()->Stop(); + milvus::scheduler::ResMgrInst::GetInstance()->Stop(); + milvus::scheduler::SchedInst::GetInstance()->Stop(); + boost::filesystem::remove_all("/tmp/milvus_web_handler_test"); + } + + protected: + void + GenTable(const std::string& table_name, int64_t dim, int64_t index_size, const std::string& metric) { + auto table_dto = milvus::server::web::TableRequestDto::createShared(); + table_dto->table_name = table_name.c_str(); + table_dto->dimension = dim; + table_dto->index_file_size = index_size; + table_dto->metric_type = metric.c_str(); + + auto status_dto = handler->CreateTable(table_dto); + } + + protected: + std::shared_ptr handler; + std::shared_ptr dummy_context; +}; + +} // namespace + +TEST_F(WebHandlerTest, TABLE) { + handler->RegisterRequestHandler(milvus::server::RequestHandler()); + auto table_name = milvus::server::web::OString(TABLE_NAME) + RandomName().c_str(); + + auto table_dto = milvus::server::web::TableRequestDto::createShared(); + table_dto->table_name = table_name; + table_dto->dimension = TABLE_DIM + 100000; + table_dto->index_file_size = INDEX_FILE_SIZE; + table_dto->metric_type = "L2"; + + // invalid dimension + auto status_dto = handler->CreateTable(table_dto); + ASSERT_EQ(StatusCode::ILLEGAL_DIMENSION, status_dto->code->getValue()); + + // invalid index file size + table_dto->dimension = TABLE_DIM; + table_dto->index_file_size = -1; + status_dto = handler->CreateTable(table_dto); + ASSERT_EQ(StatusCode::ILLEGAL_ARGUMENT, status_dto->code->getValue()); + + // invalid metric type + table_dto->index_file_size = INDEX_FILE_SIZE; + table_dto->metric_type = "L1"; + status_dto = handler->CreateTable(table_dto); + ASSERT_EQ(StatusCode::ILLEGAL_METRIC_TYPE, status_dto->code->getValue()); + + // create table successfully + table_dto->metric_type = "L2"; + status_dto = handler->CreateTable(table_dto); + ASSERT_EQ(0, status_dto->code->getValue()); + + sleep(3); + + status_dto = handler->DropTable(table_name); + ASSERT_EQ(0, status_dto->code->getValue()); + + // drop table which not exists. + status_dto = handler->DropTable(table_name + "57575yfhfdhfhdh436gdsgpppdgsgv3233"); + ASSERT_EQ(StatusCode::TABLE_NOT_EXISTS, status_dto->code->getValue()); +} + +TEST_F(WebHandlerTest, HAS_TABLE_TEST) { + handler->RegisterRequestHandler(milvus::server::RequestHandler()); + auto table_name = milvus::server::web::OString(TABLE_NAME) + RandomName().c_str(); + + GenTable(table_name->std_str(), 10, 10, "L2"); + + milvus::server::web::OQueryParams query_params; + auto tables_dto = milvus::server::web::TableFieldsDto::createShared(); + auto status_dto = handler->GetTable(table_name, query_params, tables_dto); + ASSERT_EQ(0, status_dto->code->getValue()); +} + +TEST_F(WebHandlerTest, GET_TABLE) { + handler->RegisterRequestHandler(milvus::server::RequestHandler()); + + auto table_name = milvus::server::web::OString(TABLE_NAME) + RandomName().c_str(); + GenTable(table_name->std_str(), 10, 10, "L2"); + + milvus::server::web::OQueryParams query_params; + auto table_dto = milvus::server::web::TableFieldsDto::createShared(); + auto status_dto = handler->GetTable(table_name, query_params, table_dto); + ASSERT_EQ(0, status_dto->code->getValue()); + ASSERT_EQ(10, table_dto->dimension->getValue()); + ASSERT_EQ(10, table_dto->index_file_size->getValue()); + ASSERT_EQ("L2", table_dto->metric_type->std_str()); +} + +TEST_F(WebHandlerTest, INSERT_COUNT) { + handler->RegisterRequestHandler(milvus::server::RequestHandler()); + + auto table_name = milvus::server::web::OString(TABLE_NAME) + RandomName().c_str(); + GenTable(table_name->std_str(), 16, 10, "L2"); + + auto insert_request_dto = milvus::server::web::InsertRequestDto::createShared(); + insert_request_dto->records = insert_request_dto->records->createShared(); + for (size_t i = 0; i < 1000; i++) { + insert_request_dto->records->pushBack(RandomRowRecordDto(16)); + } + insert_request_dto->ids = insert_request_dto->ids->createShared(); + + auto ids_dto = milvus::server::web::VectorIdsDto::createShared(); + + auto status_dto = handler->Insert(table_name, insert_request_dto, ids_dto); + + ASSERT_EQ(0, status_dto->code->getValue()); + ASSERT_EQ(1000, ids_dto->ids->count()); + + sleep(8); + + milvus::server::web::OQueryParams query_params; + query_params.put("fields", "num"); + auto tables_dto = milvus::server::web::TableFieldsDto::createShared(); + status_dto = handler->GetTable(table_name, query_params, tables_dto); + ASSERT_EQ(0, status_dto->code->getValue()); + ASSERT_EQ(1000, tables_dto->count->getValue()); +} + +TEST_F(WebHandlerTest, INDEX) { + handler->RegisterRequestHandler(milvus::server::RequestHandler()); + + auto table_name = milvus::server::web::OString(TABLE_NAME) + RandomName().c_str(); + GenTable(table_name->std_str(), 16, 10, "L2"); + + auto index_request_dto = milvus::server::web::IndexRequestDto::createShared(); + index_request_dto->index_type = "FLAT"; + index_request_dto->nlist = 10; + + milvus::server::web::StatusDto::createShared(); + + auto status_dto = handler->CreateIndex(table_name, index_request_dto); + ASSERT_EQ(0, status_dto->code->getValue()); + + status_dto = handler->DropIndex(table_name); + ASSERT_EQ(0, status_dto->code->getValue()); + + // invalid index_type + index_request_dto->index_type = "AAA"; + status_dto = handler->CreateIndex(table_name, index_request_dto); + ASSERT_NE(0, status_dto->code->getValue()); + ASSERT_EQ(StatusCode::ILLEGAL_INDEX_TYPE, status_dto->code->getValue()); + + // invalid nlist + index_request_dto->index_type = "FLAT"; + index_request_dto->nlist = -1; + status_dto = handler->CreateIndex(table_name, index_request_dto); + ASSERT_NE(0, status_dto->code->getValue()); + ASSERT_EQ(StatusCode::ILLEGAL_NLIST, status_dto->code->getValue()); +} + +TEST_F(WebHandlerTest, PARTITION) { + handler->RegisterRequestHandler(milvus::server::RequestHandler()); + + auto table_name = milvus::server::web::OString(TABLE_NAME) + RandomName().c_str(); + GenTable(table_name->std_str(), 16, 10, "L2"); + + auto partition_dto = milvus::server::web::PartitionRequestDto::createShared(); + partition_dto->partition_name = "partition_test"; + partition_dto->partition_tag = "test"; + + auto status_dto = handler->CreatePartition(table_name, partition_dto); + ASSERT_EQ(0, status_dto->code->getValue()); + + // test partition name equal to table name + partition_dto->partition_name = table_name; + partition_dto->partition_tag = "test02"; + status_dto = handler->CreatePartition(table_name, partition_dto); + ASSERT_NE(0, status_dto->code->getValue()); + ASSERT_EQ(StatusCode::ILLEGAL_TABLE_NAME, status_dto->code->getValue()); + + auto partitions_dto = milvus::server::web::PartitionListDto::createShared(); + status_dto = handler->ShowPartitions(0, 10, table_name, partitions_dto); + ASSERT_EQ(1, partitions_dto->partitions->count()); + + status_dto = handler->DropPartition(table_name, "test"); + ASSERT_EQ(0, status_dto->code->getValue()); + + // Show all partitions + partitions_dto = milvus::server::web::PartitionListDto::createShared(); + status_dto = handler->ShowPartitions(0, 10, table_name, partitions_dto); +} + +TEST_F(WebHandlerTest, SEARCH) { + handler->RegisterRequestHandler(milvus::server::RequestHandler()); + + auto table_name = milvus::server::web::OString(TABLE_NAME) + RandomName().c_str(); + GenTable(table_name->std_str(), TABLE_DIM, 10, "L2"); + + auto insert_request_dto = milvus::server::web::InsertRequestDto::createShared(); + insert_request_dto->records = insert_request_dto->records->createShared(); + for (size_t i = 0; i < 1000; i++) { + insert_request_dto->records->pushBack(RandomRowRecordDto(TABLE_DIM)); + } + insert_request_dto->ids = insert_request_dto->ids->createShared(); + auto ids_dto = milvus::server::web::VectorIdsDto::createShared(); + auto status_dto = handler->Insert(table_name, insert_request_dto, ids_dto); + + auto search_request_dto = milvus::server::web::SearchRequestDto::createShared(); + search_request_dto->records = RandomRecordsDto(TABLE_DIM, 10); + search_request_dto->topk = 1; + search_request_dto->nprobe = 1; + + auto results_dto = milvus::server::web::TopkResultsDto::createShared(); + + status_dto = handler->Search(table_name, search_request_dto, results_dto); + ASSERT_EQ(0, status_dto->code->getValue()) << status_dto->message->std_str(); +} + +TEST_F(WebHandlerTest, CMD) { + handler->RegisterRequestHandler(milvus::server::RequestHandler()); + milvus::server::web::OString cmd; + auto cmd_dto = milvus::server::web::CommandDto::createShared(); + + cmd = "status"; + auto status_dto = handler->Cmd(cmd, cmd_dto); + ASSERT_EQ(0, status_dto->code->getValue()); + ASSERT_EQ("OK", cmd_dto->reply->std_str()); + + cmd = "version"; + status_dto = handler->Cmd(cmd, cmd_dto); + ASSERT_EQ(0, status_dto->code->getValue()); + ASSERT_EQ("0.6.0", cmd_dto->reply->std_str()); +} + +/////////////////////////////////////////////////////////////////////////////////////// + +namespace { +static const char* CONTROLLER_TEST_TABLE_NAME = "controller_unit_test"; + +class TestClient : public oatpp::web::client::ApiClient { + public: +#include OATPP_CODEGEN_BEGIN(ApiClient) + API_CLIENT_INIT(TestClient) + + API_CALL("GET", "/", root) + + API_CALL("GET", "/state", getState) + + API_CALL("GET", "/devices", getDevices) + + API_CALL("GET", "/config/advanced", getAdvanced) + + API_CALL("OPTIONS", "/config/advanced", optionsAdvanced) + + API_CALL("PUT", "/config/advanced", setAdvanced, + BODY_DTO(milvus::server::web::AdvancedConfigDto::ObjectWrapper, body)) + +#ifdef MILVUS_GPU_VERSION + + API_CALL("OPTIONS", "config/gpu_resources", optionsGpuConfig) + + API_CALL("GET", "/config/gpu_resources", getGPUConfig) + + API_CALL("PUT", "/config/gpu_resources", setGPUConfig, + BODY_DTO(milvus::server::web::GPUConfigDto::ObjectWrapper, body)) + +#endif + + API_CALL("OPTIONS", "/tables", optionsTables) + + API_CALL("POST", "/tables", createTable, BODY_DTO(milvus::server::web::TableRequestDto::ObjectWrapper, body)) + + API_CALL("GET", "/tables", showTables, QUERY(Int64, offset), QUERY(Int64, page_size)) + + API_CALL("OPTIONS", "/tables/{table_name}", optionsTable, PATH(String, table_name, "table_name")) + + API_CALL("GET", "/tables/{table_name}", getTable, PATH(String, table_name, "table_name")) + + API_CALL("DELETE", "/tables/{table_name}", dropTable, PATH(String, table_name, "table_name")) + + API_CALL("OPTIONS", "/tables/{table_name}/indexes", optionsIndexes, PATH(String, table_name, "table_name")) + + API_CALL("POST", + "/tables/{table_name}/indexes", + createIndex, + PATH(String, table_name, "table_name"), + BODY_DTO(milvus::server::web::IndexRequestDto::ObjectWrapper, body)) + + API_CALL("GET", "/tables/{table_name}/indexes", getIndex, PATH(String, table_name, "table_name")) + + API_CALL("DELETE", "/tables/{table_name}/indexes", dropIndex, PATH(String, table_name, "table_name")) + + API_CALL("OPTIONS", "/tables/{table_name}/partitions", optionsPartitions, PATH(String, table_name, "table_name")) + + API_CALL("POST", + "/tables/{table_name}/partitions", + createPartition, + PATH(String, table_name, "table_name"), + BODY_DTO(milvus::server::web::PartitionRequestDto::ObjectWrapper, body)) + + API_CALL("GET", + "/tables/{table_name}/partitions", + showPartitions, + PATH(String, table_name, "table_name"), + QUERY(Int64, offset), + QUERY(Int64, page_size)) + + API_CALL("OPTIONS", + "/tables/{table_name}/partitions/{partition_tag}", + optionsParTag, + PATH(String, table_name, "table_name"), + PATH(String, partition_tag, "partition_tag")) + + API_CALL("DELETE", + "/tables/{table_name}/partitions/{partition_tag}", + dropPartition, + PATH(String, table_name, "table_name"), + PATH(String, partition_tag)) + + API_CALL("OPTIONS", "/tables/{table_name}/vectors", optionsVectors, PATH(String, table_name, "table_name")) + + API_CALL("POST", + "/tables/{table_name}/vectors", + insert, + PATH(String, table_name, "table_name"), + BODY_DTO(milvus::server::web::InsertRequestDto::ObjectWrapper, body)) + + API_CALL("PUT", + "/tables/{table_name}/vectors", + search, + PATH(String, table_name, "table_name"), + BODY_DTO(milvus::server::web::SearchRequestDto::ObjectWrapper, body)) + + API_CALL("GET", "/system/{msg}", cmd, PATH(String, cmd_str, "msg")) + +#include OATPP_CODEGEN_END(ApiClient) +}; + +class WebControllerTest : public testing::Test { + protected: + static void + SetUpTestCase() { + auto res_mgr = milvus::scheduler::ResMgrInst::GetInstance(); + res_mgr->Clear(); + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("disk", "DISK", 0, false)); + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("cpu", "CPU", 0)); + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("gtx1660", "GPU", 0)); + + auto default_conn = milvus::scheduler::Connection("IO", 500.0); + auto PCIE = milvus::scheduler::Connection("IO", 11000.0); + res_mgr->Connect("disk", "cpu", default_conn); + res_mgr->Connect("cpu", "gtx1660", PCIE); + res_mgr->Start(); + milvus::scheduler::SchedInst::GetInstance()->Start(); + milvus::scheduler::JobMgrInst::GetInstance()->Start(); + + milvus::engine::DBOptions opt; + + milvus::server::Config::GetInstance().SetDBConfigBackendUrl("sqlite://:@:/"); + boost::filesystem::remove_all("/tmp/milvus_web_controller_test"); + milvus::server::Config::GetInstance().SetStorageConfigPrimaryPath("/tmp/milvus_web_controller_test"); + milvus::server::Config::GetInstance().SetStorageConfigSecondaryPath(""); + milvus::server::Config::GetInstance().SetDBConfigArchiveDiskThreshold(""); + milvus::server::Config::GetInstance().SetDBConfigArchiveDaysThreshold(""); + milvus::server::Config::GetInstance().SetCacheConfigCacheInsertData(""); + milvus::server::Config::GetInstance().SetEngineConfigOmpThreadNum(""); + + milvus::server::DBWrapper::GetInstance().StartService(); + + milvus::server::Config::GetInstance().SetServerConfigWebPort("29999"); + + milvus::server::web::WebServer::GetInstance().Start(); + + sleep(5); + } + + static void + TearDownTestCase() { + milvus::server::web::WebServer::GetInstance().Stop(); + + milvus::server::DBWrapper::GetInstance().StopService(); + milvus::scheduler::JobMgrInst::GetInstance()->Stop(); + milvus::scheduler::ResMgrInst::GetInstance()->Stop(); + milvus::scheduler::SchedInst::GetInstance()->Stop(); + boost::filesystem::remove_all("/tmp/milvus_web_controller_test"); + } + + void + GenTable(const OString& table_name, int64_t dim, int64_t index_size, const OString& metric) { + auto response = client_ptr->getTable(table_name, conncetion_ptr); + if (OStatus::CODE_200.code == response->getStatusCode()) { + return; + } + auto table_dto = milvus::server::web::TableRequestDto::createShared(); + table_dto->table_name = table_name; + table_dto->dimension = dim; + table_dto->index_file_size = index_size; + table_dto->metric_type = metric; + client_ptr->createTable(table_dto, conncetion_ptr); + } + + void + SetUp() override { + OATPP_COMPONENT(std::shared_ptr, clientConnectionProvider); + OATPP_COMPONENT(std::shared_ptr, objectMapper); + object_mapper = objectMapper; + + auto requestExecutor = oatpp::web::client::HttpRequestExecutor::createShared(clientConnectionProvider); + client_ptr = TestClient::createShared(requestExecutor, objectMapper); + + conncetion_ptr = client_ptr->getConnection(); + } + + void + TearDown() override { + }; + + protected: + std::shared_ptr object_mapper; + std::shared_ptr conncetion_ptr; + std::shared_ptr client_ptr; + + protected: + void GenTable(const std::string& table_name, int64_t dim, int64_t index_file_size, int64_t metric_type) { + auto table_dto = milvus::server::web::TableRequestDto::createShared(); + table_dto->table_name = OString(table_name.c_str()); + table_dto->dimension = dim; + table_dto->index_file_size = index_file_size; + table_dto->metric_type = metric_type; + + client_ptr->createTable(table_dto, conncetion_ptr); + } +}; + +} // namespace +TEST_F(WebControllerTest, OPTIONS) { + auto response = client_ptr->root(conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + response = client_ptr->getState(conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + response = client_ptr->optionsAdvanced(conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + +#ifdef MILVUS_GPU_VERSION + + response = client_ptr->optionsGpuConfig(conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + +#endif + + response = client_ptr->optionsIndexes("test", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + response = client_ptr->optionsParTag("test", "tag", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + response = client_ptr->optionsPartitions("table_name", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + response = client_ptr->optionsTable("table", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + response = client_ptr->optionsTables(conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + response = client_ptr->optionsVectors("table", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, CREATE_TABLE) { + auto table_dto = milvus::server::web::TableRequestDto::createShared(); + auto response = client_ptr->createTable(table_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code) << result_dto->message->std_str(); + + OString table_name = "web_test_create_table" + OString(RandomName().c_str()); + + table_dto->table_name = table_name; + response = client_ptr->createTable(table_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code) << result_dto->message->std_str(); + + table_dto->dimension = 128; + table_dto->index_file_size = 10; + table_dto->metric_type = "L2"; + + response = client_ptr->createTable(table_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + + // invalid table name + table_dto->table_name = "9090&*&()"; + response = client_ptr->createTable(table_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, GET_TABLE) { + OString table_name = "web_test_create_table" + OString(RandomName().c_str()); + GenTable(table_name, 10, 10, "L2"); + + OQueryParams params; + + // fields value is 'num', test count table + params.put("fields", "num"); + auto response = client_ptr->getTable(table_name, conncetion_ptr); + + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + + response = client_ptr->getTable(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + // invalid table name + table_name = "57474dgdfhdfhdh dgd"; + response = client_ptr->getTable(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + auto status_sto = response->readBodyToDto(object_mapper.get()); + + table_name = "test_table_not_found_0000000001110101010020202030203030435"; + response = client_ptr->getTable(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_404.code, response->getStatusCode()); + status_sto = response->readBodyToDto(object_mapper.get()); +} + +TEST_F(WebControllerTest, SHOW_TABLES) { + // test query table limit 1 + auto response = client_ptr->showTables(1, 1, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + // test query table empty + response = client_ptr->showTables(0, 0, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + response = client_ptr->showTables(-1, 0, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + + response = client_ptr->showTables(0, -10, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, DROP_TABLE) { + auto table_name = "table_drop_test" + OString(RandomName().c_str()); + GenTable(table_name, 128, 100, "L2"); + + sleep(1); + + auto response = client_ptr->dropTable(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, INSERT) { + auto table_name = "test_insert_table_test" + OString(RandomName().c_str()); + const int64_t dim = 64; + GenTable(table_name, dim, 100, "L2"); + + auto insert_dto = milvus::server::web::InsertRequestDto::createShared(); + insert_dto->ids = insert_dto->ids->createShared(); + insert_dto->records = RandomRecordsDto(dim, 20); + + auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(20, result_dto->ids->count()); + + response = client_ptr->dropTable(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, INSERT_IDS) { + auto table_name = "test_insert_table_test" + OString(RandomName().c_str()); + const int64_t dim = 64; + GenTable(table_name, dim, 100, "L2"); + + auto insert_dto = milvus::server::web::InsertRequestDto::createShared(); + insert_dto->ids = insert_dto->ids->createShared(); + for (size_t i = 0; i < 20; i++) { + insert_dto->ids->pushBack(i); + } + + insert_dto->records = RandomRecordsDto(dim, 20); + + auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(20, result_dto->ids->count()); + + response = client_ptr->dropTable(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, INDEX) { + auto table_name = "test_insert_table_test" + OString(RandomName().c_str()); + GenTable(table_name, 64, 100, "L2"); + + // test index with imcomplete param + auto index_dto = milvus::server::web::IndexRequestDto::createShared(); + auto response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + // drop index + response = client_ptr->dropIndex(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + index_dto->index_type = milvus::server::web::IndexMap.at(milvus::engine::EngineType::FAISS_IDMAP).c_str(); + + response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + // drop index + response = client_ptr->dropIndex(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + index_dto->index_type = "J46"; + response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::ILLEGAL_INDEX_TYPE, result_dto->code); + + index_dto->index_type = milvus::server::web::IndexMap.at(milvus::engine::EngineType::FAISS_IDMAP).c_str(); + index_dto->nlist = 10; + + response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + + // drop index + response = client_ptr->dropIndex(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + // invalid index type + index_dto->index_type = 100; + response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr); + ASSERT_NE(OStatus::CODE_201.code, response->getStatusCode()); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + + // insert data and create index + response = client_ptr->dropIndex(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); + + auto insert_dto = milvus::server::web::InsertRequestDto::createShared(); + insert_dto->ids = insert_dto->ids->createShared(); + insert_dto->records = RandomRecordsDto(64, 200); + + response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + + index_dto->index_type = milvus::server::web::IndexMap.at(milvus::engine::EngineType::FAISS_IDMAP).c_str(); + response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + + // get index + response = client_ptr->getIndex(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, PARTITION) { + const OString table_name = "test_controller_partition_" + OString(RandomName().c_str()); + GenTable(table_name, 64, 100, "L2"); + + auto par_param = milvus::server::web::PartitionRequestDto::createShared(); + auto response = client_ptr->createPartition(table_name, par_param); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code); + + par_param->partition_name = "partition01" + OString(RandomName().c_str()); + response = client_ptr->createPartition(table_name, par_param); + result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code); + + par_param->partition_tag = "tag01"; + response = client_ptr->createPartition(table_name, par_param); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + + // insert 200 vectors into table with tag = 'tag01' + OQueryParams query_params; + // add partition tag + auto insert_dto = milvus::server::web::InsertRequestDto::createShared(); + // add partition tag + insert_dto->tag = OString("tag01"); + insert_dto->ids = insert_dto->ids->createShared(); + insert_dto->records = insert_dto->records->createShared(); + for (size_t i = 0; i < 200; i++) { + insert_dto->records->pushBack(RandomRowRecordDto(64)); + } + response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + + // Show all partitins + response = client_ptr->showPartitions(table_name, 0, 10, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + response = client_ptr->dropPartition(table_name, "tag01", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, SEARCH) { + const OString table_name = "test_partition_table_test" + OString(RandomName().c_str()); + GenTable(table_name, 64, 100, "L2"); + + // Insert 200 vectors into table + OQueryParams query_params; + auto insert_dto = milvus::server::web::InsertRequestDto::createShared(); + insert_dto->ids = insert_dto->ids->createShared(); + insert_dto->records = RandomRecordsDto(64, 200);// insert_dto->records->createShared(); + + auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + + sleep(4); + + //Create partition and insert 200 vectors into it + auto par_param = milvus::server::web::PartitionRequestDto::createShared(); + par_param->partition_name = "partition" + OString(RandomName().c_str()); + par_param->partition_tag = "tag" + OString(RandomName().c_str()); + response = client_ptr->createPartition(table_name, par_param); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()) + << "Error: " << response->getStatusDescription()->std_str(); + + insert_dto->tag = par_param->partition_tag; + response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + sleep(5); + + // Test search + auto search_request_dto = milvus::server::web::SearchRequestDto::createShared(); + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code); + + search_request_dto->nprobe = 1; + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code); + + search_request_dto->topk = 1; + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code); + + search_request_dto->records = RandomRecordsDto(64, 10); + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + // Test search with tags + search_request_dto->tags = search_request_dto->tags->createShared(); + search_request_dto->tags->pushBack(par_param->partition_tag); + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, CMD) { + auto response = client_ptr->cmd("status", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + response = client_ptr->cmd("version", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, ADVANCEDCONFIG) { + auto response = client_ptr->getAdvanced(conncetion_ptr); + + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + auto config_dto = milvus::server::web::AdvancedConfigDto::createShared(); + response = client_ptr->setAdvanced(config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + config_dto->cpu_cache_capacity = 3; + response = client_ptr->setAdvanced(config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + config_dto->cache_insert_data = true; + response = client_ptr->setAdvanced(config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + +#ifdef MILVUS_GPU_VERSION + + config_dto->gpu_search_threshold = 1000; + response = client_ptr->setAdvanced(config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + +#endif + + config_dto->use_blas_threshold = 1000; + response = client_ptr->setAdvanced(config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); +} + +#ifdef MILVUS_GPU_VERSION +TEST_F(WebControllerTest, GPUCONFIG) { + auto response = client_ptr->getGPUConfig(conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + auto gpu_config_dto = milvus::server::web::GPUConfigDto::createShared(); + + response = client_ptr->setGPUConfig(gpu_config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + gpu_config_dto->enable = true; + response = client_ptr->setGPUConfig(gpu_config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + gpu_config_dto->cache_capacity = 2; + response = client_ptr->setGPUConfig(gpu_config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared(); + gpu_config_dto->build_index_resources->pushBack("GPU0"); + response = client_ptr->setGPUConfig(gpu_config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + gpu_config_dto->search_resources = gpu_config_dto->search_resources->createShared(); + gpu_config_dto->search_resources->pushBack("GPU0"); + + response = client_ptr->setGPUConfig(gpu_config_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); +} + +#endif + +TEST_F(WebControllerTest, DEVICESCONFIG) { + auto response = WebControllerTest::client_ptr->getDevices(conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); +} + -- GitLab