提交 c6ce5772 编写于 作者: X xj.lin

fix bad alloc and add idmap


Former-commit-id: bd2686574ad9010e33dcf34f3ae45308d5b3c971
上级 4fe9622b
......@@ -4,6 +4,7 @@
* Proprietary and confidential.
******************************************************************************/
#include <src/server/ServerConfig.h>
#include <src/metrics/Metrics.h>
#include "Log.h"
#include "src/cache/CpuCacheMgr.h"
......@@ -16,55 +17,6 @@ namespace zilliz {
namespace milvus {
namespace engine {
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
size_t operator()(void *ptr, size_t size, size_t pos);
};
FileIOReader::FileIOReader(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary);
}
FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
}
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
return 0;
}
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
}
FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
}
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension,
const std::string &location,
EngineType type)
......@@ -89,7 +41,7 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
break;
}
case EngineType::FAISS_IVFFLAT_GPU: {
index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_GPU);
index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_MIX);
break;
}
case EngineType::FAISS_IVFFLAT_CPU: {
......@@ -130,89 +82,32 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
}
Status ExecutionEngineImpl::Serialize() {
auto binaryset = index_->Serialize();
FileIOWriter writer(location_);
writer(&current_type, sizeof(current_type));
for (auto &iter: binaryset.binary_map_) {
auto meta = iter.first.c_str();
size_t meta_length = iter.first.length();
writer(&meta_length, sizeof(meta_length));
writer((void *) meta, meta_length);
auto binary = iter.second;
size_t binary_length = binary->size;
writer(&binary_length, sizeof(binary_length));
writer((void *) binary->data.get(), binary_length);
}
write_index(index_, location_);
return Status::OK();
}
Status ExecutionEngineImpl::Load() {
index_ = Load(location_);
return Status::OK();
}
VecIndexPtr ExecutionEngineImpl::Load(const std::string &location) {
knowhere::BinarySet load_data_list;
FileIOReader reader(location);
reader.fs.seekg(0, reader.fs.end);
size_t length = reader.fs.tellg();
reader.fs.seekg(0);
size_t rp = 0;
reader(&current_type, sizeof(current_type));
rp += sizeof(current_type);
while (rp < length) {
size_t meta_length;
reader(&meta_length, sizeof(meta_length));
rp += sizeof(meta_length);
reader.fs.seekg(rp);
auto meta = new char[meta_length];
reader(meta, meta_length);
rp += meta_length;
reader.fs.seekg(rp);
size_t bin_length;
reader(&bin_length, sizeof(bin_length));
rp += sizeof(bin_length);
reader.fs.seekg(rp);
index_ = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location_);
bool to_cache = false;
auto start_time = METRICS_NOW_TIME;
if (!index_) {
index_ = read_index(location_);
to_cache = true;
ENGINE_LOG_DEBUG << "Disk io from: " << location_;
}
auto bin = new uint8_t[bin_length];
reader(bin, bin_length);
rp += bin_length;
if (to_cache) {
Cache();
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
auto binptr = std::make_shared<uint8_t>();
binptr.reset(bin);
load_data_list.Append(std::string(meta, meta_length), binptr, bin_length);
}
server::Metrics::GetInstance().FaissDiskLoadDurationSecondsHistogramObserve(total_time);
double total_size = Size();
auto index_type = IndexType::INVALID;
switch (current_type) {
case EngineType::FAISS_IDMAP: {
index_type = IndexType::FAISS_IDMAP;
break;
}
case EngineType::FAISS_IVFFLAT_CPU: {
index_type = IndexType::FAISS_IVFFLAT_CPU;
break;
}
case EngineType::FAISS_IVFFLAT_GPU: {
index_type = IndexType::FAISS_IVFFLAT_GPU;
break;
}
case EngineType::SPTAG_KDT_RNT_CPU: {
index_type = IndexType::SPTAG_KDT_RNT_CPU;
break;
}
default: {
ENGINE_LOG_ERROR << "wrong index_type";
return nullptr;
}
server::Metrics::GetInstance().FaissDiskLoadSizeBytesHistogramObserve(total_size);
server::Metrics::GetInstance().FaissDiskLoadIOSpeedGaugeSet(total_size / double(total_time));
}
return LoadVecIndex(index_type, load_data_list);
return Status::OK();
}
Status ExecutionEngineImpl::Merge(const std::string &location) {
......@@ -223,15 +118,17 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
auto to_merge = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location);
if (!to_merge) {
to_merge = Load(location);
to_merge = read_index(location);
}
auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge);
index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
return Status::OK();
if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
return Status::OK();
} else {
return Status::Error("file index type is not idmap");
}
}
// TODO(linxj): add config
ExecutionEnginePtr
ExecutionEngineImpl::BuildIndex(const std::string &location) {
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;
......
......@@ -6,6 +6,7 @@
#include <src/utils/Log.h>
#include "knowhere/index/vector_index/idmap.h"
#include "knowhere/index/vector_index/gpu_ivf.h"
#include "vec_impl.h"
#include "data_transfer.h"
......@@ -98,6 +99,10 @@ int64_t VecIndexImpl::Count() {
return index_->Count();
}
IndexType VecIndexImpl::GetType() {
return type;
}
float *BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); }
......@@ -126,6 +131,38 @@ void BFIndex::BuildAll(const long &nb,
index_->Add(dataset, cfg);
}
// TODO(linxj): add lock here.
void IVFMixIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
index_->set_preprocessor(preprocessor);
auto nlist = int(nb / 1000000.0 * 16384);
auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
auto model = index_->Train(dataset, cfg_t);
index_->set_index_model(model);
index_->Add(dataset, cfg);
if (auto device_index = std::dynamic_pointer_cast<GPUIVF>(index_)) {
auto host_index = device_index->Copy_index_gpu_to_cpu();
index_ = host_index;
} else {
// TODO(linxj): LOG ERROR
}
}
void IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
index_ = std::make_shared<IVF>();
index_->Load(index_binary);
dim = Dimension();
}
}
}
}
......@@ -17,13 +17,15 @@ namespace engine {
class VecIndexImpl : public VecIndex {
public:
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : index_(std::move(index)) {};
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
: index_(std::move(index)), type(type) {};
void BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
IndexType GetType() override;
int64_t Dimension() override;
int64_t Count() override;
void Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
......@@ -33,21 +35,36 @@ class VecIndexImpl : public VecIndex {
protected:
int64_t dim = 0;
IndexType type = IndexType::INVALID;
std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr;
};
class IVFMixIndex : public VecIndexImpl {
public:
explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
IndexType::FAISS_IVFFLAT_MIX) {};
void BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
void Load(const zilliz::knowhere::BinarySet &index_binary) override;
};
class BFIndex : public VecIndexImpl {
public:
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index)) {};
void Build(const int64_t& d);
float* GetRawVectors();
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
IndexType::FAISS_IDMAP) {};
void Build(const int64_t &d);
float *GetRawVectors();
void BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
int64_t* GetRawIds();
int64_t *GetRawIds();
};
}
......
......@@ -16,7 +16,56 @@ namespace zilliz {
namespace milvus {
namespace engine {
// TODO(linxj): index_type => enum struct
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
size_t operator()(void *ptr, size_t size, size_t pos);
};
FileIOReader::FileIOReader(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary);
}
FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
}
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
return 0;
}
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
}
FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
}
VecIndexPtr GetVecIndexFactory(const IndexType &type) {
std::shared_ptr<zilliz::knowhere::VectorIndex> index;
switch (type) {
......@@ -32,6 +81,10 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
index = std::make_shared<zilliz::knowhere::GPUIVF>(0);
break;
}
case IndexType::FAISS_IVFFLAT_MIX: {
index = std::make_shared<zilliz::knowhere::GPUIVF>(0);
return std::make_shared<IVFMixIndex>(index);
}
case IndexType::FAISS_IVFPQ_CPU: {
index = std::make_shared<zilliz::knowhere::IVFPQ>();
break;
......@@ -44,15 +97,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
index = std::make_shared<zilliz::knowhere::CPUKDTRNG>();
break;
}
//case IndexType::NSG: { // TODO(linxj): bug.
// index = std::make_shared<zilliz::knowhere::NSG>();
// break;
//}
//case IndexType::NSG: { // TODO(linxj): bug.
// index = std::make_shared<zilliz::knowhere::NSG>();
// break;
//}
default: {
return nullptr;
}
}
return std::make_shared<VecIndexImpl>(index);
return std::make_shared<VecIndexImpl>(index, type);
}
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
......@@ -61,6 +114,64 @@ VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::Bi
return index;
}
VecIndexPtr read_index(const std::string &location) {
knowhere::BinarySet load_data_list;
FileIOReader reader(location);
reader.fs.seekg(0, reader.fs.end);
size_t length = reader.fs.tellg();
reader.fs.seekg(0);
size_t rp = 0;
auto current_type = IndexType::INVALID;
reader(&current_type, sizeof(current_type));
rp += sizeof(current_type);
while (rp < length) {
size_t meta_length;
reader(&meta_length, sizeof(meta_length));
rp += sizeof(meta_length);
reader.fs.seekg(rp);
auto meta = new char[meta_length];
reader(meta, meta_length);
rp += meta_length;
reader.fs.seekg(rp);
size_t bin_length;
reader(&bin_length, sizeof(bin_length));
rp += sizeof(bin_length);
reader.fs.seekg(rp);
auto bin = new uint8_t[bin_length];
reader(bin, bin_length);
rp += bin_length;
auto binptr = std::make_shared<uint8_t>();
binptr.reset(bin);
load_data_list.Append(std::string(meta, meta_length), binptr, bin_length);
}
return LoadVecIndex(current_type, load_data_list);
}
void write_index(VecIndexPtr index, const std::string &location) {
auto binaryset = index->Serialize();
auto index_type = index->GetType();
FileIOWriter writer(location);
writer(&index_type, sizeof(IndexType));
for (auto &iter: binaryset.binary_map_) {
auto meta = iter.first.c_str();
size_t meta_length = iter.first.length();
writer(&meta_length, sizeof(meta_length));
writer((void *) meta, meta_length);
auto binary = iter.second;
int64_t binary_length = binary->size;
writer(&binary_length, sizeof(binary_length));
writer((void *) binary->data.get(), binary_length);
}
}
}
}
}
......@@ -20,6 +20,18 @@ namespace engine {
// TODO(linxj): jsoncons => rapidjson or other.
using Config = zilliz::knowhere::Config;
enum class IndexType {
INVALID = 0,
FAISS_IDMAP = 1,
FAISS_IVFFLAT_CPU,
FAISS_IVFFLAT_GPU,
FAISS_IVFFLAT_MIX, // build on gpu and search on cpu
FAISS_IVFPQ_CPU,
FAISS_IVFPQ_GPU,
SPTAG_KDT_RNT_CPU,
//NSG,
};
class VecIndex {
public:
virtual void BuildAll(const long &nb,
......@@ -40,6 +52,8 @@ class VecIndex {
long *ids,
const Config &cfg = Config()) = 0;
virtual IndexType GetType() = 0;
virtual int64_t Dimension() = 0;
virtual int64_t Count() = 0;
......@@ -51,16 +65,9 @@ class VecIndex {
using VecIndexPtr = std::shared_ptr<VecIndex>;
enum class IndexType {
INVALID = 0,
FAISS_IDMAP = 1,
FAISS_IVFFLAT_CPU,
FAISS_IVFFLAT_GPU,
FAISS_IVFPQ_CPU,
FAISS_IVFPQ_GPU,
SPTAG_KDT_RNT_CPU,
//NSG,
};
extern void write_index(VecIndexPtr index, const std::string &location);
extern VecIndexPtr read_index(const std::string &location);
extern VecIndexPtr GetVecIndexFactory(const IndexType &type);
......
knowhere @ ca99a689
Subproject commit c3123501d62f69f9eacaa73ee96c0daeb24620a5
Subproject commit ca99a6899be4e8a0806452656cf0f2be19d79c1a
......@@ -28,11 +28,37 @@ class KnowhereWrapperTest
//auto generator = GetGenerateFactory(generator_type);
auto generator = std::make_shared<DataGenBase>();
generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids);
generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis);
index_ = GetVecIndexFactory(index_type);
}
void AssertResult(const std::vector<long> &ids, const std::vector<float> &dis) {
EXPECT_EQ(ids.size(), nq * k);
EXPECT_EQ(dis.size(), nq * k);
for (auto i = 0; i < nq; i++) {
EXPECT_EQ(ids[i * k], gt_ids[i * k]);
EXPECT_EQ(dis[i * k], gt_dis[i * k]);
}
int match = 0;
for (int i = 0; i < nq; ++i) {
for (int j = 0; j < k; ++j) {
for (int l = 0; l < k; ++l) {
if (ids[i * nq + j] == gt_ids[i * nq + l]) match++;
}
}
}
auto precision = float(match) / (nq * k);
EXPECT_GT(precision, 0.5);
std::cout << std::endl << "Precision: " << precision
<< ", match: " << match
<< ", total: " << nq * k
<< std::endl;
}
protected:
IndexType index_type;
Config train_cfg;
......@@ -50,126 +76,88 @@ class KnowhereWrapperTest
// Ground Truth
std::vector<long> gt_ids;
std::vector<float> gt_dis;
};
INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest,
Values(
// ["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default",
64, 10000, 10, 10,
64, 100000, 10, 10,
Config::object{{"nlist", 100}, {"dim", 64}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 20}}
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}}
),
std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default",
64, 10000, 10, 10,
Config::object{{"TPTNumber", 1}, {"dim", 64}},
//std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default",
// 64, 10000, 10, 10,
// Config::object{{"nlist", 100}, {"dim", 64}},
// Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 40}}
//),
std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default",
64, 100000, 10, 10,
Config::object{{"nlist", 100}, {"dim", 64}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}}
),
std::make_tuple(IndexType::FAISS_IDMAP, "Default",
64, 100000, 10, 10,
Config::object{{"dim", 64}},
Config::object{{"dim", 64}, {"k", 10}}
)
//std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default",
// 64, 10000, 10, 10,
// Config::object{{"TPTNumber", 1}, {"dim", 64}},
// Config::object{{"dim", 64}, {"k", 10}}
//)
)
);
void AssertAnns(const std::vector<long> &gt,
const std::vector<long> &res,
const int &nq,
const int &k) {
EXPECT_EQ(res.size(), nq * k);
for (auto i = 0; i < nq; i++) {
EXPECT_EQ(gt[i * k], res[i * k]);
}
int match = 0;
for (int i = 0; i < nq; ++i) {
for (int j = 0; j < k; ++j) {
for (int l = 0; l < k; ++l) {
if (gt[i * nq + j] == res[i * nq + l]) match++;
}
}
}
// TODO(linxj): percision check
EXPECT_GT(float(match/nq*k), 0.5);
}
TEST_P(KnowhereWrapperTest, base_test) {
std::vector<long> res_ids;
float *D = new float[k * nq];
res_ids.resize(nq * k);
EXPECT_EQ(index_->GetType(), index_type);
auto elems = nq * k;
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k);
delete[] D;
index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
AssertResult(res_ids, res_dis);
}
TEST_P(KnowhereWrapperTest, serialize_test) {
std::vector<long> res_ids;
float *D = new float[k * nq];
res_ids.resize(nq * k);
TEST_P(KnowhereWrapperTest, serialize) {
EXPECT_EQ(index_->GetType(), index_type);
auto elems = nq * k;
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k);
index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
AssertResult(res_ids, res_dis);
{
auto binaryset = index_->Serialize();
//int fileno = 0;
//const std::string &base_name = "/tmp/wrapper_serialize_test_bin_";
//std::vector<std::string> filename_list;
//std::vector<std::pair<std::string, size_t >> meta_list;
//for (auto &iter: binaryset.binary_map_) {
// const std::string &filename = base_name + std::to_string(fileno);
// FileIOWriter writer(filename);
// writer(iter.second->data.get(), iter.second->size);
//
// meta_list.push_back(std::make_pair(iter.first, iter.second.size));
// filename_list.push_back(filename);
// ++fileno;
//}
//
//BinarySet load_data_list;
//for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
// auto bin_size = meta_list[i].second;
// FileIOReader reader(filename_list[i]);
// std::vector<uint8_t> load_data(bin_size);
// reader(load_data.data(), bin_size);
// load_data_list.Append(meta_list[i].first, load_data);
//}
int fileno = 0;
std::vector<std::string> filename_list;
const std::string &base_name = "/tmp/wrapper_serialize_test_bin_";
std::vector<std::pair<std::string, size_t >> meta_list;
for (auto &iter: binaryset.binary_map_) {
const std::string &filename = base_name + std::to_string(fileno);
FileIOWriter writer(filename);
writer(iter.second->data.get(), iter.second->size);
meta_list.emplace_back(std::make_pair(iter.first, iter.second->size));
filename_list.push_back(filename);
++fileno;
}
BinarySet load_data_list;
for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
auto bin_size = meta_list[i].second;
FileIOReader reader(filename_list[i]);
auto load_data = new uint8_t[bin_size];
reader(load_data, bin_size);
auto data = std::make_shared<uint8_t>();
data.reset(load_data);
load_data_list.Append(meta_list[i].first, data, bin_size);
}
res_ids.clear();
res_ids.resize(nq * k);
auto new_index = GetVecIndexFactory(index_type);
new_index->Load(load_data_list);
new_index->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
AssertAnns(gt_ids, res_ids, nq, k);
auto binary = index_->Serialize();
auto type = index_->GetType();
auto new_index = GetVecIndexFactory(type);
new_index->Load(binary);
EXPECT_EQ(new_index->Dimension(), index_->Dimension());
EXPECT_EQ(new_index->Count(), index_->Count());
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
AssertResult(res_ids, res_dis);
}
delete[] D;
{
std::string file_location = "/tmp/whatever";
write_index(index_, file_location);
auto new_index = read_index(file_location);
EXPECT_EQ(new_index->GetType(), index_type);
EXPECT_EQ(new_index->Dimension(), index_->Dimension());
EXPECT_EQ(new_index->Count(), index_->Count());
std::vector<int64_t> res_ids(elems);
std::vector<float> res_dis(elems);
new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
AssertResult(res_ids, res_dis);
}
}
......@@ -19,7 +19,7 @@ DataGenPtr GetGenerateFactory(const std::string &gen_type) {
void DataGenBase::GenData(const int &dim, const int &nb, const int &nq,
float *xb, float *xq, long *ids,
const int &k, long *gt_ids) {
const int &k, long *gt_ids, float *gt_dis) {
for (auto i = 0; i < nb; ++i) {
for (auto j = 0; j < dim; ++j) {
//p_data[i * d + j] = float(base + i);
......@@ -35,8 +35,7 @@ void DataGenBase::GenData(const int &dim, const int &nb, const int &nq,
faiss::IndexFlatL2 index(dim);
//index.add_with_ids(nb, xb, ids);
index.add(nb, xb);
float *D = new float[k * nq];
index.search(nq, xq, k, D, gt_ids);
index.search(nq, xq, k, gt_dis, gt_ids);
}
void DataGenBase::GenData(const int &dim,
......@@ -46,36 +45,12 @@ void DataGenBase::GenData(const int &dim,
std::vector<float> &xq,
std::vector<long> &ids,
const int &k,
std::vector<long> &gt_ids) {
std::vector<long> &gt_ids,
std::vector<float> &gt_dis) {
xb.resize(nb * dim);
xq.resize(nq * dim);
ids.resize(nb);
gt_ids.resize(nq * k);
GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data());
}
FileIOReader::FileIOReader(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary);
}
FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
}
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
}
FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
gt_dis.resize(nq * k);
GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data(), gt_dis.data());
}
......@@ -23,7 +23,7 @@ extern DataGenPtr GetGenerateFactory(const std::string &gen_type);
class DataGenBase {
public:
virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
const int &k, long *gt_ids);
const int &k, long *gt_ids, float *gt_dis);
virtual void GenData(const int &dim,
const int &nb,
......@@ -32,30 +32,14 @@ class DataGenBase {
std::vector<float> &xq,
std::vector<long> &ids,
const int &k,
std::vector<long> &gt_ids);
std::vector<long> &gt_ids,
std::vector<float> &gt_dis);
};
class SanityCheck : public DataGenBase {
public:
void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
const int &k, long *gt_ids) override;
};
//class SanityCheck : public DataGenBase {
// public:
// void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
// const int &k, long *gt_ids, float *gt_dis) override;
//};
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
};
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册