未验证 提交 68e1a8da 编写于 作者: C Cai Yudong 提交者: GitHub

add field name and segment size validation check when create collection (#3065)

* add field name and segment size validation check when create collection
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* fix clang-format
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* add collection name check
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* clean header files
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* move constants to Types.h
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>
上级 5e59056a
......@@ -264,7 +264,7 @@ DBImpl::HasCollection(const std::string& collection_name, bool& has_or_not) {
auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name);
has_or_not = status.ok();
return status;
return Status::OK();
}
Status
......
......@@ -90,7 +90,12 @@ extern const char* PARAM_INDEX_METRIC_TYPE;
extern const char* PARAM_INDEX_EXTRA_PARAMS;
extern const char* PARAM_SEGMENT_ROW_COUNT;
constexpr int64_t MAX_NAME_LENGTH = 255;
constexpr int64_t MAX_DIMENSION = 32768;
constexpr int32_t MAX_SEGMENT_ROW_COUNT = 4 * 1024 * 1024;
constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000;
constexpr int64_t M_BYTE = 1024 * 1024;
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE;
using FieldType = meta::DataType;
......
......@@ -47,7 +47,7 @@ class ReferenceProxy {
}
}
[[nodiscard]] int64_t
int64_t
ref_count() const {
return ref_count_;
}
......
......@@ -26,12 +26,6 @@ namespace server {
namespace {
constexpr size_t NAME_SIZE_LIMIT = 255;
constexpr int64_t COLLECTION_DIMENSION_LIMIT = 32768;
constexpr int32_t SEGMENT_ROW_COUNT_LIMIT = 4 * 1024 * 1024;
constexpr int64_t M_BYTE = 1024 * 1024;
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE;
Status
CheckParameterRange(const milvus::json& json_params, const std::string& param_name, int64_t min, int64_t max,
bool min_close = true, bool max_closed = true) {
......@@ -102,7 +96,7 @@ ValidateCollectionName(const std::string& collection_name) {
std::string invalid_msg = "Invalid collection name: " + collection_name + ". ";
// Collection name size shouldn't exceed 255.
if (collection_name.size() > NAME_SIZE_LIMIT) {
if (collection_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a collection name must be less than 255 characters.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
......@@ -140,7 +134,7 @@ ValidateFieldName(const std::string& field_name) {
std::string invalid_msg = "Invalid field name: " + field_name + ". ";
// Field name size shouldn't exceed 255.
if (field_name.size() > NAME_SIZE_LIMIT) {
if (field_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a field name must be less than 255 characters.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_FIELD_NAME, msg);
......@@ -202,22 +196,18 @@ ValidateIndexType(const std::string& index_type) {
}
Status
ValidateVectorDimension(int64_t dimension, const std::string& metric_type) {
if (dimension <= 0 || dimension > COLLECTION_DIMENSION_LIMIT) {
std::string msg = "Invalid collection dimension: " + std::to_string(dimension) + ". " +
"The collection dimension must be within the range of 1 ~ " +
std::to_string(COLLECTION_DIMENSION_LIMIT) + ".";
ValidateDimension(int64_t dim, bool is_binary) {
if (dim <= 0 || dim > engine::MAX_DIMENSION) {
std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be in range 1 ~ " +
std::to_string(engine::MAX_DIMENSION) + ".";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
}
if (milvus::engine::utils::IsBinaryMetricType(metric_type)) {
if ((dimension % 8) != 0) {
std::string msg = "Invalid collection dimension: " + std::to_string(dimension) + ". " +
"The collection dimension must be a multiple of 8";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
}
if (is_binary && (dim % 8) != 0) {
std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be multiple of 8.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
}
return Status::OK();
......@@ -310,14 +300,12 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
Status
ValidateSegmentRowCount(int64_t segment_row_count) {
if (segment_row_count <= 0 || segment_row_count > SEGMENT_ROW_COUNT_LIMIT) {
if (segment_row_count <= 0 || segment_row_count > engine::MAX_SEGMENT_ROW_COUNT) {
std::string msg = "Invalid segment row count: " + std::to_string(segment_row_count) + ". " +
"The segment row count must be within the range of 1 ~ " +
std::to_string(SEGMENT_ROW_COUNT_LIMIT) + ".";
"Should be in range 1 ~ " + std::to_string(engine::MAX_SEGMENT_ROW_COUNT) + ".";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_SEGMENT_ROW_COUNT, msg);
}
return Status::OK();
}
......@@ -363,7 +351,7 @@ ValidatePartitionName(const std::string& partition_name) {
std::string invalid_msg = "Invalid partition name: " + partition_name + ". ";
// Collection name size shouldn't exceed 255.
if (partition_name.size() > NAME_SIZE_LIMIT) {
if (partition_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a partition name must be less than 255 characters.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
......
......@@ -30,6 +30,12 @@ ValidateCollectionName(const std::string& collection_name);
extern Status
ValidateFieldName(const std::string& field_name);
extern Status
ValidateIndexName(const std::string& index_name);
extern Status
ValidateDimension(int64_t dimension, bool is_binary);
extern Status
ValidateIndexType(const std::string& index_type);
......
......@@ -13,17 +13,10 @@
#include "db/Utils.h"
#include "server/DBWrapper.h"
#include "server/ValidationUtil.h"
#include "server/delivery/request/BaseReq.h"
#include "server/web_impl/Constants.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include <fiu-local.h>
#include <src/db/snapshot/Context.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace milvus {
namespace server {
......@@ -79,16 +72,23 @@ CreateCollectionReq::OnExecute() {
auto& field_params = field_schema.field_params_;
auto& index_params = field_schema.index_params_;
std::cout << index_params.dump() << std::endl;
STATUS_CHECK(ValidateFieldName(field_name));
std::string index_name;
if (index_params.contains("name")) {
index_name = index_params["name"];
}
std::cout << field_params.dump() << std::endl;
if (field_type == engine::FieldType::VECTOR_FLOAT || field_type == engine::FieldType::VECTOR_BINARY) {
if (!field_params.contains(engine::PARAM_DIMENSION)) {
return Status(SERVER_INVALID_VECTOR_DIMENSION, "Dimension not defined in field_params");
} else {
auto dim = field_params[engine::PARAM_DIMENSION].get<int64_t>();
if (field_type == engine::FieldType::VECTOR_FLOAT) {
STATUS_CHECK(ValidateDimension(dim, false));
} else {
STATUS_CHECK(ValidateDimension(dim, true));
}
}
}
......
......@@ -14,8 +14,6 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "server/delivery/request/BaseReq.h"
......
......@@ -15,11 +15,6 @@
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include <fiu-local.h>
#include <memory>
#include <unordered_map>
#include <vector>
namespace milvus {
namespace server {
......@@ -39,6 +34,8 @@ DropCollectionReq::OnExecute() {
std::string hdr = "DropCollectionReq(collection=" + collection_name_ + ")";
TimeRecorder rc(hdr);
STATUS_CHECK(ValidateCollectionName(collection_name_));
bool exist = false;
auto status = DBWrapper::DB()->HasCollection(collection_name_, exist);
if (!exist) {
......
......@@ -12,16 +12,12 @@
#include "server/delivery/request/GetCollectionInfoReq.h"
#include "db/Utils.h"
#include "server/DBWrapper.h"
#include "server/ValidationUtil.h"
#include "server/web_impl/Constants.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include <fiu-local.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace milvus {
namespace server {
......@@ -45,6 +41,8 @@ GetCollectionInfoReq::OnExecute() {
TimeRecorderAuto rc(hdr);
try {
STATUS_CHECK(ValidateCollectionName(collection_name_));
engine::snapshot::CollectionPtr collection;
engine::snapshot::CollectionMappings collection_mappings;
STATUS_CHECK(DBWrapper::DB()->GetCollectionInfo(collection_name_, collection, collection_mappings));
......
......@@ -13,9 +13,6 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "server/delivery/request/BaseReq.h"
......
......@@ -21,10 +21,6 @@
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include <memory>
#include <unordered_map>
#include <vector>
namespace milvus {
namespace server {
......@@ -46,16 +42,22 @@ GetCollectionStatsReq::OnExecute() {
std::string hdr = "GetCollectionStatsReq(collection=" + collection_name_ + ")";
TimeRecorderAuto rc(hdr);
bool exist = false;
auto status = DBWrapper::DB()->HasCollection(collection_name_, exist);
if (!exist) {
return Status(SERVER_COLLECTION_NOT_EXIST, CollectionNotExistMsg(collection_name_));
}
try {
STATUS_CHECK(ValidateCollectionName(collection_name_));
nlohmann::json json_stats;
STATUS_CHECK(DBWrapper::DB()->GetCollectionStats(collection_name_, json_stats));
collection_stats_ = json_stats.dump();
rc.ElapseFromBegin("done");
bool exist = false;
auto status = DBWrapper::DB()->HasCollection(collection_name_, exist);
if (!exist) {
return Status(SERVER_COLLECTION_NOT_EXIST, CollectionNotExistMsg(collection_name_));
}
milvus::json json_stats;
STATUS_CHECK(DBWrapper::DB()->GetCollectionStats(collection_name_, json_stats));
collection_stats_ = json_stats.dump();
rc.ElapseFromBegin("done");
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
}
return Status::OK();
}
......
......@@ -21,7 +21,6 @@
#include <memory>
#include <string>
#include <vector>
namespace milvus {
namespace server {
......
......@@ -16,7 +16,6 @@
#include "utils/TimeRecorder.h"
#include <fiu-local.h>
#include <memory>
namespace milvus {
namespace server {
......@@ -38,6 +37,7 @@ HasCollectionReq::OnExecute() {
std::string hdr = "HasCollectionReq(collection=" + collection_name_ + ")";
TimeRecorderAuto rc(hdr);
STATUS_CHECK(ValidateCollectionName(collection_name_));
STATUS_CHECK(DBWrapper::DB()->HasCollection(collection_name_, exist_));
rc.ElapseFromBegin("done");
......
......@@ -15,9 +15,7 @@
#include "utils/TimeRecorder.h"
#include <fiu-local.h>
#include <memory>
#include <string>
#include <vector>
namespace milvus {
namespace server {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册