Unverified · Commit 26536fac authored by yukun, committed by GitHub

Fix CreateCollection extra_param bugs (#2995)

* Use unordered_map in CollectionMappings
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix ParseMetaUri bug
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Delete GetVectorsByID
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix CreateCollection extra_param bugs
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Change dimension to dim
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix InsertEntities bugs
Signed-off-by: fishpenguin <kun.yu@zilliz.com>
Parent 0f7e8085
......@@ -21,9 +21,10 @@ const char* DEFAULT_BLOOM_FILTER_NAME = "_blf";
const char* DEFAULT_DELETED_DOCS_NAME = "_del";
const char* DEFAULT_INDEX_NAME = "_idx";
const char* PARAM_COLLECTION_DIMENSION = "dimension";
const char* PARAM_COLLECTION_DIMENSION = "dim";
const char* PARAM_INDEX_METRIC_TYPE = "metric_type";
const char* PARAM_INDEX_EXTRA_PARAMS = "extra_params";
const char* PARAM_SEGMENT_SIZE = "segment_size";
} // namespace engine
} // namespace milvus
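
With this rename, PARAM_COLLECTION_DIMENSION now resolves to the JSON key "dim", so vector-field extra_params are keyed by "dim" rather than "dimension", while "metric_type" and the collection-level "segment_size" keep their existing keys. Below is a minimal sketch (not part of this commit) of the JSON shapes these constants describe, assuming nlohmann::json, which the server already wraps as milvus::json; the concrete values are placeholders.

#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    using nlohmann::json;

    // Vector-field extra_params: keyed by "dim" (PARAM_COLLECTION_DIMENSION)
    // and "metric_type" (PARAM_INDEX_METRIC_TYPE). Values are illustrative.
    json vector_field_params = {{"dim", 128}, {"metric_type", "L2"}};

    // Collection-level extra_params: "segment_size" (PARAM_SEGMENT_SIZE).
    json collection_params = {{"segment_size", 1024}};

    std::cout << vector_field_params.dump() << "\n" << collection_params.dump() << std::endl;
    return 0;
}
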
......@@ -91,6 +91,7 @@ extern const char* DEFAULT_INDEX_NAME;
extern const char* PARAM_COLLECTION_DIMENSION;
extern const char* PARAM_INDEX_METRIC_TYPE;
extern const char* PARAM_INDEX_EXTRA_PARAMS;
extern const char* PARAM_SEGMENT_SIZE;
using FieldType = meta::hybrid::DataType;
......
......@@ -709,7 +709,8 @@ CreateCollectionOperation::DoExecute(StorePtr store) {
return Status(SS_DUPLICATED_ERROR, emsg.str());
}
auto status = store->CreateResource<Collection>(Collection(c_context_.collection->GetName()), collection);
auto status = store->CreateResource<Collection>(
Collection(c_context_.collection->GetName(), c_context_.collection->GetParams()), collection);
if (!status.ok()) {
std::cerr << status.ToString() << std::endl;
return status;
......
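
The create operation now forwards the collection params from the create-collection context, which is what allows DescribeCollection to hand them back later through GetParams() (see the DescribeHybridCollectionRequest hunk below). A minimal round-trip sketch, not part of this commit, assuming GetParams() returns the JSON passed to the constructor; the name and params are illustrative only.

// Illustrative only; names of the Collection constructor and GetParams() are taken from this diff.
nlohmann::json params = {{"segment_size", 1024}};
auto collection = std::make_shared<engine::snapshot::Collection>("demo_collection", params);
// Before this fix the params were dropped at creation time; now the same JSON
// should be visible again via collection->GetParams().
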
......@@ -279,9 +279,10 @@ RequestHandler::HasHybridCollection(const std::shared_ptr<Context>& context, std
Status
RequestHandler::InsertEntity(const std::shared_ptr<Context>& context, const std::string& collection_name,
const std::string& partition_name,
const std::string& partition_name, const int32_t& row_count,
std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data) {
BaseRequestPtr request_ptr = InsertEntityRequest::Create(context, collection_name, partition_name, chunk_data);
BaseRequestPtr request_ptr =
InsertEntityRequest::Create(context, collection_name, partition_name, row_count, chunk_data);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
......
......@@ -129,7 +129,8 @@ class RequestHandler {
Status
InsertEntity(const std::shared_ptr<Context>& context, const std::string& collection_name,
const std::string& partition_name, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data);
const std::string& partition_name, const int32_t& row_count,
std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data);
Status
GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
......
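
Because chunk_data carries only raw per-field byte buffers, the server can no longer infer the entity count reliably, so the handler now takes row_count explicitly. The sketch below is a hypothetical consistency check under that assumption; the helper name, field name, and dimension are illustrative and not part of this commit.

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical helper: verify that a float-vector field's byte buffer matches the declared row count.
bool RowCountMatches(const std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data,
                     const std::string& vector_field, int64_t dim, int32_t row_count) {
    auto it = chunk_data.find(vector_field);
    if (it == chunk_data.end()) {
        return false;
    }
    // Each row of a float vector occupies dim * sizeof(float) bytes.
    return it->second.size() == static_cast<size_t>(row_count) * dim * sizeof(float);
}
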
......@@ -97,8 +97,8 @@ CreateHybridCollectionRequest::OnExecute() {
if (field_type.second == engine::meta::hybrid::DataType::VECTOR_FLOAT ||
field_type.second == engine::meta::hybrid::DataType::VECTOR_BINARY) {
vector_param = milvus::json::parse(field_param);
if (vector_param.contains("dimension")) {
dimension = vector_param["dimension"].get<uint16_t>();
if (vector_param.contains(engine::PARAM_COLLECTION_DIMENSION)) {
dimension = vector_param[engine::PARAM_COLLECTION_DIMENSION].get<uint16_t>();
} else {
return Status{milvus::SERVER_INVALID_VECTOR_DIMENSION,
"Dimension should be defined in vector field extra_params"};
......@@ -110,8 +110,8 @@ CreateHybridCollectionRequest::OnExecute() {
collection_info.collection_id_ = collection_name_;
collection_info.dimension_ = dimension;
if (extra_params_.contains("segment_size")) {
auto segment_size = extra_params_["segment_size"].get<int64_t>();
if (extra_params_.contains(engine::PARAM_SEGMENT_SIZE)) {
auto segment_size = extra_params_[engine::PARAM_SEGMENT_SIZE].get<int64_t>();
collection_info.index_file_size_ = segment_size;
status = ValidateCollectionIndexFileSize(segment_size);
if (!status.ok()) {
......@@ -119,14 +119,15 @@ CreateHybridCollectionRequest::OnExecute() {
}
}
if (vector_param.contains("metric_type")) {
int32_t metric_type = (int32_t)milvus::engine::s_map_metric_type.at(vector_param["metric_type"]);
if (vector_param.contains(engine::PARAM_INDEX_METRIC_TYPE)) {
int32_t metric_type =
(int32_t)milvus::engine::s_map_metric_type.at(vector_param[engine::PARAM_INDEX_METRIC_TYPE]);
collection_info.metric_type_ = metric_type;
}
// step 3: create snapshot collection
engine::snapshot::CreateCollectionContext create_collection_context;
auto ss_collection_schema = std::make_shared<engine::snapshot::Collection>(collection_name_);
auto ss_collection_schema = std::make_shared<engine::snapshot::Collection>(collection_name_, extra_params_);
create_collection_context.collection = ss_collection_schema;
for (const auto& schema : fields_schema.fields_schema_) {
auto field = std::make_shared<engine::snapshot::Field>(
......
......@@ -49,17 +49,20 @@ DescribeHybridCollectionRequest::OnExecute() {
try {
engine::snapshot::CollectionPtr collection;
engine::snapshot::CollectionMappings fields_schema;
auto status = DBWrapper::SSDB()->DescribeCollection(collection_name_, collection, fields_schema);
engine::snapshot::CollectionMappings collection_mappings;
auto status = DBWrapper::SSDB()->DescribeCollection(collection_name_, collection, collection_mappings);
if (!status.ok()) {
return status;
}
collection_schema_.collection_name_ = collection_name_;
collection_schema_.extra_params_ = collection->GetParams();
engine::meta::hybrid::FieldsSchema fields_info;
for (auto field_it = fields_schema.begin(); field_it != fields_schema.end(); field_it++) {
engine::meta::hybrid::FieldsSchema fields_schema;
for (auto field_it = collection_mappings.begin(); field_it != collection_mappings.end(); field_it++) {
engine::meta::hybrid::FieldSchema schema;
auto field = field_it->first;
if (field->GetFtype() == (int)engine::meta::hybrid::DataType::UID) {
continue;
}
schema.field_name_ = field->GetName();
schema.field_type_ = (int)field->GetFtype();
schema.field_params_ = field->GetParams().dump();
......@@ -71,9 +74,10 @@ DescribeHybridCollectionRequest::OnExecute() {
break;
}
}
fields_schema.fields_schema_.emplace_back(schema);
}
for (const auto& schema : fields_info.fields_schema_) {
for (const auto& schema : fields_schema.fields_schema_) {
auto field_name = schema.field_name_;
collection_schema_.field_types_.insert(
std::make_pair(field_name, (engine::meta::hybrid::DataType)schema.field_type_));
......
......@@ -34,18 +34,21 @@ namespace server {
InsertEntityRequest::InsertEntityRequest(const std::shared_ptr<milvus::server::Context>& context,
const std::string& collection_name, const std::string& partition_name,
const int32_t& row_count,
std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data)
: BaseRequest(context, BaseRequest::kInsertEntity),
collection_name_(collection_name),
partition_name_(partition_name),
row_count_(row_count),
chunk_data_(chunk_data) {
}
BaseRequestPtr
InsertEntityRequest::Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const std::string& partition_name,
const std::string& partition_name, const int32_t& row_count,
std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data) {
return std::shared_ptr<BaseRequest>(new InsertEntityRequest(context, collection_name, partition_name, chunk_data));
return std::shared_ptr<BaseRequest>(
new InsertEntityRequest(context, collection_name, partition_name, row_count, chunk_data));
}
Status
......@@ -79,6 +82,7 @@ InsertEntityRequest::OnExecute() {
}
engine::DataChunkPtr data_chunk = std::make_shared<engine::DataChunk>();
data_chunk->count_ = row_count_;
data_chunk->fixed_fields_.swap(chunk_data_);
status = DBWrapper::SSDB()->InsertEntities(collection_name_, partition_name_, data_chunk);
if (!status.ok()) {
......
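
The request now copies the client-supplied count into the DataChunk instead of deriving it from buffer sizes, and swaps the byte buffers in rather than copying them, since per-field buffers can be large. A minimal sketch of the resulting chunk assembly, with member names taken from the diff and everything else illustrative:

// Illustrative only, not part of this commit.
auto data_chunk = std::make_shared<engine::DataChunk>();
data_chunk->count_ = row_count_;              // explicit entity count supplied by the client
data_chunk->fixed_fields_.swap(chunk_data_);  // hand over the raw per-field byte buffers without copying
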
......@@ -25,11 +25,12 @@ class InsertEntityRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const std::string& partition_name, std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data);
const std::string& partition_name, const int32_t& row_count,
std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data);
protected:
InsertEntityRequest(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const std::string& partition_name,
const std::string& partition_name, const int32_t& row_count,
std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data);
Status
......@@ -38,6 +39,7 @@ class InsertEntityRequest : public BaseRequest {
private:
const std::string collection_name_;
const std::string partition_name_;
const int32_t row_count_;
std::unordered_map<std::string, std::vector<uint8_t>>& chunk_data_;
};
......
......@@ -1291,6 +1291,9 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
temp_data.resize(grpc_double_size * sizeof(double));
memcpy(temp_data.data(), field.attr_record().double_value().data(), grpc_double_size * sizeof(double));
} else {
if (!valid_row_count(row_num, field.vector_record().records_size())) {
return ::grpc::Status::OK;
}
CopyVectorData(field.vector_record().records(), temp_data);
}
......@@ -1307,7 +1310,8 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
std::string collection_name = request->collection_name();
std::string partition_name = request->partition_tag();
Status status = request_handler_.InsertEntity(GetContext(context), collection_name, partition_name, chunk_data);
Status status =
request_handler_.InsertEntity(GetContext(context), collection_name, partition_name, row_num, chunk_data);
if (!status.ok()) {
SET_RESPONSE(response->mutable_status(), status, context);
return ::grpc::Status::OK;
......@@ -1316,7 +1320,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
// return generated ids
auto pair = chunk_data.find(engine::DEFAULT_UID_NAME);
if (pair != chunk_data.end()) {
response->mutable_entity_id_array()->Resize(static_cast<int>(pair->second.size()), 0);
response->mutable_entity_id_array()->Resize(static_cast<int>(pair->second.size() / sizeof(int64_t)), 0);
memcpy(response->mutable_entity_id_array()->mutable_data(), pair->second.data(), pair->second.size());
}
......
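
The returned ID field holds the generated int64 IDs as a raw byte buffer, so the response array must be sized as bytes divided by sizeof(int64_t) rather than by the byte count itself; the earlier code over-sized it by a factor of eight. A standalone sketch of that conversion follows; the helper name is an assumption for illustration and not part of this commit.

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: reinterpret the "_uid" field's raw bytes as int64 entity IDs.
std::vector<int64_t> BytesToIds(const std::vector<uint8_t>& raw) {
    std::vector<int64_t> ids(raw.size() / sizeof(int64_t));  // number of IDs, not number of bytes
    std::memcpy(ids.data(), raw.data(), ids.size() * sizeof(int64_t));
    return ids;
}
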
......@@ -1755,7 +1755,7 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
auto body_json = nlohmann::json::parse(body->c_str());
std::string partition_name = body_json["partition_tag"];
uint64_t row_num = body_json["row_num"];
int32_t row_num = body_json["row_num"];
std::unordered_map<std::string, engine::meta::hybrid::DataType> field_types;
auto status = Status::OK();
......@@ -1811,7 +1811,7 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
chunk_data.insert(std::make_pair(field_name, temp_data));
}
status = request_handler_.InsertEntity(context_ptr_, collection_name->c_str(), partition_name, chunk_data);
status = request_handler_.InsertEntity(context_ptr_, collection_name->c_str(), partition_name, row_num, chunk_data);
if (!status.ok()) {
RETURN_STATUS_DTO(UNEXPECTED_ERROR, "Failed to insert data");
}
......
......@@ -35,7 +35,10 @@ constexpr int64_t SEARCH_TARGET = BATCH_ENTITY_COUNT / 2; // change this value,
constexpr int64_t ADD_ENTITY_LOOP = 1;
constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFFLAT;
constexpr int32_t NLIST = 16384;
constexpr char* PARTITION_TAG = "part";
const char* PARTITION_TAG = "part";
const char* DIMENSION = "dim";
const char* METRICTYPE = "metric_type";
const char* INDEXTYPE = "index_type";
void
PrintEntity(const std::string& tag, const milvus::VectorData& entity) {
......@@ -115,10 +118,11 @@ ClientTest::CreateCollection(const std::string& collection_name) {
field_ptr4->field_name = "field_vec";
field_ptr4->field_type = milvus::DataType::VECTOR_FLOAT;
JSON index_param_4;
index_param_4["name"] = "index_3";
index_param_4["name"] = "index_vec";
field_ptr4->index_params = index_param_4.dump();
JSON extra_params_4;
extra_params_4["dimension"] = COLLECTION_DIMENSION;
extra_params_4[METRICTYPE] = "L2";
extra_params_4[DIMENSION] = COLLECTION_DIMENSION;
field_ptr4->extra_params = extra_params_4.dump();
JSON extra_params;
......@@ -150,6 +154,7 @@ ClientTest::InsertEntities(const std::string& collection_name) {
milvus_sdk::Utils::BuildEntities(begin_index, begin_index + BATCH_ENTITY_COUNT, field_value, entity_ids,
COLLECTION_DIMENSION);
}
entity_ids.clear();
milvus::Status status = conn_->Insert(collection_name, "", field_value, entity_ids);
search_id_array_.emplace_back(entity_ids[10]);
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
......@@ -157,6 +162,13 @@ ClientTest::InsertEntities(const std::string& collection_name) {
}
}
void
ClientTest::CountEntities(const std::string& collection_name) {
int64_t entity_count = 0;
auto status = conn_->CountEntities(collection_name, entity_count);
std::cout << "Collection " << collection_name << " entity count: " << entity_count << std::endl;
}
void
ClientTest::Flush(const std::string& collection_name) {
milvus_sdk::TimeRecorder rc("Flush");
......@@ -338,21 +350,22 @@ ClientTest::Test() {
InsertEntities(collection_name);
Flush(collection_name);
GetCollectionStats(collection_name);
BuildVectors(NQ, COLLECTION_DIMENSION);
GetEntityByID(collection_name, search_id_array_);
SearchEntities(collection_name, TOP_K, NPROBE);
GetCollectionStats(collection_name);
std::vector<int64_t> delete_ids = {search_id_array_[0], search_id_array_[1]};
DeleteByIds(collection_name, delete_ids);
GetEntityByID(collection_name, search_id_array_);
CompactCollection(collection_name);
LoadCollection(collection_name);
SearchEntities(collection_name, TOP_K, NPROBE); // this line get two search error since we delete two entities
DropIndex(collection_name, "field_vec", "index_3");
DropCollection(collection_name);
CountEntities(collection_name);
// GetCollectionStats(collection_name);
//
// BuildVectors(NQ, COLLECTION_DIMENSION);
// GetEntityByID(collection_name, search_id_array_);
// SearchEntities(collection_name, TOP_K, NPROBE);
// GetCollectionStats(collection_name);
//
// std::vector<int64_t> delete_ids = {search_id_array_[0], search_id_array_[1]};
// DeleteByIds(collection_name, delete_ids);
// GetEntityByID(collection_name, search_id_array_);
// CompactCollection(collection_name);
//
// LoadCollection(collection_name);
// SearchEntities(collection_name, TOP_K, NPROBE); // this line get two search error since we delete two entities
//
// DropIndex(collection_name, "field_vec", "index_3");
// DropCollection(collection_name);
}
......@@ -45,7 +45,8 @@ class ClientTest {
void
InsertEntities(const std::string&);
void BuildSearchEntities(int64_t, int64_t);
void
CountEntities(const std::string&);
void
Flush(const std::string&);
......