diff --git a/cpp/src/db/ExecutionEngine.h b/cpp/src/db/ExecutionEngine.h index 4503e5565066b489fb8786d6559ecbc6a3ea8012..cd9e071b9c7412570d782974e5f2df9107d14f3d 100644 --- a/cpp/src/db/ExecutionEngine.h +++ b/cpp/src/db/ExecutionEngine.h @@ -19,6 +19,7 @@ enum class EngineType { FAISS_IDMAP = 1, FAISS_IVFFLAT, FAISS_IVFSQ8, + MAX_VALUE = FAISS_IVFSQ8, }; class ExecutionEngine { diff --git a/cpp/src/server/RequestTask.cpp b/cpp/src/server/RequestTask.cpp index d4051eba667931c25c48a2470c0cdba19a6d5071..af6264da908059e46d3a5d944d2731bc338a659b 100644 --- a/cpp/src/server/RequestTask.cpp +++ b/cpp/src/server/RequestTask.cpp @@ -148,17 +148,17 @@ ServerError CreateTableTask::OnExecute() { ServerError res = SERVER_SUCCESS; res = ValidateTableName(schema_.table_name); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid table name: " + schema_.table_name); } res = ValidateTableDimension(schema_.dimension); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid table dimension: " + std::to_string(schema_.dimension)); } res = ValidateTableIndexType(schema_.index_type); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid index type: " + std::to_string(schema_.index_type)); } //step 2: construct table schema @@ -203,7 +203,7 @@ ServerError DescribeTableTask::OnExecute() { ServerError res = SERVER_SUCCESS; res = ValidateTableName(table_name_); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid table name: " + table_name_); } //step 2: get table info @@ -243,12 +243,20 @@ ServerError BuildIndexTask::OnExecute() { TimeRecorder rc("BuildIndexTask"); //step 1: check arguments - if(table_name_.empty()) { - return SetError(SERVER_INVALID_TABLE_NAME, "Empty table name"); + ServerError res = SERVER_SUCCESS; + res = ValidateTableName(table_name_); + if(res != SERVER_SUCCESS) { + return SetError(res, "Invalid table name: " + table_name_); + } + + bool has_table = false; + engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table); + if(!has_table) { + return SetError(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); } //step 2: check table existence - engine::Status stat = DBWrapper::DB()->BuildIndex(table_name_); + stat = DBWrapper::DB()->BuildIndex(table_name_); if(!stat.ok()) { return SetError(SERVER_BUILD_INDEX_ERROR, "Engine failed: " + stat.ToString()); } @@ -281,8 +289,9 @@ ServerError HasTableTask::OnExecute() { ServerError res = SERVER_SUCCESS; res = ValidateTableName(table_name_); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid table name: " + table_name_); } + //step 2: check table existence engine::Status stat = DBWrapper::DB()->HasTable(table_name_, has_table_); if(!stat.ok()) { @@ -316,7 +325,7 @@ ServerError DeleteTableTask::OnExecute() { ServerError res = SERVER_SUCCESS; res = ValidateTableName(table_name_); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid table name: " + table_name_); } //step 2: check table existence @@ -400,7 +409,7 @@ ServerError AddVectorTask::OnExecute() { ServerError res = SERVER_SUCCESS; res = ValidateTableName(table_name_); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid table name: " + table_name_); } if(record_array_.empty()) { @@ -491,7 +500,7 @@ ServerError SearchVectorTask::OnExecute() { ServerError res = SERVER_SUCCESS; res = ValidateTableName(table_name_); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid table name: " + table_name_); } if(top_k_ <= 0) { @@ -606,7 +615,7 @@ ServerError GetTableRowCountTask::OnExecute() { ServerError res = SERVER_SUCCESS; res = ValidateTableName(table_name_); if(res != SERVER_SUCCESS) { - return res; + return SetError(res, "Invalid table name: " + table_name_); } //step 2: get row count diff --git a/cpp/src/utils/ValidationUtil.cpp b/cpp/src/utils/ValidationUtil.cpp index a1e3f0dffcce8711faff32b530a89f76828be7c6..53f00a4fc7bedb2a553b70b26a82827762b83853 100644 --- a/cpp/src/utils/ValidationUtil.cpp +++ b/cpp/src/utils/ValidationUtil.cpp @@ -56,18 +56,13 @@ ValidateTableDimension(int64_t dimension) { ServerError ValidateTableIndexType(int32_t index_type) { - auto engine_type = engine::EngineType(index_type); - switch (engine_type) { - case engine::EngineType::FAISS_IDMAP: - case engine::EngineType::FAISS_IVFFLAT: - case engine::EngineType::FAISS_IVFSQ8:{ - SERVER_LOG_DEBUG << "Index type: " << index_type; - return SERVER_SUCCESS; - } - default: { - return SERVER_INVALID_INDEX_TYPE; - } + int engine_type = (int)engine::EngineType(index_type); + if(engine_type <= 0 || engine_type > (int)engine::EngineType::MAX_VALUE) { + return SERVER_INVALID_INDEX_TYPE; } + + SERVER_LOG_DEBUG << "Index type: " << index_type; + return SERVER_SUCCESS; } } diff --git a/cpp/unittest/utils/ValidationUtilTest.cpp b/cpp/unittest/utils/ValidationUtilTest.cpp index 095614e3257935e85f2e72de15d6b0e761855d5d..38fc63a10d54934224533600d0d76b054eea187c 100644 --- a/cpp/unittest/utils/ValidationUtilTest.cpp +++ b/cpp/unittest/utils/ValidationUtilTest.cpp @@ -7,9 +7,11 @@ #include "utils/ValidationUtil.h" #include "utils/Error.h" +#include "db/ExecutionEngine.h" #include +using namespace zilliz::milvus; using namespace zilliz::milvus::server; TEST(ValidationUtilTest, TableNameTest) { @@ -53,9 +55,9 @@ TEST(ValidationUtilTest, TableDimensionTest) { } TEST(ValidationUtilTest, TableIndexTypeTest) { - ASSERT_EQ(ValidateTableIndexType(0), SERVER_INVALID_INDEX_TYPE); - ASSERT_EQ(ValidateTableIndexType(1), SERVER_SUCCESS); - ASSERT_EQ(ValidateTableIndexType(2), SERVER_SUCCESS); - ASSERT_EQ(ValidateTableIndexType(3), SERVER_INVALID_INDEX_TYPE); - ASSERT_EQ(ValidateTableIndexType(4), SERVER_INVALID_INDEX_TYPE); + ASSERT_EQ(ValidateTableIndexType((int)engine::EngineType::INVALID), SERVER_INVALID_INDEX_TYPE); + for(int i = 1; i <= (int)engine::EngineType::MAX_VALUE; i++) { + ASSERT_EQ(ValidateTableIndexType(i), SERVER_SUCCESS); + } + ASSERT_EQ(ValidateTableIndexType((int)engine::EngineType::MAX_VALUE + 1), SERVER_INVALID_INDEX_TYPE); }