未验证 提交 8f8fa0aa 编写于 作者: G groot 提交者: GitHub

use index type replace name (#3114)

* use index type replace name
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>
上级 6ecf7a27
...@@ -50,26 +50,23 @@ SetSnapshotIndex(const std::string& collection_name, const std::string& field_na ...@@ -50,26 +50,23 @@ SetSnapshotIndex(const std::string& collection_name, const std::string& field_na
} }
snapshot::OperationContext ss_context; snapshot::OperationContext ss_context;
auto index_element =
std::make_shared<snapshot::FieldElement>(ss->GetCollectionId(), field->GetID(), index_info.index_name_,
milvus::engine::FieldElementType::FET_INDEX, index_info.index_type_);
ss_context.new_field_elements.push_back(index_element);
if (IsVectorField(field)) { if (IsVectorField(field)) {
auto new_element = std::make_shared<snapshot::FieldElement>(
ss->GetCollectionId(), field->GetID(), index_info.index_name_, milvus::engine::FieldElementType::FET_INDEX);
milvus::json json; milvus::json json;
json[engine::PARAM_INDEX_METRIC_TYPE] = index_info.metric_name_; json[engine::PARAM_INDEX_METRIC_TYPE] = index_info.metric_name_;
json[engine::PARAM_INDEX_EXTRA_PARAMS] = index_info.extra_params_; json[engine::PARAM_INDEX_EXTRA_PARAMS] = index_info.extra_params_;
new_element->SetParams(json); index_element->SetParams(json);
ss_context.new_field_elements.push_back(new_element);
if (index_info.index_name_ == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR || if (index_info.index_name_ == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR ||
index_info.index_name_ == knowhere::IndexEnum::INDEX_HNSW_SQ8NM) { index_info.index_name_ == knowhere::IndexEnum::INDEX_HNSW_SQ8NM) {
auto new_element = std::make_shared<snapshot::FieldElement>( auto compress_element = std::make_shared<snapshot::FieldElement>(
ss->GetCollectionId(), field->GetID(), DEFAULT_INDEX_COMPRESS_NAME, ss->GetCollectionId(), field->GetID(), DEFAULT_INDEX_COMPRESS_NAME,
milvus::engine::FieldElementType::FET_COMPRESS_SQ8); milvus::engine::FieldElementType::FET_COMPRESS_SQ8);
ss_context.new_field_elements.push_back(new_element); ss_context.new_field_elements.push_back(compress_element);
} }
} else {
auto new_element = std::make_shared<snapshot::FieldElement>(
ss->GetCollectionId(), field->GetID(), index_info.index_name_, milvus::engine::FieldElementType::FET_INDEX);
ss_context.new_field_elements.push_back(new_element);
} }
auto op = std::make_shared<snapshot::AddFieldElementOperation>(ss_context, ss); auto op = std::make_shared<snapshot::AddFieldElementOperation>(ss_context, ss);
...@@ -97,6 +94,7 @@ GetSnapshotIndex(const std::string& collection_name, const std::string& field_na ...@@ -97,6 +94,7 @@ GetSnapshotIndex(const std::string& collection_name, const std::string& field_na
for (auto& field_element : field_elements) { for (auto& field_element : field_elements) {
if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) { if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) {
index_info.index_name_ = field_element->GetName(); index_info.index_name_ = field_element->GetName();
index_info.index_type_ = field_element->GetTypeName();
auto json = field_element->GetParams(); auto json = field_element->GetParams();
if (json.find(engine::PARAM_INDEX_METRIC_TYPE) != json.end()) { if (json.find(engine::PARAM_INDEX_METRIC_TYPE) != json.end()) {
index_info.metric_name_ = json[engine::PARAM_INDEX_METRIC_TYPE]; index_info.metric_name_ = json[engine::PARAM_INDEX_METRIC_TYPE];
...@@ -110,7 +108,8 @@ GetSnapshotIndex(const std::string& collection_name, const std::string& field_na ...@@ -110,7 +108,8 @@ GetSnapshotIndex(const std::string& collection_name, const std::string& field_na
} else { } else {
for (auto& field_element : field_elements) { for (auto& field_element : field_elements) {
if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) { if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) {
index_info.index_name_ = DEFAULT_STRUCTURED_INDEX_NAME; index_info.index_name_ = field_element->GetName();
index_info.index_type_ = field_element->GetTypeName();
} }
} }
} }
......
...@@ -78,6 +78,7 @@ using DataChunkPtr = std::shared_ptr<DataChunk>; ...@@ -78,6 +78,7 @@ using DataChunkPtr = std::shared_ptr<DataChunk>;
struct CollectionIndex { struct CollectionIndex {
std::string index_name_; std::string index_name_;
std::string index_type_;
std::string metric_name_; std::string metric_name_;
milvus::json extra_params_ = {{"nlist", 2048}}; milvus::json extra_params_ = {{"nlist", 2048}};
}; };
......
...@@ -49,7 +49,7 @@ GetMicroSecTimeStamp() { ...@@ -49,7 +49,7 @@ GetMicroSecTimeStamp() {
bool bool
IsSameIndex(const CollectionIndex& index1, const CollectionIndex& index2) { IsSameIndex(const CollectionIndex& index1, const CollectionIndex& index2) {
return index1.index_name_ == index2.index_name_ && index1.extra_params_ == index2.extra_params_ && return index1.index_type_ == index2.index_type_ && index1.extra_params_ == index2.extra_params_ &&
index1.metric_name_ == index2.metric_name_; index1.metric_name_ == index2.metric_name_;
} }
......
...@@ -647,6 +647,7 @@ ExecutionEngineImpl::CreateSnapshotIndexFile(AddSegmentFileOperation& operation, ...@@ -647,6 +647,7 @@ ExecutionEngineImpl::CreateSnapshotIndexFile(AddSegmentFileOperation& operation,
auto& index_element = element_visitor->GetElement(); auto& index_element = element_visitor->GetElement();
index_info.index_name_ = index_element->GetName(); index_info.index_name_ = index_element->GetName();
index_info.index_type_ = index_element->GetTypeName();
auto params = index_element->GetParams(); auto params = index_element->GetParams();
if (params.find(engine::PARAM_INDEX_METRIC_TYPE) != params.end()) { if (params.find(engine::PARAM_INDEX_METRIC_TYPE) != params.end()) {
index_info.metric_name_ = params[engine::PARAM_INDEX_METRIC_TYPE]; index_info.metric_name_ = params[engine::PARAM_INDEX_METRIC_TYPE];
...@@ -728,7 +729,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col ...@@ -728,7 +729,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col
} }
// build index by knowhere // build index by knowhere
new_index = CreateVecIndex(index_info.index_name_); new_index = CreateVecIndex(index_info.index_type_);
if (!new_index) { if (!new_index) {
throw Exception(DB_ERROR, "Unsupported index type"); throw Exception(DB_ERROR, "Unsupported index type");
} }
......
...@@ -84,7 +84,7 @@ CreateIndexReq::OnExecute() { ...@@ -84,7 +84,7 @@ CreateIndexReq::OnExecute() {
int64_t dimension = params[engine::PARAM_DIMENSION].get<int64_t>(); int64_t dimension = params[engine::PARAM_DIMENSION].get<int64_t>();
// validate index type // validate index type
std::string index_type = 0; std::string index_type;
if (json_params_.contains(engine::PARAM_INDEX_TYPE)) { if (json_params_.contains(engine::PARAM_INDEX_TYPE)) {
index_type = json_params_[engine::PARAM_INDEX_TYPE].get<std::string>(); index_type = json_params_[engine::PARAM_INDEX_TYPE].get<std::string>();
} }
...@@ -94,7 +94,7 @@ CreateIndexReq::OnExecute() { ...@@ -94,7 +94,7 @@ CreateIndexReq::OnExecute() {
} }
// validate metric type // validate metric type
std::string metric_type = 0; std::string metric_type;
if (json_params_.contains(engine::PARAM_INDEX_METRIC_TYPE)) { if (json_params_.contains(engine::PARAM_INDEX_METRIC_TYPE)) {
metric_type = json_params_[engine::PARAM_INDEX_METRIC_TYPE].get<std::string>(); metric_type = json_params_[engine::PARAM_INDEX_METRIC_TYPE].get<std::string>();
} }
...@@ -111,13 +111,19 @@ CreateIndexReq::OnExecute() { ...@@ -111,13 +111,19 @@ CreateIndexReq::OnExecute() {
rc.RecordSection("check validation"); rc.RecordSection("check validation");
index.index_name_ = index_type; index.index_name_ = index_name_;
index.index_type_ = index_type;
index.metric_name_ = metric_type; index.metric_name_ = metric_type;
if (json_params_.contains(engine::PARAM_INDEX_EXTRA_PARAMS)) { if (json_params_.contains(engine::PARAM_INDEX_EXTRA_PARAMS)) {
index.extra_params_ = json_params_[engine::PARAM_INDEX_EXTRA_PARAMS]; index.extra_params_ = json_params_[engine::PARAM_INDEX_EXTRA_PARAMS];
} }
} else { } else {
index.index_name_ = index_name_; index.index_name_ = index_name_;
std::string index_type;
if (json_params_.contains(engine::PARAM_INDEX_TYPE)) {
index_type = json_params_[engine::PARAM_INDEX_TYPE].get<std::string>();
}
index.index_type_ = index_type;
} }
STATUS_CHECK(DBWrapper::DB()->CreateIndex(context_, collection_name_, field_name_, index)); STATUS_CHECK(DBWrapper::DB()->CreateIndex(context_, collection_name_, field_name_, index));
......
...@@ -502,7 +502,8 @@ TEST_F(DBTest, IndexTest) { ...@@ -502,7 +502,8 @@ TEST_F(DBTest, IndexTest) {
{ {
milvus::engine::CollectionIndex index; milvus::engine::CollectionIndex index;
index.index_name_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; index.index_name_ = "my_index1";
index.index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
index.metric_name_ = milvus::knowhere::Metric::L2; index.metric_name_ = milvus::knowhere::Metric::L2;
index.extra_params_["nlist"] = 2048; index.extra_params_["nlist"] = 2048;
status = db_->CreateIndex(dummy_context_, collection_name, VECTOR_FIELD_NAME, index); status = db_->CreateIndex(dummy_context_, collection_name, VECTOR_FIELD_NAME, index);
...@@ -512,13 +513,15 @@ TEST_F(DBTest, IndexTest) { ...@@ -512,13 +513,15 @@ TEST_F(DBTest, IndexTest) {
status = db_->DescribeIndex(collection_name, VECTOR_FIELD_NAME, index_get); status = db_->DescribeIndex(collection_name, VECTOR_FIELD_NAME, index_get);
ASSERT_TRUE(status.ok()); ASSERT_TRUE(status.ok());
ASSERT_EQ(index.index_name_, index_get.index_name_); ASSERT_EQ(index.index_name_, index_get.index_name_);
ASSERT_EQ(index.index_type_, index_get.index_type_);
ASSERT_EQ(index.metric_name_, index_get.metric_name_); ASSERT_EQ(index.metric_name_, index_get.metric_name_);
ASSERT_EQ(index.extra_params_, index_get.extra_params_); ASSERT_EQ(index.extra_params_, index_get.extra_params_);
} }
{ {
milvus::engine::CollectionIndex index; milvus::engine::CollectionIndex index;
index.index_name_ = "SORTED"; index.index_name_ = "my_index2";
index.index_type_ = milvus::engine::DEFAULT_STRUCTURED_INDEX_NAME;
status = db_->CreateIndex(dummy_context_, collection_name, "field_0", index); status = db_->CreateIndex(dummy_context_, collection_name, "field_0", index);
ASSERT_TRUE(status.ok()); ASSERT_TRUE(status.ok());
status = db_->CreateIndex(dummy_context_, collection_name, "field_1", index); status = db_->CreateIndex(dummy_context_, collection_name, "field_1", index);
...@@ -530,6 +533,7 @@ TEST_F(DBTest, IndexTest) { ...@@ -530,6 +533,7 @@ TEST_F(DBTest, IndexTest) {
status = db_->DescribeIndex(collection_name, "field_0", index_get); status = db_->DescribeIndex(collection_name, "field_0", index_get);
ASSERT_TRUE(status.ok()); ASSERT_TRUE(status.ok());
ASSERT_EQ(index.index_name_, index_get.index_name_); ASSERT_EQ(index.index_name_, index_get.index_name_);
ASSERT_EQ(index.index_type_, index_get.index_type_);
} }
{ {
...@@ -577,7 +581,7 @@ TEST_F(DBTest, StatsTest) { ...@@ -577,7 +581,7 @@ TEST_F(DBTest, StatsTest) {
{ {
milvus::engine::CollectionIndex index; milvus::engine::CollectionIndex index;
index.index_name_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; index.index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
index.metric_name_ = milvus::knowhere::Metric::L2; index.metric_name_ = milvus::knowhere::Metric::L2;
index.extra_params_["nlist"] = 2048; index.extra_params_["nlist"] = 2048;
status = db_->CreateIndex(dummy_context_, collection_name, VECTOR_FIELD_NAME, index); status = db_->CreateIndex(dummy_context_, collection_name, VECTOR_FIELD_NAME, index);
...@@ -586,7 +590,7 @@ TEST_F(DBTest, StatsTest) { ...@@ -586,7 +590,7 @@ TEST_F(DBTest, StatsTest) {
{ {
milvus::engine::CollectionIndex index; milvus::engine::CollectionIndex index;
index.index_name_ = "SORTED"; index.index_type_ = milvus::engine::DEFAULT_STRUCTURED_INDEX_NAME;
status = db_->CreateIndex(dummy_context_, collection_name, "field_0", index); status = db_->CreateIndex(dummy_context_, collection_name, "field_0", index);
ASSERT_TRUE(status.ok()); ASSERT_TRUE(status.ok());
status = db_->CreateIndex(dummy_context_, collection_name, "field_1", index); status = db_->CreateIndex(dummy_context_, collection_name, "field_1", index);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册