From 4aabe33207c68c84e1c94af7d59033eea935b3da Mon Sep 17 00:00:00 2001 From: yu yunfeng Date: Wed, 3 Jul 2019 11:09:17 +0800 Subject: [PATCH] update IVF nprobe Former-commit-id: 41aa77c14de1db37cfb1a177ba223de90829e003 --- cpp/conf/server_config.template | 5 ++++- cpp/src/db/EngineFactory.cpp | 34 +++++++++++++++++++++-------- cpp/src/db/ExecutionEngine.h | 2 ++ cpp/src/db/FaissExecutionEngine.cpp | 30 ++++++++++++++++++++++++- cpp/src/db/FaissExecutionEngine.h | 9 ++++---- cpp/src/wrapper/Index.cpp | 29 ------------------------ 6 files changed, 65 insertions(+), 44 deletions(-) diff --git a/cpp/conf/server_config.template b/cpp/conf/server_config.template index c2ed7756..0383e00b 100644 --- a/cpp/conf/server_config.template +++ b/cpp/conf/server_config.template @@ -30,4 +30,7 @@ license_config: # license configure license_path: "@MILVUS_DB_PATH@/system.license" # license file path cache_config: # cache configure - cpu_cache_capacity: 16 # how many memory are used as cache, unit: GB, range: 0 ~ less than total memory \ No newline at end of file + cpu_cache_capacity: 16 # how many memory are used as cache, unit: GB, range: 0 ~ less than total memory + +engine_config: + nprobe: 10 \ No newline at end of file diff --git a/cpp/src/db/EngineFactory.cpp b/cpp/src/db/EngineFactory.cpp index 26ef639c..bacce70c 100644 --- a/cpp/src/db/EngineFactory.cpp +++ b/cpp/src/db/EngineFactory.cpp @@ -7,23 +7,39 @@ #include "FaissExecutionEngine.h" #include "Log.h" + namespace zilliz { namespace milvus { namespace engine { ExecutionEnginePtr EngineFactory::Build(uint16_t dimension, - const std::string& location, - EngineType type) { - switch(type) { - case EngineType::FAISS_IDMAP: - return ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, "IDMap", "IDMap,Flat")); - case EngineType::FAISS_IVFFLAT: - return ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, "IVF", "IDMap,Flat")); - default: - ENGINE_LOG_ERROR << "Unsupportted engine type"; + const std::string &location, + EngineType type) { + + ExecutionEnginePtr execution_engine_ptr; + + switch (type) { + case EngineType::FAISS_IDMAP: { + execution_engine_ptr = + ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, "IDMap", "IDMap,Flat")); + break; + } + + case EngineType::FAISS_IVFFLAT: { + execution_engine_ptr = + ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, "IVF", "IDMap,Flat")); + break; + } + + default: { + ENGINE_LOG_ERROR << "Unsupported engine type"; return nullptr; + } } + + execution_engine_ptr->Init(); + return execution_engine_ptr; } } diff --git a/cpp/src/db/ExecutionEngine.h b/cpp/src/db/ExecutionEngine.h index d2b4d01e..f8c05f6f 100644 --- a/cpp/src/db/ExecutionEngine.h +++ b/cpp/src/db/ExecutionEngine.h @@ -50,6 +50,8 @@ public: virtual std::shared_ptr BuildIndex(const std::string&) = 0; virtual Status Cache() = 0; + + virtual Status Init() = 0; }; using ExecutionEnginePtr = std::shared_ptr; diff --git a/cpp/src/db/FaissExecutionEngine.cpp b/cpp/src/db/FaissExecutionEngine.cpp index 9dfdd978..20bd530e 100644 --- a/cpp/src/db/FaissExecutionEngine.cpp +++ b/cpp/src/db/FaissExecutionEngine.cpp @@ -13,6 +13,7 @@ #include #include #include +#include "faiss/IndexIVF.h" #include "metrics/Metrics.h" @@ -135,7 +136,16 @@ Status FaissExecutionEngine::Search(long n, float *distances, long *labels) const { auto start_time = METRICS_NOW_TIME; - pIndex_->search(n, data, k, distances, labels); + + std::shared_ptr ivf_index = std::dynamic_pointer_cast(pIndex_); + if(ivf_index) { + ENGINE_LOG_DEBUG << "Index type: IVFFLAT nProbe: " << nprobe_; + ivf_index->nprobe = nprobe_; + ivf_index->search(n, data, k, distances, labels); + } else { + pIndex_->search(n, data, k, distances, labels); + } + auto end_time = METRICS_NOW_TIME; auto total_time = METRICS_MICROSECONDS(start_time,end_time); server::Metrics::GetInstance().QueryIndexTypePerSecondSet(build_index_type_, double(n)/double(total_time)); @@ -149,6 +159,24 @@ Status FaissExecutionEngine::Cache() { return Status::OK(); } +Status FaissExecutionEngine::Init() { + + if(build_index_type_ == "IVF") { + + using namespace zilliz::milvus::server; + ServerConfig &config = ServerConfig::GetInstance(); + ConfigNode engine_config = config.GetConfig(CONFIG_ENGINE); + nprobe_ = engine_config.GetInt32Value(CONFIG_NPROBE, 1000); + + } else if(build_index_type_ == "IDMap") { + ; + } else { + return Status::Error("Wrong index type: ", build_index_type_); + } + + return Status::OK(); +} + } // namespace engine } // namespace milvus diff --git a/cpp/src/db/FaissExecutionEngine.h b/cpp/src/db/FaissExecutionEngine.h index 5667df34..f9f37ad9 100644 --- a/cpp/src/db/FaissExecutionEngine.h +++ b/cpp/src/db/FaissExecutionEngine.h @@ -6,14 +6,11 @@ #pragma once #include "ExecutionEngine.h" +#include "faiss/Index.h" #include #include -namespace faiss { - class Index; -} - namespace zilliz { namespace milvus { namespace engine { @@ -58,12 +55,16 @@ public: Status Cache() override; + Status Init() override; + protected: std::shared_ptr pIndex_; std::string location_; std::string build_index_type_; std::string raw_index_type_; + + size_t nprobe_ = 0; }; diff --git a/cpp/src/wrapper/Index.cpp b/cpp/src/wrapper/Index.cpp index 57c462a2..18e20d83 100644 --- a/cpp/src/wrapper/Index.cpp +++ b/cpp/src/wrapper/Index.cpp @@ -25,32 +25,6 @@ using std::string; using std::unordered_map; using std::vector; -class Nprobe { - public: - static Nprobe &GetInstance() { - static Nprobe instance; - return instance; - } - - void SelectNprobe() { - using namespace zilliz::milvus::server; - ServerConfig &config = ServerConfig::GetInstance(); - ConfigNode engine_config = config.GetConfig(CONFIG_ENGINE); - nprobe_ = engine_config.GetInt32Value(CONFIG_NPROBE, 1000); - } - - size_t GetNprobe() { - return nprobe_; - } - - private: - Nprobe() : nprobe_(1000) { SelectNprobe(); } - - private: - size_t nprobe_; -}; - - Index::Index(const std::shared_ptr &raw_index) { index_ = raw_index; dim = index_->d; @@ -84,9 +58,6 @@ bool Index::add_with_ids(idx_t n, const float *xdata, const long *xids) { bool Index::search(idx_t n, const float *data, idx_t k, float *distances, long *labels) const { try { - if(auto ivf_index = std::dynamic_pointer_cast(index_)) { - ivf_index->nprobe = Nprobe::GetInstance().GetNprobe(); - } index_->search(n, data, k, distances, labels); } catch (std::exception &e) { -- GitLab