From 679118d34c20633bb21a57666679bccf70637ab7 Mon Sep 17 00:00:00 2001 From: starlord Date: Fri, 12 Jul 2019 15:30:05 +0800 Subject: [PATCH] MS-212 Support Inner product metric type Former-commit-id: 068ed6d011b45f46abc485036ca8e3cf397dfcda --- cpp/CHANGELOG.md | 1 + cpp/conf/server_config.template | 3 +- cpp/src/db/FaissExecutionEngine.cpp | 15 ++++++- cpp/src/db/scheduler/task/SearchTask.cpp | 43 ++++++++++++++----- cpp/src/db/scheduler/task/SearchTask.h | 5 ++- .../sdk/examples/simple/src/ClientTest.cpp | 2 +- cpp/src/server/ServerConfig.h | 1 + cpp/src/wrapper/IndexBuilder.cpp | 9 ++-- cpp/unittest/db/search_test.cpp | 18 ++++---- 9 files changed, 70 insertions(+), 27 deletions(-) diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index 6ea62b4ef..9ac6cb11b 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -18,6 +18,7 @@ Please mark all change in change log and use the ticket from JIRA. - MS-204 - Support multi db_path - MS-206 - Support SQ8 index type - MS-208 - Add buildinde interface for C++ SDK +- MS-212 - Support Inner product metric type ## New Feature - MS-195 - Add nlist and use_blas_threshold conf diff --git a/cpp/conf/server_config.template b/cpp/conf/server_config.template index ec5bde9d8..3f3b00445 100644 --- a/cpp/conf/server_config.template +++ b/cpp/conf/server_config.template @@ -36,4 +36,5 @@ cache_config: # cache configure engine_config: nprobe: 10 nlist: 16384 - use_blas_threshold: 20 \ No newline at end of file + use_blas_threshold: 20 + metric_type: L2 #L2 or Inner Product \ No newline at end of file diff --git a/cpp/src/db/FaissExecutionEngine.cpp b/cpp/src/db/FaissExecutionEngine.cpp index dd22f9cb0..d2cb83536 100644 --- a/cpp/src/db/FaissExecutionEngine.cpp +++ b/cpp/src/db/FaissExecutionEngine.cpp @@ -22,15 +22,25 @@ namespace zilliz { namespace milvus { namespace engine { +namespace { +std::string GetMetricType() { + server::ServerConfig &config = server::ServerConfig::GetInstance(); + server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE); + return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2"); +} +} FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension, const std::string& location, const std::string& build_index_type, const std::string& raw_index_type) - : pIndex_(faiss::index_factory(dimension, raw_index_type.c_str())), - location_(location), + : location_(location), build_index_type_(build_index_type), raw_index_type_(raw_index_type) { + + std::string metric_type = GetMetricType(); + faiss::MetricType faiss_metric_type = (metric_type == "L2") ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; + pIndex_.reset(faiss::index_factory(dimension, raw_index_type.c_str(), faiss_metric_type)); } FaissExecutionEngine::FaissExecutionEngine(std::shared_ptr index, @@ -119,6 +129,7 @@ FaissExecutionEngine::BuildIndex(const std::string& location) { auto opd = std::make_shared(); opd->d = pIndex_->d; opd->index_type = build_index_type_; + opd->metric_type = GetMetricType(); IndexBuilderPtr pBuilder = GetIndexBuilder(opd); auto from_index = dynamic_cast(pIndex_.get()); diff --git a/cpp/src/db/scheduler/task/SearchTask.cpp b/cpp/src/db/scheduler/task/SearchTask.cpp index 708bcc870..8036cf986 100644 --- a/cpp/src/db/scheduler/task/SearchTask.cpp +++ b/cpp/src/db/scheduler/task/SearchTask.cpp @@ -30,11 +30,20 @@ void CollectDurationMetrics(int index_type, double total_time) { } } +std::string GetMetricType() { + server::ServerConfig &config = server::ServerConfig::GetInstance(); + server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE); + return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2"); +} + } SearchTask::SearchTask() : IScheduleTask(ScheduleTaskType::kSearch) { - + std::string metric_type = GetMetricType(); + if(metric_type != "L2") { + metric_l2 = false; + } } std::shared_ptr SearchTask::Execute() { @@ -71,7 +80,7 @@ std::shared_ptr SearchTask::Execute() { rc.Record("cluster result"); //step 4: pick up topk result - SearchTask::TopkResult(result_set, inner_k, context->GetResult()); + SearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult()); rc.Record("reduce topk"); } catch (std::exception& ex) { @@ -125,7 +134,8 @@ Status SearchTask::ClusterResult(const std::vector &output_ids, Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src, SearchContext::Id2DistanceMap &distance_target, - uint64_t topk) { + uint64_t topk, + bool ascending) { //Note: the score_src and score_target are already arranged by score in ascending order if(distance_src.empty()) { SERVER_LOG_WARNING << "Empty distance source array"; @@ -161,15 +171,27 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src, break; } - //compare score, put smallest score to score_merged one by one + //compare score, + // if ascending = true, put smallest score to score_merged one by one + // else, put largest score to score_merged one by one auto& src_pair = distance_src[src_index]; auto& target_pair = distance_target[target_index]; - if(src_pair.second > target_pair.second) { - distance_merged.push_back(target_pair); - target_index++; + if(ascending){ + if(src_pair.second > target_pair.second) { + distance_merged.push_back(target_pair); + target_index++; + } else { + distance_merged.push_back(src_pair); + src_index++; + } } else { - distance_merged.push_back(src_pair); - src_index++; + if(src_pair.second < target_pair.second) { + distance_merged.push_back(target_pair); + target_index++; + } else { + distance_merged.push_back(src_pair); + src_index++; + } } //score_merged.size() already equal topk @@ -185,6 +207,7 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src, Status SearchTask::TopkResult(SearchContext::ResultSet &result_src, uint64_t topk, + bool ascending, SearchContext::ResultSet &result_target) { if (result_target.empty()) { result_target.swap(result_src); @@ -200,7 +223,7 @@ Status SearchTask::TopkResult(SearchContext::ResultSet &result_src, for (size_t i = 0; i < result_src.size(); i++) { SearchContext::Id2DistanceMap &score_src = result_src[i]; SearchContext::Id2DistanceMap &score_target = result_target[i]; - SearchTask::MergeResult(score_src, score_target, topk); + SearchTask::MergeResult(score_src, score_target, topk, ascending); } return Status::OK(); diff --git a/cpp/src/db/scheduler/task/SearchTask.h b/cpp/src/db/scheduler/task/SearchTask.h index e4f0d872b..7e0aa52af 100644 --- a/cpp/src/db/scheduler/task/SearchTask.h +++ b/cpp/src/db/scheduler/task/SearchTask.h @@ -27,10 +27,12 @@ public: static Status MergeResult(SearchContext::Id2DistanceMap &distance_src, SearchContext::Id2DistanceMap &distance_target, - uint64_t topk); + uint64_t topk, + bool ascending); static Status TopkResult(SearchContext::ResultSet &result_src, uint64_t topk, + bool ascending, SearchContext::ResultSet &result_target); public: @@ -38,6 +40,7 @@ public: int index_type_ = 0; //for metrics ExecutionEnginePtr index_engine_; std::vector search_contexts_; + bool metric_l2 = true; }; using SearchTaskPtr = std::shared_ptr; diff --git a/cpp/src/sdk/examples/simple/src/ClientTest.cpp b/cpp/src/sdk/examples/simple/src/ClientTest.cpp index b2f1c56ba..da58f117a 100644 --- a/cpp/src/sdk/examples/simple/src/ClientTest.cpp +++ b/cpp/src/sdk/examples/simple/src/ClientTest.cpp @@ -98,7 +98,7 @@ namespace { TableSchema BuildTableSchema() { TableSchema tb_schema; tb_schema.table_name = TABLE_NAME; - tb_schema.index_type = IndexType::gpu_ivfsq8; + tb_schema.index_type = IndexType::gpu_ivfflat; tb_schema.dimension = TABLE_DIMENSION; tb_schema.store_raw_vector = true; diff --git a/cpp/src/server/ServerConfig.h b/cpp/src/server/ServerConfig.h index 1fa5bb65e..8b6c1761e 100644 --- a/cpp/src/server/ServerConfig.h +++ b/cpp/src/server/ServerConfig.h @@ -47,6 +47,7 @@ static const std::string CONFIG_ENGINE = "engine_config"; static const std::string CONFIG_NPROBE = "nprobe"; static const std::string CONFIG_NLIST = "nlist"; static const std::string CONFIG_DCBT = "use_blas_threshold"; +static const std::string CONFIG_METRICTYPE = "metric_type"; class ServerConfig { public: diff --git a/cpp/src/wrapper/IndexBuilder.cpp b/cpp/src/wrapper/IndexBuilder.cpp index d4429c381..41859907a 100644 --- a/cpp/src/wrapper/IndexBuilder.cpp +++ b/cpp/src/wrapper/IndexBuilder.cpp @@ -71,7 +71,8 @@ Index_ptr IndexBuilder::build_all(const long &nb, { LOG(DEBUG) << "Build index by GPU"; // TODO: list support index-type. - faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()); + faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; + faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type); std::lock_guard lk(gpu_resource); faiss::gpu::StandardGpuResources res; @@ -90,7 +91,8 @@ Index_ptr IndexBuilder::build_all(const long &nb, #else { LOG(DEBUG) << "Build index by CPU"; - faiss::Index *index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()); + faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; + faiss::Index *index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type); if (!index->is_trained) { nt == 0 || xt == nullptr ? index->train(nb, xb) : index->train(nt, xt); @@ -113,7 +115,8 @@ BgCpuBuilder::BgCpuBuilder(const zilliz::milvus::engine::Operand_ptr &opd) : Ind Index_ptr BgCpuBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) { std::shared_ptr index = nullptr; - index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str())); + faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; + index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type)); LOG(DEBUG) << "Build index by CPU"; { diff --git a/cpp/unittest/db/search_test.cpp b/cpp/unittest/db/search_test.cpp index db10bcbad..d860295cb 100644 --- a/cpp/unittest/db/search_test.cpp +++ b/cpp/unittest/db/search_test.cpp @@ -73,13 +73,13 @@ TEST(DBSearchTest, TOPK_TEST) { ASSERT_EQ(src_result.size(), NQ); engine::SearchContext::ResultSet target_result; - status = engine::SearchTask::TopkResult(target_result, TOP_K, target_result); + status = engine::SearchTask::TopkResult(target_result, TOP_K, true, target_result); ASSERT_TRUE(status.ok()); - status = engine::SearchTask::TopkResult(target_result, TOP_K, src_result); + status = engine::SearchTask::TopkResult(target_result, TOP_K, true, src_result); ASSERT_FALSE(status.ok()); - status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result); + status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result); ASSERT_TRUE(status.ok()); ASSERT_TRUE(src_result.empty()); ASSERT_EQ(target_result.size(), NQ); @@ -92,7 +92,7 @@ TEST(DBSearchTest, TOPK_TEST) { status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result); ASSERT_TRUE(status.ok()); - status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result); + status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result); ASSERT_TRUE(status.ok()); for(uint64_t i = 0; i < NQ; i++) { ASSERT_EQ(target_result[i].size(), TOP_K); @@ -101,7 +101,7 @@ TEST(DBSearchTest, TOPK_TEST) { wrong_topk = TOP_K + 10; BuildResult(NQ, wrong_topk, src_ids, src_distence); - status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result); + status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result); ASSERT_TRUE(status.ok()); for(uint64_t i = 0; i < NQ; i++) { ASSERT_EQ(target_result[i].size(), TOP_K); @@ -126,7 +126,7 @@ TEST(DBSearchTest, MERGE_TEST) { { engine::SearchContext::Id2DistanceMap src = src_result[0]; engine::SearchContext::Id2DistanceMap target = target_result[0]; - status = engine::SearchTask::MergeResult(src, target, 10); + status = engine::SearchTask::MergeResult(src, target, 10, true); ASSERT_TRUE(status.ok()); ASSERT_EQ(target.size(), 10); CheckResult(src_result[0], target_result[0], target); @@ -135,7 +135,7 @@ TEST(DBSearchTest, MERGE_TEST) { { engine::SearchContext::Id2DistanceMap src = src_result[0]; engine::SearchContext::Id2DistanceMap target; - status = engine::SearchTask::MergeResult(src, target, 10); + status = engine::SearchTask::MergeResult(src, target, 10, true); ASSERT_TRUE(status.ok()); ASSERT_EQ(target.size(), src_count); ASSERT_TRUE(src.empty()); @@ -145,7 +145,7 @@ TEST(DBSearchTest, MERGE_TEST) { { engine::SearchContext::Id2DistanceMap src = src_result[0]; engine::SearchContext::Id2DistanceMap target = target_result[0]; - status = engine::SearchTask::MergeResult(src, target, 30); + status = engine::SearchTask::MergeResult(src, target, 30, true); ASSERT_TRUE(status.ok()); ASSERT_EQ(target.size(), src_count + target_count); CheckResult(src_result[0], target_result[0], target); @@ -154,7 +154,7 @@ TEST(DBSearchTest, MERGE_TEST) { { engine::SearchContext::Id2DistanceMap target = src_result[0]; engine::SearchContext::Id2DistanceMap src = target_result[0]; - status = engine::SearchTask::MergeResult(src, target, 30); + status = engine::SearchTask::MergeResult(src, target, 30, true); ASSERT_TRUE(status.ok()); ASSERT_EQ(target.size(), src_count + target_count); CheckResult(src_result[0], target_result[0], target); -- GitLab