diff --git a/cpp/src/db/ExecutionEngineImpl.cpp b/cpp/src/db/ExecutionEngineImpl.cpp index bba29fc9a077a439d8f0fbfc557864c2e801db89..64aabb3777bc28f9cf9fee7641e94b1acf1c93fc 100644 --- a/cpp/src/db/ExecutionEngineImpl.cpp +++ b/cpp/src/db/ExecutionEngineImpl.cpp @@ -4,6 +4,7 @@ * Proprietary and confidential. ******************************************************************************/ #include +#include #include "Log.h" #include "src/cache/CpuCacheMgr.h" @@ -16,55 +17,6 @@ namespace zilliz { namespace milvus { namespace engine { -struct FileIOWriter { - std::fstream fs; - std::string name; - - FileIOWriter(const std::string &fname); - ~FileIOWriter(); - size_t operator()(void *ptr, size_t size); -}; - -struct FileIOReader { - std::fstream fs; - std::string name; - - FileIOReader(const std::string &fname); - ~FileIOReader(); - size_t operator()(void *ptr, size_t size); - size_t operator()(void *ptr, size_t size, size_t pos); -}; - -FileIOReader::FileIOReader(const std::string &fname) { - name = fname; - fs = std::fstream(name, std::ios::in | std::ios::binary); -} - -FileIOReader::~FileIOReader() { - fs.close(); -} - -size_t FileIOReader::operator()(void *ptr, size_t size) { - fs.read(reinterpret_cast(ptr), size); -} - -size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) { - return 0; -} - -FileIOWriter::FileIOWriter(const std::string &fname) { - name = fname; - fs = std::fstream(name, std::ios::out | std::ios::binary); -} - -FileIOWriter::~FileIOWriter() { - fs.close(); -} - -size_t FileIOWriter::operator()(void *ptr, size_t size) { - fs.write(reinterpret_cast(ptr), size); -} - ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string &location, EngineType type) @@ -89,7 +41,7 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) { break; } case EngineType::FAISS_IVFFLAT_GPU: { - index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_GPU); + index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_MIX); break; } case EngineType::FAISS_IVFFLAT_CPU: { @@ -130,89 +82,32 @@ size_t ExecutionEngineImpl::PhysicalSize() const { } Status ExecutionEngineImpl::Serialize() { - auto binaryset = index_->Serialize(); - - FileIOWriter writer(location_); - writer(¤t_type, sizeof(current_type)); - for (auto &iter: binaryset.binary_map_) { - auto meta = iter.first.c_str(); - size_t meta_length = iter.first.length(); - writer(&meta_length, sizeof(meta_length)); - writer((void *) meta, meta_length); - - auto binary = iter.second; - size_t binary_length = binary->size; - writer(&binary_length, sizeof(binary_length)); - writer((void *) binary->data.get(), binary_length); - } + write_index(index_, location_); return Status::OK(); } Status ExecutionEngineImpl::Load() { - index_ = Load(location_); - return Status::OK(); -} - -VecIndexPtr ExecutionEngineImpl::Load(const std::string &location) { - knowhere::BinarySet load_data_list; - FileIOReader reader(location); - reader.fs.seekg(0, reader.fs.end); - size_t length = reader.fs.tellg(); - reader.fs.seekg(0); - - size_t rp = 0; - reader(¤t_type, sizeof(current_type)); - rp += sizeof(current_type); - while (rp < length) { - size_t meta_length; - reader(&meta_length, sizeof(meta_length)); - rp += sizeof(meta_length); - reader.fs.seekg(rp); - - auto meta = new char[meta_length]; - reader(meta, meta_length); - rp += meta_length; - reader.fs.seekg(rp); - - size_t bin_length; - reader(&bin_length, sizeof(bin_length)); - rp += sizeof(bin_length); - reader.fs.seekg(rp); + index_ = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location_); + bool to_cache = false; + auto start_time = METRICS_NOW_TIME; + if (!index_) { + index_ = read_index(location_); + to_cache = true; + ENGINE_LOG_DEBUG << "Disk io from: " << location_; + } - auto bin = new uint8_t[bin_length]; - reader(bin, bin_length); - rp += bin_length; + if (to_cache) { + Cache(); + auto end_time = METRICS_NOW_TIME; + auto total_time = METRICS_MICROSECONDS(start_time, end_time); - auto binptr = std::make_shared(); - binptr.reset(bin); - load_data_list.Append(std::string(meta, meta_length), binptr, bin_length); - } + server::Metrics::GetInstance().FaissDiskLoadDurationSecondsHistogramObserve(total_time); + double total_size = Size(); - auto index_type = IndexType::INVALID; - switch (current_type) { - case EngineType::FAISS_IDMAP: { - index_type = IndexType::FAISS_IDMAP; - break; - } - case EngineType::FAISS_IVFFLAT_CPU: { - index_type = IndexType::FAISS_IVFFLAT_CPU; - break; - } - case EngineType::FAISS_IVFFLAT_GPU: { - index_type = IndexType::FAISS_IVFFLAT_GPU; - break; - } - case EngineType::SPTAG_KDT_RNT_CPU: { - index_type = IndexType::SPTAG_KDT_RNT_CPU; - break; - } - default: { - ENGINE_LOG_ERROR << "wrong index_type"; - return nullptr; - } + server::Metrics::GetInstance().FaissDiskLoadSizeBytesHistogramObserve(total_size); + server::Metrics::GetInstance().FaissDiskLoadIOSpeedGaugeSet(total_size / double(total_time)); } - - return LoadVecIndex(index_type, load_data_list); + return Status::OK(); } Status ExecutionEngineImpl::Merge(const std::string &location) { @@ -223,15 +118,17 @@ Status ExecutionEngineImpl::Merge(const std::string &location) { auto to_merge = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location); if (!to_merge) { - to_merge = Load(location); + to_merge = read_index(location); } - auto file_index = std::dynamic_pointer_cast(to_merge); - index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds()); - return Status::OK(); + if (auto file_index = std::dynamic_pointer_cast(to_merge)) { + index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds()); + return Status::OK(); + } else { + return Status::Error("file index type is not idmap"); + } } -// TODO(linxj): add config ExecutionEnginePtr ExecutionEngineImpl::BuildIndex(const std::string &location) { ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_; diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 9b1afb84efee5bcae839507a3eea4c8bb161575c..d50bfe34da8da850ba37c61288d399ec3da41f8b 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -6,6 +6,7 @@ #include #include "knowhere/index/vector_index/idmap.h" +#include "knowhere/index/vector_index/gpu_ivf.h" #include "vec_impl.h" #include "data_transfer.h" @@ -98,6 +99,10 @@ int64_t VecIndexImpl::Count() { return index_->Count(); } +IndexType VecIndexImpl::GetType() { + return type; +} + float *BFIndex::GetRawVectors() { auto raw_index = std::dynamic_pointer_cast(index_); if (raw_index) { return raw_index->GetRawVectors(); } @@ -126,6 +131,38 @@ void BFIndex::BuildAll(const long &nb, index_->Add(dataset, cfg); } +// TODO(linxj): add lock here. +void IVFMixIndex::BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt, + const float *xt) { + dim = cfg["dim"].as(); + auto dataset = GenDatasetWithIds(nb, dim, xb, ids); + + auto preprocessor = index_->BuildPreprocessor(dataset, cfg); + index_->set_preprocessor(preprocessor); + auto nlist = int(nb / 1000000.0 * 16384); + auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}}; + auto model = index_->Train(dataset, cfg_t); + index_->set_index_model(model); + index_->Add(dataset, cfg); + + if (auto device_index = std::dynamic_pointer_cast(index_)) { + auto host_index = device_index->Copy_index_gpu_to_cpu(); + index_ = host_index; + } else { + // TODO(linxj): LOG ERROR + } +} + +void IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { + index_ = std::make_shared(); + index_->Load(index_binary); + dim = Dimension(); +} + } } } diff --git a/cpp/src/wrapper/knowhere/vec_impl.h b/cpp/src/wrapper/knowhere/vec_impl.h index ab6c6b8a791f9697bdc35f88369ae02d90b47f14..1d09a069d231e21bbd43d01e95b00c779621d627 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.h +++ b/cpp/src/wrapper/knowhere/vec_impl.h @@ -17,13 +17,15 @@ namespace engine { class VecIndexImpl : public VecIndex { public: - explicit VecIndexImpl(std::shared_ptr index) : index_(std::move(index)) {}; + explicit VecIndexImpl(std::shared_ptr index, const IndexType &type) + : index_(std::move(index)), type(type) {}; void BuildAll(const long &nb, const float *xb, const long *ids, const Config &cfg, const long &nt, const float *xt) override; + IndexType GetType() override; int64_t Dimension() override; int64_t Count() override; void Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override; @@ -33,21 +35,36 @@ class VecIndexImpl : public VecIndex { protected: int64_t dim = 0; + IndexType type = IndexType::INVALID; std::shared_ptr index_ = nullptr; }; +class IVFMixIndex : public VecIndexImpl { + public: + explicit IVFMixIndex(std::shared_ptr index) : VecIndexImpl(std::move(index), + IndexType::FAISS_IVFFLAT_MIX) {}; + void BuildAll(const long &nb, + const float *xb, + const long *ids, + const Config &cfg, + const long &nt, + const float *xt) override; + void Load(const zilliz::knowhere::BinarySet &index_binary) override; +}; + class BFIndex : public VecIndexImpl { public: - explicit BFIndex(std::shared_ptr index) : VecIndexImpl(std::move(index)) {}; - void Build(const int64_t& d); - float* GetRawVectors(); + explicit BFIndex(std::shared_ptr index) : VecIndexImpl(std::move(index), + IndexType::FAISS_IDMAP) {}; + void Build(const int64_t &d); + float *GetRawVectors(); void BuildAll(const long &nb, const float *xb, const long *ids, const Config &cfg, const long &nt, const float *xt) override; - int64_t* GetRawIds(); + int64_t *GetRawIds(); }; } diff --git a/cpp/src/wrapper/knowhere/vec_index.cpp b/cpp/src/wrapper/knowhere/vec_index.cpp index 17aa428613f1fa11edcc31e352c1c9982d1ae904..55e1ea4ceacaedc5a523a10cead6e417d0def1ec 100644 --- a/cpp/src/wrapper/knowhere/vec_index.cpp +++ b/cpp/src/wrapper/knowhere/vec_index.cpp @@ -16,7 +16,56 @@ namespace zilliz { namespace milvus { namespace engine { -// TODO(linxj): index_type => enum struct +struct FileIOWriter { + std::fstream fs; + std::string name; + + FileIOWriter(const std::string &fname); + ~FileIOWriter(); + size_t operator()(void *ptr, size_t size); +}; + +struct FileIOReader { + std::fstream fs; + std::string name; + + FileIOReader(const std::string &fname); + ~FileIOReader(); + size_t operator()(void *ptr, size_t size); + size_t operator()(void *ptr, size_t size, size_t pos); +}; + +FileIOReader::FileIOReader(const std::string &fname) { + name = fname; + fs = std::fstream(name, std::ios::in | std::ios::binary); +} + +FileIOReader::~FileIOReader() { + fs.close(); +} + +size_t FileIOReader::operator()(void *ptr, size_t size) { + fs.read(reinterpret_cast(ptr), size); +} + +size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) { + return 0; +} + +FileIOWriter::FileIOWriter(const std::string &fname) { + name = fname; + fs = std::fstream(name, std::ios::out | std::ios::binary); +} + +FileIOWriter::~FileIOWriter() { + fs.close(); +} + +size_t FileIOWriter::operator()(void *ptr, size_t size) { + fs.write(reinterpret_cast(ptr), size); +} + + VecIndexPtr GetVecIndexFactory(const IndexType &type) { std::shared_ptr index; switch (type) { @@ -32,6 +81,10 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) { index = std::make_shared(0); break; } + case IndexType::FAISS_IVFFLAT_MIX: { + index = std::make_shared(0); + return std::make_shared(index); + } case IndexType::FAISS_IVFPQ_CPU: { index = std::make_shared(); break; @@ -44,15 +97,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) { index = std::make_shared(); break; } - //case IndexType::NSG: { // TODO(linxj): bug. - // index = std::make_shared(); - // break; - //} + //case IndexType::NSG: { // TODO(linxj): bug. + // index = std::make_shared(); + // break; + //} default: { return nullptr; } } - return std::make_shared(index); + return std::make_shared(index, type); } VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) { @@ -61,6 +114,64 @@ VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::Bi return index; } +VecIndexPtr read_index(const std::string &location) { + knowhere::BinarySet load_data_list; + FileIOReader reader(location); + reader.fs.seekg(0, reader.fs.end); + size_t length = reader.fs.tellg(); + reader.fs.seekg(0); + + size_t rp = 0; + auto current_type = IndexType::INVALID; + reader(¤t_type, sizeof(current_type)); + rp += sizeof(current_type); + while (rp < length) { + size_t meta_length; + reader(&meta_length, sizeof(meta_length)); + rp += sizeof(meta_length); + reader.fs.seekg(rp); + + auto meta = new char[meta_length]; + reader(meta, meta_length); + rp += meta_length; + reader.fs.seekg(rp); + + size_t bin_length; + reader(&bin_length, sizeof(bin_length)); + rp += sizeof(bin_length); + reader.fs.seekg(rp); + + auto bin = new uint8_t[bin_length]; + reader(bin, bin_length); + rp += bin_length; + + auto binptr = std::make_shared(); + binptr.reset(bin); + load_data_list.Append(std::string(meta, meta_length), binptr, bin_length); + } + + return LoadVecIndex(current_type, load_data_list); +} + +void write_index(VecIndexPtr index, const std::string &location) { + auto binaryset = index->Serialize(); + auto index_type = index->GetType(); + + FileIOWriter writer(location); + writer(&index_type, sizeof(IndexType)); + for (auto &iter: binaryset.binary_map_) { + auto meta = iter.first.c_str(); + size_t meta_length = iter.first.length(); + writer(&meta_length, sizeof(meta_length)); + writer((void *) meta, meta_length); + + auto binary = iter.second; + int64_t binary_length = binary->size; + writer(&binary_length, sizeof(binary_length)); + writer((void *) binary->data.get(), binary_length); + } +} + } } } diff --git a/cpp/src/wrapper/knowhere/vec_index.h b/cpp/src/wrapper/knowhere/vec_index.h index 76c69537b53bc9f2ade9a1bb95b9a054d17b1b85..a488922d9edf3295f81628162cc8bc4e930c1b81 100644 --- a/cpp/src/wrapper/knowhere/vec_index.h +++ b/cpp/src/wrapper/knowhere/vec_index.h @@ -20,6 +20,18 @@ namespace engine { // TODO(linxj): jsoncons => rapidjson or other. using Config = zilliz::knowhere::Config; +enum class IndexType { + INVALID = 0, + FAISS_IDMAP = 1, + FAISS_IVFFLAT_CPU, + FAISS_IVFFLAT_GPU, + FAISS_IVFFLAT_MIX, // build on gpu and search on cpu + FAISS_IVFPQ_CPU, + FAISS_IVFPQ_GPU, + SPTAG_KDT_RNT_CPU, + //NSG, +}; + class VecIndex { public: virtual void BuildAll(const long &nb, @@ -40,6 +52,8 @@ class VecIndex { long *ids, const Config &cfg = Config()) = 0; + virtual IndexType GetType() = 0; + virtual int64_t Dimension() = 0; virtual int64_t Count() = 0; @@ -51,16 +65,9 @@ class VecIndex { using VecIndexPtr = std::shared_ptr; -enum class IndexType { - INVALID = 0, - FAISS_IDMAP = 1, - FAISS_IVFFLAT_CPU, - FAISS_IVFFLAT_GPU, - FAISS_IVFPQ_CPU, - FAISS_IVFPQ_GPU, - SPTAG_KDT_RNT_CPU, - //NSG, -}; +extern void write_index(VecIndexPtr index, const std::string &location); + +extern VecIndexPtr read_index(const std::string &location); extern VecIndexPtr GetVecIndexFactory(const IndexType &type); diff --git a/cpp/thirdparty/knowhere b/cpp/thirdparty/knowhere index c3123501d62f69f9eacaa73ee96c0daeb24620a5..ca99a6899be4e8a0806452656cf0f2be19d79c1a 160000 --- a/cpp/thirdparty/knowhere +++ b/cpp/thirdparty/knowhere @@ -1 +1 @@ -Subproject commit c3123501d62f69f9eacaa73ee96c0daeb24620a5 +Subproject commit ca99a6899be4e8a0806452656cf0f2be19d79c1a diff --git a/cpp/unittest/index_wrapper/knowhere_test.cpp b/cpp/unittest/index_wrapper/knowhere_test.cpp index b4f8feba0365721f2a2ed67b1350821e8c183dee..30673dba1f5a2c3cef9d9d0dfe3e52a6a152ca21 100644 --- a/cpp/unittest/index_wrapper/knowhere_test.cpp +++ b/cpp/unittest/index_wrapper/knowhere_test.cpp @@ -28,11 +28,37 @@ class KnowhereWrapperTest //auto generator = GetGenerateFactory(generator_type); auto generator = std::make_shared(); - generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids); + generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis); index_ = GetVecIndexFactory(index_type); } + void AssertResult(const std::vector &ids, const std::vector &dis) { + EXPECT_EQ(ids.size(), nq * k); + EXPECT_EQ(dis.size(), nq * k); + + for (auto i = 0; i < nq; i++) { + EXPECT_EQ(ids[i * k], gt_ids[i * k]); + EXPECT_EQ(dis[i * k], gt_dis[i * k]); + } + + int match = 0; + for (int i = 0; i < nq; ++i) { + for (int j = 0; j < k; ++j) { + for (int l = 0; l < k; ++l) { + if (ids[i * nq + j] == gt_ids[i * nq + l]) match++; + } + } + } + + auto precision = float(match) / (nq * k); + EXPECT_GT(precision, 0.5); + std::cout << std::endl << "Precision: " << precision + << ", match: " << match + << ", total: " << nq * k + << std::endl; + } + protected: IndexType index_type; Config train_cfg; @@ -50,126 +76,88 @@ class KnowhereWrapperTest // Ground Truth std::vector gt_ids; + std::vector gt_dis; }; INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest, Values( - // ["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"] + //["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"] std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default", - 64, 10000, 10, 10, + 64, 100000, 10, 10, Config::object{{"nlist", 100}, {"dim", 64}}, - Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 20}} + Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}} ), - std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default", - 64, 10000, 10, 10, - Config::object{{"TPTNumber", 1}, {"dim", 64}}, + //std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default", + // 64, 10000, 10, 10, + // Config::object{{"nlist", 100}, {"dim", 64}}, + // Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 40}} + //), + std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default", + 64, 100000, 10, 10, + Config::object{{"nlist", 100}, {"dim", 64}}, + Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}} + ), + std::make_tuple(IndexType::FAISS_IDMAP, "Default", + 64, 100000, 10, 10, + Config::object{{"dim", 64}}, Config::object{{"dim", 64}, {"k", 10}} ) + //std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default", + // 64, 10000, 10, 10, + // Config::object{{"TPTNumber", 1}, {"dim", 64}}, + // Config::object{{"dim", 64}, {"k", 10}} + //) ) ); -void AssertAnns(const std::vector >, - const std::vector &res, - const int &nq, - const int &k) { - EXPECT_EQ(res.size(), nq * k); - - for (auto i = 0; i < nq; i++) { - EXPECT_EQ(gt[i * k], res[i * k]); - } - - int match = 0; - for (int i = 0; i < nq; ++i) { - for (int j = 0; j < k; ++j) { - for (int l = 0; l < k; ++l) { - if (gt[i * nq + j] == res[i * nq + l]) match++; - } - } - } - - // TODO(linxj): percision check - EXPECT_GT(float(match/nq*k), 0.5); -} - TEST_P(KnowhereWrapperTest, base_test) { - std::vector res_ids; - float *D = new float[k * nq]; - res_ids.resize(nq * k); + EXPECT_EQ(index_->GetType(), index_type); + + auto elems = nq * k; + std::vector res_ids(elems); + std::vector res_dis(elems); index_->BuildAll(nb, xb.data(), ids.data(), train_cfg); - index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg); - AssertAnns(gt_ids, res_ids, nq, k); - delete[] D; + index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg); + AssertResult(res_ids, res_dis); } -TEST_P(KnowhereWrapperTest, serialize_test) { - std::vector res_ids; - float *D = new float[k * nq]; - res_ids.resize(nq * k); +TEST_P(KnowhereWrapperTest, serialize) { + EXPECT_EQ(index_->GetType(), index_type); + auto elems = nq * k; + std::vector res_ids(elems); + std::vector res_dis(elems); index_->BuildAll(nb, xb.data(), ids.data(), train_cfg); - index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg); - AssertAnns(gt_ids, res_ids, nq, k); + index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg); + AssertResult(res_ids, res_dis); { - auto binaryset = index_->Serialize(); - //int fileno = 0; - //const std::string &base_name = "/tmp/wrapper_serialize_test_bin_"; - //std::vector filename_list; - //std::vector> meta_list; - //for (auto &iter: binaryset.binary_map_) { - // const std::string &filename = base_name + std::to_string(fileno); - // FileIOWriter writer(filename); - // writer(iter.second->data.get(), iter.second->size); - // - // meta_list.push_back(std::make_pair(iter.first, iter.second.size)); - // filename_list.push_back(filename); - // ++fileno; - //} - // - //BinarySet load_data_list; - //for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) { - // auto bin_size = meta_list[i].second; - // FileIOReader reader(filename_list[i]); - // std::vector load_data(bin_size); - // reader(load_data.data(), bin_size); - // load_data_list.Append(meta_list[i].first, load_data); - //} - - int fileno = 0; - std::vector filename_list; - const std::string &base_name = "/tmp/wrapper_serialize_test_bin_"; - std::vector> meta_list; - for (auto &iter: binaryset.binary_map_) { - const std::string &filename = base_name + std::to_string(fileno); - FileIOWriter writer(filename); - writer(iter.second->data.get(), iter.second->size); - - meta_list.emplace_back(std::make_pair(iter.first, iter.second->size)); - filename_list.push_back(filename); - ++fileno; - } - - BinarySet load_data_list; - for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) { - auto bin_size = meta_list[i].second; - FileIOReader reader(filename_list[i]); - - auto load_data = new uint8_t[bin_size]; - reader(load_data, bin_size); - auto data = std::make_shared(); - data.reset(load_data); - load_data_list.Append(meta_list[i].first, data, bin_size); - } - - - res_ids.clear(); - res_ids.resize(nq * k); - auto new_index = GetVecIndexFactory(index_type); - new_index->Load(load_data_list); - new_index->Search(nq, xq.data(), D, res_ids.data(), search_cfg); - AssertAnns(gt_ids, res_ids, nq, k); + auto binary = index_->Serialize(); + auto type = index_->GetType(); + auto new_index = GetVecIndexFactory(type); + new_index->Load(binary); + EXPECT_EQ(new_index->Dimension(), index_->Dimension()); + EXPECT_EQ(new_index->Count(), index_->Count()); + + std::vector res_ids(elems); + std::vector res_dis(elems); + new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg); + AssertResult(res_ids, res_dis); } - delete[] D; + { + std::string file_location = "/tmp/whatever"; + write_index(index_, file_location); + auto new_index = read_index(file_location); + EXPECT_EQ(new_index->GetType(), index_type); + EXPECT_EQ(new_index->Dimension(), index_->Dimension()); + EXPECT_EQ(new_index->Count(), index_->Count()); + + std::vector res_ids(elems); + std::vector res_dis(elems); + new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg); + AssertResult(res_ids, res_dis); + } } + diff --git a/cpp/unittest/index_wrapper/utils.cpp b/cpp/unittest/index_wrapper/utils.cpp index e228ae001d9ec5c83d31f6a2c9ddc288473298cc..ede5dd048502f9f5c1fd59611e154a83cccfd5cf 100644 --- a/cpp/unittest/index_wrapper/utils.cpp +++ b/cpp/unittest/index_wrapper/utils.cpp @@ -19,7 +19,7 @@ DataGenPtr GetGenerateFactory(const std::string &gen_type) { void DataGenBase::GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids, - const int &k, long *gt_ids) { + const int &k, long *gt_ids, float *gt_dis) { for (auto i = 0; i < nb; ++i) { for (auto j = 0; j < dim; ++j) { //p_data[i * d + j] = float(base + i); @@ -35,8 +35,7 @@ void DataGenBase::GenData(const int &dim, const int &nb, const int &nq, faiss::IndexFlatL2 index(dim); //index.add_with_ids(nb, xb, ids); index.add(nb, xb); - float *D = new float[k * nq]; - index.search(nq, xq, k, D, gt_ids); + index.search(nq, xq, k, gt_dis, gt_ids); } void DataGenBase::GenData(const int &dim, @@ -46,36 +45,12 @@ void DataGenBase::GenData(const int &dim, std::vector &xq, std::vector &ids, const int &k, - std::vector >_ids) { + std::vector >_ids, + std::vector >_dis) { xb.resize(nb * dim); xq.resize(nq * dim); ids.resize(nb); gt_ids.resize(nq * k); - GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data()); -} - -FileIOReader::FileIOReader(const std::string &fname) { - name = fname; - fs = std::fstream(name, std::ios::in | std::ios::binary); -} - -FileIOReader::~FileIOReader() { - fs.close(); -} - -size_t FileIOReader::operator()(void *ptr, size_t size) { - fs.read(reinterpret_cast(ptr), size); -} - -FileIOWriter::FileIOWriter(const std::string &fname) { - name = fname; - fs = std::fstream(name, std::ios::out | std::ios::binary); -} - -FileIOWriter::~FileIOWriter() { - fs.close(); -} - -size_t FileIOWriter::operator()(void *ptr, size_t size) { - fs.write(reinterpret_cast(ptr), size); + gt_dis.resize(nq * k); + GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data(), gt_dis.data()); } diff --git a/cpp/unittest/index_wrapper/utils.h b/cpp/unittest/index_wrapper/utils.h index bbc52a011bbff2717f40b0990f3e1cae0f847b90..ce3c428d68a43ed0e1cf0ba8acae698ad3de57cb 100644 --- a/cpp/unittest/index_wrapper/utils.h +++ b/cpp/unittest/index_wrapper/utils.h @@ -23,7 +23,7 @@ extern DataGenPtr GetGenerateFactory(const std::string &gen_type); class DataGenBase { public: virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids, - const int &k, long *gt_ids); + const int &k, long *gt_ids, float *gt_dis); virtual void GenData(const int &dim, const int &nb, @@ -32,30 +32,14 @@ class DataGenBase { std::vector &xq, std::vector &ids, const int &k, - std::vector >_ids); + std::vector >_ids, + std::vector >_dis); }; -class SanityCheck : public DataGenBase { - public: - void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids, - const int &k, long *gt_ids) override; -}; +//class SanityCheck : public DataGenBase { +// public: +// void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids, +// const int &k, long *gt_ids, float *gt_dis) override; +//}; -struct FileIOWriter { - std::fstream fs; - std::string name; - - FileIOWriter(const std::string &fname); - ~FileIOWriter(); - size_t operator()(void *ptr, size_t size); -}; - -struct FileIOReader { - std::fstream fs; - std::string name; - - FileIOReader(const std::string &fname); - ~FileIOReader(); - size_t operator()(void *ptr, size_t size); -};