未验证 提交 68e1a8da 编写于 作者: C Cai Yudong 提交者: GitHub

add field name and segment size validation checks when creating a collection (#3065)

* add field name and segment size validation checks when creating a collection
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* fix clang-format
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* add collection name check
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* clean header files
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* move constants to Types.h
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
上级 5e59056a
...@@ -264,7 +264,7 @@ DBImpl::HasCollection(const std::string& collection_name, bool& has_or_not) { ...@@ -264,7 +264,7 @@ DBImpl::HasCollection(const std::string& collection_name, bool& has_or_not) {
auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name); auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name);
has_or_not = status.ok(); has_or_not = status.ok();
return status; return Status::OK();
} }
Status Status
......
...@@ -90,7 +90,12 @@ extern const char* PARAM_INDEX_METRIC_TYPE; ...@@ -90,7 +90,12 @@ extern const char* PARAM_INDEX_METRIC_TYPE;
extern const char* PARAM_INDEX_EXTRA_PARAMS; extern const char* PARAM_INDEX_EXTRA_PARAMS;
extern const char* PARAM_SEGMENT_ROW_COUNT; extern const char* PARAM_SEGMENT_ROW_COUNT;
constexpr int64_t MAX_NAME_LENGTH = 255;
constexpr int64_t MAX_DIMENSION = 32768;
constexpr int32_t MAX_SEGMENT_ROW_COUNT = 4 * 1024 * 1024;
constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000; constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000;
constexpr int64_t M_BYTE = 1024 * 1024;
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE;
using FieldType = meta::DataType; using FieldType = meta::DataType;
......
...@@ -47,7 +47,7 @@ class ReferenceProxy { ...@@ -47,7 +47,7 @@ class ReferenceProxy {
} }
} }
[[nodiscard]] int64_t int64_t
ref_count() const { ref_count() const {
return ref_count_; return ref_count_;
} }
......
...@@ -26,12 +26,6 @@ namespace server { ...@@ -26,12 +26,6 @@ namespace server {
namespace { namespace {
constexpr size_t NAME_SIZE_LIMIT = 255;
constexpr int64_t COLLECTION_DIMENSION_LIMIT = 32768;
constexpr int32_t SEGMENT_ROW_COUNT_LIMIT = 4 * 1024 * 1024;
constexpr int64_t M_BYTE = 1024 * 1024;
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE;
Status Status
CheckParameterRange(const milvus::json& json_params, const std::string& param_name, int64_t min, int64_t max, CheckParameterRange(const milvus::json& json_params, const std::string& param_name, int64_t min, int64_t max,
bool min_close = true, bool max_closed = true) { bool min_close = true, bool max_closed = true) {
...@@ -102,7 +96,7 @@ ValidateCollectionName(const std::string& collection_name) { ...@@ -102,7 +96,7 @@ ValidateCollectionName(const std::string& collection_name) {
std::string invalid_msg = "Invalid collection name: " + collection_name + ". "; std::string invalid_msg = "Invalid collection name: " + collection_name + ". ";
// Collection name size shouldn't exceed 255. // Collection name size shouldn't exceed 255.
if (collection_name.size() > NAME_SIZE_LIMIT) { if (collection_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a collection name must be less than 255 characters."; std::string msg = invalid_msg + "The length of a collection name must be less than 255 characters.";
LOG_SERVER_ERROR_ << msg; LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_NAME, msg); return Status(SERVER_INVALID_COLLECTION_NAME, msg);
...@@ -140,7 +134,7 @@ ValidateFieldName(const std::string& field_name) { ...@@ -140,7 +134,7 @@ ValidateFieldName(const std::string& field_name) {
std::string invalid_msg = "Invalid field name: " + field_name + ". "; std::string invalid_msg = "Invalid field name: " + field_name + ". ";
// Field name size shouldn't exceed 255. // Field name size shouldn't exceed 255.
if (field_name.size() > NAME_SIZE_LIMIT) { if (field_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a field name must be less than 255 characters."; std::string msg = invalid_msg + "The length of a field name must be less than 255 characters.";
LOG_SERVER_ERROR_ << msg; LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_FIELD_NAME, msg); return Status(SERVER_INVALID_FIELD_NAME, msg);
...@@ -202,22 +196,18 @@ ValidateIndexType(const std::string& index_type) { ...@@ -202,22 +196,18 @@ ValidateIndexType(const std::string& index_type) {
} }
Status Status
ValidateVectorDimension(int64_t dimension, const std::string& metric_type) { ValidateDimension(int64_t dim, bool is_binary) {
if (dimension <= 0 || dimension > COLLECTION_DIMENSION_LIMIT) { if (dim <= 0 || dim > engine::MAX_DIMENSION) {
std::string msg = "Invalid collection dimension: " + std::to_string(dimension) + ". " + std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be in range 1 ~ " +
"The collection dimension must be within the range of 1 ~ " + std::to_string(engine::MAX_DIMENSION) + ".";
std::to_string(COLLECTION_DIMENSION_LIMIT) + ".";
LOG_SERVER_ERROR_ << msg; LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
} }
if (milvus::engine::utils::IsBinaryMetricType(metric_type)) { if (is_binary && (dim % 8) != 0) {
if ((dimension % 8) != 0) { std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be multiple of 8.";
std::string msg = "Invalid collection dimension: " + std::to_string(dimension) + ". " + LOG_SERVER_ERROR_ << msg;
"The collection dimension must be a multiple of 8"; return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
}
} }
return Status::OK(); return Status::OK();
...@@ -310,14 +300,12 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s ...@@ -310,14 +300,12 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
Status Status
ValidateSegmentRowCount(int64_t segment_row_count) { ValidateSegmentRowCount(int64_t segment_row_count) {
if (segment_row_count <= 0 || segment_row_count > SEGMENT_ROW_COUNT_LIMIT) { if (segment_row_count <= 0 || segment_row_count > engine::MAX_SEGMENT_ROW_COUNT) {
std::string msg = "Invalid segment row count: " + std::to_string(segment_row_count) + ". " + std::string msg = "Invalid segment row count: " + std::to_string(segment_row_count) + ". " +
"The segment row count must be within the range of 1 ~ " + "Should be in range 1 ~ " + std::to_string(engine::MAX_SEGMENT_ROW_COUNT) + ".";
std::to_string(SEGMENT_ROW_COUNT_LIMIT) + ".";
LOG_SERVER_ERROR_ << msg; LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_SEGMENT_ROW_COUNT, msg); return Status(SERVER_INVALID_SEGMENT_ROW_COUNT, msg);
} }
return Status::OK(); return Status::OK();
} }
...@@ -363,7 +351,7 @@ ValidatePartitionName(const std::string& partition_name) { ...@@ -363,7 +351,7 @@ ValidatePartitionName(const std::string& partition_name) {
std::string invalid_msg = "Invalid partition name: " + partition_name + ". "; std::string invalid_msg = "Invalid partition name: " + partition_name + ". ";
// Collection name size shouldn't exceed 255. // Collection name size shouldn't exceed 255.
if (partition_name.size() > NAME_SIZE_LIMIT) { if (partition_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a partition name must be less than 255 characters."; std::string msg = invalid_msg + "The length of a partition name must be less than 255 characters.";
LOG_SERVER_ERROR_ << msg; LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_NAME, msg); return Status(SERVER_INVALID_COLLECTION_NAME, msg);
......
...@@ -30,6 +30,12 @@ ValidateCollectionName(const std::string& collection_name); ...@@ -30,6 +30,12 @@ ValidateCollectionName(const std::string& collection_name);
extern Status extern Status
ValidateFieldName(const std::string& field_name); ValidateFieldName(const std::string& field_name);
extern Status
ValidateIndexName(const std::string& index_name);
extern Status
ValidateDimension(int64_t dimension, bool is_binary);
extern Status extern Status
ValidateIndexType(const std::string& index_type); ValidateIndexType(const std::string& index_type);
......
...@@ -13,17 +13,10 @@ ...@@ -13,17 +13,10 @@
#include "db/Utils.h" #include "db/Utils.h"
#include "server/DBWrapper.h" #include "server/DBWrapper.h"
#include "server/ValidationUtil.h" #include "server/ValidationUtil.h"
#include "server/delivery/request/BaseReq.h"
#include "server/web_impl/Constants.h"
#include "utils/Log.h" #include "utils/Log.h"
#include "utils/TimeRecorder.h" #include "utils/TimeRecorder.h"
#include <fiu-local.h> #include <fiu-local.h>
#include <src/db/snapshot/Context.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace milvus { namespace milvus {
namespace server { namespace server {
...@@ -79,16 +72,23 @@ CreateCollectionReq::OnExecute() { ...@@ -79,16 +72,23 @@ CreateCollectionReq::OnExecute() {
auto& field_params = field_schema.field_params_; auto& field_params = field_schema.field_params_;
auto& index_params = field_schema.index_params_; auto& index_params = field_schema.index_params_;
std::cout << index_params.dump() << std::endl; STATUS_CHECK(ValidateFieldName(field_name));
std::string index_name; std::string index_name;
if (index_params.contains("name")) { if (index_params.contains("name")) {
index_name = index_params["name"]; index_name = index_params["name"];
} }
std::cout << field_params.dump() << std::endl;
if (field_type == engine::FieldType::VECTOR_FLOAT || field_type == engine::FieldType::VECTOR_BINARY) { if (field_type == engine::FieldType::VECTOR_FLOAT || field_type == engine::FieldType::VECTOR_BINARY) {
if (!field_params.contains(engine::PARAM_DIMENSION)) { if (!field_params.contains(engine::PARAM_DIMENSION)) {
return Status(SERVER_INVALID_VECTOR_DIMENSION, "Dimension not defined in field_params"); return Status(SERVER_INVALID_VECTOR_DIMENSION, "Dimension not defined in field_params");
} else {
auto dim = field_params[engine::PARAM_DIMENSION].get<int64_t>();
if (field_type == engine::FieldType::VECTOR_FLOAT) {
STATUS_CHECK(ValidateDimension(dim, false));
} else {
STATUS_CHECK(ValidateDimension(dim, true));
}
} }
} }
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <vector>
#include "server/delivery/request/BaseReq.h" #include "server/delivery/request/BaseReq.h"
......
...@@ -15,11 +15,6 @@ ...@@ -15,11 +15,6 @@
#include "utils/Log.h" #include "utils/Log.h"
#include "utils/TimeRecorder.h" #include "utils/TimeRecorder.h"
#include <fiu-local.h>
#include <memory>
#include <unordered_map>
#include <vector>
namespace milvus { namespace milvus {
namespace server { namespace server {
...@@ -39,6 +34,8 @@ DropCollectionReq::OnExecute() { ...@@ -39,6 +34,8 @@ DropCollectionReq::OnExecute() {
std::string hdr = "DropCollectionReq(collection=" + collection_name_ + ")"; std::string hdr = "DropCollectionReq(collection=" + collection_name_ + ")";
TimeRecorder rc(hdr); TimeRecorder rc(hdr);
STATUS_CHECK(ValidateCollectionName(collection_name_));
bool exist = false; bool exist = false;
auto status = DBWrapper::DB()->HasCollection(collection_name_, exist); auto status = DBWrapper::DB()->HasCollection(collection_name_, exist);
if (!exist) { if (!exist) {
......
...@@ -12,16 +12,12 @@ ...@@ -12,16 +12,12 @@
#include "server/delivery/request/GetCollectionInfoReq.h" #include "server/delivery/request/GetCollectionInfoReq.h"
#include "db/Utils.h" #include "db/Utils.h"
#include "server/DBWrapper.h" #include "server/DBWrapper.h"
#include "server/ValidationUtil.h"
#include "server/web_impl/Constants.h" #include "server/web_impl/Constants.h"
#include "utils/Log.h" #include "utils/Log.h"
#include "utils/TimeRecorder.h" #include "utils/TimeRecorder.h"
#include <fiu-local.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector>
namespace milvus { namespace milvus {
namespace server { namespace server {
...@@ -45,6 +41,8 @@ GetCollectionInfoReq::OnExecute() { ...@@ -45,6 +41,8 @@ GetCollectionInfoReq::OnExecute() {
TimeRecorderAuto rc(hdr); TimeRecorderAuto rc(hdr);
try { try {
STATUS_CHECK(ValidateCollectionName(collection_name_));
engine::snapshot::CollectionPtr collection; engine::snapshot::CollectionPtr collection;
engine::snapshot::CollectionMappings collection_mappings; engine::snapshot::CollectionMappings collection_mappings;
STATUS_CHECK(DBWrapper::DB()->GetCollectionInfo(collection_name_, collection, collection_mappings)); STATUS_CHECK(DBWrapper::DB()->GetCollectionInfo(collection_name_, collection, collection_mappings));
......
...@@ -13,9 +13,6 @@ ...@@ -13,9 +13,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "server/delivery/request/BaseReq.h" #include "server/delivery/request/BaseReq.h"
......
...@@ -21,10 +21,6 @@ ...@@ -21,10 +21,6 @@
#include "utils/Log.h" #include "utils/Log.h"
#include "utils/TimeRecorder.h" #include "utils/TimeRecorder.h"
#include <memory>
#include <unordered_map>
#include <vector>
namespace milvus { namespace milvus {
namespace server { namespace server {
...@@ -46,16 +42,22 @@ GetCollectionStatsReq::OnExecute() { ...@@ -46,16 +42,22 @@ GetCollectionStatsReq::OnExecute() {
std::string hdr = "GetCollectionStatsReq(collection=" + collection_name_ + ")"; std::string hdr = "GetCollectionStatsReq(collection=" + collection_name_ + ")";
TimeRecorderAuto rc(hdr); TimeRecorderAuto rc(hdr);
bool exist = false; try {
auto status = DBWrapper::DB()->HasCollection(collection_name_, exist); STATUS_CHECK(ValidateCollectionName(collection_name_));
if (!exist) {
return Status(SERVER_COLLECTION_NOT_EXIST, CollectionNotExistMsg(collection_name_));
}
nlohmann::json json_stats; bool exist = false;
STATUS_CHECK(DBWrapper::DB()->GetCollectionStats(collection_name_, json_stats)); auto status = DBWrapper::DB()->HasCollection(collection_name_, exist);
collection_stats_ = json_stats.dump(); if (!exist) {
rc.ElapseFromBegin("done"); return Status(SERVER_COLLECTION_NOT_EXIST, CollectionNotExistMsg(collection_name_));
}
milvus::json json_stats;
STATUS_CHECK(DBWrapper::DB()->GetCollectionStats(collection_name_, json_stats));
collection_stats_ = json_stats.dump();
rc.ElapseFromBegin("done");
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
}
return Status::OK(); return Status::OK();
} }
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
namespace milvus { namespace milvus {
namespace server { namespace server {
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "utils/TimeRecorder.h" #include "utils/TimeRecorder.h"
#include <fiu-local.h> #include <fiu-local.h>
#include <memory>
namespace milvus { namespace milvus {
namespace server { namespace server {
...@@ -38,6 +37,7 @@ HasCollectionReq::OnExecute() { ...@@ -38,6 +37,7 @@ HasCollectionReq::OnExecute() {
std::string hdr = "HasCollectionReq(collection=" + collection_name_ + ")"; std::string hdr = "HasCollectionReq(collection=" + collection_name_ + ")";
TimeRecorderAuto rc(hdr); TimeRecorderAuto rc(hdr);
STATUS_CHECK(ValidateCollectionName(collection_name_));
STATUS_CHECK(DBWrapper::DB()->HasCollection(collection_name_, exist_)); STATUS_CHECK(DBWrapper::DB()->HasCollection(collection_name_, exist_));
rc.ElapseFromBegin("done"); rc.ElapseFromBegin("done");
......
...@@ -15,9 +15,7 @@ ...@@ -15,9 +15,7 @@
#include "utils/TimeRecorder.h" #include "utils/TimeRecorder.h"
#include <fiu-local.h> #include <fiu-local.h>
#include <memory>
#include <string> #include <string>
#include <vector>
namespace milvus { namespace milvus {
namespace server { namespace server {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册