From 8f8fa0aaf24b836653c1be628030afa3655ce61c Mon Sep 17 00:00:00 2001 From: groot Date: Tue, 4 Aug 2020 10:45:41 +0800 Subject: [PATCH] use index type replace name (#3114) * use index type replace name Signed-off-by: yhmo --- core/src/db/SnapshotUtils.cpp | 21 +++++++++---------- core/src/db/Types.h | 1 + core/src/db/Utils.cpp | 2 +- core/src/db/engine/ExecutionEngineImpl.cpp | 3 ++- .../delivery/request/CreateIndexReq.cpp | 12 ++++++++--- core/unittest/db/test_db.cpp | 12 +++++++---- 6 files changed, 31 insertions(+), 20 deletions(-) diff --git a/core/src/db/SnapshotUtils.cpp b/core/src/db/SnapshotUtils.cpp index 5e9f95575..f1e96a658 100644 --- a/core/src/db/SnapshotUtils.cpp +++ b/core/src/db/SnapshotUtils.cpp @@ -50,26 +50,23 @@ SetSnapshotIndex(const std::string& collection_name, const std::string& field_na } snapshot::OperationContext ss_context; + auto index_element = + std::make_shared(ss->GetCollectionId(), field->GetID(), index_info.index_name_, + milvus::engine::FieldElementType::FET_INDEX, index_info.index_type_); + ss_context.new_field_elements.push_back(index_element); if (IsVectorField(field)) { - auto new_element = std::make_shared( - ss->GetCollectionId(), field->GetID(), index_info.index_name_, milvus::engine::FieldElementType::FET_INDEX); milvus::json json; json[engine::PARAM_INDEX_METRIC_TYPE] = index_info.metric_name_; json[engine::PARAM_INDEX_EXTRA_PARAMS] = index_info.extra_params_; - new_element->SetParams(json); - ss_context.new_field_elements.push_back(new_element); + index_element->SetParams(json); if (index_info.index_name_ == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR || index_info.index_name_ == knowhere::IndexEnum::INDEX_HNSW_SQ8NM) { - auto new_element = std::make_shared( + auto compress_element = std::make_shared( ss->GetCollectionId(), field->GetID(), DEFAULT_INDEX_COMPRESS_NAME, milvus::engine::FieldElementType::FET_COMPRESS_SQ8); - ss_context.new_field_elements.push_back(new_element); + ss_context.new_field_elements.push_back(compress_element); } - } else { - auto new_element = std::make_shared( - ss->GetCollectionId(), field->GetID(), index_info.index_name_, milvus::engine::FieldElementType::FET_INDEX); - ss_context.new_field_elements.push_back(new_element); } auto op = std::make_shared(ss_context, ss); @@ -97,6 +94,7 @@ GetSnapshotIndex(const std::string& collection_name, const std::string& field_na for (auto& field_element : field_elements) { if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) { index_info.index_name_ = field_element->GetName(); + index_info.index_type_ = field_element->GetTypeName(); auto json = field_element->GetParams(); if (json.find(engine::PARAM_INDEX_METRIC_TYPE) != json.end()) { index_info.metric_name_ = json[engine::PARAM_INDEX_METRIC_TYPE]; @@ -110,7 +108,8 @@ GetSnapshotIndex(const std::string& collection_name, const std::string& field_na } else { for (auto& field_element : field_elements) { if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) { - index_info.index_name_ = DEFAULT_STRUCTURED_INDEX_NAME; + index_info.index_name_ = field_element->GetName(); + index_info.index_type_ = field_element->GetTypeName(); } } } diff --git a/core/src/db/Types.h b/core/src/db/Types.h index 0a582bc77..a555d278a 100644 --- a/core/src/db/Types.h +++ b/core/src/db/Types.h @@ -78,6 +78,7 @@ using DataChunkPtr = std::shared_ptr; struct CollectionIndex { std::string index_name_; + std::string index_type_; std::string metric_name_; milvus::json extra_params_ = {{"nlist", 2048}}; }; diff --git a/core/src/db/Utils.cpp b/core/src/db/Utils.cpp index 6fb1d11dc..c48a92b5d 100644 --- a/core/src/db/Utils.cpp +++ b/core/src/db/Utils.cpp @@ -49,7 +49,7 @@ GetMicroSecTimeStamp() { bool IsSameIndex(const CollectionIndex& index1, const CollectionIndex& index2) { - return index1.index_name_ == index2.index_name_ && index1.extra_params_ == index2.extra_params_ && + return index1.index_type_ == index2.index_type_ && index1.extra_params_ == index2.extra_params_ && index1.metric_name_ == index2.metric_name_; } diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index ada42c96c..5265f5997 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -647,6 +647,7 @@ ExecutionEngineImpl::CreateSnapshotIndexFile(AddSegmentFileOperation& operation, auto& index_element = element_visitor->GetElement(); index_info.index_name_ = index_element->GetName(); + index_info.index_type_ = index_element->GetTypeName(); auto params = index_element->GetParams(); if (params.find(engine::PARAM_INDEX_METRIC_TYPE) != params.end()) { index_info.metric_name_ = params[engine::PARAM_INDEX_METRIC_TYPE]; @@ -728,7 +729,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col } // build index by knowhere - new_index = CreateVecIndex(index_info.index_name_); + new_index = CreateVecIndex(index_info.index_type_); if (!new_index) { throw Exception(DB_ERROR, "Unsupported index type"); } diff --git a/core/src/server/delivery/request/CreateIndexReq.cpp b/core/src/server/delivery/request/CreateIndexReq.cpp index f36b8bdb2..fb30232c2 100644 --- a/core/src/server/delivery/request/CreateIndexReq.cpp +++ b/core/src/server/delivery/request/CreateIndexReq.cpp @@ -84,7 +84,7 @@ CreateIndexReq::OnExecute() { int64_t dimension = params[engine::PARAM_DIMENSION].get(); // validate index type - std::string index_type = 0; + std::string index_type; if (json_params_.contains(engine::PARAM_INDEX_TYPE)) { index_type = json_params_[engine::PARAM_INDEX_TYPE].get(); } @@ -94,7 +94,7 @@ CreateIndexReq::OnExecute() { } // validate metric type - std::string metric_type = 0; + std::string metric_type; if (json_params_.contains(engine::PARAM_INDEX_METRIC_TYPE)) { metric_type = json_params_[engine::PARAM_INDEX_METRIC_TYPE].get(); } @@ -111,13 +111,19 @@ CreateIndexReq::OnExecute() { rc.RecordSection("check validation"); - index.index_name_ = index_type; + index.index_name_ = index_name_; + index.index_type_ = index_type; index.metric_name_ = metric_type; if (json_params_.contains(engine::PARAM_INDEX_EXTRA_PARAMS)) { index.extra_params_ = json_params_[engine::PARAM_INDEX_EXTRA_PARAMS]; } } else { index.index_name_ = index_name_; + std::string index_type; + if (json_params_.contains(engine::PARAM_INDEX_TYPE)) { + index_type = json_params_[engine::PARAM_INDEX_TYPE].get(); + } + index.index_type_ = index_type; } STATUS_CHECK(DBWrapper::DB()->CreateIndex(context_, collection_name_, field_name_, index)); diff --git a/core/unittest/db/test_db.cpp b/core/unittest/db/test_db.cpp index 72b362322..730a430fe 100644 --- a/core/unittest/db/test_db.cpp +++ b/core/unittest/db/test_db.cpp @@ -502,7 +502,8 @@ TEST_F(DBTest, IndexTest) { { milvus::engine::CollectionIndex index; - index.index_name_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + index.index_name_ = "my_index1"; + index.index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; index.metric_name_ = milvus::knowhere::Metric::L2; index.extra_params_["nlist"] = 2048; status = db_->CreateIndex(dummy_context_, collection_name, VECTOR_FIELD_NAME, index); @@ -512,13 +513,15 @@ TEST_F(DBTest, IndexTest) { status = db_->DescribeIndex(collection_name, VECTOR_FIELD_NAME, index_get); ASSERT_TRUE(status.ok()); ASSERT_EQ(index.index_name_, index_get.index_name_); + ASSERT_EQ(index.index_type_, index_get.index_type_); ASSERT_EQ(index.metric_name_, index_get.metric_name_); ASSERT_EQ(index.extra_params_, index_get.extra_params_); } { milvus::engine::CollectionIndex index; - index.index_name_ = "SORTED"; + index.index_name_ = "my_index2"; + index.index_type_ = milvus::engine::DEFAULT_STRUCTURED_INDEX_NAME; status = db_->CreateIndex(dummy_context_, collection_name, "field_0", index); ASSERT_TRUE(status.ok()); status = db_->CreateIndex(dummy_context_, collection_name, "field_1", index); @@ -530,6 +533,7 @@ TEST_F(DBTest, IndexTest) { status = db_->DescribeIndex(collection_name, "field_0", index_get); ASSERT_TRUE(status.ok()); ASSERT_EQ(index.index_name_, index_get.index_name_); + ASSERT_EQ(index.index_type_, index_get.index_type_); } { @@ -577,7 +581,7 @@ TEST_F(DBTest, StatsTest) { { milvus::engine::CollectionIndex index; - index.index_name_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + index.index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; index.metric_name_ = milvus::knowhere::Metric::L2; index.extra_params_["nlist"] = 2048; status = db_->CreateIndex(dummy_context_, collection_name, VECTOR_FIELD_NAME, index); @@ -586,7 +590,7 @@ TEST_F(DBTest, StatsTest) { { milvus::engine::CollectionIndex index; - index.index_name_ = "SORTED"; + index.index_type_ = milvus::engine::DEFAULT_STRUCTURED_INDEX_NAME; status = db_->CreateIndex(dummy_context_, collection_name, "field_0", index); ASSERT_TRUE(status.ok()); status = db_->CreateIndex(dummy_context_, collection_name, "field_1", index); -- GitLab