From 8866d36252b4c58fe804a27937b7a827587129a6 Mon Sep 17 00:00:00 2001 From: Xu Peng Date: Tue, 30 Apr 2019 20:14:38 +0800 Subject: [PATCH] refactor(db): refactor execution_engine Former-commit-id: 90ddd165224135d190108f5d7bce544b5c0f305e --- cpp/src/db/ExecutionEngine.cpp | 65 ++++++++++++++++++++ cpp/src/db/ExecutionEngine.h | 32 ++++++++++ cpp/src/db/FaissExecutionEngine.cpp | 94 +++++++++++++++++++++++++++++ cpp/src/db/FaissExecutionEngine.h | 36 +++++++++++ cpp/src/db/MemManager.cpp | 2 +- cpp/src/db/MemManager.h | 4 +- 6 files changed, 230 insertions(+), 3 deletions(-) diff --git a/cpp/src/db/ExecutionEngine.cpp b/cpp/src/db/ExecutionEngine.cpp index 9aea3f7e..ffb54504 100644 --- a/cpp/src/db/ExecutionEngine.cpp +++ b/cpp/src/db/ExecutionEngine.cpp @@ -15,6 +15,71 @@ Status ExecutionEngine::AddWithIds(const std::vector& vectors, const std: return AddWithIds(n1, vectors.data(), vector_ids.data()); } +template +Status ExecutionEngineBase::AddWithIds(const std::vector& vectors, const std::vector& vector_ids) { + long n1 = (long)vectors.size(); + long n2 = (long)vector_ids.size(); + if (n1 != n2) { + LOG(ERROR) << "vectors size is not equal to the size of vector_ids: " << n1 << "!=" << n2; + return Status::Error("Error: AddWithIds"); + } + return AddWithIds(n1, vectors.data(), vector_ids.data()); +} + +template +Status ExecutionEngineBase::AddWithIds(long n, const float *xdata, const long *xids) { + return static_cast(this)->AddWithIds(n, xdata, xids); +} + +template +size_t ExecutionEngineBase::Count() const { + return static_cast(this)->Count(); +} + +template +size_t ExecutionEngineBase::Size() const { + return static_cast(this)->Size(); +} + +template +size_t ExecutionEngineBase::PhysicalSize() const { + return static_cast(this)->PhysicalSize(); +} + +template +Status ExecutionEngineBase::Serialize() { + return static_cast(this)->Serialize(); +} + +template +Status ExecutionEngineBase::Load() { + return static_cast(this)->Load(); +} + +template +Status ExecutionEngineBase::Merge(const std::string& location) { + return static_cast(this)->Merge(location); +} + +template +Status ExecutionEngineBase::Search(long n, + const float *data, + long k, + float *distances, + long *labels) const { + return static_cast(this)->Search(n, data, k, distances, labels); +} + +template +Status ExecutionEngineBase::Cache() { + return static_cast(this)->Cache(); +} + +template +std::shared_ptr ExecutionEngineBase::BuildIndex(const std::string& location) { + return static_cast(this)->BuildIndex(location); +} + } // namespace engine } // namespace vecwise diff --git a/cpp/src/db/ExecutionEngine.h b/cpp/src/db/ExecutionEngine.h index 97440225..30b5f6ea 100644 --- a/cpp/src/db/ExecutionEngine.h +++ b/cpp/src/db/ExecutionEngine.h @@ -44,6 +44,38 @@ public: virtual ~ExecutionEngine() {} }; +template +class ExecutionEngineBase { +public: + + Status AddWithIds(const std::vector& vectors, + const std::vector& vector_ids); + + Status AddWithIds(long n, const float *xdata, const long *xids); + + size_t Count() const; + + size_t Size() const; + + size_t PhysicalSize() const; + + Status Serialize(); + + Status Load(); + + Status Merge(const std::string& location); + + Status Search(long n, + const float *data, + long k, + float *distances, + long *labels) const; + + std::shared_ptr BuildIndex(const std::string&); + + Status Cache(); +}; + } // namespace engine } // namespace vecwise diff --git a/cpp/src/db/FaissExecutionEngine.cpp b/cpp/src/db/FaissExecutionEngine.cpp index 26bc161f..ab836d68 100644 --- a/cpp/src/db/FaissExecutionEngine.cpp +++ b/cpp/src/db/FaissExecutionEngine.cpp @@ -108,6 +108,100 @@ Status FaissExecutionEngine::Cache() { return Status::OK(); } + +FaissExecutionEngineBase::FaissExecutionEngineBase(uint16_t dimension, const std::string& location) + : pIndex_(faiss::index_factory(dimension, RawIndexType.c_str())), + location_(location) { +} + +FaissExecutionEngineBase::FaissExecutionEngineBase(std::shared_ptr index, const std::string& location) + : pIndex_(index), + location_(location) { +} + +Status FaissExecutionEngineBase::AddWithIds(long n, const float *xdata, const long *xids) { + pIndex_->add_with_ids(n, xdata, xids); + return Status::OK(); +} + +size_t FaissExecutionEngineBase::Count() const { + return (size_t)(pIndex_->ntotal); +} + +size_t FaissExecutionEngineBase::Size() const { + return (size_t)(Count() * pIndex_->d); +} + +size_t FaissExecutionEngineBase::PhysicalSize() const { + return (size_t)(Size()*sizeof(float)); +} + +Status FaissExecutionEngineBase::Serialize() { + write_index(pIndex_.get(), location_.c_str()); + return Status::OK(); +} + +Status FaissExecutionEngineBase::Load() { + auto index = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location_); + if (!index) { + index = read_index(location_); + Cache(); + LOG(DEBUG) << "Disk io from: " << location_; + } + + pIndex_ = index->data(); + return Status::OK(); +} + +Status FaissExecutionEngineBase::Merge(const std::string& location) { + if (location == location_) { + return Status::Error("Cannot Merge Self"); + } + auto to_merge = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location); + if (!to_merge) { + to_merge = read_index(location); + } + auto file_index = dynamic_cast(to_merge->data().get()); + pIndex_->add_with_ids(file_index->ntotal, dynamic_cast(file_index->index)->xb.data(), + file_index->id_map.data()); + return Status::OK(); +} + +std::shared_ptr FaissExecutionEngineBase::BuildIndex(const std::string& location) { + auto opd = std::make_shared(); + opd->d = pIndex_->d; + opd->index_type = BuildIndexType; + IndexBuilderPtr pBuilder = GetIndexBuilder(opd); + + auto from_index = dynamic_cast(pIndex_.get()); + + auto index = pBuilder->build_all(from_index->ntotal, + dynamic_cast(from_index->index)->xb.data(), + from_index->id_map.data()); + + std::shared_ptr new_ee(new FaissExecutionEngineBase(index->data(), location)); + new_ee->Serialize(); + return new_ee; +} + +Status FaissExecutionEngineBase::Search(long n, + const float *data, + long k, + float *distances, + long *labels) const { + + pIndex_->search(n, data, k, distances, labels); + return Status::OK(); +} + +Status FaissExecutionEngineBase::Cache() { + zilliz::vecwise::cache::CpuCacheMgr::GetInstance( + )->InsertItem(location_, std::make_shared(pIndex_)); + + return Status::OK(); +} + + } // namespace engine } // namespace vecwise } // namespace zilliz diff --git a/cpp/src/db/FaissExecutionEngine.h b/cpp/src/db/FaissExecutionEngine.h index a1208b00..925b2685 100644 --- a/cpp/src/db/FaissExecutionEngine.h +++ b/cpp/src/db/FaissExecutionEngine.h @@ -46,6 +46,42 @@ protected: std::string location_; }; +class FaissExecutionEngineBase : public ExecutionEngineBase { +public: + FaissExecutionEngineBase(uint16_t dimension, const std::string& location); + FaissExecutionEngineBase(std::shared_ptr index, const std::string& location); + + Status AddWithIds(const std::vector& vectors, + const std::vector& vector_ids); + + Status AddWithIds(long n, const float *xdata, const long *xids); + + size_t Count() const; + + size_t Size() const; + + size_t PhysicalSize() const; + + Status Serialize(); + + Status Load(); + + Status Merge(const std::string& location); + + Status Search(long n, + const float *data, + long k, + float *distances, + long *labels) const; + + std::shared_ptr BuildIndex(const std::string&); + + Status Cache(); +protected: + std::shared_ptr pIndex_; + std::string location_; +}; + } // namespace engine } // namespace vecwise diff --git a/cpp/src/db/MemManager.cpp b/cpp/src/db/MemManager.cpp index 904c5db1..fa2858e3 100644 --- a/cpp/src/db/MemManager.cpp +++ b/cpp/src/db/MemManager.cpp @@ -18,7 +18,7 @@ MemVectors::MemVectors(const std::shared_ptr& meta_ptr, options_(options), schema_(schema), _pIdGenerator(new SimpleIDGenerator()), - pEE_(new FaissExecutionEngine(schema_.dimension, schema_.location)) { + pEE_(new FaissExecutionEngineBase(schema_.dimension, schema_.location)) { } void MemVectors::add(size_t n_, const float* vectors_, IDNumbers& vector_ids_) { diff --git a/cpp/src/db/MemManager.h b/cpp/src/db/MemManager.h index 0a298bf2..374f75cd 100644 --- a/cpp/src/db/MemManager.h +++ b/cpp/src/db/MemManager.h @@ -19,7 +19,7 @@ namespace meta { class Meta; } -class ExecutionEngine; +class FaissExecutionEngineBase; class MemVectors { public: @@ -47,7 +47,7 @@ private: Options options_; meta::GroupFileSchema schema_; IDGenerator* _pIdGenerator; - std::shared_ptr pEE_; + std::shared_ptr pEE_; }; // MemVectors -- GitLab