未验证 提交 8f8fa0aa 编写于 作者: G groot 提交者: GitHub

use index type replace name (#3114)

* use index type replace name
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>
上级 6ecf7a27
......@@ -50,26 +50,23 @@ SetSnapshotIndex(const std::string& collection_name, const std::string& field_na
}
snapshot::OperationContext ss_context;
auto index_element =
std::make_shared<snapshot::FieldElement>(ss->GetCollectionId(), field->GetID(), index_info.index_name_,
milvus::engine::FieldElementType::FET_INDEX, index_info.index_type_);
ss_context.new_field_elements.push_back(index_element);
if (IsVectorField(field)) {
auto new_element = std::make_shared<snapshot::FieldElement>(
ss->GetCollectionId(), field->GetID(), index_info.index_name_, milvus::engine::FieldElementType::FET_INDEX);
milvus::json json;
json[engine::PARAM_INDEX_METRIC_TYPE] = index_info.metric_name_;
json[engine::PARAM_INDEX_EXTRA_PARAMS] = index_info.extra_params_;
new_element->SetParams(json);
ss_context.new_field_elements.push_back(new_element);
index_element->SetParams(json);
if (index_info.index_name_ == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR ||
index_info.index_name_ == knowhere::IndexEnum::INDEX_HNSW_SQ8NM) {
auto new_element = std::make_shared<snapshot::FieldElement>(
auto compress_element = std::make_shared<snapshot::FieldElement>(
ss->GetCollectionId(), field->GetID(), DEFAULT_INDEX_COMPRESS_NAME,
milvus::engine::FieldElementType::FET_COMPRESS_SQ8);
ss_context.new_field_elements.push_back(new_element);
ss_context.new_field_elements.push_back(compress_element);
}
} else {
auto new_element = std::make_shared<snapshot::FieldElement>(
ss->GetCollectionId(), field->GetID(), index_info.index_name_, milvus::engine::FieldElementType::FET_INDEX);
ss_context.new_field_elements.push_back(new_element);
}
auto op = std::make_shared<snapshot::AddFieldElementOperation>(ss_context, ss);
......@@ -97,6 +94,7 @@ GetSnapshotIndex(const std::string& collection_name, const std::string& field_na
for (auto& field_element : field_elements) {
if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) {
index_info.index_name_ = field_element->GetName();
index_info.index_type_ = field_element->GetTypeName();
auto json = field_element->GetParams();
if (json.find(engine::PARAM_INDEX_METRIC_TYPE) != json.end()) {
index_info.metric_name_ = json[engine::PARAM_INDEX_METRIC_TYPE];
......@@ -110,7 +108,8 @@ GetSnapshotIndex(const std::string& collection_name, const std::string& field_na
} else {
for (auto& field_element : field_elements) {
if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) {
index_info.index_name_ = DEFAULT_STRUCTURED_INDEX_NAME;
index_info.index_name_ = field_element->GetName();
index_info.index_type_ = field_element->GetTypeName();
}
}
}
......
......@@ -78,6 +78,7 @@ using DataChunkPtr = std::shared_ptr<DataChunk>;
struct CollectionIndex {
std::string index_name_;
std::string index_type_;
std::string metric_name_;
milvus::json extra_params_ = {{"nlist", 2048}};
};
......
......@@ -49,7 +49,7 @@ GetMicroSecTimeStamp() {
bool
IsSameIndex(const CollectionIndex& index1, const CollectionIndex& index2) {
return index1.index_name_ == index2.index_name_ && index1.extra_params_ == index2.extra_params_ &&
return index1.index_type_ == index2.index_type_ && index1.extra_params_ == index2.extra_params_ &&
index1.metric_name_ == index2.metric_name_;
}
......
......@@ -647,6 +647,7 @@ ExecutionEngineImpl::CreateSnapshotIndexFile(AddSegmentFileOperation& operation,
auto& index_element = element_visitor->GetElement();
index_info.index_name_ = index_element->GetName();
index_info.index_type_ = index_element->GetTypeName();
auto params = index_element->GetParams();
if (params.find(engine::PARAM_INDEX_METRIC_TYPE) != params.end()) {
index_info.metric_name_ = params[engine::PARAM_INDEX_METRIC_TYPE];
......@@ -728,7 +729,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col
}
// build index by knowhere
new_index = CreateVecIndex(index_info.index_name_);
new_index = CreateVecIndex(index_info.index_type_);
if (!new_index) {
throw Exception(DB_ERROR, "Unsupported index type");
}
......
......@@ -84,7 +84,7 @@ CreateIndexReq::OnExecute() {
int64_t dimension = params[engine::PARAM_DIMENSION].get<int64_t>();
// validate index type
std::string index_type = 0;
std::string index_type;
if (json_params_.contains(engine::PARAM_INDEX_TYPE)) {
index_type = json_params_[engine::PARAM_INDEX_TYPE].get<std::string>();
}
......@@ -94,7 +94,7 @@ CreateIndexReq::OnExecute() {
}
// validate metric type
std::string metric_type = 0;
std::string metric_type;
if (json_params_.contains(engine::PARAM_INDEX_METRIC_TYPE)) {
metric_type = json_params_[engine::PARAM_INDEX_METRIC_TYPE].get<std::string>();
}
......@@ -111,13 +111,19 @@ CreateIndexReq::OnExecute() {
rc.RecordSection("check validation");
index.index_name_ = index_type;
index.index_name_ = index_name_;
index.index_type_ = index_type;
index.metric_name_ = metric_type;
if (json_params_.contains(engine::PARAM_INDEX_EXTRA_PARAMS)) {
index.extra_params_ = json_params_[engine::PARAM_INDEX_EXTRA_PARAMS];
}
} else {
index.index_name_ = index_name_;
std::string index_type;
if (json_params_.contains(engine::PARAM_INDEX_TYPE)) {
index_type = json_params_[engine::PARAM_INDEX_TYPE].get<std::string>();
}
index.index_type_ = index_type;
}
STATUS_CHECK(DBWrapper::DB()->CreateIndex(context_, collection_name_, field_name_, index));
......
......@@ -502,7 +502,8 @@ TEST_F(DBTest, IndexTest) {
{
milvus::engine::CollectionIndex index;
index.index_name_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
index.index_name_ = "my_index1";
index.index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
index.metric_name_ = milvus::knowhere::Metric::L2;
index.extra_params_["nlist"] = 2048;
status = db_->CreateIndex(dummy_context_, collection_name, VECTOR_FIELD_NAME, index);
......@@ -512,13 +513,15 @@ TEST_F(DBTest, IndexTest) {
status = db_->DescribeIndex(collection_name, VECTOR_FIELD_NAME, index_get);
ASSERT_TRUE(status.ok());
ASSERT_EQ(index.index_name_, index_get.index_name_);
ASSERT_EQ(index.index_type_, index_get.index_type_);
ASSERT_EQ(index.metric_name_, index_get.metric_name_);
ASSERT_EQ(index.extra_params_, index_get.extra_params_);
}
{
milvus::engine::CollectionIndex index;
index.index_name_ = "SORTED";
index.index_name_ = "my_index2";
index.index_type_ = milvus::engine::DEFAULT_STRUCTURED_INDEX_NAME;
status = db_->CreateIndex(dummy_context_, collection_name, "field_0", index);
ASSERT_TRUE(status.ok());
status = db_->CreateIndex(dummy_context_, collection_name, "field_1", index);
......@@ -530,6 +533,7 @@ TEST_F(DBTest, IndexTest) {
status = db_->DescribeIndex(collection_name, "field_0", index_get);
ASSERT_TRUE(status.ok());
ASSERT_EQ(index.index_name_, index_get.index_name_);
ASSERT_EQ(index.index_type_, index_get.index_type_);
}
{
......@@ -577,7 +581,7 @@ TEST_F(DBTest, StatsTest) {
{
milvus::engine::CollectionIndex index;
index.index_name_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
index.index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
index.metric_name_ = milvus::knowhere::Metric::L2;
index.extra_params_["nlist"] = 2048;
status = db_->CreateIndex(dummy_context_, collection_name, VECTOR_FIELD_NAME, index);
......@@ -586,7 +590,7 @@ TEST_F(DBTest, StatsTest) {
{
milvus::engine::CollectionIndex index;
index.index_name_ = "SORTED";
index.index_type_ = milvus::engine::DEFAULT_STRUCTURED_INDEX_NAME;
status = db_->CreateIndex(dummy_context_, collection_name, "field_0", index);
ASSERT_TRUE(status.ok());
status = db_->CreateIndex(dummy_context_, collection_name, "field_1", index);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册