提交 c0326906 编写于 作者: X Xu Peng

feat(wrapper): add one more build_all api in IndexBuilder


Former-commit-id: d6d7187a419bafb81751b815e2ebd235058e43f9
上级 c87cdc87
...@@ -9,5 +9,6 @@ cmake_build ...@@ -9,5 +9,6 @@ cmake_build
*.o *.o
*.lo *.lo
*.tar.gz *.tar.gz
*.log
cpp/third_party/thrift-0.12.0/ cpp/third_party/thrift-0.12.0/
...@@ -21,9 +21,11 @@ IndexBuilder::IndexBuilder(const Operand_ptr &opd) { ...@@ -21,9 +21,11 @@ IndexBuilder::IndexBuilder(const Operand_ptr &opd) {
opd_ = opd; opd_ = opd;
} }
Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb, Index_ptr IndexBuilder::build_all(const long &nb,
const vector<long> &ids, const float* xb,
const long &nt, const vector<float> &xt) { const long* ids,
const long &nt,
const float* xt) {
std::shared_ptr<faiss::Index> index = nullptr; std::shared_ptr<faiss::Index> index = nullptr;
index.reset(faiss::index_factory(opd_->d, opd_->index_type.c_str())); index.reset(faiss::index_factory(opd_->d, opd_->index_type.c_str()));
...@@ -31,14 +33,20 @@ Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb, ...@@ -31,14 +33,20 @@ Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb,
// currently only cpu resources are used. // currently only cpu resources are used.
std::lock_guard<std::mutex> lk(cpu_resource); std::lock_guard<std::mutex> lk(cpu_resource);
if (!index->is_trained) { if (!index->is_trained) {
nt == 0 || xt.empty() ? index->train(nb, xb.data()) nt == 0 || xt == nullptr ? index->train(nb, xb)
: index->train(nt, xt.data()); : index->train(nt, xt);
} }
index->add(nb, xb.data()); index->add_with_ids(nb, xb, ids); // todo(linxj): support add_with_idmap
index->add_with_ids(nb, xb.data(), ids.data()); // todo(linxj): support add_with_idmap
} }
return std::make_shared<Index>(index); return std::make_shared<Index>(index);
}
Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb,
const vector<long> &ids,
const long &nt, const vector<float> &xt) {
return build_all(nb, xb.data(), ids.data(), nt, xt.data());
} }
// Be Factory pattern later // Be Factory pattern later
......
...@@ -19,6 +19,12 @@ class IndexBuilder { ...@@ -19,6 +19,12 @@ class IndexBuilder {
public: public:
explicit IndexBuilder(const Operand_ptr &opd); explicit IndexBuilder(const Operand_ptr &opd);
Index_ptr build_all(const long &nb,
const float* xb,
const long* ids,
const long &nt = 0,
const float* xt = nullptr);
Index_ptr build_all(const long &nb, Index_ptr build_all(const long &nb,
const std::vector<float> &xb, const std::vector<float> &xb,
const std::vector<long> &ids, const std::vector<long> &ids,
...@@ -47,5 +53,3 @@ extern IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd); ...@@ -47,5 +53,3 @@ extern IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd);
} }
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册