From c0326906d972687d0ba8588056b21160cc96cb81 Mon Sep 17 00:00:00 2001 From: Xu Peng Date: Wed, 17 Apr 2019 11:47:57 +0800 Subject: [PATCH] feat(wrapper): add one more build_all api in IndexBuilder Former-commit-id: d6d7187a419bafb81751b815e2ebd235058e43f9 --- .gitignore | 1 + cpp/src/wrapper/IndexBuilder.cpp | 22 +++++++++++++++------- cpp/src/wrapper/IndexBuilder.h | 8 ++++++-- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 3eb961bb..e0cd6a4d 100644 --- a/.gitignore +++ b/.gitignore @@ -9,5 +9,6 @@ cmake_build *.o *.lo *.tar.gz +*.log cpp/third_party/thrift-0.12.0/ diff --git a/cpp/src/wrapper/IndexBuilder.cpp b/cpp/src/wrapper/IndexBuilder.cpp index 1e28df19..5d0180a3 100644 --- a/cpp/src/wrapper/IndexBuilder.cpp +++ b/cpp/src/wrapper/IndexBuilder.cpp @@ -21,9 +21,11 @@ IndexBuilder::IndexBuilder(const Operand_ptr &opd) { opd_ = opd; } -Index_ptr IndexBuilder::build_all(const long &nb, const vector &xb, - const vector &ids, - const long &nt, const vector &xt) { +Index_ptr IndexBuilder::build_all(const long &nb, + const float* xb, + const long* ids, + const long &nt, + const float* xt) { std::shared_ptr index = nullptr; index.reset(faiss::index_factory(opd_->d, opd_->index_type.c_str())); @@ -31,14 +33,20 @@ Index_ptr IndexBuilder::build_all(const long &nb, const vector &xb, // currently only cpu resources are used. std::lock_guard lk(cpu_resource); if (!index->is_trained) { - nt == 0 || xt.empty() ? index->train(nb, xb.data()) - : index->train(nt, xt.data()); + nt == 0 || xt == nullptr ? index->train(nb, xb) + : index->train(nt, xt); } - index->add(nb, xb.data()); - index->add_with_ids(nb, xb.data(), ids.data()); // todo(linxj): support add_with_idmap + index->add_with_ids(nb, xb, ids); // todo(linxj): support add_with_idmap } return std::make_shared(index); + +} + +Index_ptr IndexBuilder::build_all(const long &nb, const vector &xb, + const vector &ids, + const long &nt, const vector &xt) { + return build_all(nb, xb.data(), ids.data(), nt, xt.data()); } // Be Factory pattern later diff --git a/cpp/src/wrapper/IndexBuilder.h b/cpp/src/wrapper/IndexBuilder.h index 97479b91..ed5f8a39 100644 --- a/cpp/src/wrapper/IndexBuilder.h +++ b/cpp/src/wrapper/IndexBuilder.h @@ -19,6 +19,12 @@ class IndexBuilder { public: explicit IndexBuilder(const Operand_ptr &opd); + Index_ptr build_all(const long &nb, + const float* xb, + const long* ids, + const long &nt = 0, + const float* xt = nullptr); + Index_ptr build_all(const long &nb, const std::vector &xb, const std::vector &ids, @@ -47,5 +53,3 @@ extern IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd); } } } - - -- GitLab