diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index 88bd9b8d959c1c5c342ba8c8d82c65091ba0be18..859a1d75aa65c0aa133ca711bc68d7b526382d26 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -228,7 +228,7 @@ Status DBImpl::merge_files(const std::string& group_id, const meta::DateT& date, for (auto& file : files) { auto to_merge = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(file.location); if (!to_merge) { - to_merge = read_index(file.location.c_str()); + to_merge = read_index(file.location); } auto file_index = dynamic_cast(to_merge->data().get()); index->add_with_ids(file_index->ntotal, dynamic_cast(file_index->index)->xb.data(), diff --git a/cpp/src/db/ExecutionEngine.cpp b/cpp/src/db/ExecutionEngine.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9aea3f7ea12433170b75aa62386fa99c1377d4ce --- /dev/null +++ b/cpp/src/db/ExecutionEngine.cpp @@ -0,0 +1,21 @@ +#include +#include "ExecutionEngine.h" + +namespace zilliz { +namespace vecwise { +namespace engine { + +Status ExecutionEngine::AddWithIds(const std::vector& vectors, const std::vector& vector_ids) { + long n1 = (long)vectors.size(); + long n2 = (long)vector_ids.size(); + if (n1 != n2) { + LOG(ERROR) << "vectors size is not equal to the size of vector_ids: " << n1 << "!=" << n2; + return Status::Error("Error: AddWithIds"); + } + return AddWithIds(n1, vectors.data(), vector_ids.data()); +} + + +} // namespace engine +} // namespace vecwise +} // namespace zilliz diff --git a/cpp/src/db/ExecutionEngine.h b/cpp/src/db/ExecutionEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..4b08149f45a2f8437b661a2c7f580b2d608d00e2 --- /dev/null +++ b/cpp/src/db/ExecutionEngine.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include "Status.h" + +namespace zilliz { +namespace vecwise { +namespace engine { + +class ExecutionEngine { +public: + + Status AddWithIds(const std::vector& vectors, + const std::vector& vector_ids); + + virtual Status AddWithIds(long n, const float *xdata, const long *xids) = 0; + + virtual size_t Count() const = 0; + + virtual size_t Size() const = 0; + + virtual Status Serialize() = 0; + + virtual Status Cache() = 0; + + virtual ~ExecutionEngine() {} +}; + + +} // namespace engine +} // namespace vecwise +} // namespace zilliz diff --git a/cpp/src/db/FaissExecutionEngine.cpp b/cpp/src/db/FaissExecutionEngine.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffa3ef3c24573fcc7eddc1946b0a26b495704f7b --- /dev/null +++ b/cpp/src/db/FaissExecutionEngine.cpp @@ -0,0 +1,46 @@ +#include +#include +#include +#include + +#include "FaissExecutionEngine.h" + +namespace zilliz { +namespace vecwise { +namespace engine { + +const std::string IndexType = "IDMap,Flat"; + +FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension, const std::string& location) + : pIndex_(faiss::index_factory(dimension, IndexType.c_str())), + location_(location) { +} + +Status FaissExecutionEngine::AddWithIds(long n, const float *xdata, const long *xids) { + pIndex_->add_with_ids(n, xdata, xids); + return Status::OK(); +} + +size_t FaissExecutionEngine::Count() const { + return (size_t)(pIndex_->ntotal); +} + +size_t FaissExecutionEngine::Size() const { + return (size_t)(Count() * pIndex_->d); +} + +Status FaissExecutionEngine::Serialize() { + write_index(pIndex_.get(), location_.c_str()); + return Status::OK(); +} + +Status FaissExecutionEngine::Cache() { + zilliz::vecwise::cache::CpuCacheMgr::GetInstance( + )->InsertItem(location_, std::make_shared(pIndex_)); + + return Status::OK(); +} + +} // namespace engine +} // namespace vecwise +} // namespace zilliz diff --git a/cpp/src/db/FaissExecutionEngine.h b/cpp/src/db/FaissExecutionEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..bb1f5a477096cfdb79e855540c428ffde7f5e8c3 --- /dev/null +++ b/cpp/src/db/FaissExecutionEngine.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +#include "ExecutionEngine.h" + +namespace faiss { + class Index; +} + +namespace zilliz { +namespace vecwise { +namespace engine { + +class FaissExecutionEngine : public ExecutionEngine { +public: + FaissExecutionEngine(uint16_t dimension, const std::string& location); + virtual Status AddWithIds(long n, const float *xdata, const long *xids) override; + + virtual size_t Count() const override; + + virtual size_t Size() const override; + + virtual Status Serialize() override; + + virtual Status Cache() override; + +protected: + std::shared_ptr pIndex_; + std::string location_; +}; + + +} // namespace engine +} // namespace vecwise +} // namespace zilliz diff --git a/cpp/src/db/MemManager.cpp b/cpp/src/db/MemManager.cpp index e3c9407d861038dcc690078da36b57a2d96b583c..10ec5829352ec2d3c89e98db3acfbd05df5c55ae 100644 --- a/cpp/src/db/MemManager.cpp +++ b/cpp/src/db/MemManager.cpp @@ -1,14 +1,12 @@ -#include #include #include #include -#include -#include #include #include "MemManager.h" #include "Meta.h" +#include "FaissExecutionEngine.h" namespace zilliz { @@ -21,43 +19,36 @@ MemVectors::MemVectors(const std::shared_ptr& meta_ptr, options_(options), schema_(schema), _pIdGenerator(new SimpleIDGenerator()), - pIndex_(faiss::index_factory(schema_.dimension, "IDMap,Flat")) { + pEE_(new FaissExecutionEngine(schema_.dimension, schema_.location)) { } void MemVectors::add(size_t n_, const float* vectors_, IDNumbers& vector_ids_) { _pIdGenerator->getNextIDNumbers(n_, vector_ids_); - pIndex_->add_with_ids(n_, vectors_, &vector_ids_[0]); + pEE_->AddWithIds(n_, vectors_, vector_ids_.data()); for(auto i=0 ; intotal; + return pEE_->Count(); } size_t MemVectors::approximate_size() const { - return total() * schema_.dimension; + return pEE_->Size(); } Status MemVectors::serialize(std::string& group_id) { - /* std::stringstream ss; */ - /* ss << "/tmp/test/" << _pIdGenerator->getNextIDNumber(); */ - /* faiss::write_index(pIndex_, ss.str().c_str()); */ - /* std::cout << pIndex_->ntotal << std::endl; */ - /* std::cout << _file_location << std::endl; */ - /* faiss::write_index(pIndex_, _file_location.c_str()); */ group_id = schema_.group_id; auto rows = approximate_size(); - write_index(pIndex_.get(), schema_.location.c_str()); + pEE_->Serialize(); schema_.rows = rows; schema_.file_type = (rows >= options_.index_trigger_size) ? meta::GroupFileSchema::TO_INDEX : meta::GroupFileSchema::RAW; auto status = pMeta_->update_group_file(schema_); - zilliz::vecwise::cache::CpuCacheMgr::GetInstance( - )->InsertItem(schema_.location, std::make_shared(pIndex_)); + pEE_->Cache(); return status; } diff --git a/cpp/src/db/MemManager.h b/cpp/src/db/MemManager.h index 077e045286627e9200d725cc35b5773e4049684f..0a298bf22d2b490350a5b022aed32a4b2eb709ae 100644 --- a/cpp/src/db/MemManager.h +++ b/cpp/src/db/MemManager.h @@ -10,10 +10,6 @@ #include "Status.h" #include "Meta.h" -namespace faiss { - class Index; -} - namespace zilliz { namespace vecwise { @@ -23,6 +19,8 @@ namespace meta { class Meta; } +class ExecutionEngine; + class MemVectors { public: explicit MemVectors(const std::shared_ptr&, @@ -49,7 +47,7 @@ private: Options options_; meta::GroupFileSchema schema_; IDGenerator* _pIdGenerator; - std::shared_ptr pIndex_; + std::shared_ptr pEE_; }; // MemVectors diff --git a/cpp/src/db/Status.h b/cpp/src/db/Status.h index f45c9f6bd13477880ed2e974518c8f158576202f..4db2b4c6e0c03bcdf14a8339e5dafeeb14d7d7da 100644 --- a/cpp/src/db/Status.h +++ b/cpp/src/db/Status.h @@ -21,6 +21,9 @@ public: static Status NotFound(const std::string& msg, const std::string& msg2="") { return Status(kNotFound, msg, msg2); } + static Status Error(const std::string& msg, const std::string& msg2="") { + return Status(kError, msg, msg2); + } static Status InvalidDBPath(const std::string& msg, const std::string& msg2="") { return Status(kInvalidDBPath, msg, msg2); @@ -35,6 +38,7 @@ public: bool ok() const { return state_ == nullptr; } bool IsNotFound() const { return code() == kNotFound; } + bool IsError() const { return code() == kError; } bool IsInvalidDBPath() const { return code() == kInvalidDBPath; } bool IsGroupError() const { return code() == kGroupError; } @@ -48,6 +52,7 @@ private: enum Code { kOK = 0, kNotFound, + kError, kInvalidDBPath, kGroupError,