提交 56bbe40f 编写于 作者: X xj.lin

1. fix operand serialize bug

2. support gpu-build
3. add unittest


Former-commit-id: bb36dcb05220d8f0648f282c7e38fe20f4ab3c16
上级 675777d0
...@@ -39,11 +39,11 @@ public: ...@@ -39,11 +39,11 @@ public:
virtual bool reset(); virtual bool reset();
/** /**
* @brief Same as add, but stores xids instead of sequential ids. * @brief Same as add, but stores xids instead of sequential ids.
* *
* @param data input matrix, size n * d * @param data input matrix, size n * d
* @param if ids is not empty ids for the std::vectors * @param if ids is not empty ids for the std::vectors
*/ */
virtual bool add_with_ids(idx_t n, const float *xdata, const long *xids); virtual bool add_with_ids(idx_t n, const float *xdata, const long *xids);
/** /**
...@@ -57,23 +57,20 @@ public: ...@@ -57,23 +57,20 @@ public:
*/ */
virtual bool search(idx_t n, const float *data, idx_t k, float *distances, long *labels) const; virtual bool search(idx_t n, const float *data, idx_t k, float *distances, long *labels) const;
// virtual bool remove_ids(const faiss::IDSelector &sel, long &nremove, long &location); //virtual bool search(idx_t n, const std::vector<float> &data, idx_t k,
// std::vector<float> &distances, std::vector<float> &labels) const;
// virtual bool remove_ids_range(const faiss::IDSelector &sel, long &nremove); //virtual bool remove_ids(const faiss::IDSelector &sel, long &nremove, long &location);
//virtual bool remove_ids_range(const faiss::IDSelector &sel, long &nremove);
//virtual bool index_display();
// virtual bool index_display();
//
virtual std::shared_ptr<faiss::Index> data() { return index_; } virtual std::shared_ptr<faiss::Index> data() { return index_; }
virtual const std::shared_ptr<faiss::Index>& data() const { return index_; } virtual const std::shared_ptr<faiss::Index>& data() const { return index_; }
private: private:
friend void write_index(const Index_ptr &index, const std::string &file_name); friend void write_index(const Index_ptr &index, const std::string &file_name);
std::shared_ptr<faiss::Index> index_ = nullptr; std::shared_ptr<faiss::Index> index_ = nullptr;
// std::vector<faiss::gpu::GpuResources *> res_;
// std::vector<int> devs_;
// bool usegpu = true;
// int ngpus = 0;
// faiss::gpu::GpuMultipleClonerOptions *options = new faiss::gpu::GpuMultipleClonerOptions();
}; };
......
...@@ -6,41 +6,52 @@ ...@@ -6,41 +6,52 @@
#include "mutex" #include "mutex"
#include <faiss/gpu/StandardGpuResources.h>
#include "faiss/gpu/GpuIndexIVFFlat.h"
#include "faiss/gpu/GpuAutoTune.h"
#include "IndexBuilder.h" #include "IndexBuilder.h"
namespace zilliz { namespace zilliz {
namespace vecwise { namespace vecwise {
namespace engine { namespace engine {
using std::vector; using std::vector;
// todo(linxj): use ResourceMgr instead static std::mutex gpu_resource;
static std::mutex cpu_resource;
IndexBuilder::IndexBuilder(const Operand_ptr &opd) { IndexBuilder::IndexBuilder(const Operand_ptr &opd) {
opd_ = opd; opd_ = opd;
} }
// Default: build use gpu
Index_ptr IndexBuilder::build_all(const long &nb, Index_ptr IndexBuilder::build_all(const long &nb,
const float* xb, const float* xb,
const long* ids, const long* ids,
const long &nt, const long &nt,
const float* xt) { const float* xt) {
std::shared_ptr<faiss::Index> index = nullptr; std::shared_ptr<faiss::Index> host_index = nullptr;
index.reset(faiss::index_factory(opd_->d, opd_->index_type.c_str()));
{ {
// currently only cpu resources are used. // TODO: list support index-type.
std::lock_guard<std::mutex> lk(cpu_resource); faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->index_type.c_str());
if (!index->is_trained) {
nt == 0 || xt == nullptr ? index->train(nb, xb) std::lock_guard<std::mutex> lk(gpu_resource);
: index->train(nt, xt); faiss::gpu::StandardGpuResources res;
auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index);
if (!device_index->is_trained) {
nt == 0 || xt == nullptr ? device_index->train(nb, xb)
: device_index->train(nt, xt);
} }
index->add_with_ids(nb, xb, ids); // todo(linxj): support add_with_idmap device_index->add_with_ids(nb, xb, ids);
}
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
return std::make_shared<Index>(index); delete device_index;
delete ori_index;
}
return std::make_shared<Index>(host_index);
} }
Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb, Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb,
......
...@@ -43,7 +43,6 @@ public: ...@@ -43,7 +43,6 @@ public:
private: private:
Operand_ptr opd_ = nullptr; Operand_ptr opd_ = nullptr;
// std::shared_ptr<faiss::Index> index_ = nullptr;
}; };
using IndexBuilderPtr = std::shared_ptr<IndexBuilder>; using IndexBuilderPtr = std::shared_ptr<IndexBuilder>;
......
...@@ -13,9 +13,9 @@ namespace engine { ...@@ -13,9 +13,9 @@ namespace engine {
std::ostream &operator<<(std::ostream &os, const Operand &obj) { std::ostream &operator<<(std::ostream &os, const Operand &obj) {
os << obj.d << " " os << obj.d << " "
<< obj.index_type << " " << obj.index_type << " "
<< obj.metric_type << " "
<< obj.preproc << " " << obj.preproc << " "
<< obj.postproc << " " << obj.postproc << " "
<< obj.metric_type << " "
<< obj.ncent; << obj.ncent;
return os; return os;
} }
...@@ -23,16 +23,16 @@ std::ostream &operator<<(std::ostream &os, const Operand &obj) { ...@@ -23,16 +23,16 @@ std::ostream &operator<<(std::ostream &os, const Operand &obj) {
std::istream &operator>>(std::istream &is, Operand &obj) { std::istream &operator>>(std::istream &is, Operand &obj) {
is >> obj.d is >> obj.d
>> obj.index_type >> obj.index_type
>> obj.metric_type
>> obj.preproc >> obj.preproc
>> obj.postproc >> obj.postproc
>> obj.metric_type
>> obj.ncent; >> obj.ncent;
return is; return is;
} }
std::string operand_to_str(const Operand_ptr &opd) { std::string operand_to_str(const Operand_ptr &opd) {
std::ostringstream ss; std::ostringstream ss;
ss << opd; ss << *opd;
return ss.str(); return ss.str();
} }
......
...@@ -22,9 +22,9 @@ struct Operand { ...@@ -22,9 +22,9 @@ struct Operand {
int d; int d;
std::string index_type = "IVF13864,Flat"; std::string index_type = "IVF13864,Flat";
std::string metric_type = "L2"; //> L2 / Inner Product
std::string preproc; std::string preproc;
std::string postproc; std::string postproc;
std::string metric_type = "L2"; // L2 / Inner Product
int ncent; int ncent;
}; };
......
...@@ -14,11 +14,21 @@ using namespace zilliz::vecwise::engine; ...@@ -14,11 +14,21 @@ using namespace zilliz::vecwise::engine;
TEST(operand_test, Wrapper_Test) { TEST(operand_test, Wrapper_Test) {
using std::cout;
using std::endl;
auto opd = std::make_shared<Operand>(); auto opd = std::make_shared<Operand>();
opd->index_type = "IVF16384,Flat"; opd->index_type = "IDMap,Flat";
opd->d = 256; opd->preproc = "opq";
opd->postproc = "pq";
opd->metric_type = "L2";
opd->ncent = 256;
opd->d = 64;
auto opd_str = operand_to_str(opd);
auto new_opd = str_to_operand(opd_str);
std::cout << opd << std::endl; assert(new_opd->index_type == opd->index_type);
} }
TEST(build_test, Wrapper_Test) { TEST(build_test, Wrapper_Test) {
...@@ -68,59 +78,61 @@ TEST(build_test, Wrapper_Test) { ...@@ -68,59 +78,61 @@ TEST(build_test, Wrapper_Test) {
//search in first quadrant //search in first quadrant
int nq = 1, k = 10; int nq = 1, k = 10;
std::vector<float> xq = {0.5, 0.5, 0.5}; std::vector<float> xq = {0.5, 0.5, 0.5};
float* result_dists = new float[k]; float *result_dists = new float[k];
long* result_ids = new long[k]; long *result_ids = new long[k];
index_1->search(nq, xq.data(), k, result_dists, result_ids); index_1->search(nq, xq.data(), k, result_dists, result_ids);
for(int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
if(result_ids[i] < 0) { if (result_ids[i] < 0) {
ASSERT_TRUE(false); ASSERT_TRUE(false);
break; break;
} }
long id = result_ids[i]; long id = result_ids[i];
std::cout << "No." << id << " [" << xb[id*3] << ", " << xb[id*3 + 1] << ", " std::cout << "No." << id << " [" << xb[id * 3] << ", " << xb[id * 3 + 1] << ", "
<< xb[id*3 + 2] <<"] distance = " << result_dists[i] << std::endl; << xb[id * 3 + 2] << "] distance = " << result_dists[i] << std::endl;
//makesure result vector is in first quadrant //makesure result vector is in first quadrant
ASSERT_TRUE(xb[id*3] > 0.0); ASSERT_TRUE(xb[id * 3] > 0.0);
ASSERT_TRUE(xb[id*3 + 1] > 0.0); ASSERT_TRUE(xb[id * 3 + 1] > 0.0);
ASSERT_TRUE(xb[id*3 + 2] > 0.0); ASSERT_TRUE(xb[id * 3 + 2] > 0.0);
} }
delete[] result_dists; delete[] result_dists;
delete[] result_ids; delete[] result_ids;
} }
TEST(search_test, Wrapper_Test) { TEST(gpu_build_test, Wrapper_Test) {
const int dim = 256; using std::vector;
size_t nb = 25000; int d = 256;
size_t nq = 100; int nb = 3 * 1000 * 100;
size_t k = 100; int nq = 100;
std::vector<float> xb(nb*dim); vector<float> xb(d * nb);
std::vector<float> xq(nq*dim); vector<float> xq(d * nq);
std::vector<long> ids(nb*dim); vector<long> ids(nb);
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::uniform_real_distribution<> dis_xt(-1.0, 1.0); std::uniform_real_distribution<> dis_xt(-1.0, 1.0);
for (size_t i = 0; i < nb*dim; i++) { for (auto &e : xb) { e = float(dis_xt(gen)); }
xb[i] = dis_xt(gen); for (auto &e : xq) { e = float(dis_xt(gen)); }
ids[i] = i; for (int i = 0; i < nb; ++i) { ids[i] = i; }
}
for (size_t i = 0; i < nq*dim; i++) {
xq[i] = dis_xt(gen);
}
// result data auto opd = std::make_shared<Operand>();
std::vector<long> nns_gt(nq*k); // nns = nearst neg search opd->index_type = "IVF256,Flat";
std::vector<long> nns(nq*k); opd->d = d;
std::vector<float> dis_gt(nq*k); opd->ncent = 256;
std::vector<float> dis(nq*k);
faiss::Index* index_gt(faiss::index_factory(dim, "IDMap,Flat"));
index_gt->add_with_ids(nb, xb.data(), ids.data());
index_gt->search(nq, xq.data(), 10, dis_gt.data(), nns_gt.data());
std::cout << "data: " << nns_gt[0];
IndexBuilderPtr index_builder_1 = GetIndexBuilder(opd);
auto index_1 = index_builder_1->build_all(nb, xb.data(), ids.data());
assert(index_1->ntotal == nb);
assert(index_1->dim == d);
// sanity check: search 5 first vectors of xb
int k = 1;
vector<long> I(5 * k);
vector<float> D(5 * k);
index_1->search(5, xb.data(), k, D.data(), I.data());
for (int i = 0; i < 5; ++i) { assert(i == I[i]); }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册