未验证 提交 ecb0aff9 编写于 作者: G groot 提交者: GitHub

#4511 Insert should be failed if field type not matched (#4517)

Signed-off-by: Nyhmo <yihua.mo@zilliz.com>
上级 56b4e40c
......@@ -23,6 +23,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#4329 C++ sdk sdk_binary needs to update
- \#4418 Fix search when there are multiple vector fields
- \#4488 get_entity_by_id() performance is poor in 0.11.0
- \#4511 Insert should be failed if field type not matched
## Feature
- \#4163 Update C++ sdk search interface
......
......@@ -77,6 +77,27 @@ InsertReq::Create(const ContextPtr& context, const std::string& collection_name,
return std::shared_ptr<BaseReq>(new InsertReq(context, collection_name, partition_name, insert_param));
}
Status
InsertReq::ValidateFieldType(const InsertParam& insert_param) {
engine::snapshot::CollectionPtr collection;
engine::snapshot::FieldElementMappings field_mappings;
STATUS_CHECK(DBWrapper::DB()->GetCollectionInfo(collection_name_, collection, field_mappings));
for (auto& field_kv : field_mappings) {
auto field = field_kv.first;
auto field_name = field->GetName();
auto field_type = field->GetFtype();
auto iter = insert_param.fields_type_.find(field_name);
if (iter != insert_param.fields_type_.end()) {
if (iter->second != field_type) {
return Status{SERVER_INVALID_ARGUMENT, "Field type is incorrect"};
}
}
}
return Status::OK();
}
Status
InsertReq::OnExecute() {
try {
......@@ -105,6 +126,12 @@ InsertReq::OnExecute() {
return status;
}
status = ValidateFieldType(insert_param_);
if (!status.ok()) {
LOG_SERVER_ERROR_ << LogOut("[%s][%d] Invalid field type: %s", "insert", 0, status.message().c_str());
return status;
}
// step 3: construct insert data
engine::DataChunkPtr data_chunk;
STATUS_CHECK(ConvertToChunk(insert_param_, data_chunk));
......
......@@ -34,6 +34,9 @@ class InsertReq : public BaseReq {
Status
OnExecute() override;
Status
ValidateFieldType(const InsertParam& insert_param);
private:
std::string collection_name_;
std::string partition_name_;
......
......@@ -89,10 +89,12 @@ struct InsertParam {
using DataSegment = std::pair<const char*, int64_t>;
using DataSegments = std::vector<DataSegment>;
using FieldDataMap = std::unordered_map<std::string, DataSegments>;
using FieldTypeMap = std::unordered_map<std::string, engine::DataType>;
// for the purpose to avoid data copy
// the fields_data_ only pass data address, makesure all data address are keep alive
FieldDataMap fields_data_;
FieldTypeMap fields_type_;
int64_t row_count_ = 0;
// to return entities id
......
......@@ -105,7 +105,7 @@ RecordDataAddr(const std::string& field_name, int32_t num, const T* data, Insert
void
RecordVectorDataAddr(const std::string& field_name,
const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRowRecord>& grpc_records,
InsertParam& insert_param) {
InsertParam& insert_param, bool& is_binary) {
// calculate data size
int64_t float_data_size = 0, binary_data_size = 0;
for (auto& record : grpc_records) {
......@@ -114,10 +114,12 @@ RecordVectorDataAddr(const std::string& field_name,
}
if (float_data_size > 0) {
is_binary = false;
for (auto& record : grpc_records) {
RecordDataAddr<float>(field_name, record.float_data_size(), record.float_data().data(), insert_param);
}
} else if (binary_data_size > 0) {
is_binary = true;
for (auto& record : grpc_records) {
RecordDataAddr<char>(field_name, record.binary_data().size(), record.binary_data().data(), insert_param);
}
......@@ -1366,28 +1368,36 @@ GrpcRequestHandler::OnInsert(::grpc::ServerContext* context, const ::milvus::grp
}
RecordDataAddr<int32_t>(field_name, grpc_int32_size, field.attr_record().int32_value().data(),
insert_param);
insert_param.fields_type_.insert(std::make_pair(field_name, engine::DataType::INT32));
} else if (grpc_int64_size > 0) {
if (!valid_row_count(row_num, grpc_int64_size)) {
return ::grpc::Status::OK;
}
RecordDataAddr<int64_t>(field_name, grpc_int64_size, field.attr_record().int64_value().data(),
insert_param);
insert_param.fields_type_.insert(std::make_pair(field_name, engine::DataType::INT64));
} else if (grpc_float_size > 0) {
if (!valid_row_count(row_num, grpc_float_size)) {
return ::grpc::Status::OK;
}
RecordDataAddr<float>(field_name, grpc_float_size, field.attr_record().float_value().data(), insert_param);
insert_param.fields_type_.insert(std::make_pair(field_name, engine::DataType::FLOAT));
} else if (grpc_double_size > 0) {
if (!valid_row_count(row_num, grpc_double_size)) {
return ::grpc::Status::OK;
}
RecordDataAddr<double>(field_name, grpc_double_size, field.attr_record().double_value().data(),
insert_param);
insert_param.fields_type_.insert(std::make_pair(field_name, engine::DataType::DOUBLE));
} else {
if (!valid_row_count(row_num, field.vector_record().records_size())) {
return ::grpc::Status::OK;
}
RecordVectorDataAddr(field_name, field.vector_record().records(), insert_param);
bool is_binary = false;
RecordVectorDataAddr(field_name, field.vector_record().records(), insert_param, is_binary);
engine::DataType dt = is_binary ? engine::DataType::VECTOR_BINARY : engine::DataType::VECTOR_FLOAT;
insert_param.fields_type_.insert(std::make_pair(field_name, dt));
}
}
insert_param.row_count_ = row_num;
......@@ -1396,6 +1406,7 @@ GrpcRequestHandler::OnInsert(::grpc::ServerContext* context, const ::milvus::grp
if (request->entity_id_array_size() > 0) {
RecordDataAddr<int64_t>(engine::FIELD_UID, request->entity_id_array_size(), request->entity_id_array().data(),
insert_param);
insert_param.fields_type_.insert(std::make_pair(engine::FIELD_UID, engine::DataType::INT64));
}
std::string collection_name = request->collection_name();
......
......@@ -1656,6 +1656,7 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
}
// construct chunk data by json object
InsertParam insert_param;
ChunkDataMap chunk_data;
int64_t row_num = entities_json.size();
int64_t offset = 0;
......@@ -1665,9 +1666,11 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
if (field_name == NAME_ID) {
// special handle id field
CopyRowStructuredData<int64_t>(entity.value(), engine::FIELD_UID, offset, row_num, chunk_data);
insert_param.fields_type_.insert(std::make_pair(engine::FIELD_UID, engine::DataType::INT64));
continue;
}
insert_param.fields_type_.insert(std::make_pair(field_name, field_types.at(field_name)));
switch (field_types.at(field_name)) {
case engine::DataType::INT32: {
CopyRowStructuredData<int32_t>(entity.value(), field_name, offset, row_num, chunk_data);
......@@ -1698,7 +1701,6 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
}
// conver to InsertParam, no memory copy, just record the data address and pass to InsertReq
InsertParam insert_param;
ConvertToParam(chunk_data, row_num, insert_param);
// do insert
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册