提交 b7966df1 编写于 作者: P peng.xu

Merge branch 'add_unittest' into 'branch-0.3.1'

Add unittest

See merge request megasearch/milvus!185

Former-commit-id: fe37fe22833770f07ccf552364e3e7e31659232c
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#include <src/server/ServerConfig.h> #include <src/server/ServerConfig.h>
#include <src/metrics/Metrics.h>
#include "Log.h" #include "Log.h"
#include "src/cache/CpuCacheMgr.h" #include "src/cache/CpuCacheMgr.h"
...@@ -16,55 +17,6 @@ namespace zilliz { ...@@ -16,55 +17,6 @@ namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
size_t operator()(void *ptr, size_t size, size_t pos);
};
FileIOReader::FileIOReader(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary);
}
FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
}
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
return 0;
}
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
}
FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
}
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension,
const std::string &location, const std::string &location,
EngineType type) EngineType type)
...@@ -89,7 +41,7 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) { ...@@ -89,7 +41,7 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
break; break;
} }
case EngineType::FAISS_IVFFLAT_GPU: { case EngineType::FAISS_IVFFLAT_GPU: {
index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_GPU); index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_MIX);
break; break;
} }
case EngineType::FAISS_IVFFLAT_CPU: { case EngineType::FAISS_IVFFLAT_CPU: {
...@@ -130,89 +82,32 @@ size_t ExecutionEngineImpl::PhysicalSize() const { ...@@ -130,89 +82,32 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
} }
Status ExecutionEngineImpl::Serialize() { Status ExecutionEngineImpl::Serialize() {
auto binaryset = index_->Serialize(); write_index(index_, location_);
FileIOWriter writer(location_);
writer(&current_type, sizeof(current_type));
for (auto &iter: binaryset.binary_map_) {
auto meta = iter.first.c_str();
size_t meta_length = iter.first.length();
writer(&meta_length, sizeof(meta_length));
writer((void *) meta, meta_length);
auto binary = iter.second;
size_t binary_length = binary->size;
writer(&binary_length, sizeof(binary_length));
writer((void *) binary->data.get(), binary_length);
}
return Status::OK(); return Status::OK();
} }
Status ExecutionEngineImpl::Load() { Status ExecutionEngineImpl::Load() {
index_ = Load(location_); index_ = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location_);
return Status::OK(); bool to_cache = false;
} auto start_time = METRICS_NOW_TIME;
if (!index_) {
VecIndexPtr ExecutionEngineImpl::Load(const std::string &location) { index_ = read_index(location_);
knowhere::BinarySet load_data_list; to_cache = true;
FileIOReader reader(location); ENGINE_LOG_DEBUG << "Disk io from: " << location_;
reader.fs.seekg(0, reader.fs.end); }
size_t length = reader.fs.tellg();
reader.fs.seekg(0);
size_t rp = 0;
reader(&current_type, sizeof(current_type));
rp += sizeof(current_type);
while (rp < length) {
size_t meta_length;
reader(&meta_length, sizeof(meta_length));
rp += sizeof(meta_length);
reader.fs.seekg(rp);
auto meta = new char[meta_length];
reader(meta, meta_length);
rp += meta_length;
reader.fs.seekg(rp);
size_t bin_length;
reader(&bin_length, sizeof(bin_length));
rp += sizeof(bin_length);
reader.fs.seekg(rp);
auto bin = new uint8_t[bin_length]; if (to_cache) {
reader(bin, bin_length); Cache();
rp += bin_length; auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
auto binptr = std::make_shared<uint8_t>(); server::Metrics::GetInstance().FaissDiskLoadDurationSecondsHistogramObserve(total_time);
binptr.reset(bin); double total_size = Size();
load_data_list.Append(std::string(meta, meta_length), binptr, bin_length);
}
auto index_type = IndexType::INVALID; server::Metrics::GetInstance().FaissDiskLoadSizeBytesHistogramObserve(total_size);
switch (current_type) { server::Metrics::GetInstance().FaissDiskLoadIOSpeedGaugeSet(total_size / double(total_time));
case EngineType::FAISS_IDMAP: {
index_type = IndexType::FAISS_IDMAP;
break;
}
case EngineType::FAISS_IVFFLAT_CPU: {
index_type = IndexType::FAISS_IVFFLAT_CPU;
break;
}
case EngineType::FAISS_IVFFLAT_GPU: {
index_type = IndexType::FAISS_IVFFLAT_GPU;
break;
}
case EngineType::SPTAG_KDT_RNT_CPU: {
index_type = IndexType::SPTAG_KDT_RNT_CPU;
break;
}
default: {
ENGINE_LOG_ERROR << "wrong index_type";
return nullptr;
}
} }
return Status::OK();
return LoadVecIndex(index_type, load_data_list);
} }
Status ExecutionEngineImpl::Merge(const std::string &location) { Status ExecutionEngineImpl::Merge(const std::string &location) {
...@@ -223,15 +118,17 @@ Status ExecutionEngineImpl::Merge(const std::string &location) { ...@@ -223,15 +118,17 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
auto to_merge = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location); auto to_merge = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location);
if (!to_merge) { if (!to_merge) {
to_merge = Load(location); to_merge = read_index(location);
} }
auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge); if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds()); index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
return Status::OK(); return Status::OK();
} else {
return Status::Error("file index type is not idmap");
}
} }
// TODO(linxj): add config
ExecutionEnginePtr ExecutionEnginePtr
ExecutionEngineImpl::BuildIndex(const std::string &location) { ExecutionEngineImpl::BuildIndex(const std::string &location) {
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_; ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <src/utils/Log.h> #include <src/utils/Log.h>
#include "knowhere/index/vector_index/idmap.h" #include "knowhere/index/vector_index/idmap.h"
#include "knowhere/index/vector_index/gpu_ivf.h"
#include "vec_impl.h" #include "vec_impl.h"
#include "data_transfer.h" #include "data_transfer.h"
...@@ -98,6 +99,10 @@ int64_t VecIndexImpl::Count() { ...@@ -98,6 +99,10 @@ int64_t VecIndexImpl::Count() {
return index_->Count(); return index_->Count();
} }
IndexType VecIndexImpl::GetType() {
return type;
}
float *BFIndex::GetRawVectors() { float *BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_); auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); } if (raw_index) { return raw_index->GetRawVectors(); }
...@@ -126,6 +131,38 @@ void BFIndex::BuildAll(const long &nb, ...@@ -126,6 +131,38 @@ void BFIndex::BuildAll(const long &nb,
index_->Add(dataset, cfg); index_->Add(dataset, cfg);
} }
// TODO(linxj): add lock here.
void IVFMixIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
index_->set_preprocessor(preprocessor);
auto nlist = int(nb / 1000000.0 * 16384);
auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
auto model = index_->Train(dataset, cfg_t);
index_->set_index_model(model);
index_->Add(dataset, cfg);
if (auto device_index = std::dynamic_pointer_cast<GPUIVF>(index_)) {
auto host_index = device_index->Copy_index_gpu_to_cpu();
index_ = host_index;
} else {
// TODO(linxj): LOG ERROR
}
}
void IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
index_ = std::make_shared<IVF>();
index_->Load(index_binary);
dim = Dimension();
}
} }
} }
} }
...@@ -17,13 +17,15 @@ namespace engine { ...@@ -17,13 +17,15 @@ namespace engine {
class VecIndexImpl : public VecIndex { class VecIndexImpl : public VecIndex {
public: public:
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : index_(std::move(index)) {}; explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
: index_(std::move(index)), type(type) {};
void BuildAll(const long &nb, void BuildAll(const long &nb,
const float *xb, const float *xb,
const long *ids, const long *ids,
const Config &cfg, const Config &cfg,
const long &nt, const long &nt,
const float *xt) override; const float *xt) override;
IndexType GetType() override;
int64_t Dimension() override; int64_t Dimension() override;
int64_t Count() override; int64_t Count() override;
void Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override; void Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
...@@ -33,21 +35,36 @@ class VecIndexImpl : public VecIndex { ...@@ -33,21 +35,36 @@ class VecIndexImpl : public VecIndex {
protected: protected:
int64_t dim = 0; int64_t dim = 0;
IndexType type = IndexType::INVALID;
std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr; std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr;
}; };
class IVFMixIndex : public VecIndexImpl {
public:
explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
IndexType::FAISS_IVFFLAT_MIX) {};
void BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
void Load(const zilliz::knowhere::BinarySet &index_binary) override;
};
class BFIndex : public VecIndexImpl { class BFIndex : public VecIndexImpl {
public: public:
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index)) {}; explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
void Build(const int64_t& d); IndexType::FAISS_IDMAP) {};
float* GetRawVectors(); void Build(const int64_t &d);
float *GetRawVectors();
void BuildAll(const long &nb, void BuildAll(const long &nb,
const float *xb, const float *xb,
const long *ids, const long *ids,
const Config &cfg, const Config &cfg,
const long &nt, const long &nt,
const float *xt) override; const float *xt) override;
int64_t* GetRawIds(); int64_t *GetRawIds();
}; };
} }
......
...@@ -16,7 +16,56 @@ namespace zilliz { ...@@ -16,7 +16,56 @@ namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
// TODO(linxj): index_type => enum struct struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
size_t operator()(void *ptr, size_t size, size_t pos);
};
FileIOReader::FileIOReader(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary);
}
FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
}
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
return 0;
}
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
}
FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
}
VecIndexPtr GetVecIndexFactory(const IndexType &type) { VecIndexPtr GetVecIndexFactory(const IndexType &type) {
std::shared_ptr<zilliz::knowhere::VectorIndex> index; std::shared_ptr<zilliz::knowhere::VectorIndex> index;
switch (type) { switch (type) {
...@@ -32,6 +81,10 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) { ...@@ -32,6 +81,10 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
index = std::make_shared<zilliz::knowhere::GPUIVF>(0); index = std::make_shared<zilliz::knowhere::GPUIVF>(0);
break; break;
} }
case IndexType::FAISS_IVFFLAT_MIX: {
index = std::make_shared<zilliz::knowhere::GPUIVF>(0);
return std::make_shared<IVFMixIndex>(index);
}
case IndexType::FAISS_IVFPQ_CPU: { case IndexType::FAISS_IVFPQ_CPU: {
index = std::make_shared<zilliz::knowhere::IVFPQ>(); index = std::make_shared<zilliz::knowhere::IVFPQ>();
break; break;
...@@ -44,15 +97,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) { ...@@ -44,15 +97,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
index = std::make_shared<zilliz::knowhere::CPUKDTRNG>(); index = std::make_shared<zilliz::knowhere::CPUKDTRNG>();
break; break;
} }
//case IndexType::NSG: { // TODO(linxj): bug. //case IndexType::NSG: { // TODO(linxj): bug.
// index = std::make_shared<zilliz::knowhere::NSG>(); // index = std::make_shared<zilliz::knowhere::NSG>();
// break; // break;
//} //}
default: { default: {
return nullptr; return nullptr;
} }
} }
return std::make_shared<VecIndexImpl>(index); return std::make_shared<VecIndexImpl>(index, type);
} }
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) { VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
...@@ -61,6 +114,64 @@ VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::Bi ...@@ -61,6 +114,64 @@ VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::Bi
return index; return index;
} }
VecIndexPtr read_index(const std::string &location) {
knowhere::BinarySet load_data_list;
FileIOReader reader(location);
reader.fs.seekg(0, reader.fs.end);
size_t length = reader.fs.tellg();
reader.fs.seekg(0);
size_t rp = 0;
auto current_type = IndexType::INVALID;
reader(&current_type, sizeof(current_type));
rp += sizeof(current_type);
while (rp < length) {
size_t meta_length;
reader(&meta_length, sizeof(meta_length));
rp += sizeof(meta_length);
reader.fs.seekg(rp);
auto meta = new char[meta_length];
reader(meta, meta_length);
rp += meta_length;
reader.fs.seekg(rp);
size_t bin_length;
reader(&bin_length, sizeof(bin_length));
rp += sizeof(bin_length);
reader.fs.seekg(rp);
auto bin = new uint8_t[bin_length];
reader(bin, bin_length);
rp += bin_length;
auto binptr = std::make_shared<uint8_t>();
binptr.reset(bin);
load_data_list.Append(std::string(meta, meta_length), binptr, bin_length);
}
return LoadVecIndex(current_type, load_data_list);
}
void write_index(VecIndexPtr index, const std::string &location) {
auto binaryset = index->Serialize();
auto index_type = index->GetType();
FileIOWriter writer(location);
writer(&index_type, sizeof(IndexType));
for (auto &iter: binaryset.binary_map_) {
auto meta = iter.first.c_str();
size_t meta_length = iter.first.length();
writer(&meta_length, sizeof(meta_length));
writer((void *) meta, meta_length);
auto binary = iter.second;
int64_t binary_length = binary->size;
writer(&binary_length, sizeof(binary_length));
writer((void *) binary->data.get(), binary_length);
}
}
} }
} }
} }
...@@ -20,6 +20,18 @@ namespace engine { ...@@ -20,6 +20,18 @@ namespace engine {
// TODO(linxj): jsoncons => rapidjson or other. // TODO(linxj): jsoncons => rapidjson or other.
using Config = zilliz::knowhere::Config; using Config = zilliz::knowhere::Config;
enum class IndexType {
INVALID = 0,
FAISS_IDMAP = 1,
FAISS_IVFFLAT_CPU,
FAISS_IVFFLAT_GPU,
FAISS_IVFFLAT_MIX, // build on gpu and search on cpu
FAISS_IVFPQ_CPU,
FAISS_IVFPQ_GPU,
SPTAG_KDT_RNT_CPU,
//NSG,
};
class VecIndex { class VecIndex {
public: public:
virtual void BuildAll(const long &nb, virtual void BuildAll(const long &nb,
...@@ -40,6 +52,8 @@ class VecIndex { ...@@ -40,6 +52,8 @@ class VecIndex {
long *ids, long *ids,
const Config &cfg = Config()) = 0; const Config &cfg = Config()) = 0;
virtual IndexType GetType() = 0;
virtual int64_t Dimension() = 0; virtual int64_t Dimension() = 0;
virtual int64_t Count() = 0; virtual int64_t Count() = 0;
...@@ -51,16 +65,9 @@ class VecIndex { ...@@ -51,16 +65,9 @@ class VecIndex {
using VecIndexPtr = std::shared_ptr<VecIndex>; using VecIndexPtr = std::shared_ptr<VecIndex>;
enum class IndexType { extern void write_index(VecIndexPtr index, const std::string &location);
INVALID = 0,
FAISS_IDMAP = 1, extern VecIndexPtr read_index(const std::string &location);
FAISS_IVFFLAT_CPU,
FAISS_IVFFLAT_GPU,
FAISS_IVFPQ_CPU,
FAISS_IVFPQ_GPU,
SPTAG_KDT_RNT_CPU,
//NSG,
};
extern VecIndexPtr GetVecIndexFactory(const IndexType &type); extern VecIndexPtr GetVecIndexFactory(const IndexType &type);
......
knowhere @ ca99a689
Subproject commit c3123501d62f69f9eacaa73ee96c0daeb24620a5 Subproject commit ca99a6899be4e8a0806452656cf0f2be19d79c1a
...@@ -28,11 +28,37 @@ class KnowhereWrapperTest ...@@ -28,11 +28,37 @@ class KnowhereWrapperTest
//auto generator = GetGenerateFactory(generator_type); //auto generator = GetGenerateFactory(generator_type);
auto generator = std::make_shared<DataGenBase>(); auto generator = std::make_shared<DataGenBase>();
generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids); generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis);
index_ = GetVecIndexFactory(index_type); index_ = GetVecIndexFactory(index_type);
} }
void AssertResult(const std::vector<long> &ids, const std::vector<float> &dis) {
EXPECT_EQ(ids.size(), nq * k);
EXPECT_EQ(dis.size(), nq * k);
for (auto i = 0; i < nq; i++) {
EXPECT_EQ(ids[i * k], gt_ids[i * k]);
EXPECT_EQ(dis[i * k], gt_dis[i * k]);
}
int match = 0;
for (int i = 0; i < nq; ++i) {
for (int j = 0; j < k; ++j) {
for (int l = 0; l < k; ++l) {
if (ids[i * nq + j] == gt_ids[i * nq + l]) match++;
}
}
}
auto precision = float(match) / (nq * k);
EXPECT_GT(precision, 0.5);
std::cout << std::endl << "Precision: " << precision
<< ", match: " << match
<< ", total: " << nq * k
<< std::endl;
}
protected: protected:
IndexType index_type; IndexType index_type;
Config train_cfg; Config train_cfg;
...@@ -50,126 +76,88 @@ class KnowhereWrapperTest ...@@ -50,126 +76,88 @@ class KnowhereWrapperTest
// Ground Truth // Ground Truth
std::vector<long> gt_ids; std::vector<long> gt_ids;
std::vector<float> gt_dis;
}; };
INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest, INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest,
Values( Values(
// ["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"] //["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default", std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default",
64, 10000, 10, 10, 64, 100000, 10, 10,
Config::object{{"nlist", 100}, {"dim", 64}}, Config::object{{"nlist", 100}, {"dim", 64}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 20}} Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}}
), ),
std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default", //std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default",
64, 10000, 10, 10, // 64, 10000, 10, 10,
Config::object{{"TPTNumber", 1}, {"dim", 64}}, // Config::object{{"nlist", 100}, {"dim", 64}},
// Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 40}}
//),
std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default",
64, 100000, 10, 10,
Config::object{{"nlist", 100}, {"dim", 64}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}}
),
std::make_tuple(IndexType::FAISS_IDMAP, "Default",
64, 100000, 10, 10,
Config::object{{"dim", 64}},
Config::object{{"dim", 64}, {"k", 10}} Config::object{{"dim", 64}, {"k", 10}}
) )
//std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default",
// 64, 10000, 10, 10,
// Config::object{{"TPTNumber", 1}, {"dim", 64}},
// Config::object{{"dim", 64}, {"k", 10}}
//)
) )
); );
void AssertAnns(const std::vector<long> &gt,
const std::vector<long> &res,
const int &nq,
const int &k) {
EXPECT_EQ(res.size(), nq * k);
for (auto i = 0; i < nq; i++) {
EXPECT_EQ(gt[i * k], res[i * k]);
}
int match = 0;
for (int i = 0; i < nq; ++i) {
for (int j = 0; j < k; ++j) {
for (int l = 0; l < k; ++l) {
if (gt[i * nq + j] == res[i * nq + l]) match++;
}
}
}
// TODO(linxj): percision check
EXPECT_GT(float(match/nq*k), 0.5);
}
TEST_P(KnowhereWrapperTest, base_test) { TEST_P(KnowhereWrapperTest, base_test) {
std::vector<long> res_ids; EXPECT_EQ(index_->GetType(), index_type);
float *D = new float[k * nq];
res_ids.resize(nq * k); auto elems = nq * k;
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg); index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg); index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k); AssertResult(res_ids, res_dis);
delete[] D;
} }
TEST_P(KnowhereWrapperTest, serialize_test) { TEST_P(KnowhereWrapperTest, serialize) {
std::vector<long> res_ids; EXPECT_EQ(index_->GetType(), index_type);
float *D = new float[k * nq];
res_ids.resize(nq * k);
auto elems = nq * k;
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg); index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg); index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k); AssertResult(res_ids, res_dis);
{ {
auto binaryset = index_->Serialize(); auto binary = index_->Serialize();
//int fileno = 0; auto type = index_->GetType();
//const std::string &base_name = "/tmp/wrapper_serialize_test_bin_"; auto new_index = GetVecIndexFactory(type);
//std::vector<std::string> filename_list; new_index->Load(binary);
//std::vector<std::pair<std::string, size_t >> meta_list; EXPECT_EQ(new_index->Dimension(), index_->Dimension());
//for (auto &iter: binaryset.binary_map_) { EXPECT_EQ(new_index->Count(), index_->Count());
// const std::string &filename = base_name + std::to_string(fileno);
// FileIOWriter writer(filename); std::vector<int64_t> res_ids(elems);
// writer(iter.second->data.get(), iter.second->size); std::vector<float> res_dis(elems);
// new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
// meta_list.push_back(std::make_pair(iter.first, iter.second.size)); AssertResult(res_ids, res_dis);
// filename_list.push_back(filename);
// ++fileno;
//}
//
//BinarySet load_data_list;
//for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
// auto bin_size = meta_list[i].second;
// FileIOReader reader(filename_list[i]);
// std::vector<uint8_t> load_data(bin_size);
// reader(load_data.data(), bin_size);
// load_data_list.Append(meta_list[i].first, load_data);
//}
int fileno = 0;
std::vector<std::string> filename_list;
const std::string &base_name = "/tmp/wrapper_serialize_test_bin_";
std::vector<std::pair<std::string, size_t >> meta_list;
for (auto &iter: binaryset.binary_map_) {
const std::string &filename = base_name + std::to_string(fileno);
FileIOWriter writer(filename);
writer(iter.second->data.get(), iter.second->size);
meta_list.emplace_back(std::make_pair(iter.first, iter.second->size));
filename_list.push_back(filename);
++fileno;
}
BinarySet load_data_list;
for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
auto bin_size = meta_list[i].second;
FileIOReader reader(filename_list[i]);
auto load_data = new uint8_t[bin_size];
reader(load_data, bin_size);
auto data = std::make_shared<uint8_t>();
data.reset(load_data);
load_data_list.Append(meta_list[i].first, data, bin_size);
}
res_ids.clear();
res_ids.resize(nq * k);
auto new_index = GetVecIndexFactory(index_type);
new_index->Load(load_data_list);
new_index->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k);
} }
delete[] D; {
std::string file_location = "/tmp/whatever";
write_index(index_, file_location);
auto new_index = read_index(file_location);
EXPECT_EQ(new_index->GetType(), index_type);
EXPECT_EQ(new_index->Dimension(), index_->Dimension());
EXPECT_EQ(new_index->Count(), index_->Count());
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
AssertResult(res_ids, res_dis);
}
} }
...@@ -19,7 +19,7 @@ DataGenPtr GetGenerateFactory(const std::string &gen_type) { ...@@ -19,7 +19,7 @@ DataGenPtr GetGenerateFactory(const std::string &gen_type) {
void DataGenBase::GenData(const int &dim, const int &nb, const int &nq, void DataGenBase::GenData(const int &dim, const int &nb, const int &nq,
float *xb, float *xq, long *ids, float *xb, float *xq, long *ids,
const int &k, long *gt_ids) { const int &k, long *gt_ids, float *gt_dis) {
for (auto i = 0; i < nb; ++i) { for (auto i = 0; i < nb; ++i) {
for (auto j = 0; j < dim; ++j) { for (auto j = 0; j < dim; ++j) {
//p_data[i * d + j] = float(base + i); //p_data[i * d + j] = float(base + i);
...@@ -35,8 +35,7 @@ void DataGenBase::GenData(const int &dim, const int &nb, const int &nq, ...@@ -35,8 +35,7 @@ void DataGenBase::GenData(const int &dim, const int &nb, const int &nq,
faiss::IndexFlatL2 index(dim); faiss::IndexFlatL2 index(dim);
//index.add_with_ids(nb, xb, ids); //index.add_with_ids(nb, xb, ids);
index.add(nb, xb); index.add(nb, xb);
float *D = new float[k * nq]; index.search(nq, xq, k, gt_dis, gt_ids);
index.search(nq, xq, k, D, gt_ids);
} }
void DataGenBase::GenData(const int &dim, void DataGenBase::GenData(const int &dim,
...@@ -46,36 +45,12 @@ void DataGenBase::GenData(const int &dim, ...@@ -46,36 +45,12 @@ void DataGenBase::GenData(const int &dim,
std::vector<float> &xq, std::vector<float> &xq,
std::vector<long> &ids, std::vector<long> &ids,
const int &k, const int &k,
std::vector<long> &gt_ids) { std::vector<long> &gt_ids,
std::vector<float> &gt_dis) {
xb.resize(nb * dim); xb.resize(nb * dim);
xq.resize(nq * dim); xq.resize(nq * dim);
ids.resize(nb); ids.resize(nb);
gt_ids.resize(nq * k); gt_ids.resize(nq * k);
GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data()); gt_dis.resize(nq * k);
} GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data(), gt_dis.data());
FileIOReader::FileIOReader(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary);
}
FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
}
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
}
FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
} }
...@@ -23,7 +23,7 @@ extern DataGenPtr GetGenerateFactory(const std::string &gen_type); ...@@ -23,7 +23,7 @@ extern DataGenPtr GetGenerateFactory(const std::string &gen_type);
class DataGenBase { class DataGenBase {
public: public:
virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids, virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
const int &k, long *gt_ids); const int &k, long *gt_ids, float *gt_dis);
virtual void GenData(const int &dim, virtual void GenData(const int &dim,
const int &nb, const int &nb,
...@@ -32,30 +32,14 @@ class DataGenBase { ...@@ -32,30 +32,14 @@ class DataGenBase {
std::vector<float> &xq, std::vector<float> &xq,
std::vector<long> &ids, std::vector<long> &ids,
const int &k, const int &k,
std::vector<long> &gt_ids); std::vector<long> &gt_ids,
std::vector<float> &gt_dis);
}; };
class SanityCheck : public DataGenBase { //class SanityCheck : public DataGenBase {
public: // public:
void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids, // void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
const int &k, long *gt_ids) override; // const int &k, long *gt_ids, float *gt_dis) override;
}; //};
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
};
...@@ -38,6 +38,10 @@ public: ...@@ -38,6 +38,10 @@ public:
} }
engine::IndexType GetType() override {
return engine::IndexType::INVALID;
}
virtual void Add(const long &nb, virtual void Add(const long &nb,
const float *xb, const float *xb,
const long *ids, const long *ids,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册