提交 0e37089c 编写于 作者: X Xu Peng

refactor(db): refactor all for crtp replacement


Former-commit-id: 3c5d3ddeec04d573ef3916f673a7504c51a0a3bf
上级 67960d5a
......@@ -113,7 +113,7 @@ Status DBImpl::search(const std::string& group_id, size_t k, size_t nq,
auto search_in_index = [&](meta::GroupFilesSchema& file_vec) -> void {
for (auto &file : file_vec) {
FaissExecutionEngineBase index(file.dimension, file.location);
FaissExecutionEngine index(file.dimension, file.location);
index.Load();
auto file_size = index.PhysicalSize()/(1024*1024);
search_set_size += file_size;
......@@ -213,7 +213,7 @@ Status DBImpl::merge_files(const std::string& group_id, const meta::DateT& date,
return status;
}
FaissExecutionEngineBase index(group_file.dimension, group_file.location);
FaissExecutionEngine index(group_file.dimension, group_file.location);
meta::GroupFilesSchema updated;
long index_size = 0;
......@@ -286,7 +286,7 @@ Status DBImpl::build_index(const meta::GroupFileSchema& file) {
return status;
}
FaissExecutionEngineBase to_index(file.dimension, file.location);
FaissExecutionEngine to_index(file.dimension, file.location);
to_index.Load();
auto index = to_index.BuildIndex(group_file.location);
......
......@@ -5,18 +5,8 @@ namespace zilliz {
namespace vecwise {
namespace engine {
/// Convenience overload: checks that `vectors` and `vector_ids` have the same
/// number of elements, then forwards their raw buffers to
/// AddWithIds(long, const float*, const long*).
///
/// @param vectors     values passed to the pointer overload as `xdata`
/// @param vector_ids  ids passed to the pointer overload as `xids`
/// @return Status::Error("Error: AddWithIds") when the sizes differ, otherwise
///         whatever the pointer overload returns.
///
/// NOTE(review): the count passed down is vectors.size(); if callers store more
/// than one float per id, this size check can never pass — confirm with callers.
Status ExecutionEngine::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
    const auto vector_count = static_cast<long>(vectors.size());
    const auto id_count = static_cast<long>(vector_ids.size());
    if (vector_count == id_count) {
        return AddWithIds(vector_count, vectors.data(), vector_ids.data());
    }
    LOG(ERROR) << "vectors size is not equal to the size of vector_ids: " << vector_count << "!=" << id_count;
    return Status::Error("Error: AddWithIds");
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
Status ExecutionEngine<Derived>::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
long n1 = (long)vectors.size();
long n2 = (long)vector_ids.size();
if (n1 != n2) {
......@@ -27,42 +17,42 @@ Status ExecutionEngineBase<Derived>::AddWithIds(const std::vector<float>& vector
}
// CRTP forwarder: downcasts `this` to Derived* and calls
// Derived::AddWithIds(n, xdata, xids), returning its Status unchanged.
// NOTE(review): Derived must declare its own AddWithIds(long, const float*, const long*);
// otherwise name lookup in Derived finds this base member again and the call
// recurses without end.
template<typename Derived>
Status ExecutionEngineBase<Derived>::AddWithIds(long n, const float *xdata, const long *xids) {
Status ExecutionEngine<Derived>::AddWithIds(long n, const float *xdata, const long *xids) {
return static_cast<Derived*>(this)->AddWithIds(n, xdata, xids);
}
// CRTP forwarder for Count(): delegates to Derived::Count() and returns its result.
// Inside a const member function `this` is a pointer-to-const, so the downcast has
// to target `const Derived*`; casting to `Derived*` would drop const, which
// static_cast is not allowed to do, and fails to compile once this member is
// instantiated.
// NOTE(review): Derived must declare its own `Count() const`; otherwise this call
// resolves back to this function and recurses without end.
template<typename Derived>
size_t ExecutionEngineBase<Derived>::Count() const {
size_t ExecutionEngine<Derived>::Count() const {
    return static_cast<const Derived*>(this)->Count();
}
// CRTP forwarder for Size(): delegates to Derived::Size() and returns its result.
// Inside a const member function `this` is a pointer-to-const, so the downcast has
// to target `const Derived*`; casting to `Derived*` would drop const, which
// static_cast is not allowed to do, and fails to compile once this member is
// instantiated.
// NOTE(review): Derived must declare its own `Size() const`; otherwise this call
// resolves back to this function and recurses without end.
template<typename Derived>
size_t ExecutionEngineBase<Derived>::Size() const {
size_t ExecutionEngine<Derived>::Size() const {
    return static_cast<const Derived*>(this)->Size();
}
// CRTP forwarder for PhysicalSize(): delegates to Derived::PhysicalSize() and
// returns its result.
// Inside a const member function `this` is a pointer-to-const, so the downcast has
// to target `const Derived*`; casting to `Derived*` would drop const, which
// static_cast is not allowed to do, and fails to compile once this member is
// instantiated.
// NOTE(review): Derived must declare its own `PhysicalSize() const`; otherwise this
// call resolves back to this function and recurses without end.
template<typename Derived>
size_t ExecutionEngineBase<Derived>::PhysicalSize() const {
size_t ExecutionEngine<Derived>::PhysicalSize() const {
    return static_cast<const Derived*>(this)->PhysicalSize();
}
// CRTP forwarder: downcasts `this` to Derived* and returns Derived::Serialize().
// NOTE(review): Derived must declare its own Serialize(); otherwise the call
// resolves back to this function and recurses without end.
template<typename Derived>
Status ExecutionEngineBase<Derived>::Serialize() {
Status ExecutionEngine<Derived>::Serialize() {
return static_cast<Derived*>(this)->Serialize();
}
// CRTP forwarder: downcasts `this` to Derived* and returns Derived::Load().
// NOTE(review): Derived must declare its own Load(); otherwise the call
// resolves back to this function and recurses without end.
template<typename Derived>
Status ExecutionEngineBase<Derived>::Load() {
Status ExecutionEngine<Derived>::Load() {
return static_cast<Derived*>(this)->Load();
}
// CRTP forwarder: downcasts `this` to Derived* and returns Derived::Merge(location).
// NOTE(review): Derived must declare its own Merge(const std::string&); otherwise
// the call resolves back to this function and recurses without end.
template<typename Derived>
Status ExecutionEngineBase<Derived>::Merge(const std::string& location) {
Status ExecutionEngine<Derived>::Merge(const std::string& location) {
return static_cast<Derived*>(this)->Merge(location);
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::Search(long n,
Status ExecutionEngine<Derived>::Search(long n,
const float *data,
long k,
float *distances,
......@@ -71,12 +61,12 @@ Status ExecutionEngineBase<Derived>::Search(long n,
}
// CRTP forwarder: downcasts `this` to Derived* and returns Derived::Cache().
// NOTE(review): Derived must declare its own Cache(); otherwise the call
// resolves back to this function and recurses without end.
template<typename Derived>
Status ExecutionEngineBase<Derived>::Cache() {
Status ExecutionEngine<Derived>::Cache() {
return static_cast<Derived*>(this)->Cache();
}
// CRTP forwarder: downcasts `this` to Derived* and returns
// Derived::BuildIndex(location). The result is typed as std::shared_ptr<Derived>,
// so callers receive the concrete engine type rather than a base pointer.
// NOTE(review): Derived must declare its own BuildIndex(const std::string&);
// otherwise the call resolves back to this function and recurses without end.
template<typename Derived>
std::shared_ptr<Derived> ExecutionEngineBase<Derived>::BuildIndex(const std::string& location) {
std::shared_ptr<Derived> ExecutionEngine<Derived>::BuildIndex(const std::string& location) {
return static_cast<Derived*>(this)->BuildIndex(location);
}
......
......@@ -9,43 +9,8 @@ namespace zilliz {
namespace vecwise {
namespace engine {
class ExecutionEngine;
// Abstract execution-engine interface based on virtual dispatch.
// Every operation except the std::vector overload of AddWithIds is pure virtual
// and has to be supplied by a concrete subclass.
class ExecutionEngine {
public:
// Non-virtual convenience overload taking the values and ids as std::vectors.
Status AddWithIds(const std::vector<float>& vectors,
const std::vector<long>& vector_ids);
// Pointer-based overload: `n` is a count, `xdata` and `xids` are raw buffers.
// NOTE(review): presumably n vectors with one id each — confirm buffer layout.
virtual Status AddWithIds(long n, const float *xdata, const long *xids) = 0;
// Size queries; the units of each are defined by the implementing subclass.
virtual size_t Count() const = 0;
virtual size_t Size() const = 0;
virtual size_t PhysicalSize() const = 0;
// Persistence and loading hooks, defined by the implementing subclass.
virtual Status Serialize() = 0;
virtual Status Load() = 0;
// Merges data identified by `location` into this engine.
virtual Status Merge(const std::string& location) = 0;
// Search entry point; `distances` and `labels` are non-const output buffers.
// NOTE(review): presumably n queries with k results each — confirm buffer sizes.
virtual Status Search(long n,
const float *data,
long k,
float *distances,
long *labels) const = 0;
// Returns a new engine held through a base-class shared_ptr.
virtual std::shared_ptr<ExecutionEngine> BuildIndex(const std::string&) = 0;
virtual Status Cache() = 0;
// Virtual destructor so subclasses can be destroyed through a base pointer.
virtual ~ExecutionEngine() {}
};
template <typename Derived>
class ExecutionEngineBase {
class ExecutionEngine {
public:
Status AddWithIds(const std::vector<float>& vectors,
......
......@@ -16,6 +16,7 @@ namespace engine {
// Index description strings. RawIndexType is handed to faiss::index_factory when an
// engine is constructed from a dimension; BuildIndexType is assigned to
// Operand::index_type in BuildIndex. Both currently hold the same value.
const std::string RawIndexType = "IDMap,Flat";
const std::string BuildIndexType = "IDMap,Flat";
FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension, const std::string& location)
: pIndex_(faiss::index_factory(dimension, RawIndexType.c_str())),
location_(location) {
......@@ -74,7 +75,7 @@ Status FaissExecutionEngine::Merge(const std::string& location) {
return Status::OK();
}
std::shared_ptr<ExecutionEngine> FaissExecutionEngine::BuildIndex(const std::string& location) {
std::shared_ptr<FaissExecutionEngine> FaissExecutionEngine::BuildIndex(const std::string& location) {
auto opd = std::make_shared<Operand>();
opd->d = pIndex_->d;
opd->index_type = BuildIndexType;
......@@ -86,7 +87,7 @@ std::shared_ptr<ExecutionEngine> FaissExecutionEngine::BuildIndex(const std::str
dynamic_cast<faiss::IndexFlat*>(from_index->index)->xb.data(),
from_index->id_map.data());
std::shared_ptr<ExecutionEngine> new_ee(new FaissExecutionEngine(index->data(), location));
std::shared_ptr<FaissExecutionEngine> new_ee(new FaissExecutionEngine(index->data(), location));
new_ee->Serialize();
return new_ee;
}
......@@ -109,99 +110,6 @@ Status FaissExecutionEngine::Cache() {
}
FaissExecutionEngineBase::FaissExecutionEngineBase(uint16_t dimension, const std::string& location)
: pIndex_(faiss::index_factory(dimension, RawIndexType.c_str())),
location_(location) {
}
FaissExecutionEngineBase::FaissExecutionEngineBase(std::shared_ptr<faiss::Index> index, const std::string& location)
: pIndex_(index),
location_(location) {
}
/// Passes `n`, `xdata` and `xids` straight to add_with_ids on the wrapped index.
/// @return always Status::OK().
Status FaissExecutionEngineBase::AddWithIds(long n, const float *xdata, const long *xids) {
    faiss::Index* const target = pIndex_.get();
    target->add_with_ids(n, xdata, xids);
    return Status::OK();
}
/// Returns the wrapped index's `ntotal`, converted to size_t.
size_t FaissExecutionEngineBase::Count() const {
    const auto total = pIndex_->ntotal;
    return static_cast<size_t>(total);
}
/// Returns Count() multiplied by the wrapped index's `d`, i.e. the total number
/// of scalar values across all counted entries.
size_t FaissExecutionEngineBase::Size() const {
    const auto value_count = Count() * pIndex_->d;
    return static_cast<size_t>(value_count);
}
/// Returns Size() multiplied by sizeof(float): the byte size of the stored values.
size_t FaissExecutionEngineBase::PhysicalSize() const {
    const auto byte_count = Size() * sizeof(float);
    return static_cast<size_t>(byte_count);
}
/// Hands the wrapped index and the path in location_ to write_index.
/// @return always Status::OK(); the result of write_index is not inspected.
Status FaissExecutionEngineBase::Serialize() {
    const char* const target_path = location_.c_str();
    write_index(pIndex_.get(), target_path);
    return Status::OK();
}
/// Points pIndex_ at the index stored under location_.
///
/// The CPU cache is consulted first; on a miss the index is read from location_
/// via read_index and then inserted into the cache.
///
/// @return Status::OK() on success, Status::Error if read_index yields no index.
Status FaissExecutionEngineBase::Load() {
    auto index = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location_);
    const bool loaded_from_disk = !index;
    if (loaded_from_disk) {
        index = read_index(location_);
        if (!index) {
            return Status::Error("Error: Load");
        }
    }
    pIndex_ = index->data();
    if (loaded_from_disk) {
        // Cache() inserts whatever pIndex_ currently holds, so it must run only after
        // pIndex_ refers to the index just read; calling it earlier would cache the
        // stale index created by the constructor under location_.
        Cache();
        LOG(DEBUG) << "Disk io from: " << location_;
    }
    return Status::OK();
}
/// Appends every vector/id pair of the index stored at `location` to this engine.
///
/// The source index is taken from the CPU cache when present, otherwise read via
/// read_index. It must be a faiss::IndexIDMap wrapping a faiss::IndexFlat, because
/// the raw vectors are taken from IndexFlat::xb and the ids from IndexIDMap::id_map.
///
/// @param location  identifies the index to merge; must differ from location_.
/// @return Status::OK() on success; Status::Error when `location` equals location_,
///         when the source index cannot be obtained, or when it does not have the
///         expected IDMap-over-Flat structure.
Status FaissExecutionEngineBase::Merge(const std::string& location) {
    if (location == location_) {
        return Status::Error("Cannot Merge Self");
    }

    auto to_merge = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location);
    if (!to_merge) {
        to_merge = read_index(location);
    }
    if (!to_merge) {
        return Status::Error("Cannot Merge: index not available");
    }

    // dynamic_cast yields nullptr on a type mismatch; check before dereferencing.
    auto file_index = dynamic_cast<faiss::IndexIDMap*>(to_merge->data().get());
    if (file_index == nullptr) {
        return Status::Error("Cannot Merge: index is not an IndexIDMap");
    }
    auto flat_index = dynamic_cast<faiss::IndexFlat*>(file_index->index);
    if (flat_index == nullptr) {
        return Status::Error("Cannot Merge: wrapped index is not an IndexFlat");
    }

    pIndex_->add_with_ids(file_index->ntotal, flat_index->xb.data(), file_index->id_map.data());
    return Status::OK();
}
/// Builds a new index of type BuildIndexType from the data held by this engine,
/// wraps it in a new engine bound to `location`, serializes that engine and
/// returns it.
///
/// @param location  location given to the new engine (used by its Serialize()).
/// @return the newly created engine.
///
/// NOTE(review): both dynamic_casts are dereferenced unchecked, so pIndex_ must be
/// a faiss::IndexIDMap wrapping a faiss::IndexFlat; the Status returned by
/// Serialize() is ignored.
std::shared_ptr<FaissExecutionEngineBase> FaissExecutionEngineBase::BuildIndex(const std::string& location) {
    auto opd = std::make_shared<Operand>();
    opd->d = pIndex_->d;
    opd->index_type = BuildIndexType;
    IndexBuilderPtr pBuilder = GetIndexBuilder(opd);

    auto from_index = dynamic_cast<faiss::IndexIDMap*>(pIndex_.get());

    auto index = pBuilder->build_all(from_index->ntotal,
                                     dynamic_cast<faiss::IndexFlat*>(from_index->index)->xb.data(),
                                     from_index->id_map.data());

    // make_shared instead of a naked `new`: one allocation and no window in which
    // the object exists without an owner.
    auto new_ee = std::make_shared<FaissExecutionEngineBase>(index->data(), location);
    new_ee->Serialize();
    return new_ee;
}
/// Passes all arguments unchanged to search() on the wrapped index; results are
/// written by that call into `distances` and `labels`.
/// @return always Status::OK().
Status FaissExecutionEngineBase::Search(long n,
                                        const float *data,
                                        long k,
                                        float *distances,
                                        long *labels) const {
    faiss::Index* const searcher = pIndex_.get();
    searcher->search(n, data, k, distances, labels);
    return Status::OK();
}
/// Inserts the current pIndex_, wrapped in a new Index object, into the CPU cache
/// manager under the key location_.
/// @return always Status::OK().
Status FaissExecutionEngineBase::Cache() {
    auto&& cache_mgr = zilliz::vecwise::cache::CpuCacheMgr::GetInstance();
    cache_mgr->InsertItem(location_, std::make_shared<Index>(pIndex_));
    return Status::OK();
}
} // namespace engine
} // namespace vecwise
} // namespace zilliz
......@@ -13,44 +13,12 @@ namespace zilliz {
namespace vecwise {
namespace engine {
// Faiss-backed execution engine. It owns a shared faiss::Index (pIndex_) and a
// location string (location_), and overrides every pure-virtual operation of the
// ExecutionEngine interface.
class FaissExecutionEngine : public ExecutionEngine {
class FaissExecutionEngine : public ExecutionEngine<FaissExecutionEngine> {
public:
// Constructs the engine from a dimension and a location.
FaissExecutionEngine(uint16_t dimension, const std::string& location);
// Constructs the engine around an existing shared index.
FaissExecutionEngine(std::shared_ptr<faiss::Index> index, const std::string& location);
// Pointer-based insertion: `n` is a count, `xdata` and `xids` are raw buffers.
virtual Status AddWithIds(long n, const float *xdata, const long *xids) override;
// Size queries.
virtual size_t Count() const override;
virtual size_t Size() const override;
virtual size_t PhysicalSize() const override;
// Merges the index identified by `location` into this engine.
virtual Status Merge(const std::string& location) override;
// Persistence, loading and caching hooks.
virtual Status Serialize() override;
virtual Status Load() override;
virtual Status Cache() override;
// Search entry point; `distances` and `labels` are non-const output buffers.
virtual Status Search(long n,
const float *data,
long k,
float *distances,
long *labels) const override;
// Returns a new engine held through a base-class shared_ptr.
virtual std::shared_ptr<ExecutionEngine> BuildIndex(const std::string&) override;
protected:
// The wrapped faiss index; shared ownership.
std::shared_ptr<faiss::Index> pIndex_;
// Location string supplied at construction.
std::string location_;
};
class FaissExecutionEngineBase : public ExecutionEngineBase<FaissExecutionEngineBase> {
public:
FaissExecutionEngineBase(uint16_t dimension, const std::string& location);
FaissExecutionEngineBase(std::shared_ptr<faiss::Index> index, const std::string& location);
Status AddWithIds(const std::vector<float>& vectors,
const std::vector<long>& vector_ids);
......@@ -74,7 +42,7 @@ public:
float *distances,
long *labels) const;
std::shared_ptr<FaissExecutionEngineBase> BuildIndex(const std::string&);
std::shared_ptr<FaissExecutionEngine> BuildIndex(const std::string&);
Status Cache();
protected:
......
......@@ -18,7 +18,7 @@ MemVectors::MemVectors(const std::shared_ptr<meta::Meta>& meta_ptr,
options_(options),
schema_(schema),
_pIdGenerator(new SimpleIDGenerator()),
pEE_(new FaissExecutionEngineBase(schema_.dimension, schema_.location)) {
pEE_(new FaissExecutionEngine(schema_.dimension, schema_.location)) {
}
void MemVectors::add(size_t n_, const float* vectors_, IDNumbers& vector_ids_) {
......
......@@ -19,7 +19,7 @@ namespace meta {
class Meta;
}
class FaissExecutionEngineBase;
class FaissExecutionEngine;
class MemVectors {
public:
......@@ -47,7 +47,7 @@ private:
Options options_;
meta::GroupFileSchema schema_;
IDGenerator* _pIdGenerator;
std::shared_ptr<FaissExecutionEngineBase> pEE_;
std::shared_ptr<FaissExecutionEngine> pEE_;
}; // MemVectors
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册