From 015f0352a62fe11d6a589703bce39af8f2d78f6a Mon Sep 17 00:00:00 2001 From: BossZou <40255591+BossZou@users.noreply.github.com> Date: Sat, 18 Jan 2020 10:05:49 +0800 Subject: [PATCH] Fix http bug & add binary vectors support (#1073) * refactoring(create_table done) * refactoring * refactor server delivery (insert done) * refactoring server module (count_table done) * server refactor done * cmake pass * refactor server module done. * set grpc response status correctly * format done. * fix redefine ErrorMap() * optimize insert reducing ids data copy * optimize grpc request with reducing data copy * clang format * [skip ci] Refactor server module done. update changlog. prepare for PR * remove explicit and change int32_t to int64_t * add web server * [skip ci] add license in web module * modify header include & comment oatpp environment config * add port configure & create table in handler * modify web url * simple url complation done & add swagger * make sure web url * web functionality done. debuging * add web unittest * web test pass * add web server port * add web server port in template * update unittest cmake file * change web server default port to 19121 * rename method in web module & unittest pass * add search case in unittest for web module * rename some variables * fix bug * unittest pass * web prepare * fix cmd bug(check server status) * update changlog * add web port validate & default set * clang-format pass * add web port test in unittest * add CORS & redirect root to swagger ui * add web status * web table method func cascade test pass * add config url in web module * modify thirdparty cmake to avoid building oatpp test * clang format * update changlog * add constants in web module * reserve Config.cpp * fix constants reference bug * replace web server with async module * modify component to support async * format * developing controller & add test clent into unittest * add web port into demo/server_config * modify thirdparty cmake to allow build test * remove unnecessary comment * add endpoint info in controller * finish web test(bug here) * clang format * add web test cpp to lint exclusions * check null field in GetConfig * add macro RETURN STATUS DTo * fix cmake conflict * fix crash when exit server * remove surplus comments & add http param check * add uri /docs to direct swagger * format * change cmd to system * add default value & unittest in web module * add macros to judge if GPU supported * add macros in unit & add default in index dto & print error message when bind http port fail * format (fix #788) * fix cors bug (not completed) * comment cors * change web framework to simple api * comments optimize * change to simple API * remove comments in controller.hpp * remove EP_COMMON_CMAKE_ARGS in oatpp and oatpp-swagger * add ep cmake args to sqlite * clang-format * change a format * test pass * change name to * fix compiler issue(oatpp-swagger depend on oatpp) * add & in start_server.h * specify lib location with oatpp and oatpp-swagger * add comments * add swagger definition * [skip ci] change http method options status code * remove oatpp swagger(fix #970) * remove comments * check Start web behavior * add default to cpu_cache_capacity * remove swagger component.hpp & /docs url * remove /docs info * remove /docs in unittest * remove space in test rpc * remove repeate info in CHANGLOG * change cache_insert_data default value as a constant * [skip ci] Fix some broken links (#960) * [skip ci] Fix broken link * [skip ci] Fix broken link * [skip ci] Fix broken link * [skip ci] Fix broken links * fix issue 373 (#964) * fix issue 373 * Adjustment format * Adjustment format * Adjustment format * change readme * #966 update NOTICE.md (#967) * remove comments * check Start web behavior * add default to cpu_cache_capacity * remove swagger component.hpp & /docs url * remove /docs info * remove /docs in unittest * remove space in test rpc * remove repeate info in CHANGLOG * change cache_insert_data default value as a constant * adjust web port cofig place * rename web_port variable * change gpu resources invoke way to cmd() * set advanced config name add DEFAULT * change config setting to cmd * modify .. * optimize code * assign TableDto' count default value 0 (fix #995) * check if table exists when show partitions (fix #1028) * check table exists when drop partition (fix #1029) * check if partition name is legal (fix #1022) * modify status code when partition tag is illegal * update changlog * add info to /system url * add binary index and add bin uri & handler method(not completed) * optimize http insert and search time(fix #1066) | add binary vectors support(fix #1067) * fix test partition bug * fix test bug when check insert records * add binary vectors test * add default for offset and page_size * fix uinttest bug * [skip ci] remove comments * optimize web code for PR comments * add new folder named utils Co-authored-by: jielinxu <52057195+jielinxu@users.noreply.github.com> Co-authored-by: JackLCL <53512883+JackLCL@users.noreply.github.com> Co-authored-by: Cai Yudong --- CHANGELOG.md | 7 + core/src/CMakeLists.txt | 2 + .../delivery/request/DropPartitionRequest.cpp | 15 +- .../request/ShowPartitionsRequest.cpp | 10 + core/src/server/web_impl/Constants.h | 4 +- core/src/server/web_impl/Types.h | 13 +- core/src/server/web_impl/WebServer.h | 2 - .../web_impl/controller/WebController.hpp | 370 ++++++++----- core/src/server/web_impl/dto/TableDto.hpp | 9 +- core/src/server/web_impl/dto/VectorDto.hpp | 10 +- .../web_impl/handler/WebRequestHandler.cpp | 524 +++++++++--------- .../web_impl/handler/WebRequestHandler.h | 20 +- core/src/server/web_impl/utils/Util.cpp | 65 +++ core/src/server/web_impl/utils/Util.h | 38 ++ core/src/utils/ValidationUtil.cpp | 30 +- core/unittest/CMakeLists.txt | 2 + core/unittest/server/test_web.cpp | 237 ++++++-- tests/milvus_python_test/test_partition.py | 6 +- 18 files changed, 901 insertions(+), 463 deletions(-) create mode 100644 core/src/server/web_impl/utils/Util.cpp create mode 100644 core/src/server/web_impl/utils/Util.h diff --git a/CHANGELOG.md b/CHANGELOG.md index 1609e48f..518dbd3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,13 @@ Please mark all change in change log and use the issue from GitHub - \#805 - IVFTest.gpu_seal_test unittest failed - \#831 - Judge branch error in CommonUtil.cpp - \#977 - Server crash when create tables concurrently +- \#995 - table count set to 0 if no tables found +- \#1010 - improve error message when offset or page_size is equal 0 +- \#1022 - check if partition name is legal +- \#1028 - check if table exists when show partitions +- \#1029 - check if table exists when try to delete partition +- \#1066 - optimize http insert and search speed +- \#1067 - Add binary vectors support in http server ## Feature - \#216 - Add CLI to get server info diff --git a/core/src/CMakeLists.txt b/core/src/CMakeLists.txt index 9b02c55a..e34f8ad0 100644 --- a/core/src/CMakeLists.txt +++ b/core/src/CMakeLists.txt @@ -96,12 +96,14 @@ aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/handler web_handler_fi aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/component web_conponent_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/controller web_controller_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/dto web_dto_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/utils web_utils_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl web_impl_files) set(web_server_files ${web_handler_files} ${web_conponent_files} ${web_controller_files} ${web_dto_files} + ${web_utils_files} ${web_impl_files} ) diff --git a/core/src/server/delivery/request/DropPartitionRequest.cpp b/core/src/server/delivery/request/DropPartitionRequest.cpp index 461ca8b1..e0d2bcf9 100644 --- a/core/src/server/delivery/request/DropPartitionRequest.cpp +++ b/core/src/server/delivery/request/DropPartitionRequest.cpp @@ -47,8 +47,19 @@ DropPartitionRequest::OnExecute() { std::string table_name = table_name_; std::string partition_name = partition_name_; std::string partition_tag = tag_; + + bool exists; + auto status = DBWrapper::DB()->HasTable(table_name, exists); + if (!status.ok()) { + return status; + } + + if (!exists) { + return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); + } + if (!partition_name.empty()) { - auto status = ValidationUtil::ValidateTableName(partition_name); + status = ValidationUtil::ValidateTableName(partition_name); if (!status.ok()) { return status; } @@ -68,7 +79,7 @@ DropPartitionRequest::OnExecute() { return DBWrapper::DB()->DropPartition(partition_name); } else { - auto status = ValidationUtil::ValidateTableName(table_name); + status = ValidationUtil::ValidateTableName(table_name); if (!status.ok()) { return status; } diff --git a/core/src/server/delivery/request/ShowPartitionsRequest.cpp b/core/src/server/delivery/request/ShowPartitionsRequest.cpp index eddbd67b..2d1e4ce3 100644 --- a/core/src/server/delivery/request/ShowPartitionsRequest.cpp +++ b/core/src/server/delivery/request/ShowPartitionsRequest.cpp @@ -48,6 +48,16 @@ ShowPartitionsRequest::OnExecute() { return status; } + bool exists = false; + status = DBWrapper::DB()->HasTable(table_name_, exists); + if (!status.ok()) { + return status; + } + + if (!exists) { + return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); + } + std::vector schema_array; status = DBWrapper::DB()->ShowPartitions(table_name_, schema_array); if (!status.ok()) { diff --git a/core/src/server/web_impl/Constants.h b/core/src/server/web_impl/Constants.h index 31941c39..8845c913 100644 --- a/core/src/server/web_impl/Constants.h +++ b/core/src/server/web_impl/Constants.h @@ -1,4 +1,3 @@ - // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -46,6 +45,9 @@ static const char* NAME_ENGINE_TYPE_IVFPQ = "IVFPQ"; static const char* NAME_METRIC_TYPE_L2 = "L2"; static const char* NAME_METRIC_TYPE_IP = "IP"; +static const char* NAME_METRIC_TYPE_HAMMING = "HAMMING"; +static const char* NAME_METRIC_TYPE_JACCARD = "JACCARD"; +static const char* NAME_METRIC_TYPE_TANIMOTO = "TANIMOTO"; //////////////////////////////////////////////////// diff --git a/core/src/server/web_impl/Types.h b/core/src/server/web_impl/Types.h index 97b09961..a54dca8a 100644 --- a/core/src/server/web_impl/Types.h +++ b/core/src/server/web_impl/Types.h @@ -21,9 +21,9 @@ #include #include +#include #include "db/engine/ExecutionEngine.h" - #include "server/web_impl/Constants.h" namespace milvus { @@ -31,6 +31,8 @@ namespace server { namespace web { using OString = oatpp::data::mapping::type::String; +using OInt8 = oatpp::data::mapping::type::Int8; +using OInt16 = oatpp::data::mapping::type::Int16; using OInt64 = oatpp::data::mapping::type::Int64; using OFloat32 = oatpp::data::mapping::type::Float32; template @@ -65,10 +67,11 @@ enum StatusCode : int { ILLEGAL_METRIC_TYPE = 23, OUT_OF_MEMORY = 24, - // HTTP status code + // HTTP error code PATH_PARAM_LOSS = 31, QUERY_PARAM_LOSS = 32, BODY_FIELD_LOSS = 33, + ILLEGAL_QUERY_PARAM = 36, }; static const std::unordered_map IndexMap = { @@ -92,11 +95,17 @@ static const std::unordered_map IndexNameMap = static const std::unordered_map MetricMap = { {engine::MetricType::L2, NAME_METRIC_TYPE_L2}, {engine::MetricType::IP, NAME_METRIC_TYPE_IP}, + {engine::MetricType::HAMMING, NAME_METRIC_TYPE_HAMMING}, + {engine::MetricType::JACCARD, NAME_METRIC_TYPE_JACCARD}, + {engine::MetricType::TANIMOTO, NAME_METRIC_TYPE_TANIMOTO}, }; static const std::unordered_map MetricNameMap = { {NAME_METRIC_TYPE_L2, engine::MetricType::L2}, {NAME_METRIC_TYPE_IP, engine::MetricType::IP}, + {NAME_METRIC_TYPE_HAMMING, engine::MetricType::HAMMING}, + {NAME_METRIC_TYPE_JACCARD, engine::MetricType::JACCARD}, + {NAME_METRIC_TYPE_TANIMOTO, engine::MetricType::TANIMOTO}, }; } // namespace web diff --git a/core/src/server/web_impl/WebServer.h b/core/src/server/web_impl/WebServer.h index a99f5482..865fd25c 100644 --- a/core/src/server/web_impl/WebServer.h +++ b/core/src/server/web_impl/WebServer.h @@ -22,8 +22,6 @@ #include #include -#include - #include "server/web_impl/component/AppComponent.hpp" #include "utils/Status.h" diff --git a/core/src/server/web_impl/controller/WebController.hpp b/core/src/server/web_impl/controller/WebController.hpp index 585569f7..0656fa61 100644 --- a/core/src/server/web_impl/controller/WebController.hpp +++ b/core/src/server/web_impl/controller/WebController.hpp @@ -98,10 +98,12 @@ class WebController : public oatpp::web::server::api::ApiController { handler.RegisterRequestHandler(::milvus::server::RequestHandler()); auto status_dto = handler.GetDevices(devices_dto); std::shared_ptr response; - if (0 == status_dto->code->getValue()) { - response = createDtoResponse(Status::CODE_200, devices_dto); - } else { - response = createDtoResponse(Status::CODE_400, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, devices_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } return response; @@ -127,11 +129,14 @@ class WebController : public oatpp::web::server::api::ApiController { WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); auto status_dto = handler.GetAdvancedConfig(config_dto); + std::shared_ptr response; - if (0 == status_dto->code->getValue()) { - response = createDtoResponse(Status::CODE_200, config_dto); - } else { - response = createDtoResponse(Status::CODE_400, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, config_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } return response; @@ -152,13 +157,17 @@ class WebController : public oatpp::web::server::api::ApiController { WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); - auto status_dto = handler.SetAdvancedConfig(body); std::shared_ptr response; - if (0 == status_dto->code->getValue()) { - return createDtoResponse(Status::CODE_200, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + auto status_dto = handler.SetAdvancedConfig(body); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } #ifdef MILVUS_GPU_VERSION @@ -182,13 +191,18 @@ class WebController : public oatpp::web::server::api::ApiController { auto gpu_config_dto = GPUConfigDto::createShared(); WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); - auto status_dto = handler.GetGpuConfig(gpu_config_dto); - if (0 == status_dto->code->getValue()) { - return createDtoResponse(Status::CODE_200, gpu_config_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + std::shared_ptr response; + auto status_dto = handler.GetGpuConfig(gpu_config_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, gpu_config_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ENDPOINT_INFO(SetGPUConfig) { @@ -207,11 +221,15 @@ class WebController : public oatpp::web::server::api::ApiController { auto status_dto = handler.SetGpuConfig(body); std::shared_ptr response; - if (0 == status_dto->code->getValue()) { - return createDtoResponse(Status::CODE_200, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } #endif @@ -227,7 +245,7 @@ class WebController : public oatpp::web::server::api::ApiController { info->addConsumes("application/json"); - info->addResponse(Status::CODE_201, "application/json"); + info->addResponse(Status::CODE_201, "application/json"); info->addResponse(Status::CODE_400, "application/json"); } @@ -237,12 +255,17 @@ class WebController : public oatpp::web::server::api::ApiController { WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + std::shared_ptr response; auto status_dto = handler.CreateTable(body); - if (0 != status_dto->code) { - return createDtoResponse(Status::CODE_400, status_dto); - } else { - return createDtoResponse(Status::CODE_201, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_201, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ENDPOINT_INFO(ShowTables) { @@ -257,17 +280,25 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(ShowTables) - ENDPOINT("GET", "/tables", ShowTables, QUERY(Int64, offset, "offset"), QUERY(Int64, page_size, "page_size")) { + ENDPOINT("GET", "/tables", ShowTables, REQUEST( + const std::shared_ptr&, request)) { WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); auto response_dto = TableListFieldsDto::createShared(); - auto status_dto = handler.ShowTables(offset, page_size, response_dto); + auto offset = request->getQueryParameter("offset", "0"); + auto page_size = request->getQueryParameter("page_size", "10"); + std::shared_ptr response; - if (0 == status_dto->code->getValue()) { - return createDtoResponse(Status::CODE_200, response_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + auto status_dto = handler.ShowTables(offset, page_size, response_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, response_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ADD_CORS(TableOptions) @@ -296,21 +327,25 @@ class WebController : public oatpp::web::server::api::ApiController { handler.RegisterRequestHandler(::milvus::server::RequestHandler()); auto fields_dto = TableFieldsDto::createShared(); auto status_dto = handler.GetTable(table_name, query_params, fields_dto); - auto code = status_dto->code->getValue(); - if (0 == code) { - return createDtoResponse(Status::CODE_200, fields_dto); - } else if (StatusCode::TABLE_NOT_EXISTS == code) { - return createDtoResponse(Status::CODE_404, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + + std::shared_ptr response; + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, fields_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ENDPOINT_INFO(DropTable) { info->summary = "Drop table"; - info->pathParams.add("table_name"); - info->addResponse(Status::CODE_204, "application/json"); info->addResponse(Status::CODE_400, "application/json"); info->addResponse(Status::CODE_404, "application/json"); @@ -321,15 +356,21 @@ class WebController : public oatpp::web::server::api::ApiController { ENDPOINT("DELETE", "/tables/{table_name}", DropTable, PATH(String, table_name)) { WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + + std::shared_ptr response; auto status_dto = handler.DropTable(table_name); - auto code = status_dto->code->getValue(); - if (0 == code) { - return createDtoResponse(Status::CODE_204, status_dto); - } else if (StatusCode::TABLE_NOT_EXISTS == code) { - return createDtoResponse(Status::CODE_404, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_204, status_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ADD_CORS(IndexOptions) @@ -341,29 +382,34 @@ class WebController : public oatpp::web::server::api::ApiController { ENDPOINT_INFO(CreateIndex) { info->summary = "Create index"; - info->pathParams.add("table_name"); - info->addConsumes("application/json"); + info->addResponse(Status::CODE_201, "application/json"); info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); } ADD_CORS(CreateIndex) - ENDPOINT("POST", - "/tables/{table_name}/indexes", - CreateIndex, - PATH(String, table_name), - BODY_DTO(IndexRequestDto::ObjectWrapper, body)) { + ENDPOINT("POST", "/tables/{table_name}/indexes", CreateIndex, + PATH(String, table_name), BODY_DTO(IndexRequestDto::ObjectWrapper, body)) { auto handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); - auto status_dto = handler.CreateIndex(table_name, body); - if (0 == status_dto->code->getValue()) { - return createDtoResponse(Status::CODE_201, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + std::shared_ptr response; + auto status_dto = handler.CreateIndex(table_name, body); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_201, status_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ENDPOINT_INFO(GetIndex) { @@ -382,15 +428,21 @@ class WebController : public oatpp::web::server::api::ApiController { auto index_dto = IndexDto::createShared(); auto handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + + std::shared_ptr response; auto status_dto = handler.GetIndex(table_name, index_dto); - auto code = status_dto->code->getValue(); - if (0 == code) { - return createDtoResponse(Status::CODE_200, index_dto); - } else if (StatusCode::TABLE_NOT_EXISTS == code) { - return createDtoResponse(Status::CODE_404, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, index_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ENDPOINT_INFO(DropIndex) { @@ -408,15 +460,21 @@ class WebController : public oatpp::web::server::api::ApiController { ENDPOINT("DELETE", "/tables/{table_name}/indexes", DropIndex, PATH(String, table_name)) { auto handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + + std::shared_ptr response; auto status_dto = handler.DropIndex(table_name); - auto code = status_dto->code->getValue(); - if (0 == code) { - return createDtoResponse(Status::CODE_204, status_dto); - } else if (StatusCode::TABLE_NOT_EXISTS == code) { - return createDtoResponse(Status::CODE_404, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_204, status_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ADD_CORS(PartitionsOptions) @@ -434,6 +492,7 @@ class WebController : public oatpp::web::server::api::ApiController { info->addResponse(Status::CODE_201, "application/json"); info->addResponse(Status::CODE_400, "application/json"); + info->addResponse(Status::CODE_404, "application/json"); } ADD_CORS(CreatePartition) @@ -442,13 +501,21 @@ class WebController : public oatpp::web::server::api::ApiController { CreatePartition, PATH(String, table_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) { auto handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); - auto status_dto = handler.CreatePartition(table_name, body); - if (0 == status_dto->code->getValue()) { - return createDtoResponse(Status::CODE_201, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + std::shared_ptr response; + auto status_dto = handler.CreatePartition(table_name, body); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_201, status_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ENDPOINT_INFO(ShowPartitions) { @@ -469,26 +536,29 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(ShowPartitions) - ENDPOINT("GET", - "/tables/{table_name}/partitions", - ShowPartitions, - PATH(String, table_name), - QUERY(Int64, offset, "offset"), - QUERY(Int64, page_size, "page_size")) { - auto status_dto = StatusDto::createShared(); + ENDPOINT("GET", "/tables/{table_name}/partitions", ShowPartitions, + PATH(String, table_name), REQUEST( + const std::shared_ptr&, request)) { + auto offset = request->getQueryParameter("offset", "0"); + auto page_size = request->getQueryParameter("page_size", "10"); + auto partition_list_dto = PartitionListDto::createShared(); auto handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); - status_dto = handler.ShowPartitions(offset, page_size, table_name, partition_list_dto); - int64_t code = status_dto->code->getValue(); - if (0 == code) { - return createDtoResponse(Status::CODE_200, partition_list_dto); - } else if (StatusCode::TABLE_NOT_EXISTS == code) { - return createDtoResponse(Status::CODE_404, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + std::shared_ptr response; + auto status_dto = handler.ShowPartitions(offset, page_size, table_name, partition_list_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, partition_list_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default:response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ADD_CORS(PartitionOptions) @@ -510,22 +580,31 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(DropPartition) - ENDPOINT("DELETE", - "/tables/{table_name}/partitions/{partition_tag}", - DropPartition, - PATH(String, table_name), - PATH(String, partition_tag)) { + ENDPOINT("DELETE", "/tables/{table_name}/partitions/{partition_tag}", DropPartition, + PATH(String, table_name), PATH(String, partition_tag)) { auto handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + + std::shared_ptr response; auto status_dto = handler.DropPartition(table_name, partition_tag); - auto code = status_dto->code->getValue(); - if (0 == code) { - return createDtoResponse(Status::CODE_204, status_dto); - } else if (StatusCode::TABLE_NOT_EXISTS == code) { - return createDtoResponse(Status::CODE_404, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_204, status_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; + } + + ADD_CORS(VectorsOptions) + + ENDPOINT("OPTIONS", "/tables/{table_name}/vectors", VectorsOptions) { + return createResponse(Status::CODE_204, "No Content"); } ENDPOINT_INFO(Insert) { @@ -540,32 +619,28 @@ class WebController : public oatpp::web::server::api::ApiController { info->addResponse(Status::CODE_404, "application/json"); } - ADD_CORS(VectorsOptions) - - ENDPOINT("OPTIONS", "/tables/{table_name}/vectors", VectorsOptions) { - return createResponse(Status::CODE_204, "No Content"); - } - ADD_CORS(Insert) - ENDPOINT("POST", - "/tables/{table_name}/vectors", - Insert, - PATH(String, table_name), - BODY_DTO(InsertRequestDto::ObjectWrapper, body)) { + ENDPOINT("POST", "/tables/{table_name}/vectors", Insert, + PATH(String, table_name), BODY_DTO(InsertRequestDto::ObjectWrapper, body)) { auto ids_dto = VectorIdsDto::createShared(); WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); - auto status_dto = handler.Insert(table_name, body, ids_dto); - int64_t code = status_dto->code->getValue(); - if (0 == code) { - return createDtoResponse(Status::CODE_201, ids_dto); - } else if (StatusCode::TABLE_NOT_EXISTS == code) { - return createDtoResponse(Status::CODE_404, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + std::shared_ptr response; + auto status_dto = handler.Insert(table_name, body, ids_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_201, ids_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ENDPOINT_INFO(Search) { @@ -582,23 +657,26 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(Search) - ENDPOINT("PUT", - "/tables/{table_name}/vectors", - Search, - PATH(String, table_name), - BODY_DTO(SearchRequestDto::ObjectWrapper, body)) { + ENDPOINT("PUT", "/tables/{table_name}/vectors", Search, + PATH(String, table_name), BODY_DTO(SearchRequestDto::ObjectWrapper, body)) { auto results_dto = TopkResultsDto::createShared(); WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); + + std::shared_ptr response; auto status_dto = handler.Search(table_name, body, results_dto); - int64_t code = status_dto->code->getValue(); - if (0 == code) { - return createDtoResponse(Status::CODE_200, results_dto); - } else if (StatusCode::TABLE_NOT_EXISTS == code) { - return createDtoResponse(Status::CODE_404, status_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, results_dto); + break; + case StatusCode::TABLE_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); } + + return response; } ENDPOINT_INFO(SystemMsg) { @@ -608,7 +686,6 @@ class WebController : public oatpp::web::server::api::ApiController { info->addResponse(Status::CODE_200, "application/json"); info->addResponse(Status::CODE_400, "application/json"); - info->addResponse(Status::CODE_404, "application/json"); } ADD_CORS(SystemMsg) @@ -618,13 +695,18 @@ class WebController : public oatpp::web::server::api::ApiController { WebRequestHandler handler = WebRequestHandler(); handler.RegisterRequestHandler(::milvus::server::RequestHandler()); - auto status_dto = handler.Cmd(msg, cmd_dto); - if (0 == status_dto->code->getValue()) { - return createDtoResponse(Status::CODE_200, cmd_dto); - } else { - return createDtoResponse(Status::CODE_400, status_dto); + std::shared_ptr response; + auto status_dto = handler.Cmd(msg, cmd_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_200, cmd_dto); + break; + default: + return createDtoResponse(Status::CODE_400, status_dto); } + + return response; } /** diff --git a/core/src/server/web_impl/dto/TableDto.hpp b/core/src/server/web_impl/dto/TableDto.hpp index 21b5e7bb..1573041a 100644 --- a/core/src/server/web_impl/dto/TableDto.hpp +++ b/core/src/server/web_impl/dto/TableDto.hpp @@ -59,14 +59,7 @@ class TableListFieldsDto : public OObject { DTO_INIT(TableListFieldsDto, Object) DTO_FIELD(List::ObjectWrapper, tables); - DTO_FIELD(Int64, count); -}; - -class TablesResponseDto : public OObject { - DTO_INIT(TablesResponseDto, Object) - - DTO_FIELD(TableListFieldsDto::ObjectWrapper, tables_fields); - DTO_FIELD(Int64, page_num); + DTO_FIELD(Int64, count) = 0; }; #include OATPP_CODEGEN_END(DTO) diff --git a/core/src/server/web_impl/dto/VectorDto.hpp b/core/src/server/web_impl/dto/VectorDto.hpp index 8dea7c0a..96fb469c 100644 --- a/core/src/server/web_impl/dto/VectorDto.hpp +++ b/core/src/server/web_impl/dto/VectorDto.hpp @@ -46,14 +46,15 @@ class SearchRequestDto : public OObject { DTO_FIELD(List::ObjectWrapper, tags); DTO_FIELD(List::ObjectWrapper, file_ids); DTO_FIELD(List::ObjectWrapper>::ObjectWrapper, records); + DTO_FIELD(List::ObjectWrapper>::ObjectWrapper, records_bin); }; - class InsertRequestDto : public oatpp::data::mapping::type::Object { DTO_INIT(InsertRequestDto, Object) DTO_FIELD(String, tag) = VALUE_PARTITION_TAG_DEFAULT; DTO_FIELD(List::ObjectWrapper>::ObjectWrapper, records); + DTO_FIELD(List::ObjectWrapper>::ObjectWrapper, records_bin); DTO_FIELD(List::ObjectWrapper, ids); }; @@ -66,17 +67,10 @@ class VectorIdsDto : public oatpp::data::mapping::type::Object { class ResultDto : public oatpp::data::mapping::type::Object { DTO_INIT(ResultDto, Object) -// DTO_FIELD(Int64, num); DTO_FIELD(String, id); DTO_FIELD(String, dit, "distance"); }; -class RowResultsDto : public OObject { - DTO_INIT(RowResultsDto, Object) - -// DTO_FIELD(List::ObjectWrapper, ); -}; - class TopkResultsDto : public OObject { DTO_INIT(TopkResultsDto, Object); diff --git a/core/src/server/web_impl/handler/WebRequestHandler.cpp b/core/src/server/web_impl/handler/WebRequestHandler.cpp index daef63cc..634e3ef1 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.cpp +++ b/core/src/server/web_impl/handler/WebRequestHandler.cpp @@ -17,19 +17,20 @@ #include "server/web_impl/handler/WebRequestHandler.h" -#include #include +#include #include #include #include "metrics/SystemInfo.h" -#include "utils/Log.h" - #include "server/Config.h" #include "server/delivery/request/BaseRequest.h" #include "server/web_impl/Constants.h" #include "server/web_impl/Types.h" #include "server/web_impl/dto/PartitionDto.hpp" +#include "server/web_impl/utils/Util.h" +#include "utils/StringHelpFunctions.h" +#include "utils/TimeRecorder.h" namespace milvus { namespace server { @@ -79,83 +80,9 @@ WebErrorMap(ErrorCode code) { } } -namespace { -Status -CopyRowRecords(const InsertRequestDto::ObjectWrapper& param, engine::VectorsData& vectors) { - vectors.float_data_.clear(); - vectors.binary_data_.clear(); - vectors.id_array_.clear(); - vectors.vector_count_ = param->records->count(); - - // step 1: copy vector data - if (nullptr == param->records.get()) { - return Status(SERVER_INVALID_ROWRECORD_ARRAY, ""); - } - - size_t tal_size = 0; - for (int64_t i = 0; i < param->records->count(); i++) { - tal_size += param->records->get(i)->count(); - } - - std::vector& datas = vectors.float_data_; - datas.resize(tal_size); - size_t index_offset = 0; - param->records->forEach([&datas, &index_offset](const OList::ObjectWrapper& row_item) { - row_item->forEach([&datas, &index_offset](const OFloat32& item) { - datas[index_offset] = item->getValue(); - index_offset++; - }); - }); - - // step 2: copy id array - if (nullptr == param->ids.get()) { - return Status(SERVER_ILLEGAL_VECTOR_ID, ""); - } - - for (int64_t i = 0; i < param->ids->count(); i++) { - vectors.id_array_.emplace_back(param->ids->get(i)->getValue()); - } - - return Status::OK(); -} - -Status -CopyRowRecords(const SearchRequestDto::ObjectWrapper& param, engine::VectorsData& vectors) { - vectors.float_data_.clear(); - vectors.binary_data_.clear(); - vectors.id_array_.clear(); - vectors.vector_count_ = param->records->count(); - - // step 1: copy vector data - if (nullptr == param->records.get()) { - return Status(SERVER_INVALID_ROWRECORD_ARRAY, ""); - } - - size_t tal_size = 0; - for (int64_t i = 0; i < param->records->count(); i++) { - tal_size += param->records->get(i)->count(); - } - - std::vector& datas = vectors.float_data_; - datas.resize(tal_size); - size_t index_offset = 0; - param->records->forEach([&datas, &index_offset](const OList::ObjectWrapper& row_item) { - row_item->forEach([&datas, &index_offset](const OFloat32& item) { - datas[index_offset] = item->getValue(); - index_offset++; - }); - }); - - return Status::OK(); -} - -} // namespace - ///////////////////////// WebRequestHandler methods /////////////////////////////////////// - Status -WebRequestHandler::GetTaleInfo(const std::shared_ptr& context, const std::string& table_name, - std::map& table_info) { +WebRequestHandler::GetTableInfo(const std::string& table_name, TableFieldsDto::ObjectWrapper& table_fields) { TableSchema schema; auto status = request_handler_.DescribeTable(context_ptr_, table_name, schema); if (!status.ok()) { @@ -174,26 +101,28 @@ WebRequestHandler::GetTaleInfo(const std::shared_ptr& context, const st return status; } - table_info[KEY_TABLE_TABLE_NAME] = schema.table_name_; - table_info[KEY_TABLE_DIMENSION] = std::to_string(schema.dimension_); - table_info[KEY_TABLE_INDEX_METRIC_TYPE] = std::string(MetricMap.at(engine::MetricType(schema.metric_type_))); - table_info[KEY_TABLE_INDEX_FILE_SIZE] = std::to_string(schema.index_file_size_); - - table_info[KEY_INDEX_INDEX_TYPE] = std::string(IndexMap.at(engine::EngineType(index_param.index_type_))); - table_info[KEY_INDEX_NLIST] = std::to_string(index_param.nlist_); + table_fields->table_name = schema.table_name_.c_str(); + table_fields->dimension = schema.dimension_; + table_fields->index_file_size = schema.index_file_size_; + table_fields->index = IndexMap.at(engine::EngineType(index_param.index_type_)).c_str(); + table_fields->nlist = index_param.nlist_; + table_fields->metric_type = MetricMap.at(engine::MetricType(schema.metric_type_)).c_str(); + table_fields->count = count; +} - table_info[KEY_TABLE_COUNT] = std::to_string(count); +Status +WebRequestHandler::CommandLine(const std::string& cmd, std::string& reply) { + return request_handler_.Cmd(context_ptr_, cmd, reply); } /////////////////////////////////////////// Router methods //////////////////////////////////////////// StatusDto::ObjectWrapper WebRequestHandler::GetDevices(DevicesDto::ObjectWrapper& devices_dto) { - auto getgb = [](uint64_t x) -> uint64_t { return x / 1024 / 1024 / 1024; }; auto system_info = SystemInfo::GetInstance(); devices_dto->cpu = devices_dto->cpu->createShared(); - devices_dto->cpu->memory = getgb(system_info.GetPhysicalMemory()); + devices_dto->cpu->memory = system_info.GetPhysicalMemory() >> 30; devices_dto->gpus = devices_dto->gpus->createShared(); @@ -203,12 +132,12 @@ WebRequestHandler::GetDevices(DevicesDto::ObjectWrapper& devices_dto) { std::vector device_mems = system_info.GPUMemoryTotal(); if (count != device_mems.size()) { - ASSIGN_RETURN_STATUS_DTO(Status(UNEXPECTED_ERROR, "Can't obtain GPU info")); + RETURN_STATUS_DTO(UNEXPECTED_ERROR, "Can't obtain GPU info"); } for (size_t i = 0; i < count; i++) { auto device_dto = DeviceInfoDto::createShared(); - device_dto->memory = getgb(device_mems.at(i)); + device_dto->memory = device_mems.at(i) >> 30; devices_dto->gpus->put("GPU" + OString(std::to_string(i).c_str()), device_dto); } @@ -220,35 +149,39 @@ WebRequestHandler::GetDevices(DevicesDto::ObjectWrapper& devices_dto) { StatusDto::ObjectWrapper WebRequestHandler::GetAdvancedConfig(AdvancedConfigDto::ObjectWrapper& advanced_config) { Config& config = Config::GetInstance(); + std::string reply; + std::string cache_cmd_prefix = "get_config " + std::string(CONFIG_CACHE) + "."; - int64_t value; - auto status = config.GetCacheConfigCpuCacheCapacity(value); + std::string cache_cmd_string = cache_cmd_prefix + std::string(CONFIG_CACHE_CPU_CACHE_CAPACITY); + auto status = CommandLine(cache_cmd_string, reply); if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status); + ASSIGN_RETURN_STATUS_DTO(status) } - advanced_config->cpu_cache_capacity = value; + advanced_config->cpu_cache_capacity = std::stol(reply); - bool ok; - status = config.GetCacheConfigCacheInsertData(ok); + cache_cmd_string = cache_cmd_prefix + std::string(CONFIG_CACHE_CACHE_INSERT_DATA); + CommandLine(cache_cmd_string, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status) } - advanced_config->cache_insert_data = ok; + advanced_config->cache_insert_data = ("1" == reply || "true" == reply); - status = config.GetEngineConfigUseBlasThreshold(value); + auto engine_cmd_prefix = "get_config " + std::string(CONFIG_ENGINE) + "."; + + auto engine_cmd_string = engine_cmd_prefix + std::string(CONFIG_ENGINE_USE_BLAS_THRESHOLD); + CommandLine(engine_cmd_string, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status) } - advanced_config->use_blas_threshold = value; + advanced_config->use_blas_threshold = std::stol(reply); #ifdef MILVUS_GPU_VERSION - - status = config.GetEngineConfigGpuSearchThreshold(value); + engine_cmd_string = engine_cmd_prefix + std::string(CONFIG_ENGINE_GPU_SEARCH_THRESHOLD); + CommandLine(engine_cmd_string, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status) } - advanced_config->gpu_search_threshold = value; - + advanced_config->gpu_search_threshold = std::stol(reply); #endif ASSIGN_RETURN_STATUS_DTO(status) @@ -256,44 +189,57 @@ WebRequestHandler::GetAdvancedConfig(AdvancedConfigDto::ObjectWrapper& advanced_ StatusDto::ObjectWrapper WebRequestHandler::SetAdvancedConfig(const AdvancedConfigDto::ObjectWrapper& advanced_config) { - Config& config = Config::GetInstance(); - if (nullptr == advanced_config->cpu_cache_capacity.get()) { RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cpu_cache_capacity\' miss."); } - auto status = - config.SetCacheConfigCpuCacheCapacity(std::to_string(advanced_config->cpu_cache_capacity->getValue())); - if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status) - } if (nullptr == advanced_config->cache_insert_data.get()) { RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cache_insert_data\' miss."); } - status = config.SetCacheConfigCacheInsertData(std::to_string(advanced_config->cache_insert_data->getValue())); - if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status) - } if (nullptr == advanced_config->use_blas_threshold.get()) { RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'use_blas_threshold\' miss."); } - status = config.SetEngineConfigUseBlasThreshold(std::to_string(advanced_config->use_blas_threshold->getValue())); - if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status) - } #ifdef MILVUS_GPU_VERSION - if (nullptr == advanced_config->gpu_search_threshold.get()) { RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'gpu_search_threshold\' miss."); } - status = - config.SetEngineConfigGpuSearchThreshold(std::to_string(advanced_config->gpu_search_threshold->getValue())); +#endif + + std::string reply; + std::string cache_cmd_prefix = "set_config " + std::string(CONFIG_CACHE) + "."; + + std::string cache_cmd_string = cache_cmd_prefix + std::string(CONFIG_CACHE_CPU_CACHE_CAPACITY) + " " + + std::to_string(advanced_config->cpu_cache_capacity->getValue()); + auto status = CommandLine(cache_cmd_string, reply); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + cache_cmd_string = cache_cmd_prefix + std::string(CONFIG_CACHE_CACHE_INSERT_DATA) + " " + + std::to_string(advanced_config->cache_insert_data->getValue()); + status = CommandLine(cache_cmd_string, reply); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + auto engine_cmd_prefix = "set_config " + std::string(CONFIG_ENGINE) + "."; + + auto engine_cmd_string = engine_cmd_prefix + std::string(CONFIG_ENGINE_USE_BLAS_THRESHOLD) + " " + + std::to_string(advanced_config->use_blas_threshold->getValue()); + status = CommandLine(engine_cmd_string, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status) } +#ifdef MILVUS_GPU_VERSION + engine_cmd_string = engine_cmd_prefix + std::string(CONFIG_ENGINE_GPU_SEARCH_THRESHOLD) + " " + + std::to_string(advanced_config->gpu_search_threshold->getValue()); + CommandLine(engine_cmd_string, reply); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } #endif ASSIGN_RETURN_STATUS_DTO(status) @@ -303,46 +249,52 @@ WebRequestHandler::SetAdvancedConfig(const AdvancedConfigDto::ObjectWrapper& adv StatusDto::ObjectWrapper WebRequestHandler::GetGpuConfig(GPUConfigDto::ObjectWrapper& gpu_config_dto) { - Config& config = Config::GetInstance(); + std::string reply; + std::string gpu_cmd_prefix = "get_config " + std::string(CONFIG_GPU_RESOURCE) + "."; - bool enable; - auto status = config.GetGpuResourceConfigEnable(enable); + std::string gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_ENABLE); + auto status = CommandLine(gpu_cmd_request, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status); } - gpu_config_dto->enable = enable; + gpu_config_dto->enable = reply == "1" || reply == "true"; - if (!enable) { + if (!gpu_config_dto->enable->getValue()) { ASSIGN_RETURN_STATUS_DTO(Status::OK()); } - int64_t capacity; - status = config.GetGpuResourceConfigCacheCapacity(capacity); + gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_CACHE_CAPACITY); + status = CommandLine(gpu_cmd_request, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status); } - gpu_config_dto->cache_capacity = capacity; + gpu_config_dto->cache_capacity = std::stol(reply); - std::vector values; - status = config.GetGpuResourceConfigSearchResources(values); + gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_SEARCH_RESOURCES); + status = CommandLine(gpu_cmd_request, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status); } + std::vector gpu_entry; + StringHelpFunctions::SplitStringByDelimeter(reply, ",", gpu_entry); + gpu_config_dto->search_resources = gpu_config_dto->search_resources->createShared(); - for (auto& device_id : values) { - gpu_config_dto->search_resources->pushBack("GPU" + OString(std::to_string(device_id).c_str())); + for (auto& device_id : gpu_entry) { + gpu_config_dto->search_resources->pushBack(OString(device_id.c_str())->toUpperCase()); } + gpu_entry.clear(); - values.clear(); - status = config.GetGpuResourceConfigBuildIndexResources(values); + gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES); + status = CommandLine(gpu_cmd_request, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status); } + StringHelpFunctions::SplitStringByDelimeter(reply, ",", gpu_entry); gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared(); - for (auto& device_id : values) { - gpu_config_dto->build_index_resources->pushBack("GPU" + OString(std::to_string(device_id).c_str())); + for (auto& device_id : gpu_entry) { + gpu_config_dto->build_index_resources->pushBack(OString(device_id.c_str())->toUpperCase()); } ASSIGN_RETURN_STATUS_DTO(Status::OK()); @@ -354,31 +306,44 @@ WebRequestHandler::GetGpuConfig(GPUConfigDto::ObjectWrapper& gpu_config_dto) { StatusDto::ObjectWrapper WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dto) { - Config& config = Config::GetInstance(); - + // Step 1: Check config param if (nullptr == gpu_config_dto->enable.get()) { RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'enable\' miss") } - auto status = config.SetGpuResourceConfigEnable(std::to_string(gpu_config_dto->enable->getValue())); - if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status); + + if (nullptr == gpu_config_dto->cache_capacity.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cache_capacity\' miss") } - if (!gpu_config_dto->enable->getValue()) { - RETURN_STATUS_DTO(SUCCESS, "Set Gpu resources false"); + if (nullptr == gpu_config_dto->search_resources.get()) { + gpu_config_dto->search_resources = gpu_config_dto->search_resources->createShared(); + gpu_config_dto->search_resources->pushBack("GPU0"); } - if (nullptr == gpu_config_dto->cache_capacity.get()) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'cache_capacity\' miss") + if (nullptr == gpu_config_dto->build_index_resources.get()) { + gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared(); + gpu_config_dto->build_index_resources->pushBack("GPU0"); } - status = config.SetGpuResourceConfigCacheCapacity(std::to_string(gpu_config_dto->cache_capacity->getValue())); + + // Step 2: Set config + std::string reply; + std::string gpu_cmd_prefix = "set_config " + std::string(CONFIG_GPU_RESOURCE) + "."; + std::string gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_ENABLE) + " " + + std::to_string(gpu_config_dto->enable->getValue()); + auto status = CommandLine(gpu_cmd_request, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status); } - if (nullptr == gpu_config_dto->search_resources.get()) { - gpu_config_dto->search_resources = gpu_config_dto->search_resources->createShared(); - gpu_config_dto->search_resources->pushBack("GPU0"); + if (!gpu_config_dto->enable->getValue()) { + RETURN_STATUS_DTO(SUCCESS, "Set Gpu resources to false"); + } + + gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_CACHE_CAPACITY) + " " + + std::to_string(gpu_config_dto->cache_capacity->getValue()); + status = CommandLine(gpu_cmd_request, reply); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status); } std::vector search_resources; @@ -393,15 +358,13 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt if (len > 0) { search_resources_value.erase(len - 1); } - status = config.SetGpuResourceConfigSearchResources(search_resources_value); + + gpu_cmd_request = gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_SEARCH_RESOURCES) + " " + search_resources_value; + status = CommandLine(gpu_cmd_request, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status); } - if (nullptr == gpu_config_dto->build_index_resources.get()) { - gpu_config_dto->build_index_resources = gpu_config_dto->build_index_resources->createShared(); - gpu_config_dto->build_index_resources->pushBack("GPU0"); - } std::vector build_resources; gpu_config_dto->build_index_resources->forEach( [&build_resources](const OString& res) { build_resources.emplace_back(res->toLowerCase()->std_str()); }); @@ -415,7 +378,9 @@ WebRequestHandler::SetGpuConfig(const GPUConfigDto::ObjectWrapper& gpu_config_dt build_resources_value.erase(len - 1); } - status = config.SetGpuResourceConfigBuildIndexResources(build_resources_value); + gpu_cmd_request = + gpu_cmd_prefix + std::string(CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES) + " " + build_resources_value; + status = CommandLine(gpu_cmd_request, reply); if (!status.ok()) { ASSIGN_RETURN_STATUS_DTO(status); } @@ -461,74 +426,62 @@ WebRequestHandler::GetTable(const OString& table_name, const OQueryParams& query RETURN_STATUS_DTO(PATH_PARAM_LOSS, "Path param \'table_name\' is required!"); } - Status status = Status::OK(); - // TODO: query string field `fields` npt used here - std::map table_info; - status = GetTaleInfo(context_ptr_, table_name->std_str(), table_info); - if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status) - } - - fields_dto->table_name = table_info[KEY_TABLE_TABLE_NAME].c_str(); - fields_dto->dimension = std::stol(table_info[KEY_TABLE_DIMENSION]); - fields_dto->index = table_info[KEY_INDEX_INDEX_TYPE].c_str(); - fields_dto->nlist = std::stol(table_info[KEY_INDEX_NLIST]); - fields_dto->metric_type = table_info[KEY_TABLE_INDEX_METRIC_TYPE].c_str(); - fields_dto->index_file_size = std::stol(table_info[KEY_TABLE_INDEX_FILE_SIZE]); - fields_dto->count = std::stol(table_info[KEY_TABLE_COUNT]); + auto status = GetTableInfo(table_name->std_str(), fields_dto); ASSIGN_RETURN_STATUS_DTO(status); } StatusDto::ObjectWrapper -WebRequestHandler::ShowTables(const OInt64& offset, const OInt64& page_size, +WebRequestHandler::ShowTables(const OString& offset, const OString& page_size, TableListFieldsDto::ObjectWrapper& response_dto) { - if (nullptr == offset.get()) { - RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'offset\' is required"); + int64_t offset_value = 0; + int64_t page_size_value = 10; + + if (nullptr != offset.get()) { + try { + offset_value = std::stol(offset->std_str()); + } catch (const std::exception& e) { + RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, "Query param \'offset\' is illegal, only type of \'int\' allowed"); + } } - if (nullptr == page_size.get()) { - RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'page_size\' is required"); + if (nullptr != page_size.get()) { + try { + page_size_value = std::stol(page_size->std_str()); + } catch (const std::exception& e) { + RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, + "Query param \'page_size\' is illegal, only type of \'int\' allowed"); + } + } + + if (offset_value < 0 || page_size_value < 0) { + RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, "Query param 'offset' or 'page_size' should equal or bigger than 0"); } + std::vector tables; - Status status = Status::OK(); + auto status = request_handler_.ShowTables(context_ptr_, tables); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } response_dto->tables = response_dto->tables->createShared(); - if (offset < 0 || page_size < 0) { - ASSIGN_RETURN_STATUS_DTO( - Status(SERVER_UNEXPECTED_ERROR, "Query param 'offset' or 'page_size' should bigger than 0")); - } else { - status = request_handler_.ShowTables(context_ptr_, tables); + if (offset_value >= tables.size()) { + ASSIGN_RETURN_STATUS_DTO(Status::OK()); + } + + response_dto->count = tables.size(); + + int64_t size = page_size_value + offset_value > tables.size() ? tables.size() - offset_value : page_size_value; + for (int64_t i = offset_value; i < size + offset_value; i++) { + auto table_fields_dto = TableFieldsDto::createShared(); + status = GetTableInfo(tables.at(i), table_fields_dto); if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status) - } - if (offset < tables.size()) { - int64_t size = (page_size->getValue() + offset->getValue() > tables.size()) ? tables.size() - offset - : page_size->getValue(); - for (int64_t i = offset->getValue(); i < size + offset->getValue(); i++) { - std::map table_info; - - status = GetTaleInfo(context_ptr_, tables.at(i), table_info); - if (!status.ok()) { - break; - } - - auto table_fields_dto = TableFieldsDto::createShared(); - table_fields_dto->table_name = table_info[KEY_TABLE_TABLE_NAME].c_str(); - table_fields_dto->dimension = std::stol(table_info[std::string(KEY_TABLE_DIMENSION)]); - table_fields_dto->index_file_size = std::stol(table_info[std::string(KEY_TABLE_INDEX_FILE_SIZE)]); - table_fields_dto->index = table_info[KEY_INDEX_INDEX_TYPE].c_str(); - table_fields_dto->nlist = std::stol(table_info[KEY_INDEX_NLIST]); - table_fields_dto->metric_type = table_info[KEY_TABLE_INDEX_METRIC_TYPE].c_str(); - table_fields_dto->count = std::stol(table_info[KEY_TABLE_COUNT]); - - response_dto->tables->pushBack(table_fields_dto); - } - - response_dto->count = tables.size(); + break; } + + response_dto->tables->pushBack(table_fields_dto); } ASSIGN_RETURN_STATUS_DTO(status) @@ -598,31 +551,50 @@ WebRequestHandler::CreatePartition(const OString& table_name, const PartitionReq } StatusDto::ObjectWrapper -WebRequestHandler::ShowPartitions(const OInt64& offset, const OInt64& page_size, const OString& table_name, +WebRequestHandler::ShowPartitions(const OString& offset, const OString& page_size, const OString& table_name, PartitionListDto::ObjectWrapper& partition_list_dto) { - if (nullptr == offset.get()) { - RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'offset\' is required!"); + int64_t offset_value = 0; + int64_t page_size_value = 10; + + if (nullptr != offset.get()) { + try { + offset_value = std::stol(offset->std_str()); + } catch (const std::exception& e) { + std::string msg = "Query param \'offset\' is illegal. Reason: " + std::string(e.what()); + RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, msg.c_str()); + } } - if (nullptr == page_size.get()) { - RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'page_size\' is required!"); + if (nullptr != page_size.get()) { + try { + page_size_value = std::stol(page_size->std_str()); + } catch (const std::exception& e) { + std::string msg = "Query param \'page_size\' is illegal. Reason: " + std::string(e.what()); + RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, msg.c_str()); + } + } + + if (offset_value < 0 || page_size_value < 0) { + ASSIGN_RETURN_STATUS_DTO( + Status(SERVER_UNEXPECTED_ERROR, "Query param 'offset' or 'page_size' should equal or bigger than 0")); } std::vector partitions; auto status = request_handler_.ShowPartitions(context_ptr_, table_name->std_str(), partitions); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } - if (status.ok()) { - partition_list_dto->partitions = partition_list_dto->partitions->createShared(); - - if (offset->getValue() < partitions.size()) { - int64_t size = (offset->getValue() + page_size->getValue() > partitions.size()) ? partitions.size() - offset - : page_size->getValue(); - for (int64_t i = offset->getValue(); i < size + offset->getValue(); i++) { - auto partition_dto = PartitionFieldsDto::createShared(); - partition_dto->partition_name = partitions.at(i).partition_name_.c_str(); - partition_dto->partition_tag = partitions.at(i).tag_.c_str(); - partition_list_dto->partitions->pushBack(partition_dto); - } + partition_list_dto->partitions = partition_list_dto->partitions->createShared(); + + if (offset_value < partitions.size()) { + int64_t size = + offset_value + page_size_value > partitions.size() ? partitions.size() - offset_value : page_size_value; + for (int64_t i = offset_value; i < size + offset_value; i++) { + auto partition_dto = PartitionFieldsDto::createShared(); + partition_dto->partition_name = partitions.at(i).partition_name_.c_str(); + partition_dto->partition_tag = partitions.at(i).tag_.c_str(); + partition_list_dto->partitions->pushBack(partition_dto); } } @@ -637,15 +609,47 @@ WebRequestHandler::DropPartition(const OString& table_name, const OString& tag) } StatusDto::ObjectWrapper -WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::ObjectWrapper& param, +WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::ObjectWrapper& request, VectorIdsDto::ObjectWrapper& ids_dto) { + TableSchema schema; + auto status = request_handler_.DescribeTable(context_ptr_, table_name->std_str(), schema); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + auto metric = engine::MetricType(schema.metric_type_); engine::VectorsData vectors; - auto status = CopyRowRecords(param, vectors); - if (status.code() == SERVER_INVALID_ROWRECORD_ARRAY) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors") + bool bin_flag = engine::MetricType::HAMMING == metric || engine::MetricType::JACCARD == metric || + engine::MetricType::TANIMOTO == metric; + + if (!bin_flag) { + if (nullptr == request->records.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors"); + } + vectors.vector_count_ = request->records->count(); + status = CopyRowRecords(request->records, vectors.float_data_); + } else { + if (nullptr == request->records_bin.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records_bin\' is required to fill vectors"); + } + vectors.vector_count_ = request->records_bin->count(); + status = CopyBinRowRecords(request->records_bin, vectors.binary_data_); + } + + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + // step 2: copy id array + if (nullptr != request->ids.get()) { + auto& id_array = vectors.id_array_; + id_array.resize(request->ids->count()); + + size_t i = 0; + request->ids->forEach([&id_array, &i](const OInt64& item) { id_array[i++] = item->getValue(); }); } - status = request_handler_.Insert(context_ptr_, table_name->std_str(), vectors, param->tag->std_str()); + status = request_handler_.Insert(context_ptr_, table_name->std_str(), vectors, request->tag->std_str()); if (status.ok()) { ids_dto->ids = ids_dto->ids->createShared(); @@ -658,42 +662,58 @@ WebRequestHandler::Insert(const OString& table_name, const InsertRequestDto::Obj } StatusDto::ObjectWrapper -WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::ObjectWrapper& search_request, +WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::ObjectWrapper& request, TopkResultsDto::ObjectWrapper& results_dto) { - if (nullptr == search_request->topk.get()) { + if (nullptr == request->topk.get()) { RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'topk\' is required in request body") } - int64_t topk_t = search_request->topk->getValue(); + int64_t topk_t = request->topk->getValue(); - if (nullptr == search_request->nprobe.get()) { + if (nullptr == request->nprobe.get()) { RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'nprobe\' is required in request body") } - int64_t nprobe_t = search_request->nprobe->getValue(); + int64_t nprobe_t = request->nprobe->getValue(); std::vector tag_list; - std::vector file_id_list; - - if (nullptr != search_request->tags.get()) { - search_request->tags->forEach([&tag_list](const OString& tag) { tag_list.emplace_back(tag->std_str()); }); + if (nullptr != request->tags.get()) { + request->tags->forEach([&tag_list](const OString& tag) { tag_list.emplace_back(tag->std_str()); }); } - if (nullptr != search_request->file_ids.get()) { - search_request->file_ids->forEach( - [&file_id_list](const OString& id) { file_id_list.emplace_back(id->std_str()); }); + std::vector file_id_list; + if (nullptr != request->file_ids.get()) { + request->file_ids->forEach([&file_id_list](const OString& id) { file_id_list.emplace_back(id->std_str()); }); } - if (nullptr == search_request->records.get()) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill query vectors") + TableSchema schema; + auto status = request_handler_.DescribeTable(context_ptr_, table_name->std_str(), schema); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) } + auto metric = engine::MetricType(schema.metric_type_); + bool bin_flag = engine::MetricType::HAMMING == metric || engine::MetricType::JACCARD == metric || + engine::MetricType::TANIMOTO == metric; engine::VectorsData vectors; - auto status = CopyRowRecords(search_request, vectors); - if (status.code() == SERVER_INVALID_ROWRECORD_ARRAY) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors") + + if (!bin_flag) { + if (nullptr == request->records.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records\' is required to fill vectors"); + } + vectors.vector_count_ = request->records->count(); + status = CopyRowRecords(request->records, vectors.float_data_); + } else { + if (nullptr == request->records_bin.get()) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'records_bin\' is required to fill vectors"); + } + vectors.vector_count_ = request->records_bin->count(); + status = CopyBinRowRecords(request->records_bin, vectors.binary_data_); } - std::vector range_list; + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + std::vector range_list; TopKQueryResult result; auto context_ptr = GenContextPtr("Web Handler"); status = request_handler_.Search(context_ptr, table_name->std_str(), vectors, range_list, topk_t, nprobe_t, @@ -725,8 +745,14 @@ WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::Obj StatusDto::ObjectWrapper WebRequestHandler::Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto) { + std::string info = cmd->std_str(); + + if ("info" == info) { + info = "get_system_info"; + } + std::string reply_str; - auto status = request_handler_.Cmd(context_ptr_, cmd->std_str(), reply_str); + auto status = CommandLine(info, reply_str); if (status.ok()) { cmd_dto->reply = reply_str.c_str(); diff --git a/core/src/server/web_impl/handler/WebRequestHandler.h b/core/src/server/web_impl/handler/WebRequestHandler.h index af645531..a28d4d25 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.h +++ b/core/src/server/web_impl/handler/WebRequestHandler.h @@ -23,10 +23,9 @@ #include #include -#include - #include #include +#include #include "server/web_impl/Types.h" #include "server/web_impl/dto/CmdDto.hpp" @@ -37,6 +36,7 @@ #include "server/web_impl/dto/TableDto.hpp" #include "server/web_impl/dto/VectorDto.hpp" +#include "db/Types.h" #include "server/context/Context.h" #include "server/delivery/RequestHandler.h" #include "utils/Status.h" @@ -82,15 +82,18 @@ class WebRequestHandler { return context_ptr; } + protected: + Status + GetTableInfo(const std::string& table_name, TableFieldsDto::ObjectWrapper& table_fields); + + Status + CommandLine(const std::string& cmd, std::string& reply); + public: WebRequestHandler() { context_ptr_ = GenContextPtr("Web Handler"); } - Status - GetTaleInfo(const std::shared_ptr& context, const std::string& table_name, - std::map& table_info); - StatusDto::ObjectWrapper GetDevices(DevicesDto::ObjectWrapper& devices); @@ -115,7 +118,7 @@ class WebRequestHandler { GetTable(const OString& table_name, const OQueryParams& query_params, TableFieldsDto::ObjectWrapper& schema_dto); StatusDto::ObjectWrapper - ShowTables(const OInt64& offset, const OInt64& page_size, TableListFieldsDto::ObjectWrapper& table_list_dto); + ShowTables(const OString& offset, const OString& page_size, TableListFieldsDto::ObjectWrapper& table_list_dto); StatusDto::ObjectWrapper DropTable(const OString& table_name); @@ -133,7 +136,7 @@ class WebRequestHandler { CreatePartition(const OString& table_name, const PartitionRequestDto::ObjectWrapper& param); StatusDto::ObjectWrapper - ShowPartitions(const OInt64& offset, const OInt64& page_size, const OString& table_name, + ShowPartitions(const OString& offset, const OString& page_size, const OString& table_name, PartitionListDto::ObjectWrapper& partition_list_dto); StatusDto::ObjectWrapper @@ -150,6 +153,7 @@ class WebRequestHandler { StatusDto::ObjectWrapper Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto); + public: WebRequestHandler& RegisterRequestHandler(const RequestHandler& handler) { request_handler_ = handler; diff --git a/core/src/server/web_impl/utils/Util.cpp b/core/src/server/web_impl/utils/Util.cpp new file mode 100644 index 00000000..5f988214 --- /dev/null +++ b/core/src/server/web_impl/utils/Util.cpp @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "server/web_impl/utils/Util.h" + +namespace milvus { +namespace server { +namespace web { + +Status +CopyRowRecords(const OList::ObjectWrapper>::ObjectWrapper& records, std::vector& vectors) { + size_t tal_size = 0; + records->forEach([&tal_size](const OList::ObjectWrapper& row_item) { tal_size += row_item->count(); }); + + vectors.resize(tal_size); + size_t index_offset = 0; + records->forEach([&vectors, &index_offset](const OList::ObjectWrapper& row_item) { + row_item->forEach( + [&vectors, &index_offset](const OFloat32& item) { vectors[index_offset++] = item->getValue(); }); + }); + + return Status::OK(); +} + +Status +CopyBinRowRecords(const OList::ObjectWrapper>::ObjectWrapper& records, std::vector& vectors) { + size_t tal_size = 0; + records->forEach([&tal_size](const OList::ObjectWrapper& item) { tal_size += item->count(); }); + + vectors.resize(tal_size); + size_t index_offset = 0; + bool oor = false; + records->forEach([&vectors, &index_offset, &oor](const OList::ObjectWrapper& row_item) { + row_item->forEach([&vectors, &index_offset, &oor](const OInt64& item) { + if (!oor) { + int64_t value = item->getValue(); + if (0 > value || value > 255) { + oor = true; + } else { + vectors[index_offset++] = static_cast(value); + } + } + }); + }); + + return Status::OK(); +} + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/utils/Util.h b/core/src/server/web_impl/utils/Util.h new file mode 100644 index 00000000..ff875494 --- /dev/null +++ b/core/src/server/web_impl/utils/Util.h @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "db/Types.h" +#include "server/web_impl/Types.h" +#include "utils/Status.h" + +namespace milvus { +namespace server { +namespace web { + +Status +CopyRowRecords(const OList::ObjectWrapper>::ObjectWrapper& records, std::vector& vectors); + +Status +CopyBinRowRecords(const OList::ObjectWrapper>::ObjectWrapper& records, std::vector& vectors); + +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/utils/ValidationUtil.cpp b/core/src/utils/ValidationUtil.cpp index 3da1f640..02ba9c70 100644 --- a/core/src/utils/ValidationUtil.cpp +++ b/core/src/utils/ValidationUtil.cpp @@ -194,7 +194,33 @@ ValidationUtil::ValidatePartitionName(const std::string& partition_name) { return Status(SERVER_INVALID_TABLE_NAME, msg); } - return ValidateTableName(partition_name); + std::string invalid_msg = "Invalid partition name: " + partition_name + ". "; + // Table name size shouldn't exceed 16384. + if (partition_name.size() > TABLE_NAME_SIZE_LIMIT) { + std::string msg = invalid_msg + "The length of a partition name must be less than 255 characters."; + SERVER_LOG_ERROR << msg; + return Status(SERVER_INVALID_TABLE_NAME, msg); + } + + // Table name first character should be underscore or character. + char first_char = partition_name[0]; + if (first_char != '_' && std::isalpha(first_char) == 0) { + std::string msg = invalid_msg + "The first character of a partition name must be an underscore or letter."; + SERVER_LOG_ERROR << msg; + return Status(SERVER_INVALID_TABLE_NAME, msg); + } + + int64_t table_name_size = partition_name.size(); + for (int64_t i = 1; i < table_name_size; ++i) { + char name_char = partition_name[i]; + if (name_char != '_' && std::isalnum(name_char) == 0) { + std::string msg = invalid_msg + "Partition name can only contain numbers, letters, and underscores."; + SERVER_LOG_ERROR << msg; + return Status(SERVER_INVALID_TABLE_NAME, msg); + } + } + + return Status::OK(); } Status @@ -207,7 +233,7 @@ ValidationUtil::ValidatePartitionTags(const std::vector& partition_ if (valid_tag.empty()) { std::string msg = "Invalid partition tag: " + valid_tag + ". " + "Partition tag should not be empty."; SERVER_LOG_ERROR << msg; - return Status(SERVER_INVALID_NPROBE, msg); + return Status(SERVER_INVALID_TABLE_NAME, msg); } } diff --git a/core/unittest/CMakeLists.txt b/core/unittest/CMakeLists.txt index 9103c927..fae84319 100644 --- a/core/unittest/CMakeLists.txt +++ b/core/unittest/CMakeLists.txt @@ -78,12 +78,14 @@ aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/handler web_handler_fi aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/component web_conponent_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/controller web_controller_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/dto web_dto_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl/utils web_utils_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/web_impl web_impl_files) set(web_server_files ${web_handler_files} ${web_conponent_files} ${web_controller_files} ${web_dto_files} + ${web_utils_files} ${web_impl_files} ) diff --git a/core/unittest/server/test_web.cpp b/core/unittest/server/test_web.cpp index 1624c7e1..76cde3ab 100644 --- a/core/unittest/server/test_web.cpp +++ b/core/unittest/server/test_web.cpp @@ -53,6 +53,8 @@ #include "server/DBWrapper.h" #include "utils/CommonUtil.h" +#include "unittest/server/utils.h" + static const char* TABLE_NAME = "test_web"; static constexpr int64_t TABLE_DIM = 256; static constexpr int64_t INDEX_FILE_SIZE = 1024; @@ -66,6 +68,7 @@ using OQueryParams = milvus::server::web::OQueryParams; using OChunkedBuffer = oatpp::data::stream::ChunkedBuffer; using OOutputStream = oatpp::data::stream::BufferOutputStream; using OFloat32 = milvus::server::web::OFloat32; +using OInt64 = milvus::server::web::OInt64; template using OList = milvus::server::web::OList; @@ -86,6 +89,19 @@ RandomRowRecordDto(int64_t dim) { return row_record_dto; } +OList::ObjectWrapper +RandomBinRowRecordDto(int64_t dim) { + auto row_record_dto = OList::createShared(); + + std::default_random_engine e; + std::uniform_real_distribution u(0, 255); + for (size_t i = 0; i < dim / 8; i++) { + row_record_dto->pushBack(static_cast(u(e))); + } + + return row_record_dto; +} + OList::ObjectWrapper>::ObjectWrapper RandomRecordsDto(int64_t dim, int64_t num) { auto records_dto = OList::ObjectWrapper>::createShared(); @@ -96,6 +112,16 @@ RandomRecordsDto(int64_t dim, int64_t num) { return records_dto; } +OList::ObjectWrapper>::ObjectWrapper +RandomBinRecordsDto(int64_t dim, int64_t num) { + auto records_dto = OList::ObjectWrapper>::createShared(); + for (size_t i = 0; i < num; i++) { + records_dto->pushBack(RandomBinRowRecordDto(dim)); + } + + return records_dto; +} + std::string RandomName() { unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); @@ -281,7 +307,7 @@ TEST_F(WebHandlerTest, INSERT_COUNT) { ASSERT_EQ(0, status_dto->code->getValue()); ASSERT_EQ(1000, ids_dto->ids->count()); - sleep(8); + sleep(2); milvus::server::web::OQueryParams query_params; query_params.put("fields", "num"); @@ -344,7 +370,7 @@ TEST_F(WebHandlerTest, PARTITION) { ASSERT_EQ(StatusCode::ILLEGAL_TABLE_NAME, status_dto->code->getValue()); auto partitions_dto = milvus::server::web::PartitionListDto::createShared(); - status_dto = handler->ShowPartitions(0, 10, table_name, partitions_dto); + status_dto = handler->ShowPartitions("0", "10", table_name, partitions_dto); ASSERT_EQ(1, partitions_dto->partitions->count()); status_dto = handler->DropPartition(table_name, "test"); @@ -352,7 +378,7 @@ TEST_F(WebHandlerTest, PARTITION) { // Show all partitions partitions_dto = milvus::server::web::PartitionListDto::createShared(); - status_dto = handler->ShowPartitions(0, 10, table_name, partitions_dto); + status_dto = handler->ShowPartitions("0", "10", table_name, partitions_dto); } TEST_F(WebHandlerTest, SEARCH) { @@ -400,7 +426,54 @@ TEST_F(WebHandlerTest, CMD) { /////////////////////////////////////////////////////////////////////////////////////// namespace { +static const char* CONTROLLER_TEST_VALID_CONFIG_STR = + "# Default values are used when you make no changes to the following parameters.\n" + "\n" + "version: 0.1" + "\n" + "server_config:\n" + " address: 0.0.0.0 # milvus server ip address (IPv4)\n" + " port: 19530 # port range: 1025 ~ 65534\n" + " deploy_mode: single \n" + " time_zone: UTC+8\n" + "\n" + "db_config:\n" + " backend_url: sqlite://:@:/ \n" + "\n" + " insert_buffer_size: 4 # GB, maximum insert buffer size allowed\n" + " preload_table: \n" + "\n" + "storage_config:\n" + " primary_path: /tmp/milvus_web_controller_test # path used to store data and meta\n" + " secondary_path: # path used to store data only, split by semicolon\n" + "\n" + "metric_config:\n" + " enable_monitor: false # enable monitoring or not\n" + " address: 127.0.0.1\n" + " port: 8080 # port prometheus uses to fetch metrics\n" + "\n" + "cache_config:\n" + " cpu_cache_capacity: 4 # GB, CPU memory used for cache\n" + " cpu_cache_threshold: 0.85 \n" + " cache_insert_data: false # whether to load inserted data into cache\n" + "\n" + "engine_config:\n" + " use_blas_threshold: 20 \n" + "\n" + #ifdef MILVUS_GPU_VERSION + "gpu_resource_config:\n" + " enable: true # whether to enable GPU resources\n" + " cache_capacity: 4 # GB, size of GPU memory per card used for cache, must be a positive integer\n" + " search_resources: # define the GPU devices used for search computation, must be in format gpux\n" + " - gpu0\n" + " build_index_resources: # define the GPU devices used for index building, must be in format gpux\n" + " - gpu0\n" + #endif + "\n"; + static const char* CONTROLLER_TEST_TABLE_NAME = "controller_unit_test"; +static const char* CONTROLLER_TEST_CONFIG_DIR = "/tmp/milvus_web_controller_test/"; +static const char* CONTROLLER_TEST_CONFIG_FILE = "config.yaml"; class TestClient : public oatpp::web::client::ApiClient { public: @@ -445,11 +518,8 @@ class TestClient : public oatpp::web::client::ApiClient { API_CALL("OPTIONS", "/tables/{table_name}/indexes", optionsIndexes, PATH(String, table_name, "table_name")) - API_CALL("POST", - "/tables/{table_name}/indexes", - createIndex, - PATH(String, table_name, "table_name"), - BODY_DTO(milvus::server::web::IndexRequestDto::ObjectWrapper, body)) + API_CALL("POST", "/tables/{table_name}/indexes",createIndex, + PATH(String, table_name, "table_name"), BODY_DTO(milvus::server::web::IndexRequestDto::ObjectWrapper, body)) API_CALL("GET", "/tables/{table_name}/indexes", getIndex, PATH(String, table_name, "table_name")) @@ -505,6 +575,15 @@ class WebControllerTest : public testing::Test { protected: static void SetUpTestCase() { + // Basic config + std::string config_path = std::string(CONTROLLER_TEST_CONFIG_DIR).append(CONTROLLER_TEST_CONFIG_FILE); + std::fstream fs(config_path.c_str(), std::ios_base::out); + fs << CONTROLLER_TEST_VALID_CONFIG_STR; + fs.close(); + + milvus::server::Config& config = milvus::server::Config::GetInstance(); + config.LoadConfigFile(config_path); + auto res_mgr = milvus::scheduler::ResMgrInst::GetInstance(); res_mgr->Clear(); res_mgr->Add(milvus::scheduler::ResourceFactory::Create("disk", "DISK", 0, false)); @@ -522,13 +601,8 @@ class WebControllerTest : public testing::Test { milvus::engine::DBOptions opt; milvus::server::Config::GetInstance().SetDBConfigBackendUrl("sqlite://:@:/"); - boost::filesystem::remove_all("/tmp/milvus_web_controller_test"); - milvus::server::Config::GetInstance().SetStorageConfigPrimaryPath("/tmp/milvus_web_controller_test"); - milvus::server::Config::GetInstance().SetStorageConfigSecondaryPath(""); - milvus::server::Config::GetInstance().SetDBConfigArchiveDiskThreshold(""); - milvus::server::Config::GetInstance().SetDBConfigArchiveDaysThreshold(""); - milvus::server::Config::GetInstance().SetCacheConfigCacheInsertData(""); - milvus::server::Config::GetInstance().SetEngineConfigOmpThreadNum(""); + boost::filesystem::remove_all(CONTROLLER_TEST_CONFIG_DIR); + milvus::server::Config::GetInstance().SetStorageConfigPrimaryPath(CONTROLLER_TEST_CONFIG_DIR); milvus::server::DBWrapper::GetInstance().StartService(); @@ -547,7 +621,7 @@ class WebControllerTest : public testing::Test { milvus::scheduler::JobMgrInst::GetInstance()->Stop(); milvus::scheduler::ResMgrInst::GetInstance()->Stop(); milvus::scheduler::SchedInst::GetInstance()->Stop(); - boost::filesystem::remove_all("/tmp/milvus_web_controller_test"); + boost::filesystem::remove_all(CONTROLLER_TEST_CONFIG_DIR); } void @@ -638,16 +712,16 @@ TEST_F(WebControllerTest, CREATE_TABLE) { auto table_dto = milvus::server::web::TableRequestDto::createShared(); auto response = client_ptr->createTable(table_dto, conncetion_ptr); ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); - auto result_dto = response->readBodyToDto(object_mapper.get()); - ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code) << result_dto->message->std_str(); + auto error_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code) << error_dto->message->std_str(); OString table_name = "web_test_create_table" + OString(RandomName().c_str()); table_dto->table_name = table_name; response = client_ptr->createTable(table_dto, conncetion_ptr); ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); - result_dto = response->readBodyToDto(object_mapper.get()); - ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code) << result_dto->message->std_str(); + error_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code) << error_dto->message->std_str(); table_dto->dimension = 128; table_dto->index_file_size = 10; @@ -655,6 +729,8 @@ TEST_F(WebControllerTest, CREATE_TABLE) { response = client_ptr->createTable(table_dto, conncetion_ptr); ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, result_dto->code->getValue()) << result_dto->message->std_str(); // invalid table name table_dto->table_name = "9090&*&()"; @@ -671,29 +747,32 @@ TEST_F(WebControllerTest, GET_TABLE) { // fields value is 'num', test count table params.put("fields", "num"); auto response = client_ptr->getTable(table_name, conncetion_ptr); - ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); auto result_dto = response->readBodyToDto(object_mapper.get()); - - response = client_ptr->getTable(table_name, conncetion_ptr); - ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + ASSERT_EQ(table_name->std_str(), result_dto->table_name->std_str()); + ASSERT_EQ(10, result_dto->dimension); + ASSERT_EQ("L2", result_dto->metric_type->std_str()); + ASSERT_EQ(10, result_dto->index_file_size->getValue()); + ASSERT_EQ("FLAT", result_dto->index->std_str()); // invalid table name table_name = "57474dgdfhdfhdh dgd"; response = client_ptr->getTable(table_name, conncetion_ptr); ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); auto status_sto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::ILLEGAL_TABLE_NAME, status_sto->code->getValue()); - table_name = "test_table_not_found_0000000001110101010020202030203030435"; + table_name = "test_table_not_found_000000000111010101002020203020aaaaa3030435"; response = client_ptr->getTable(table_name, conncetion_ptr); ASSERT_EQ(OStatus::CODE_404.code, response->getStatusCode()); - status_sto = response->readBodyToDto(object_mapper.get()); } TEST_F(WebControllerTest, SHOW_TABLES) { // test query table limit 1 auto response = client_ptr->showTables(1, 1, conncetion_ptr); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_TRUE(result_dto->count->getValue() > 0); // test query table empty response = client_ptr->showTables(0, 0, conncetion_ptr); @@ -734,6 +813,24 @@ TEST_F(WebControllerTest, INSERT) { ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); } +TEST_F(WebControllerTest, INSERT_BIN) { + auto table_name = "test_insert_bin_table_test" + OString(RandomName().c_str()); + const int64_t dim = 64; + GenTable(table_name, dim, 100, "HAMMING"); + + auto insert_dto = milvus::server::web::InsertRequestDto::createShared(); + insert_dto->ids = insert_dto->ids->createShared(); + insert_dto->records_bin = RandomBinRecordsDto(dim, 20); + + auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(20, result_dto->ids->count()); + + response = client_ptr->dropTable(table_name, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); +} + TEST_F(WebControllerTest, INSERT_IDS) { auto table_name = "test_insert_table_test" + OString(RandomName().c_str()); const int64_t dim = 64; @@ -764,6 +861,9 @@ TEST_F(WebControllerTest, INDEX) { auto index_dto = milvus::server::web::IndexRequestDto::createShared(); auto response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr); ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto create_index_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, create_index_dto->code); + // drop index response = client_ptr->dropIndex(table_name, conncetion_ptr); ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); @@ -795,7 +895,6 @@ TEST_F(WebControllerTest, INDEX) { // invalid index type index_dto->index_type = 100; response = client_ptr->createIndex(table_name, index_dto, conncetion_ptr); - ASSERT_NE(OStatus::CODE_201.code, response->getStatusCode()); ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); // insert data and create index @@ -816,6 +915,9 @@ TEST_F(WebControllerTest, INDEX) { // get index response = client_ptr->getIndex(table_name, conncetion_ptr); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + auto result_index_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ("FLAT", result_index_dto->index_type->std_str()); + ASSERT_EQ(10, result_index_dto->nlist->getValue()); } TEST_F(WebControllerTest, PARTITION) { @@ -825,23 +927,25 @@ TEST_F(WebControllerTest, PARTITION) { auto par_param = milvus::server::web::PartitionRequestDto::createShared(); auto response = client_ptr->createPartition(table_name, par_param); ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); - auto result_dto = response->readBodyToDto(object_mapper.get()); - ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code); + auto error_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code); par_param->partition_name = "partition01" + OString(RandomName().c_str()); response = client_ptr->createPartition(table_name, par_param); - result_dto = response->readBodyToDto(object_mapper.get()); - ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + error_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code); par_param->partition_tag = "tag01"; response = client_ptr->createPartition(table_name, par_param); ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto create_result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, create_result_dto->code); // insert 200 vectors into table with tag = 'tag01' OQueryParams query_params; // add partition tag auto insert_dto = milvus::server::web::InsertRequestDto::createShared(); - // add partition tag insert_dto->tag = OString("tag01"); insert_dto->ids = insert_dto->ids->createShared(); insert_dto->records = insert_dto->records->createShared(); @@ -854,13 +958,17 @@ TEST_F(WebControllerTest, PARTITION) { // Show all partitins response = client_ptr->showPartitions(table_name, 0, 10, conncetion_ptr); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(1, result_dto->partitions->count()); + ASSERT_EQ("tag01", result_dto->partitions->get(0)->partition_tag->std_str()); + ASSERT_EQ(par_param->partition_name->std_str(), result_dto->partitions->get(0)->partition_name->std_str()); response = client_ptr->dropPartition(table_name, "tag01", conncetion_ptr); ASSERT_EQ(OStatus::CODE_204.code, response->getStatusCode()); } TEST_F(WebControllerTest, SEARCH) { - const OString table_name = "test_partition_table_test" + OString(RandomName().c_str()); + const OString table_name = "test_search_table_test" + OString(RandomName().c_str()); GenTable(table_name, 64, 100, "L2"); // Insert 200 vectors into table @@ -869,6 +977,67 @@ TEST_F(WebControllerTest, SEARCH) { insert_dto->ids = insert_dto->ids->createShared(); insert_dto->records = RandomRecordsDto(64, 200);// insert_dto->records->createShared(); + auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto insert_result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(200, insert_result_dto->ids->count()); + + sleep(4); + + //Create partition and insert 200 vectors into it + auto par_param = milvus::server::web::PartitionRequestDto::createShared(); + par_param->partition_name = "partition" + OString(RandomName().c_str()); + par_param->partition_tag = "tag" + OString(RandomName().c_str()); + response = client_ptr->createPartition(table_name, par_param); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()) + << "Error: " << response->getStatusDescription()->std_str(); + + insert_dto->tag = par_param->partition_tag; + response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + sleep(2); + + // Test search + auto search_request_dto = milvus::server::web::SearchRequestDto::createShared(); + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + auto error_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code); + + search_request_dto->nprobe = 1; + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + error_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code); + + search_request_dto->topk = 1; + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + error_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, error_dto->code); + + search_request_dto->records = RandomRecordsDto(64, 10); + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(10, result_dto->num); + ASSERT_EQ(10, result_dto->results->count()); + ASSERT_EQ(1, result_dto->results->get(0)->count()); + + // Test search with tags + search_request_dto->tags = search_request_dto->tags->createShared(); + search_request_dto->tags->pushBack(par_param->partition_tag); + response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); +} + +TEST_F(WebControllerTest, SEARCH_BIN) { + const OString table_name = "test_search_bin_table_test" + OString(RandomName().c_str()); + GenTable(table_name, 64, 100, "HAMMING"); + + // Insert 200 vectors into table + OQueryParams query_params; + auto insert_dto = milvus::server::web::InsertRequestDto::createShared(); + insert_dto->ids = insert_dto->ids->createShared(); + insert_dto->records_bin = RandomBinRecordsDto(64, 200); + auto response = client_ptr->insert(table_name, insert_dto, conncetion_ptr); ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); @@ -903,7 +1072,7 @@ TEST_F(WebControllerTest, SEARCH) { result_dto = response->readBodyToDto(object_mapper.get()); ASSERT_EQ(milvus::server::web::StatusCode::BODY_FIELD_LOSS, result_dto->code); - search_request_dto->records = RandomRecordsDto(64, 10); + search_request_dto->records_bin = RandomBinRecordsDto(64, 10); response = client_ptr->search(table_name, search_request_dto, conncetion_ptr); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); diff --git a/tests/milvus_python_test/test_partition.py b/tests/milvus_python_test/test_partition.py index cdf492e4..acd8307b 100644 --- a/tests/milvus_python_test/test_partition.py +++ b/tests/milvus_python_test/test_partition.py @@ -258,8 +258,8 @@ class TestShowBase: ''' partition_name = gen_unique_str() status, res = connect.show_partitions(partition_name) - assert status.OK() - assert len(res) == 0 + assert not status.OK() + # assert len(res) == 0 def test_show_multi_partitions(self, connect, table): ''' @@ -428,4 +428,4 @@ class TestNameInvalid(object): partition_name = gen_unique_str() status = connect.create_partition(table, partition_name, tag) status, res = connect.show_partitions(table_name) - assert not status.OK() \ No newline at end of file + assert not status.OK() -- GitLab