提交 8866d362 编写于 作者: X Xu Peng

refactor(db): refactor execution_engine


Former-commit-id: 90ddd165224135d190108f5d7bce544b5c0f305e
上级 ca195424
......@@ -15,6 +15,71 @@ Status ExecutionEngine::AddWithIds(const std::vector<float>& vectors, const std:
return AddWithIds(n1, vectors.data(), vector_ids.data());
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
long n1 = (long)vectors.size();
long n2 = (long)vector_ids.size();
if (n1 != n2) {
LOG(ERROR) << "vectors size is not equal to the size of vector_ids: " << n1 << "!=" << n2;
return Status::Error("Error: AddWithIds");
}
return AddWithIds(n1, vectors.data(), vector_ids.data());
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::AddWithIds(long n, const float *xdata, const long *xids) {
return static_cast<Derived*>(this)->AddWithIds(n, xdata, xids);
}
template<typename Derived>
size_t ExecutionEngineBase<Derived>::Count() const {
return static_cast<Derived*>(this)->Count();
}
template<typename Derived>
size_t ExecutionEngineBase<Derived>::Size() const {
return static_cast<Derived*>(this)->Size();
}
template<typename Derived>
size_t ExecutionEngineBase<Derived>::PhysicalSize() const {
return static_cast<Derived*>(this)->PhysicalSize();
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::Serialize() {
return static_cast<Derived*>(this)->Serialize();
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::Load() {
return static_cast<Derived*>(this)->Load();
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::Merge(const std::string& location) {
return static_cast<Derived*>(this)->Merge(location);
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::Search(long n,
const float *data,
long k,
float *distances,
long *labels) const {
return static_cast<Derived*>(this)->Search(n, data, k, distances, labels);
}
template<typename Derived>
Status ExecutionEngineBase<Derived>::Cache() {
return static_cast<Derived*>(this)->Cache();
}
template<typename Derived>
std::shared_ptr<Derived> ExecutionEngineBase<Derived>::BuildIndex(const std::string& location) {
return static_cast<Derived*>(this)->BuildIndex(location);
}
} // namespace engine
} // namespace vecwise
......
......@@ -44,6 +44,38 @@ public:
virtual ~ExecutionEngine() {}
};
template <typename Derived>
class ExecutionEngineBase {
public:
Status AddWithIds(const std::vector<float>& vectors,
const std::vector<long>& vector_ids);
Status AddWithIds(long n, const float *xdata, const long *xids);
size_t Count() const;
size_t Size() const;
size_t PhysicalSize() const;
Status Serialize();
Status Load();
Status Merge(const std::string& location);
Status Search(long n,
const float *data,
long k,
float *distances,
long *labels) const;
std::shared_ptr<Derived> BuildIndex(const std::string&);
Status Cache();
};
} // namespace engine
} // namespace vecwise
......
......@@ -108,6 +108,100 @@ Status FaissExecutionEngine::Cache() {
return Status::OK();
}
FaissExecutionEngineBase::FaissExecutionEngineBase(uint16_t dimension, const std::string& location)
: pIndex_(faiss::index_factory(dimension, RawIndexType.c_str())),
location_(location) {
}
FaissExecutionEngineBase::FaissExecutionEngineBase(std::shared_ptr<faiss::Index> index, const std::string& location)
: pIndex_(index),
location_(location) {
}
Status FaissExecutionEngineBase::AddWithIds(long n, const float *xdata, const long *xids) {
pIndex_->add_with_ids(n, xdata, xids);
return Status::OK();
}
size_t FaissExecutionEngineBase::Count() const {
return (size_t)(pIndex_->ntotal);
}
size_t FaissExecutionEngineBase::Size() const {
return (size_t)(Count() * pIndex_->d);
}
size_t FaissExecutionEngineBase::PhysicalSize() const {
return (size_t)(Size()*sizeof(float));
}
Status FaissExecutionEngineBase::Serialize() {
write_index(pIndex_.get(), location_.c_str());
return Status::OK();
}
Status FaissExecutionEngineBase::Load() {
auto index = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location_);
if (!index) {
index = read_index(location_);
Cache();
LOG(DEBUG) << "Disk io from: " << location_;
}
pIndex_ = index->data();
return Status::OK();
}
Status FaissExecutionEngineBase::Merge(const std::string& location) {
if (location == location_) {
return Status::Error("Cannot Merge Self");
}
auto to_merge = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location);
if (!to_merge) {
to_merge = read_index(location);
}
auto file_index = dynamic_cast<faiss::IndexIDMap*>(to_merge->data().get());
pIndex_->add_with_ids(file_index->ntotal, dynamic_cast<faiss::IndexFlat*>(file_index->index)->xb.data(),
file_index->id_map.data());
return Status::OK();
}
std::shared_ptr<FaissExecutionEngineBase> FaissExecutionEngineBase::BuildIndex(const std::string& location) {
auto opd = std::make_shared<Operand>();
opd->d = pIndex_->d;
opd->index_type = BuildIndexType;
IndexBuilderPtr pBuilder = GetIndexBuilder(opd);
auto from_index = dynamic_cast<faiss::IndexIDMap*>(pIndex_.get());
auto index = pBuilder->build_all(from_index->ntotal,
dynamic_cast<faiss::IndexFlat*>(from_index->index)->xb.data(),
from_index->id_map.data());
std::shared_ptr<FaissExecutionEngineBase> new_ee(new FaissExecutionEngineBase(index->data(), location));
new_ee->Serialize();
return new_ee;
}
Status FaissExecutionEngineBase::Search(long n,
const float *data,
long k,
float *distances,
long *labels) const {
pIndex_->search(n, data, k, distances, labels);
return Status::OK();
}
Status FaissExecutionEngineBase::Cache() {
zilliz::vecwise::cache::CpuCacheMgr::GetInstance(
)->InsertItem(location_, std::make_shared<Index>(pIndex_));
return Status::OK();
}
} // namespace engine
} // namespace vecwise
} // namespace zilliz
......@@ -46,6 +46,42 @@ protected:
std::string location_;
};
class FaissExecutionEngineBase : public ExecutionEngineBase<FaissExecutionEngineBase> {
public:
FaissExecutionEngineBase(uint16_t dimension, const std::string& location);
FaissExecutionEngineBase(std::shared_ptr<faiss::Index> index, const std::string& location);
Status AddWithIds(const std::vector<float>& vectors,
const std::vector<long>& vector_ids);
Status AddWithIds(long n, const float *xdata, const long *xids);
size_t Count() const;
size_t Size() const;
size_t PhysicalSize() const;
Status Serialize();
Status Load();
Status Merge(const std::string& location);
Status Search(long n,
const float *data,
long k,
float *distances,
long *labels) const;
std::shared_ptr<FaissExecutionEngineBase> BuildIndex(const std::string&);
Status Cache();
protected:
std::shared_ptr<faiss::Index> pIndex_;
std::string location_;
};
} // namespace engine
} // namespace vecwise
......
......@@ -18,7 +18,7 @@ MemVectors::MemVectors(const std::shared_ptr<meta::Meta>& meta_ptr,
options_(options),
schema_(schema),
_pIdGenerator(new SimpleIDGenerator()),
pEE_(new FaissExecutionEngine(schema_.dimension, schema_.location)) {
pEE_(new FaissExecutionEngineBase(schema_.dimension, schema_.location)) {
}
void MemVectors::add(size_t n_, const float* vectors_, IDNumbers& vector_ids_) {
......
......@@ -19,7 +19,7 @@ namespace meta {
class Meta;
}
class ExecutionEngine;
class FaissExecutionEngineBase;
class MemVectors {
public:
......@@ -47,7 +47,7 @@ private:
Options options_;
meta::GroupFileSchema schema_;
IDGenerator* _pIdGenerator;
std::shared_ptr<ExecutionEngine> pEE_;
std::shared_ptr<FaissExecutionEngineBase> pEE_;
}; // MemVectors
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册