diff --git a/cpp/src/server/RequestTask.cpp b/cpp/src/server/RequestTask.cpp index 07a8305d1f59f7afd446dc716a2af57e3c3ea98e..76da0f728dc157554653b10f69bcc617aee8fc66 100644 --- a/cpp/src/server/RequestTask.cpp +++ b/cpp/src/server/RequestTask.cpp @@ -8,6 +8,7 @@ #include "utils/CommonUtil.h" #include "utils/Log.h" #include "utils/TimeRecorder.h" +#include "utils/ValidationUtil.h" #include "DBWrapper.h" #include "version.h" @@ -133,19 +134,23 @@ BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) { ServerError CreateTableTask::OnExecute() { TimeRecorder rc("CreateTableTask"); - + try { //step 1: check arguments - if(schema_.table_name.empty()) { - return SetError(SERVER_INVALID_TABLE_NAME, "Empty table name"); + ServerError res = SERVER_SUCCESS; + res = ValidateTableName(schema_.table_name); + if(res != SERVER_SUCCESS) { + return res; } - if(schema_.dimension <= 0) { - return SetError(SERVER_INVALID_TABLE_DIMENSION, "Invalid table dimension: " + std::to_string(schema_.dimension)); + + res = ValidateTableDimension(schema_.dimension); + if(res != SERVER_SUCCESS) { + return res; } - engine::EngineType engine_type = EngineType(schema_.index_type); - if(engine_type == engine::EngineType::INVALID) { - return SetError(SERVER_INVALID_INDEX_TYPE, "Invalid index type: " + std::to_string(schema_.index_type)); + res = ValidateTableIndexType(schema_.index_type); + if(res != SERVER_SUCCESS) { + return res; } //step 2: construct table schema @@ -187,8 +192,10 @@ ServerError DescribeTableTask::OnExecute() { try { //step 1: check arguments - if(table_name_.empty()) { - return SetError(SERVER_INVALID_TABLE_NAME, "Empty table name"); + ServerError res = SERVER_SUCCESS; + res = ValidateTableName(table_name_); + if(res != SERVER_SUCCESS) { + return res; } //step 2: get table info @@ -230,10 +237,11 @@ ServerError HasTableTask::OnExecute() { TimeRecorder rc("HasTableTask"); //step 1: check arguments - if(table_name_.empty()) { - return SetError(SERVER_INVALID_TABLE_NAME, "Empty table name"); + ServerError res = SERVER_SUCCESS; + res = ValidateTableName(table_name_); + if(res != SERVER_SUCCESS) { + return res; } - //step 2: check table existence engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table_); if(!stat.ok()) { @@ -264,8 +272,10 @@ ServerError DeleteTableTask::OnExecute() { TimeRecorder rc("DeleteTableTask"); //step 1: check arguments - if (table_name_.empty()) { - return SetError(SERVER_INVALID_TABLE_NAME, "Empty table name"); + ServerError res = SERVER_SUCCESS; + res = ValidateTableName(table_name_); + if(res != SERVER_SUCCESS) { + return res; } //step 2: check table existence @@ -346,8 +356,10 @@ ServerError AddVectorTask::OnExecute() { TimeRecorder rc("AddVectorTask"); //step 1: check arguments - if (table_name_.empty()) { - return SetError(SERVER_INVALID_TABLE_NAME, "Empty table name"); + ServerError res = SERVER_SUCCESS; + res = ValidateTableName(table_name_); + if(res != SERVER_SUCCESS) { + return res; } if(record_array_.empty()) { @@ -435,8 +447,10 @@ ServerError SearchVectorTask::OnExecute() { TimeRecorder rc("SearchVectorTask"); //step 1: check arguments - if (table_name_.empty()) { - return SetError(SERVER_INVALID_TABLE_NAME, "Empty table name"); + ServerError res = SERVER_SUCCESS; + res = ValidateTableName(table_name_); + if(res != SERVER_SUCCESS) { + return res; } if(top_k_ <= 0) { @@ -548,8 +562,10 @@ ServerError GetTableRowCountTask::OnExecute() { TimeRecorder rc("GetTableRowCountTask"); //step 1: check arguments - if (table_name_.empty()) { - return SetError(SERVER_INVALID_TABLE_NAME, "Empty table name"); + ServerError res = SERVER_SUCCESS; + res = ValidateTableName(table_name_); + if(res != SERVER_SUCCESS) { + return res; } //step 2: get row count diff --git a/cpp/src/utils/ValidationUtil.cpp b/cpp/src/utils/ValidationUtil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b4bbd3346a3b383e89e9d828301968ee03277410 --- /dev/null +++ b/cpp/src/utils/ValidationUtil.cpp @@ -0,0 +1,74 @@ +#include +#include "ValidationUtil.h" +#include "Log.h" + + +namespace zilliz { +namespace milvus { +namespace server { + +constexpr size_t table_name_size_limit = 16384; +constexpr int64_t table_dimension_limit = 16384; + +ServerError +ValidateTableName(const std::string &table_name) { + + // Table name shouldn't be empty. + if (table_name.empty()) { + SERVER_LOG_ERROR << "Empty table name"; + return SERVER_INVALID_TABLE_NAME; + } + + // Table name size shouldn't exceed 16384. + if (table_name.size() > table_name_size_limit) { + SERVER_LOG_ERROR << "Table name size exceed the limitation"; + return SERVER_INVALID_TABLE_NAME; + } + + // Table name first character should be underscore or character. + char first_char = table_name[0]; + if (first_char != '_' && std::isalpha(first_char) == 0) { + SERVER_LOG_ERROR << "Table name first character isn't underscore or character: " << first_char; + return SERVER_INVALID_TABLE_NAME; + } + + int64_t table_name_size = table_name.size(); + for (int64_t i = 1; i < table_name_size; ++i) { + char name_char = table_name[i]; + if (name_char != '_' && std::isalnum(name_char) == 0) { + SERVER_LOG_ERROR << "Table name character isn't underscore or alphanumber: " << name_char; + return SERVER_INVALID_TABLE_NAME; + } + } + + return SERVER_SUCCESS; +} + +ServerError +ValidateTableDimension(int64_t dimension) { + if (dimension <= 0 || dimension > table_dimension_limit) { + SERVER_LOG_ERROR << "Table dimension excceed the limitation: " << table_dimension_limit; + return SERVER_INVALID_VECTOR_DIMENSION; + } else { + return SERVER_SUCCESS; + } +} + +ServerError +ValidateTableIndexType(int32_t index_type) { + auto engine_type = engine::EngineType(index_type); + switch (engine_type) { + case engine::EngineType::FAISS_IDMAP: + case engine::EngineType::FAISS_IVFFLAT: { + SERVER_LOG_DEBUG << "Index type: " << index_type; + return SERVER_SUCCESS; + } + default: { + return SERVER_INVALID_INDEX_TYPE; + } + } +} + +} +} +} \ No newline at end of file diff --git a/cpp/src/utils/ValidationUtil.h b/cpp/src/utils/ValidationUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..608ac22682c52a29cf97bebc56ac414d95069c94 --- /dev/null +++ b/cpp/src/utils/ValidationUtil.h @@ -0,0 +1,20 @@ +#pragma once + +#include "Error.h" + +namespace zilliz { +namespace milvus { +namespace server { + +ServerError +ValidateTableName(const std::string& table_name); + +ServerError +ValidateTableDimension(int64_t dimension); + +ServerError +ValidateTableIndexType(int32_t index_type); + +} +} +} \ No newline at end of file diff --git a/cpp/unittest/CMakeLists.txt b/cpp/unittest/CMakeLists.txt index 043716b58bf77e2877d2786d8f81809de8a7b0c6..8675bf8735eca1468fe760aef609e1542b3e7e4d 100644 --- a/cpp/unittest/CMakeLists.txt +++ b/cpp/unittest/CMakeLists.txt @@ -12,7 +12,6 @@ aux_source_directory(${MILVUS_ENGINE_SRC}/config config_files) set(unittest_srcs ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp) - #${EASYLOGGINGPP_INCLUDE_DIR}/easylogging++.cc) set(require_files ${MILVUS_ENGINE_SRC}/server/ServerConfig.cpp @@ -44,4 +43,5 @@ add_subdirectory(index_wrapper) #add_subdirectory(faiss_wrapper) #add_subdirectory(license) add_subdirectory(metrics) -add_subdirectory(storage) \ No newline at end of file +add_subdirectory(storage) +add_subdirectory(utils) \ No newline at end of file diff --git a/cpp/unittest/db/db_tests.cpp b/cpp/unittest/db/db_tests.cpp index bd17081af8270d08c1b289bb223b86efc419f27b..625211cae79a0c6885b3643ebc5be8b0315effb0 100644 --- a/cpp/unittest/db/db_tests.cpp +++ b/cpp/unittest/db/db_tests.cpp @@ -3,17 +3,20 @@ // Unauthorized copying of this file, via any medium is strictly prohibited. // Proprietary and confidential. //////////////////////////////////////////////////////////////////////////////// -#include -#include -#include -#include - #include "utils.h" #include "db/DB.h" #include "db/DBImpl.h" #include "db/MetaConsts.h" #include "db/Factories.h" +#include +#include + +#include + +#include +#include + using namespace zilliz::milvus; namespace { diff --git a/cpp/unittest/db/mysql_db_test.cpp b/cpp/unittest/db/mysql_db_test.cpp index 7fdb30a2042441488308cb4abd89ec4265d319d3..0e24cacdfd1e1fca09195266caa918a0d95cb7e2 100644 --- a/cpp/unittest/db/mysql_db_test.cpp +++ b/cpp/unittest/db/mysql_db_test.cpp @@ -3,17 +3,19 @@ // Unauthorized copying of this file, via any medium is strictly prohibited. // Proprietary and confidential. //////////////////////////////////////////////////////////////////////////////// -#include -#include -#include -#include - #include "utils.h" #include "db/DB.h" #include "db/DBImpl.h" #include "db/MetaConsts.h" #include "db/Factories.h" +#include +#include +#include + +#include +#include + using namespace zilliz::milvus; namespace { diff --git a/cpp/unittest/db/search_test.cpp b/cpp/unittest/db/search_test.cpp index db10bcbadf6172d0a76eac127b83d693378872d3..ce99ea78f77c58ebb91d4b791ad1b10717f333a4 100644 --- a/cpp/unittest/db/search_test.cpp +++ b/cpp/unittest/db/search_test.cpp @@ -3,10 +3,11 @@ // Unauthorized copying of this file, via any medium is strictly prohibited. // Proprietary and confidential. //////////////////////////////////////////////////////////////////////////////// -#include - #include "db/scheduler/task/SearchTask.h" +#include + +#include #include using namespace zilliz::milvus; diff --git a/cpp/unittest/faiss_wrapper/wrapper_test.cpp b/cpp/unittest/faiss_wrapper/wrapper_test.cpp index 67a6c3cde828a06ad11fe64d1922a76065fc1d8a..6f4a651a554131ce63df68b042751670156e1682 100644 --- a/cpp/unittest/faiss_wrapper/wrapper_test.cpp +++ b/cpp/unittest/faiss_wrapper/wrapper_test.cpp @@ -4,12 +4,15 @@ // Proprietary and confidential. //////////////////////////////////////////////////////////////////////////////// -#include + #include "wrapper/Operand.h" #include "wrapper/Index.h" #include "wrapper/IndexBuilder.h" +#include +#include + using namespace zilliz::milvus::engine; diff --git a/cpp/unittest/utils/CMakeLists.txt b/cpp/unittest/utils/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a46a3b05e1bcec1b0eae4b4932320be0c7c0ebdd --- /dev/null +++ b/cpp/unittest/utils/CMakeLists.txt @@ -0,0 +1,30 @@ +#------------------------------------------------------------------------------- +# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved +# Unauthorized copying of this file, via any medium is strictly prohibited. +# Proprietary and confidential. +#------------------------------------------------------------------------------- + +# Make sure that your call to link_directories takes place before your call to the relevant add_executable. +include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include") +link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") + +set(validation_util_src + ${MILVUS_ENGINE_SRC}/utils/ValidationUtil.cpp + ${MILVUS_ENGINE_SRC}/utils/ValidationUtil.h) + +set(validation_util_test_src + ${unittest_srcs} + ${validation_util_src} + ${require_files} + ValidationUtilTest.cpp + ) + +add_executable(valication_util_test + ${validation_util_test_src} + ${config_files}) + +target_link_libraries(valication_util_test + ${unittest_libs} + boost_filesystem) + +install(TARGETS valication_util_test DESTINATION bin) \ No newline at end of file diff --git a/cpp/unittest/utils/ValidationUtilTest.cpp b/cpp/unittest/utils/ValidationUtilTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..095614e3257935e85f2e72de15d6b0e761855d5d --- /dev/null +++ b/cpp/unittest/utils/ValidationUtilTest.cpp @@ -0,0 +1,61 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved +// Unauthorized copying of this file, via any medium is strictly prohibited. +// Proprietary and confidential. +//////////////////////////////////////////////////////////////////////////////// +#include + +#include "utils/ValidationUtil.h" +#include "utils/Error.h" + +#include + +using namespace zilliz::milvus::server; + +TEST(ValidationUtilTest, TableNameTest) { + std::string table_name = "Normal123_"; + ServerError res = ValidateTableName(table_name); + ASSERT_EQ(res, SERVER_SUCCESS); + + table_name = "12sds"; + res = ValidateTableName(table_name); + ASSERT_EQ(res, SERVER_INVALID_TABLE_NAME); + + table_name = ""; + res = ValidateTableName(table_name); + ASSERT_EQ(res, SERVER_INVALID_TABLE_NAME); + + table_name = "_asdasd"; + res = ValidateTableName(table_name); + ASSERT_EQ(res, SERVER_SUCCESS); + + table_name = "!@#!@"; + res = ValidateTableName(table_name); + ASSERT_EQ(res, SERVER_INVALID_TABLE_NAME); + + table_name = "中文"; + res = ValidateTableName(table_name); + ASSERT_EQ(res, SERVER_INVALID_TABLE_NAME); + + + table_name = std::string('a', 32768); + res = ValidateTableName(table_name); + ASSERT_EQ(res, SERVER_INVALID_TABLE_NAME); +} + + +TEST(ValidationUtilTest, TableDimensionTest) { + ASSERT_EQ(ValidateTableDimension(-1), SERVER_INVALID_VECTOR_DIMENSION); + ASSERT_EQ(ValidateTableDimension(0), SERVER_INVALID_VECTOR_DIMENSION); + ASSERT_EQ(ValidateTableDimension(16385), SERVER_INVALID_VECTOR_DIMENSION); + ASSERT_EQ(ValidateTableDimension(16384), SERVER_SUCCESS); + ASSERT_EQ(ValidateTableDimension(1), SERVER_SUCCESS); +} + +TEST(ValidationUtilTest, TableIndexTypeTest) { + ASSERT_EQ(ValidateTableIndexType(0), SERVER_INVALID_INDEX_TYPE); + ASSERT_EQ(ValidateTableIndexType(1), SERVER_SUCCESS); + ASSERT_EQ(ValidateTableIndexType(2), SERVER_SUCCESS); + ASSERT_EQ(ValidateTableIndexType(3), SERVER_INVALID_INDEX_TYPE); + ASSERT_EQ(ValidateTableIndexType(4), SERVER_INVALID_INDEX_TYPE); +}