From 68e1a8da2bfcd97d1dd56bba65119c215aafd80d Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Thu, 30 Jul 2020 18:17:10 +0800 Subject: [PATCH] add field name and segment size validation check when create collection (#3065) * add field name and segment size validation check when create collection Signed-off-by: yudong.cai * fix clang-format Signed-off-by: yudong.cai * add collection name check Signed-off-by: yudong.cai * clean header files Signed-off-by: yudong.cai * move constants to Types.h Signed-off-by: yudong.cai --- core/src/db/DBImpl.cpp | 2 +- core/src/db/Types.h | 5 +++ core/src/db/snapshot/ReferenceProxy.h | 2 +- core/src/server/ValidationUtil.cpp | 38 +++++++------------ core/src/server/ValidationUtil.h | 6 +++ .../delivery/request/CreateCollectionReq.cpp | 18 ++++----- .../delivery/request/CreateCollectionReq.h | 2 - .../delivery/request/DropCollectionReq.cpp | 7 +--- .../delivery/request/GetCollectionInfoReq.cpp | 8 ++-- .../delivery/request/GetCollectionInfoReq.h | 3 -- .../request/GetCollectionStatsReq.cpp | 28 +++++++------- .../delivery/request/GetCollectionStatsReq.h | 1 - .../delivery/request/HasCollectionReq.cpp | 2 +- .../delivery/request/ListCollectionsReq.cpp | 2 - 14 files changed, 56 insertions(+), 68 deletions(-) diff --git a/core/src/db/DBImpl.cpp b/core/src/db/DBImpl.cpp index 2afe537c..b7696467 100644 --- a/core/src/db/DBImpl.cpp +++ b/core/src/db/DBImpl.cpp @@ -264,7 +264,7 @@ DBImpl::HasCollection(const std::string& collection_name, bool& has_or_not) { auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name); has_or_not = status.ok(); - return status; + return Status::OK(); } Status diff --git a/core/src/db/Types.h b/core/src/db/Types.h index 51cc7c97..d7d3f22f 100644 --- a/core/src/db/Types.h +++ b/core/src/db/Types.h @@ -90,7 +90,12 @@ extern const char* PARAM_INDEX_METRIC_TYPE; extern const char* PARAM_INDEX_EXTRA_PARAMS; extern const char* PARAM_SEGMENT_ROW_COUNT; +constexpr int64_t MAX_NAME_LENGTH = 255; +constexpr int64_t MAX_DIMENSION = 32768; +constexpr int32_t MAX_SEGMENT_ROW_COUNT = 4 * 1024 * 1024; constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000; +constexpr int64_t M_BYTE = 1024 * 1024; +constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE; using FieldType = meta::DataType; diff --git a/core/src/db/snapshot/ReferenceProxy.h b/core/src/db/snapshot/ReferenceProxy.h index f0a64613..b46be877 100644 --- a/core/src/db/snapshot/ReferenceProxy.h +++ b/core/src/db/snapshot/ReferenceProxy.h @@ -47,7 +47,7 @@ class ReferenceProxy { } } - [[nodiscard]] int64_t + int64_t ref_count() const { return ref_count_; } diff --git a/core/src/server/ValidationUtil.cpp b/core/src/server/ValidationUtil.cpp index 347e7c37..c3c569d1 100644 --- a/core/src/server/ValidationUtil.cpp +++ b/core/src/server/ValidationUtil.cpp @@ -26,12 +26,6 @@ namespace server { namespace { -constexpr size_t NAME_SIZE_LIMIT = 255; -constexpr int64_t COLLECTION_DIMENSION_LIMIT = 32768; -constexpr int32_t SEGMENT_ROW_COUNT_LIMIT = 4 * 1024 * 1024; -constexpr int64_t M_BYTE = 1024 * 1024; -constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE; - Status CheckParameterRange(const milvus::json& json_params, const std::string& param_name, int64_t min, int64_t max, bool min_close = true, bool max_closed = true) { @@ -102,7 +96,7 @@ ValidateCollectionName(const std::string& collection_name) { std::string invalid_msg = "Invalid collection name: " + collection_name + ". "; // Collection name size shouldn't exceed 255. - if (collection_name.size() > NAME_SIZE_LIMIT) { + if (collection_name.size() > engine::MAX_NAME_LENGTH) { std::string msg = invalid_msg + "The length of a collection name must be less than 255 characters."; LOG_SERVER_ERROR_ << msg; return Status(SERVER_INVALID_COLLECTION_NAME, msg); @@ -140,7 +134,7 @@ ValidateFieldName(const std::string& field_name) { std::string invalid_msg = "Invalid field name: " + field_name + ". "; // Field name size shouldn't exceed 255. - if (field_name.size() > NAME_SIZE_LIMIT) { + if (field_name.size() > engine::MAX_NAME_LENGTH) { std::string msg = invalid_msg + "The length of a field name must be less than 255 characters."; LOG_SERVER_ERROR_ << msg; return Status(SERVER_INVALID_FIELD_NAME, msg); @@ -202,22 +196,18 @@ ValidateIndexType(const std::string& index_type) { } Status -ValidateVectorDimension(int64_t dimension, const std::string& metric_type) { - if (dimension <= 0 || dimension > COLLECTION_DIMENSION_LIMIT) { - std::string msg = "Invalid collection dimension: " + std::to_string(dimension) + ". " + - "The collection dimension must be within the range of 1 ~ " + - std::to_string(COLLECTION_DIMENSION_LIMIT) + "."; +ValidateDimension(int64_t dim, bool is_binary) { + if (dim <= 0 || dim > engine::MAX_DIMENSION) { + std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be in range 1 ~ " + + std::to_string(engine::MAX_DIMENSION) + "."; LOG_SERVER_ERROR_ << msg; return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); } - if (milvus::engine::utils::IsBinaryMetricType(metric_type)) { - if ((dimension % 8) != 0) { - std::string msg = "Invalid collection dimension: " + std::to_string(dimension) + ". " + - "The collection dimension must be a multiple of 8"; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); - } + if (is_binary && (dim % 8) != 0) { + std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be multiple of 8."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); } return Status::OK(); @@ -310,14 +300,12 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s Status ValidateSegmentRowCount(int64_t segment_row_count) { - if (segment_row_count <= 0 || segment_row_count > SEGMENT_ROW_COUNT_LIMIT) { + if (segment_row_count <= 0 || segment_row_count > engine::MAX_SEGMENT_ROW_COUNT) { std::string msg = "Invalid segment row count: " + std::to_string(segment_row_count) + ". " + - "The segment row count must be within the range of 1 ~ " + - std::to_string(SEGMENT_ROW_COUNT_LIMIT) + "."; + "Should be in range 1 ~ " + std::to_string(engine::MAX_SEGMENT_ROW_COUNT) + "."; LOG_SERVER_ERROR_ << msg; return Status(SERVER_INVALID_SEGMENT_ROW_COUNT, msg); } - return Status::OK(); } @@ -363,7 +351,7 @@ ValidatePartitionName(const std::string& partition_name) { std::string invalid_msg = "Invalid partition name: " + partition_name + ". "; // Collection name size shouldn't exceed 255. - if (partition_name.size() > NAME_SIZE_LIMIT) { + if (partition_name.size() > engine::MAX_NAME_LENGTH) { std::string msg = invalid_msg + "The length of a partition name must be less than 255 characters."; LOG_SERVER_ERROR_ << msg; return Status(SERVER_INVALID_COLLECTION_NAME, msg); diff --git a/core/src/server/ValidationUtil.h b/core/src/server/ValidationUtil.h index 1ff50283..079a71a8 100644 --- a/core/src/server/ValidationUtil.h +++ b/core/src/server/ValidationUtil.h @@ -30,6 +30,12 @@ ValidateCollectionName(const std::string& collection_name); extern Status ValidateFieldName(const std::string& field_name); +extern Status +ValidateIndexName(const std::string& index_name); + +extern Status +ValidateDimension(int64_t dimension, bool is_binary); + extern Status ValidateIndexType(const std::string& index_type); diff --git a/core/src/server/delivery/request/CreateCollectionReq.cpp b/core/src/server/delivery/request/CreateCollectionReq.cpp index 43b95615..eac0e534 100644 --- a/core/src/server/delivery/request/CreateCollectionReq.cpp +++ b/core/src/server/delivery/request/CreateCollectionReq.cpp @@ -13,17 +13,10 @@ #include "db/Utils.h" #include "server/DBWrapper.h" #include "server/ValidationUtil.h" -#include "server/delivery/request/BaseReq.h" -#include "server/web_impl/Constants.h" #include "utils/Log.h" #include "utils/TimeRecorder.h" #include -#include -#include -#include -#include -#include namespace milvus { namespace server { @@ -79,16 +72,23 @@ CreateCollectionReq::OnExecute() { auto& field_params = field_schema.field_params_; auto& index_params = field_schema.index_params_; - std::cout << index_params.dump() << std::endl; + STATUS_CHECK(ValidateFieldName(field_name)); + std::string index_name; if (index_params.contains("name")) { index_name = index_params["name"]; } - std::cout << field_params.dump() << std::endl; if (field_type == engine::FieldType::VECTOR_FLOAT || field_type == engine::FieldType::VECTOR_BINARY) { if (!field_params.contains(engine::PARAM_DIMENSION)) { return Status(SERVER_INVALID_VECTOR_DIMENSION, "Dimension not defined in field_params"); + } else { + auto dim = field_params[engine::PARAM_DIMENSION].get(); + if (field_type == engine::FieldType::VECTOR_FLOAT) { + STATUS_CHECK(ValidateDimension(dim, false)); + } else { + STATUS_CHECK(ValidateDimension(dim, true)); + } } } diff --git a/core/src/server/delivery/request/CreateCollectionReq.h b/core/src/server/delivery/request/CreateCollectionReq.h index c6fee6c7..1778fd88 100644 --- a/core/src/server/delivery/request/CreateCollectionReq.h +++ b/core/src/server/delivery/request/CreateCollectionReq.h @@ -14,8 +14,6 @@ #include #include #include -#include -#include #include "server/delivery/request/BaseReq.h" diff --git a/core/src/server/delivery/request/DropCollectionReq.cpp b/core/src/server/delivery/request/DropCollectionReq.cpp index 84071365..2b244efe 100644 --- a/core/src/server/delivery/request/DropCollectionReq.cpp +++ b/core/src/server/delivery/request/DropCollectionReq.cpp @@ -15,11 +15,6 @@ #include "utils/Log.h" #include "utils/TimeRecorder.h" -#include -#include -#include -#include - namespace milvus { namespace server { @@ -39,6 +34,8 @@ DropCollectionReq::OnExecute() { std::string hdr = "DropCollectionReq(collection=" + collection_name_ + ")"; TimeRecorder rc(hdr); + STATUS_CHECK(ValidateCollectionName(collection_name_)); + bool exist = false; auto status = DBWrapper::DB()->HasCollection(collection_name_, exist); if (!exist) { diff --git a/core/src/server/delivery/request/GetCollectionInfoReq.cpp b/core/src/server/delivery/request/GetCollectionInfoReq.cpp index 07584da2..afdd49bc 100644 --- a/core/src/server/delivery/request/GetCollectionInfoReq.cpp +++ b/core/src/server/delivery/request/GetCollectionInfoReq.cpp @@ -12,16 +12,12 @@ #include "server/delivery/request/GetCollectionInfoReq.h" #include "db/Utils.h" #include "server/DBWrapper.h" +#include "server/ValidationUtil.h" #include "server/web_impl/Constants.h" #include "utils/Log.h" #include "utils/TimeRecorder.h" -#include -#include -#include -#include #include -#include namespace milvus { namespace server { @@ -45,6 +41,8 @@ GetCollectionInfoReq::OnExecute() { TimeRecorderAuto rc(hdr); try { + STATUS_CHECK(ValidateCollectionName(collection_name_)); + engine::snapshot::CollectionPtr collection; engine::snapshot::CollectionMappings collection_mappings; STATUS_CHECK(DBWrapper::DB()->GetCollectionInfo(collection_name_, collection, collection_mappings)); diff --git a/core/src/server/delivery/request/GetCollectionInfoReq.h b/core/src/server/delivery/request/GetCollectionInfoReq.h index 897c8861..c2f54c27 100644 --- a/core/src/server/delivery/request/GetCollectionInfoReq.h +++ b/core/src/server/delivery/request/GetCollectionInfoReq.h @@ -13,9 +13,6 @@ #include #include -#include -#include -#include #include "server/delivery/request/BaseReq.h" diff --git a/core/src/server/delivery/request/GetCollectionStatsReq.cpp b/core/src/server/delivery/request/GetCollectionStatsReq.cpp index dadfe726..4c710a9d 100644 --- a/core/src/server/delivery/request/GetCollectionStatsReq.cpp +++ b/core/src/server/delivery/request/GetCollectionStatsReq.cpp @@ -21,10 +21,6 @@ #include "utils/Log.h" #include "utils/TimeRecorder.h" -#include -#include -#include - namespace milvus { namespace server { @@ -46,16 +42,22 @@ GetCollectionStatsReq::OnExecute() { std::string hdr = "GetCollectionStatsReq(collection=" + collection_name_ + ")"; TimeRecorderAuto rc(hdr); - bool exist = false; - auto status = DBWrapper::DB()->HasCollection(collection_name_, exist); - if (!exist) { - return Status(SERVER_COLLECTION_NOT_EXIST, CollectionNotExistMsg(collection_name_)); - } + try { + STATUS_CHECK(ValidateCollectionName(collection_name_)); - nlohmann::json json_stats; - STATUS_CHECK(DBWrapper::DB()->GetCollectionStats(collection_name_, json_stats)); - collection_stats_ = json_stats.dump(); - rc.ElapseFromBegin("done"); + bool exist = false; + auto status = DBWrapper::DB()->HasCollection(collection_name_, exist); + if (!exist) { + return Status(SERVER_COLLECTION_NOT_EXIST, CollectionNotExistMsg(collection_name_)); + } + + milvus::json json_stats; + STATUS_CHECK(DBWrapper::DB()->GetCollectionStats(collection_name_, json_stats)); + collection_stats_ = json_stats.dump(); + rc.ElapseFromBegin("done"); + } catch (std::exception& ex) { + return Status(SERVER_UNEXPECTED_ERROR, ex.what()); + } return Status::OK(); } diff --git a/core/src/server/delivery/request/GetCollectionStatsReq.h b/core/src/server/delivery/request/GetCollectionStatsReq.h index 586e3415..9f8923f3 100644 --- a/core/src/server/delivery/request/GetCollectionStatsReq.h +++ b/core/src/server/delivery/request/GetCollectionStatsReq.h @@ -21,7 +21,6 @@ #include #include -#include namespace milvus { namespace server { diff --git a/core/src/server/delivery/request/HasCollectionReq.cpp b/core/src/server/delivery/request/HasCollectionReq.cpp index 161366c0..ea54598d 100644 --- a/core/src/server/delivery/request/HasCollectionReq.cpp +++ b/core/src/server/delivery/request/HasCollectionReq.cpp @@ -16,7 +16,6 @@ #include "utils/TimeRecorder.h" #include -#include namespace milvus { namespace server { @@ -38,6 +37,7 @@ HasCollectionReq::OnExecute() { std::string hdr = "HasCollectionReq(collection=" + collection_name_ + ")"; TimeRecorderAuto rc(hdr); + STATUS_CHECK(ValidateCollectionName(collection_name_)); STATUS_CHECK(DBWrapper::DB()->HasCollection(collection_name_, exist_)); rc.ElapseFromBegin("done"); diff --git a/core/src/server/delivery/request/ListCollectionsReq.cpp b/core/src/server/delivery/request/ListCollectionsReq.cpp index 4196ab13..362e0aa2 100644 --- a/core/src/server/delivery/request/ListCollectionsReq.cpp +++ b/core/src/server/delivery/request/ListCollectionsReq.cpp @@ -15,9 +15,7 @@ #include "utils/TimeRecorder.h" #include -#include #include -#include namespace milvus { namespace server { -- GitLab