提交 8b627cac 编写于 作者: C Cai Yudong 提交者: JinHai-CN

improve knowhere coverage (#2444)

* increase nb for NSG code coverage
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* test more APIs in test_annoy
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* disable get_vector_by_id and search_by_id
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* improve code coverage
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* improve code coverage
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* update unittest
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* install test_instructionset
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* update changelog
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>
上级 e2b23e9c
...@@ -31,6 +31,7 @@ Please mark all change in change log and use the issue from GitHub ...@@ -31,6 +31,7 @@ Please mark all change in change log and use the issue from GitHub
- \#2370 Clean compile warning - \#2370 Clean compile warning
- \#2381 Upgrade FAISS to 1.6.3 - \#2381 Upgrade FAISS to 1.6.3
- \#2410 Logging build index progress - \#2410 Logging build index progress
- \#2441 Improve Knowhere code coverage
## Task ## Task
......
...@@ -109,11 +109,13 @@ class ExecutionEngine { ...@@ -109,11 +109,13 @@ class ExecutionEngine {
// virtual Status // virtual Status
// Merge(const std::string& location) = 0; // Merge(const std::string& location) = 0;
#if 0
virtual Status virtual Status
GetVectorByID(const int64_t id, float* vector, bool hybrid) = 0; GetVectorByID(const int64_t id, float* vector, bool hybrid) = 0;
virtual Status virtual Status
GetVectorByID(const int64_t id, uint8_t* vector, bool hybrid) = 0; GetVectorByID(const int64_t id, uint8_t* vector, bool hybrid) = 0;
#endif
virtual Status virtual Status
ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset, ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset,
......
...@@ -1248,6 +1248,7 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const mil ...@@ -1248,6 +1248,7 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const mil
return Status::OK(); return Status::OK();
} }
#if 0
Status Status
ExecutionEngineImpl::GetVectorByID(const int64_t id, float* vector, bool hybrid) { ExecutionEngineImpl::GetVectorByID(const int64_t id, float* vector, bool hybrid) {
if (index_ == nullptr) { if (index_ == nullptr) {
...@@ -1299,6 +1300,7 @@ ExecutionEngineImpl::GetVectorByID(const int64_t id, uint8_t* vector, bool hybri ...@@ -1299,6 +1300,7 @@ ExecutionEngineImpl::GetVectorByID(const int64_t id, uint8_t* vector, bool hybri
return Status::OK(); return Status::OK();
} }
#endif
Status Status
ExecutionEngineImpl::Cache() { ExecutionEngineImpl::Cache() {
......
...@@ -63,11 +63,13 @@ class ExecutionEngineImpl : public ExecutionEngine { ...@@ -63,11 +63,13 @@ class ExecutionEngineImpl : public ExecutionEngine {
Status Status
CopyToCpu() override; CopyToCpu() override;
#if 0
Status Status
GetVectorByID(const int64_t id, float* vector, bool hybrid) override; GetVectorByID(const int64_t id, float* vector, bool hybrid) override;
Status Status
GetVectorByID(const int64_t id, uint8_t* vector, bool hybrid) override; GetVectorByID(const int64_t id, uint8_t* vector, bool hybrid) override;
#endif
Status Status
ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset, ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset,
......
...@@ -37,11 +37,6 @@ class ToIndexData : public milvus::cache::DataObj { ...@@ -37,11 +37,6 @@ class ToIndexData : public milvus::cache::DataObj {
explicit ToIndexData(int64_t size) : size_(size) { explicit ToIndexData(int64_t size) : size_(size) {
} }
int64_t
Size() override {
return size_;
}
private: private:
int64_t size_ = 0; int64_t size_ = 0;
}; };
......
...@@ -151,7 +151,6 @@ IndexAnnoy::Count() { ...@@ -151,7 +151,6 @@ IndexAnnoy::Count() {
if (!index_) { if (!index_) {
KNOWHERE_THROW_MSG("index not initialize"); KNOWHERE_THROW_MSG("index not initialize");
} }
return index_->get_n_items(); return index_->get_n_items();
} }
...@@ -160,17 +159,8 @@ IndexAnnoy::Dim() { ...@@ -160,17 +159,8 @@ IndexAnnoy::Dim() {
if (!index_) { if (!index_) {
KNOWHERE_THROW_MSG("index not initialize"); KNOWHERE_THROW_MSG("index not initialize");
} }
return index_->get_dim(); return index_->get_dim();
} }
int64_t
IndexAnnoy::IndexSize() {
if (index_size_ != -1) {
return index_size_;
}
return index_size_ = Dim() * Count() * sizeof(float);
}
} // namespace knowhere } // namespace knowhere
} // namespace milvus } // namespace milvus
...@@ -62,9 +62,6 @@ class IndexAnnoy : public VecIndex { ...@@ -62,9 +62,6 @@ class IndexAnnoy : public VecIndex {
int64_t int64_t
Dim() override; Dim() override;
int64_t
IndexSize() override;
private: private:
MetricType metric_type_; MetricType metric_type_;
std::shared_ptr<AnnoyIndexInterface<int64_t, float>> index_ = nullptr; std::shared_ptr<AnnoyIndexInterface<int64_t, float>> index_ = nullptr;
......
...@@ -72,6 +72,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -72,6 +72,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
return ret_ds; return ret_ds;
} }
#if 0
DatasetPtr DatasetPtr
BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) { if (!index_) {
...@@ -109,6 +110,7 @@ BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -109,6 +110,7 @@ BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
return ret_ds; return ret_ds;
} }
#endif
void void
BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) { BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) {
...@@ -169,6 +171,7 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) ...@@ -169,6 +171,7 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config)
index_->add_with_ids(rows, (uint8_t*)p_data, new_ids.data()); index_->add_with_ids(rows, (uint8_t*)p_data, new_ids.data());
} }
#if 0
DatasetPtr DatasetPtr
BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) { if (!index_) {
...@@ -189,6 +192,7 @@ BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) ...@@ -189,6 +192,7 @@ BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config)
ret_ds->Set(meta::TENSOR, p_x); ret_ds->Set(meta::TENSOR, p_x);
return ret_ds; return ret_ds;
} }
#endif
void void
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
......
...@@ -50,8 +50,10 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { ...@@ -50,8 +50,10 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
DatasetPtr DatasetPtr
Query(const DatasetPtr&, const Config&) override; Query(const DatasetPtr&, const Config&) override;
#if 0
DatasetPtr DatasetPtr
QueryById(const DatasetPtr& dataset_ptr, const Config& config) override; QueryById(const DatasetPtr& dataset_ptr, const Config& config) override;
#endif
int64_t int64_t
Count() override { Count() override {
...@@ -68,8 +70,10 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { ...@@ -68,8 +70,10 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
return Count() * Dim() / 8; return Count() * Dim() / 8;
} }
#if 0
DatasetPtr DatasetPtr
GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) override; GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) override;
#endif
virtual const uint8_t* virtual const uint8_t*
GetRawVectors(); GetRawVectors();
......
...@@ -83,6 +83,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -83,6 +83,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
} }
} }
#if 0
DatasetPtr DatasetPtr
BinaryIVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { BinaryIVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) { if (!index_ || !index_->is_trained) {
...@@ -126,6 +127,7 @@ BinaryIVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -126,6 +127,7 @@ BinaryIVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG(e.what()); KNOWHERE_THROW_MSG(e.what());
} }
} }
#endif
void void
BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
...@@ -140,13 +142,14 @@ BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -140,13 +142,14 @@ BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
index_ = index; index_ = index;
} }
#if 0
DatasetPtr DatasetPtr
BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) { if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained"); KNOWHERE_THROW_MSG("index not initialize or trained");
} }
// GETBINARYTENSOR(dataset_ptr) // GETBINARYTENSOR(dataset_ptr)
// auto rows = dataset_ptr->Get<int64_t>(meta::ROWS); // auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS); auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
auto elems = dataset_ptr->Get<int64_t>(meta::DIM); auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
...@@ -166,6 +169,7 @@ BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -166,6 +169,7 @@ BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG(e.what()); KNOWHERE_THROW_MSG(e.what());
} }
} }
#endif
std::shared_ptr<faiss::IVFSearchParameters> std::shared_ptr<faiss::IVFSearchParameters>
BinaryIVF::GenParams(const Config& config) { BinaryIVF::GenParams(const Config& config) {
......
...@@ -62,8 +62,10 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { ...@@ -62,8 +62,10 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
DatasetPtr DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override; Query(const DatasetPtr& dataset_ptr, const Config& config) override;
#if 0
DatasetPtr DatasetPtr
QueryById(const DatasetPtr& dataset_ptr, const Config& config) override; QueryById(const DatasetPtr& dataset_ptr, const Config& config) override;
#endif
int64_t int64_t
Count() override { Count() override {
...@@ -75,8 +77,10 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { ...@@ -75,8 +77,10 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
return index_->d; return index_->d;
} }
#if 0
DatasetPtr DatasetPtr
GetVectorById(const DatasetPtr& dataset_ptr, const Config& config); GetVectorById(const DatasetPtr& dataset_ptr, const Config& config);
#endif
protected: protected:
virtual std::shared_ptr<faiss::IVFSearchParameters> virtual std::shared_ptr<faiss::IVFSearchParameters>
......
...@@ -113,6 +113,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -113,6 +113,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
return ret_ds; return ret_ds;
} }
#if 0
DatasetPtr DatasetPtr
IDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { IDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) { if (!index_) {
...@@ -139,6 +140,7 @@ IDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -139,6 +140,7 @@ IDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
ret_ds->Set(meta::DISTANCE, p_dist); ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds; return ret_ds;
} }
#endif
VecIndexPtr VecIndexPtr
IDMAP::CopyCpuToGpu(const int64_t device_id, const Config& config) { IDMAP::CopyCpuToGpu(const int64_t device_id, const Config& config) {
...@@ -179,6 +181,7 @@ IDMAP::GetRawIds() { ...@@ -179,6 +181,7 @@ IDMAP::GetRawIds() {
} }
} }
#if 0
DatasetPtr DatasetPtr
IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) { if (!index_) {
...@@ -198,6 +201,7 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -198,6 +201,7 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
ret_ds->Set(meta::TENSOR, p_x); ret_ds->Set(meta::TENSOR, p_x);
return ret_ds; return ret_ds;
} }
#endif
void void
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
......
...@@ -48,8 +48,10 @@ class IDMAP : public VecIndex, public FaissBaseIndex { ...@@ -48,8 +48,10 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
DatasetPtr DatasetPtr
Query(const DatasetPtr&, const Config&) override; Query(const DatasetPtr&, const Config&) override;
#if 0
DatasetPtr DatasetPtr
QueryById(const DatasetPtr& dataset, const Config& config) override; QueryById(const DatasetPtr& dataset, const Config& config) override;
#endif
int64_t int64_t
Count() override { Count() override {
...@@ -66,8 +68,10 @@ class IDMAP : public VecIndex, public FaissBaseIndex { ...@@ -66,8 +68,10 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
return Count() * Dim() * sizeof(FloatType); return Count() * Dim() * sizeof(FloatType);
} }
#if 0
DatasetPtr DatasetPtr
GetVectorById(const DatasetPtr& dataset, const Config& config) override; GetVectorById(const DatasetPtr& dataset, const Config& config) override;
#endif
VecIndexPtr VecIndexPtr
CopyCpuToGpu(const int64_t, const Config&); CopyCpuToGpu(const int64_t, const Config&);
......
...@@ -142,6 +142,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -142,6 +142,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
} }
} }
#if 0
DatasetPtr DatasetPtr
IVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { IVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) { if (!index_ || !index_->is_trained) {
...@@ -214,6 +215,7 @@ IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -214,6 +215,7 @@ IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG(e.what()); KNOWHERE_THROW_MSG(e.what());
} }
} }
#endif
void void
IVF::Seal() { IVF::Seal() {
......
...@@ -53,8 +53,10 @@ class IVF : public VecIndex, public FaissBaseIndex { ...@@ -53,8 +53,10 @@ class IVF : public VecIndex, public FaissBaseIndex {
DatasetPtr DatasetPtr
Query(const DatasetPtr&, const Config&) override; Query(const DatasetPtr&, const Config&) override;
#if 0
DatasetPtr DatasetPtr
QueryById(const DatasetPtr& dataset, const Config& config) override; QueryById(const DatasetPtr& dataset, const Config& config) override;
#endif
int64_t int64_t
Count() override { Count() override {
...@@ -66,8 +68,10 @@ class IVF : public VecIndex, public FaissBaseIndex { ...@@ -66,8 +68,10 @@ class IVF : public VecIndex, public FaissBaseIndex {
return index_->d; return index_->d;
} }
#if 0
DatasetPtr DatasetPtr
GetVectorById(const DatasetPtr& dataset, const Config& config) override; GetVectorById(const DatasetPtr& dataset, const Config& config) override;
#endif
virtual void virtual void
Seal(); Seal();
......
...@@ -45,10 +45,12 @@ class VecIndex : public Index { ...@@ -45,10 +45,12 @@ class VecIndex : public Index {
virtual DatasetPtr virtual DatasetPtr
Query(const DatasetPtr& dataset, const Config& config) = 0; Query(const DatasetPtr& dataset, const Config& config) = 0;
#if 0
virtual DatasetPtr virtual DatasetPtr
QueryById(const DatasetPtr& dataset, const Config& config) { QueryById(const DatasetPtr& dataset, const Config& config) {
return nullptr; return nullptr;
} }
#endif
// virtual DatasetPtr // virtual DatasetPtr
// QueryByRange(const DatasetPtr&, const Config&) = 0; // QueryByRange(const DatasetPtr&, const Config&) = 0;
...@@ -72,10 +74,12 @@ class VecIndex : public Index { ...@@ -72,10 +74,12 @@ class VecIndex : public Index {
return index_mode_; return index_mode_;
} }
#if 0
virtual DatasetPtr virtual DatasetPtr
GetVectorById(const DatasetPtr& dataset, const Config& config) { GetVectorById(const DatasetPtr& dataset, const Config& config) {
return nullptr; return nullptr;
} }
#endif
faiss::ConcurrentBitsetPtr faiss::ConcurrentBitsetPtr
GetBlacklist() { GetBlacklist() {
......
...@@ -147,10 +147,10 @@ GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int ...@@ -147,10 +147,10 @@ GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int
ResScope rs(res_, gpu_id_); ResScope rs(res_, gpu_id_);
// if query size > 2048 we search by blocks to avoid malloc issue // if query size > 2048 we search by blocks to avoid malloc issue
size_t block_size = 2048; const int64_t block_size = 2048;
size_t dim = device_index->d; int64_t dim = device_index->d;
for (size_t i = 0; i < n; i += block_size) { for (int64_t i = 0; i < n; i += block_size) {
size_t search_size = (n - i > block_size) ? block_size : (n - i); int64_t search_size = (n - i > block_size) ? block_size : (n - i);
device_index->search(search_size, (float*)data + i * dim, k, distances + i * k, labels + i * k, bitset_); device_index->search(search_size, (float*)data + i * dim, k, distances + i * k, labels + i * k, bitset_);
} }
} else { } else {
......
...@@ -52,6 +52,7 @@ void Index::add_with_ids( ...@@ -52,6 +52,7 @@ void Index::add_with_ids(
FAISS_THROW_MSG ("add_with_ids not implemented for this type of index"); FAISS_THROW_MSG ("add_with_ids not implemented for this type of index");
} }
#if 0
void Index::get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset) { void Index::get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset) {
FAISS_THROW_MSG ("get_vector_by_id not implemented for this type of index"); FAISS_THROW_MSG ("get_vector_by_id not implemented for this type of index");
} }
...@@ -60,6 +61,7 @@ void Index::search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, ...@@ -60,6 +61,7 @@ void Index::search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances,
ConcurrentBitsetPtr bitset) { ConcurrentBitsetPtr bitset) {
FAISS_THROW_MSG ("search_by_id not implemented for this type of index"); FAISS_THROW_MSG ("search_by_id not implemented for this type of index");
} }
#endif
size_t Index::remove_ids(const IDSelector& /*sel*/) { size_t Index::remove_ids(const IDSelector& /*sel*/) {
FAISS_THROW_MSG ("remove_ids not implemented for this type of index"); FAISS_THROW_MSG ("remove_ids not implemented for this type of index");
......
...@@ -117,6 +117,7 @@ struct Index { ...@@ -117,6 +117,7 @@ struct Index {
float *distances, idx_t *labels, float *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) const = 0; ConcurrentBitsetPtr bitset = nullptr) const = 0;
#if 0
/** query n raw vectors from the index by ids. /** query n raw vectors from the index by ids.
* *
* return n raw vectors. * return n raw vectors.
...@@ -140,6 +141,7 @@ struct Index { ...@@ -140,6 +141,7 @@ struct Index {
*/ */
virtual void search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels, virtual void search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr); ConcurrentBitsetPtr bitset = nullptr);
#endif
/** query n vectors of dimension d to the index. /** query n vectors of dimension d to the index.
* *
......
...@@ -36,6 +36,7 @@ void IndexBinary::add_with_ids(idx_t, const uint8_t *, const idx_t *) { ...@@ -36,6 +36,7 @@ void IndexBinary::add_with_ids(idx_t, const uint8_t *, const idx_t *) {
FAISS_THROW_MSG("add_with_ids not implemented for this type of index"); FAISS_THROW_MSG("add_with_ids not implemented for this type of index");
} }
#if 0
void IndexBinary::get_vector_by_id (idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset) { void IndexBinary::get_vector_by_id (idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset) {
FAISS_THROW_MSG("get_vector_by_id not implemented for this type of index"); FAISS_THROW_MSG("get_vector_by_id not implemented for this type of index");
} }
...@@ -44,6 +45,7 @@ void IndexBinary::search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *dis ...@@ -44,6 +45,7 @@ void IndexBinary::search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *dis
ConcurrentBitsetPtr bitset) { ConcurrentBitsetPtr bitset) {
FAISS_THROW_MSG("search_by_id not implemented for this type of index"); FAISS_THROW_MSG("search_by_id not implemented for this type of index");
} }
#endif
size_t IndexBinary::remove_ids(const IDSelector&) { size_t IndexBinary::remove_ids(const IDSelector&) {
FAISS_THROW_MSG("remove_ids not implemented for this type of index"); FAISS_THROW_MSG("remove_ids not implemented for this type of index");
......
...@@ -99,6 +99,7 @@ struct IndexBinary { ...@@ -99,6 +99,7 @@ struct IndexBinary {
int32_t *distances, idx_t *labels, int32_t *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) const = 0; ConcurrentBitsetPtr bitset = nullptr) const = 0;
#if 0
/** Query n raw vectors from the index by ids. /** Query n raw vectors from the index by ids.
* *
* return n raw vectors. * return n raw vectors.
...@@ -122,6 +123,7 @@ struct IndexBinary { ...@@ -122,6 +123,7 @@ struct IndexBinary {
*/ */
virtual void search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels, virtual void search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr); ConcurrentBitsetPtr bitset = nullptr);
#endif
/** Query n vectors of dimension d to the index. /** Query n vectors of dimension d to the index.
* *
......
...@@ -157,6 +157,7 @@ void IndexBinaryIVF::search(idx_t n, const uint8_t *x, idx_t k, ...@@ -157,6 +157,7 @@ void IndexBinaryIVF::search(idx_t n, const uint8_t *x, idx_t k,
indexIVF_stats.search_time += getmillisecs() - t0; indexIVF_stats.search_time += getmillisecs() - t0;
} }
#if 0
void IndexBinaryIVF::get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset) { void IndexBinaryIVF::get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset) {
make_direct_map(true); make_direct_map(true);
...@@ -181,6 +182,7 @@ void IndexBinaryIVF::search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t * ...@@ -181,6 +182,7 @@ void IndexBinaryIVF::search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *
search(n, x, k, distances, labels, bitset); search(n, x, k, distances, labels, bitset);
delete []x; delete []x;
} }
#endif
void IndexBinaryIVF::reconstruct(idx_t key, uint8_t *recons) const { void IndexBinaryIVF::reconstruct(idx_t key, uint8_t *recons) const {
idx_t lo = direct_map.get (key); idx_t lo = direct_map.get (key);
......
...@@ -115,11 +115,13 @@ struct IndexBinaryIVF : IndexBinary { ...@@ -115,11 +115,13 @@ struct IndexBinaryIVF : IndexBinary {
void search(idx_t n, const uint8_t *x, idx_t k, void search(idx_t n, const uint8_t *x, idx_t k,
int32_t *distances, idx_t *labels, ConcurrentBitsetPtr bitset = nullptr) const override; int32_t *distances, idx_t *labels, ConcurrentBitsetPtr bitset = nullptr) const override;
#if 0
/** get raw vectors by ids */ /** get raw vectors by ids */
void get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset = nullptr) override; void get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset = nullptr) override;
void search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels, void search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) override; ConcurrentBitsetPtr bitset = nullptr) override;
#endif
void range_search(idx_t n, const uint8_t *x, int radius, void range_search(idx_t n, const uint8_t *x, int radius,
RangeSearchResult *result, RangeSearchResult *result,
......
...@@ -316,6 +316,7 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k, ...@@ -316,6 +316,7 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
indexIVF_stats.search_time += getmillisecs() - t0; indexIVF_stats.search_time += getmillisecs() - t0;
} }
#if 0
void IndexIVF::get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset) { void IndexIVF::get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset) {
make_direct_map(true); make_direct_map(true);
...@@ -340,6 +341,7 @@ void IndexIVF::search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distance ...@@ -340,6 +341,7 @@ void IndexIVF::search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distance
search(n, x, k, distances, labels, bitset); search(n, x, k, distances, labels, bitset);
delete []x; delete []x;
} }
#endif
void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k, void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
const idx_t *keys, const idx_t *keys,
......
...@@ -192,11 +192,13 @@ struct IndexIVF: Index, Level1Quantizer { ...@@ -192,11 +192,13 @@ struct IndexIVF: Index, Level1Quantizer {
float *distances, idx_t *labels, float *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) const override; ConcurrentBitsetPtr bitset = nullptr) const override;
#if 0
/** get raw vectors by ids */ /** get raw vectors by ids */
void get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset = nullptr) override; void get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset = nullptr) override;
void search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels, void search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) override; ConcurrentBitsetPtr bitset = nullptr) override;
#endif
void range_search (idx_t n, const float* x, float radius, void range_search (idx_t n, const float* x, float radius,
RangeSearchResult* result, RangeSearchResult* result,
......
...@@ -93,6 +93,7 @@ void IndexIDMapTemplate<IndexT>::search ...@@ -93,6 +93,7 @@ void IndexIDMapTemplate<IndexT>::search
} }
} }
#if 0
template <typename IndexT> template <typename IndexT>
void IndexIDMapTemplate<IndexT>::get_vector_by_id(idx_t n, const idx_t *xid, component_t *x, void IndexIDMapTemplate<IndexT>::get_vector_by_id(idx_t n, const idx_t *xid, component_t *x,
ConcurrentBitsetPtr bitset) ConcurrentBitsetPtr bitset)
...@@ -117,6 +118,7 @@ void IndexIDMapTemplate<IndexT>::search_by_id (idx_t n, const idx_t *xid, idx_t ...@@ -117,6 +118,7 @@ void IndexIDMapTemplate<IndexT>::search_by_id (idx_t n, const idx_t *xid, idx_t
index->search(n, x, k, distances, labels, bitset); index->search(n, x, k, distances, labels, bitset);
delete []x; delete []x;
} }
#endif
template <typename IndexT> template <typename IndexT>
void IndexIDMapTemplate<IndexT>::range_search void IndexIDMapTemplate<IndexT>::range_search
......
...@@ -42,10 +42,12 @@ struct IndexIDMapTemplate : IndexT { ...@@ -42,10 +42,12 @@ struct IndexIDMapTemplate : IndexT {
distance_t* distances, idx_t* labels, distance_t* distances, idx_t* labels,
ConcurrentBitsetPtr bitset = nullptr) const override; ConcurrentBitsetPtr bitset = nullptr) const override;
#if 0
void get_vector_by_id(idx_t n, const idx_t *xid, component_t *x, ConcurrentBitsetPtr bitset = nullptr) override; void get_vector_by_id(idx_t n, const idx_t *xid, component_t *x, ConcurrentBitsetPtr bitset = nullptr) override;
void search_by_id (idx_t n, const idx_t *xid, idx_t k, distance_t *distances, idx_t *labels, void search_by_id (idx_t n, const idx_t *xid, idx_t k, distance_t *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) override; ConcurrentBitsetPtr bitset = nullptr) override;
#endif
void train(idx_t n, const component_t* x) override; void train(idx_t n, const component_t* x) override;
......
...@@ -80,6 +80,7 @@ if (NOT TARGET test_instructionset) ...@@ -80,6 +80,7 @@ if (NOT TARGET test_instructionset)
add_executable(test_instructionset test_instructionset.cpp) add_executable(test_instructionset test_instructionset.cpp)
endif () endif ()
target_link_libraries(test_instructionset ${depend_libs} ${unittest_libs}) target_link_libraries(test_instructionset ${depend_libs} ${unittest_libs})
install(TARGETS test_instructionset DESTINATION unittest)
################################################################################ ################################################################################
#<KNOWHERE-COMMON-TEST> #<KNOWHERE-COMMON-TEST>
......
...@@ -28,7 +28,6 @@ class AnnoyTest : public DataGen, public TestWithParam<std::string> { ...@@ -28,7 +28,6 @@ class AnnoyTest : public DataGen, public TestWithParam<std::string> {
void void
SetUp() override { SetUp() override {
IndexType = GetParam(); IndexType = GetParam();
// std::cout << "IndexType from GetParam() is: " << IndexType << std::endl;
Generate(128, 10000, 10); Generate(128, 10000, 10);
index_ = std::make_shared<milvus::knowhere::IndexAnnoy>(); index_ = std::make_shared<milvus::knowhere::IndexAnnoy>();
conf = milvus::knowhere::Config{ conf = milvus::knowhere::Config{
...@@ -38,8 +37,6 @@ class AnnoyTest : public DataGen, public TestWithParam<std::string> { ...@@ -38,8 +37,6 @@ class AnnoyTest : public DataGen, public TestWithParam<std::string> {
{milvus::knowhere::IndexParams::search_k, 100}, {milvus::knowhere::IndexParams::search_k, 100},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
}; };
// Init_with_default();
} }
protected: protected:
...@@ -53,10 +50,20 @@ INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values("Annoy")); ...@@ -53,10 +50,20 @@ INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values("Annoy"));
TEST_P(AnnoyTest, annoy_basic) { TEST_P(AnnoyTest, annoy_basic) {
assert(!xb.empty()); assert(!xb.empty());
// index_->Train(base_dataset, conf); // null faiss index
{
ASSERT_ANY_THROW(index_->Train(base_dataset, conf));
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Serialize(conf));
ASSERT_ANY_THROW(index_->Add(base_dataset, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf));
ASSERT_ANY_THROW(index_->Count());
ASSERT_ANY_THROW(index_->Dim());
}
index_->BuildAll(base_dataset, conf); // Train + Add index_->BuildAll(base_dataset, conf); // Train + Add
EXPECT_EQ(index_->Count(), nb); ASSERT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim); ASSERT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf); auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, k); AssertAnns(result, nq, k);
...@@ -89,8 +96,8 @@ TEST_P(AnnoyTest, annoy_delete) { ...@@ -89,8 +96,8 @@ TEST_P(AnnoyTest, annoy_delete) {
assert(!xb.empty()); assert(!xb.empty());
index_->BuildAll(base_dataset, conf); // Train + Add index_->BuildAll(base_dataset, conf); // Train + Add
EXPECT_EQ(index_->Count(), nb); ASSERT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim); ASSERT_EQ(index_->Dim(), dim);
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb); faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
for (auto i = 0; i < nq; ++i) { for (auto i = 0; i < nq; ++i) {
...@@ -191,8 +198,8 @@ TEST_P(AnnoyTest, annoy_serialize) { ...@@ -191,8 +198,8 @@ TEST_P(AnnoyTest, annoy_serialize) {
binaryset.Append("annoy_dim", dim_data, bin_dim->size); binaryset.Append("annoy_dim", dim_data, bin_dim->size);
index_->Load(binaryset); index_->Load(binaryset);
EXPECT_EQ(index_->Count(), nb); ASSERT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim); ASSERT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf); auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
} }
......
...@@ -49,6 +49,14 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { ...@@ -49,6 +49,14 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
{milvus::knowhere::Metric::TYPE, MetricType}, {milvus::knowhere::Metric::TYPE, MetricType},
}; };
// null faiss index
{
ASSERT_ANY_THROW(index_->Serialize());
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Add(nullptr, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
}
index_->Train(base_dataset, conf); index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf); index_->Add(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Count(), nb);
...@@ -62,7 +70,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { ...@@ -62,7 +70,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
auto binaryset = index_->Serialize(); auto binaryset = index_->Serialize();
auto new_index = std::make_shared<milvus::knowhere::BinaryIDMAP>(); auto new_index = std::make_shared<milvus::knowhere::BinaryIDMAP>();
new_index->Load(binaryset); new_index->Load(binaryset);
auto result2 = index_->Query(query_dataset, conf); auto result2 = new_index->Query(query_dataset, conf);
AssertAnns(result2, nq, k); AssertAnns(result2, nq, k);
// PrintResult(re_result, nq, k); // PrintResult(re_result, nq, k);
...@@ -72,11 +80,11 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { ...@@ -72,11 +80,11 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
} }
index_->SetBlacklist(concurrent_bitset_ptr); index_->SetBlacklist(concurrent_bitset_ptr);
auto result3 = index_->Query(query_dataset, conf); auto result_bs_1 = index_->Query(query_dataset, conf);
AssertAnns(result3, nq, k, CheckMode::CHECK_NOT_EQUAL); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
// auto result4 = index_->SearchById(id_dataset, conf); // auto result4 = index_->SearchById(id_dataset, conf);
// AssertAneq(result4, nq, k); // AssertAneq(result4, nq, k);
} }
TEST_P(BinaryIDMAPTest, binaryidmap_serialize) { TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
...@@ -98,7 +106,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) { ...@@ -98,7 +106,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
{ {
// serialize index // serialize index
index_->Train(base_dataset, conf); index_->Train(base_dataset, conf);
index_->Add(base_dataset, milvus::knowhere::Config()); index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
auto re_result = index_->Query(query_dataset, conf); auto re_result = index_->Query(query_dataset, conf);
AssertAnns(re_result, nq, k); AssertAnns(re_result, nq, k);
// PrintResult(re_result, nq, k); // PrintResult(re_result, nq, k);
...@@ -120,6 +128,6 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) { ...@@ -120,6 +128,6 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
EXPECT_EQ(index_->Dim(), dim); EXPECT_EQ(index_->Dim(), dim);
auto result = index_->Query(query_dataset, conf); auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, k); AssertAnns(result, nq, k);
// PrintResult(result, nq, k); // PrintResult(result, nq, k);
} }
} }
...@@ -60,7 +60,15 @@ INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest, ...@@ -60,7 +60,15 @@ INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest,
TEST_P(BinaryIVFTest, binaryivf_basic) { TEST_P(BinaryIVFTest, binaryivf_basic) {
assert(!xb_bin.empty()); assert(!xb_bin.empty());
index_->Train(base_dataset, conf); // null faiss index
{
ASSERT_ANY_THROW(index_->Serialize());
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Add(nullptr, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
}
index_->BuildAll(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim); EXPECT_EQ(index_->Dim(), dim);
...@@ -77,11 +85,13 @@ TEST_P(BinaryIVFTest, binaryivf_basic) { ...@@ -77,11 +85,13 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
auto result2 = index_->Query(query_dataset, conf); auto result2 = index_->Query(query_dataset, conf);
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
#if 0
auto result3 = index_->QueryById(id_dataset, conf); auto result3 = index_->QueryById(id_dataset, conf);
AssertAnns(result3, nq, k, CheckMode::CHECK_NOT_EQUAL); AssertAnns(result3, nq, k, CheckMode::CHECK_NOT_EQUAL);
// auto result4 = index_->GetVectorById(xid_dataset, conf); auto result4 = index_->GetVectorById(xid_dataset, conf);
// AssertBinVeceq(result4, base_dataset, xid_dataset, nq, dim/8); AssertBinVeceq(result4, base_dataset, xid_dataset, nq, dim/8);
#endif
} }
TEST_P(BinaryIVFTest, binaryivf_serialize) { TEST_P(BinaryIVFTest, binaryivf_serialize) {
...@@ -93,32 +103,32 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) { ...@@ -93,32 +103,32 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
reader(ret, bin->size); reader(ret, bin->size);
}; };
// { // {
// // serialize index-model // // serialize index-model
// auto model = index_->Train(base_dataset, conf); // auto model = index_->Train(base_dataset, conf);
// auto binaryset = model->Serialize(); // auto binaryset = model->Serialize();
// auto bin = binaryset.GetByName("BinaryIVF"); // auto bin = binaryset.GetByName("BinaryIVF");
// //
// std::string filename = "/tmp/binaryivf_test_model_serialize.bin"; // std::string filename = "/tmp/binaryivf_test_model_serialize.bin";
// auto load_data = new uint8_t[bin->size]; // auto load_data = new uint8_t[bin->size];
// serialize(filename, bin, load_data); // serialize(filename, bin, load_data);
// //
// binaryset.clear(); // binaryset.clear();
// auto data = std::make_shared<uint8_t>(); // auto data = std::make_shared<uint8_t>();
// data.reset(load_data); // data.reset(load_data);
// binaryset.Append("BinaryIVF", data, bin->size); // binaryset.Append("BinaryIVF", data, bin->size);
// //
// model->Load(binaryset); // model->Load(binaryset);
// //
// index_->set_index_model(model); // index_->set_index_model(model);
// index_->Add(base_dataset, conf); // index_->Add(base_dataset, conf);
// auto result = index_->Query(query_dataset, conf); // auto result = index_->Query(query_dataset, conf);
// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); // AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// } // }
{ {
// serialize index // serialize index
index_->Train(base_dataset, conf); index_->BuildAll(base_dataset, conf);
// index_->set_index_model(model); // index_->set_index_model(model);
// index_->Add(base_dataset, conf); // index_->Add(base_dataset, conf);
auto binaryset = index_->Serialize(); auto binaryset = index_->Serialize();
......
...@@ -82,20 +82,20 @@ TEST_F(SingleIndexTest, IVFSQHybrid) { ...@@ -82,20 +82,20 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
ASSERT_ANY_THROW(cpu_idx->CopyCpuToGpuWithQuantizer(-1, conf)); ASSERT_ANY_THROW(cpu_idx->CopyCpuToGpuWithQuantizer(-1, conf));
auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf); auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf);
auto gpu_idx = pair.first; auto gpu_idx = pair.first;
auto quantization = pair.second;
auto result = gpu_idx->Query(query_dataset, conf); auto result = gpu_idx->Query(query_dataset, conf);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k); // PrintResult(result, nq, k);
milvus::json quantizer_conf{{milvus::knowhere::meta::DEVICEID, DEVICEID}, {"mode", 2}}; milvus::json quantizer_conf{{milvus::knowhere::meta::DEVICEID, DEVICEID}, {"mode", 2}};
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
auto hybrid_idx = std::make_shared<milvus::knowhere::IVFSQHybrid>(DEVICEID); auto hybrid_idx = std::make_shared<milvus::knowhere::IVFSQHybrid>(DEVICEID);
hybrid_idx->Load(binaryset); hybrid_idx->Load(binaryset);
auto quantization = hybrid_idx->LoadQuantizer(quantizer_conf);
auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf); auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf);
auto result = new_idx->Query(query_dataset, conf); auto result = new_idx->Query(query_dataset, conf);
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
// PrintResult(result, nq, k); // PrintResult(result, nq, k);
} }
} }
......
...@@ -47,6 +47,16 @@ INSTANTIATE_TEST_CASE_P(HNSWParameters, HNSWTest, Values("HNSW")); ...@@ -47,6 +47,16 @@ INSTANTIATE_TEST_CASE_P(HNSWParameters, HNSWTest, Values("HNSW"));
TEST_P(HNSWTest, HNSW_basic) { TEST_P(HNSWTest, HNSW_basic) {
assert(!xb.empty()); assert(!xb.empty());
// null faiss index
{
ASSERT_ANY_THROW(index_->Serialize());
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
ASSERT_ANY_THROW(index_->Add(nullptr, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
ASSERT_ANY_THROW(index_->Count());
ASSERT_ANY_THROW(index_->Dim());
}
index_->Train(base_dataset, conf); index_->Train(base_dataset, conf);
index_->Add(base_dataset, conf); index_->Add(base_dataset, conf);
EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Count(), nb);
......
...@@ -98,15 +98,17 @@ TEST_P(IDMAPTest, idmap_basic) { ...@@ -98,15 +98,17 @@ TEST_P(IDMAPTest, idmap_basic) {
auto binaryset = index_->Serialize(); auto binaryset = index_->Serialize();
auto new_index = std::make_shared<milvus::knowhere::IDMAP>(); auto new_index = std::make_shared<milvus::knowhere::IDMAP>();
new_index->Load(binaryset); new_index->Load(binaryset);
auto result2 = index_->Query(query_dataset, conf); auto result2 = new_index->Query(query_dataset, conf);
AssertAnns(result2, nq, k); AssertAnns(result2, nq, k);
// PrintResult(re_result, nq, k); // PrintResult(re_result, nq, k);
auto result3 = index_->QueryById(id_dataset, conf); #if 0
auto result3 = new_index->QueryById(id_dataset, conf);
AssertAnns(result3, nq, k); AssertAnns(result3, nq, k);
auto result4 = index_->GetVectorById(xid_dataset, conf); auto result4 = new_index->GetVectorById(xid_dataset, conf);
AssertVec(result4, base_dataset, xid_dataset, 1, dim); AssertVec(result4, base_dataset, xid_dataset, 1, dim);
#endif
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(nb); faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(nb);
for (int64_t i = 0; i < nq; ++i) { for (int64_t i = 0; i < nq; ++i) {
...@@ -117,11 +119,13 @@ TEST_P(IDMAPTest, idmap_basic) { ...@@ -117,11 +119,13 @@ TEST_P(IDMAPTest, idmap_basic) {
auto result_bs_1 = index_->Query(query_dataset, conf); auto result_bs_1 = index_->Query(query_dataset, conf);
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
#if 0
auto result_bs_2 = index_->QueryById(id_dataset, conf); auto result_bs_2 = index_->QueryById(id_dataset, conf);
AssertAnns(result_bs_2, nq, k, CheckMode::CHECK_NOT_EQUAL); AssertAnns(result_bs_2, nq, k, CheckMode::CHECK_NOT_EQUAL);
auto result_bs_3 = index_->GetVectorById(xid_dataset, conf); auto result_bs_3 = index_->GetVectorById(xid_dataset, conf);
AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL); AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL);
#endif
} }
TEST_P(IDMAPTest, idmap_serialize) { TEST_P(IDMAPTest, idmap_serialize) {
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
TEST(InstructionSetTest, INSTRUCTION_SET_TEST) { void
ShowInstructionSet() {
auto& outstream = std::cout; auto& outstream = std::cout;
auto support_message = [&outstream](const std::string& isa_feature, bool is_supported) { auto support_message = [&outstream](const std::string& isa_feature, bool is_supported) {
...@@ -88,3 +89,7 @@ TEST(InstructionSetTest, INSTRUCTION_SET_TEST) { ...@@ -88,3 +89,7 @@ TEST(InstructionSetTest, INSTRUCTION_SET_TEST) {
support_message("XOP", instruction_set_inst.XOP()); support_message("XOP", instruction_set_inst.XOP());
support_message("XSAVE", instruction_set_inst.XSAVE()); support_message("XSAVE", instruction_set_inst.XSAVE());
} }
TEST(InstructionSetTest, INSTRUCTION_SET_TEST) {
ASSERT_NO_FATAL_FAILURE(ShowInstructionSet());
}
...@@ -102,7 +102,7 @@ TEST_P(IVFTest, ivf_basic_cpu) { ...@@ -102,7 +102,7 @@ TEST_P(IVFTest, ivf_basic_cpu) {
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_));
index_->Train(base_dataset, conf_); index_->Train(base_dataset, conf_);
index_->Add(base_dataset, conf_); index_->AddWithoutIds(base_dataset, conf_);
EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim); EXPECT_EQ(index_->Dim(), dim);
...@@ -111,6 +111,7 @@ TEST_P(IVFTest, ivf_basic_cpu) { ...@@ -111,6 +111,7 @@ TEST_P(IVFTest, ivf_basic_cpu) {
// PrintResult(result, nq, k); // PrintResult(result, nq, k);
if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) {
#if 0
auto result2 = index_->QueryById(id_dataset, conf_); auto result2 = index_->QueryById(id_dataset, conf_);
AssertAnns(result2, nq, k); AssertAnns(result2, nq, k);
...@@ -122,6 +123,7 @@ TEST_P(IVFTest, ivf_basic_cpu) { ...@@ -122,6 +123,7 @@ TEST_P(IVFTest, ivf_basic_cpu) {
/* for SQ8, sometimes the mean diff can bigger than 20% */ /* for SQ8, sometimes the mean diff can bigger than 20% */
// AssertVec(result3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_APPROXIMATE_EQUAL); // AssertVec(result3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_APPROXIMATE_EQUAL);
} }
#endif
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(nb); faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(nb);
for (int64_t i = 0; i < nq; ++i) { for (int64_t i = 0; i < nq; ++i) {
...@@ -133,12 +135,14 @@ TEST_P(IVFTest, ivf_basic_cpu) { ...@@ -133,12 +135,14 @@ TEST_P(IVFTest, ivf_basic_cpu) {
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
// PrintResult(result, nq, k); // PrintResult(result, nq, k);
#if 0
auto result_bs_2 = index_->QueryById(id_dataset, conf_); auto result_bs_2 = index_->QueryById(id_dataset, conf_);
AssertAnns(result_bs_2, nq, k, CheckMode::CHECK_NOT_EQUAL); AssertAnns(result_bs_2, nq, k, CheckMode::CHECK_NOT_EQUAL);
// PrintResult(result, nq, k); // PrintResult(result, nq, k);
auto result_bs_3 = index_->GetVectorById(xid_dataset, conf_); auto result_bs_3 = index_->GetVectorById(xid_dataset, conf_);
AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL); AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL);
#endif
} }
#ifdef MILVUS_GPU_VERSION #ifdef MILVUS_GPU_VERSION
...@@ -157,8 +161,7 @@ TEST_P(IVFTest, ivf_basic_gpu) { ...@@ -157,8 +161,7 @@ TEST_P(IVFTest, ivf_basic_gpu) {
ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); ASSERT_ANY_THROW(index_->Add(base_dataset, conf_));
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_));
index_->Train(base_dataset, conf_); index_->BuildAll(base_dataset, conf_);
index_->Add(base_dataset, conf_);
EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dim(), dim); EXPECT_EQ(index_->Dim(), dim);
......
...@@ -33,7 +33,7 @@ using ::testing::Combine; ...@@ -33,7 +33,7 @@ using ::testing::Combine;
using ::testing::TestWithParam; using ::testing::TestWithParam;
using ::testing::Values; using ::testing::Values;
constexpr int64_t DEVICEID = 0; constexpr int64_t DEVICE_GPU0 = 0;
class NSGInterfaceTest : public DataGen, public ::testing::Test { class NSGInterfaceTest : public DataGen, public ::testing::Test {
protected: protected:
...@@ -41,10 +41,10 @@ class NSGInterfaceTest : public DataGen, public ::testing::Test { ...@@ -41,10 +41,10 @@ class NSGInterfaceTest : public DataGen, public ::testing::Test {
SetUp() override { SetUp() override {
#ifdef MILVUS_GPU_VERSION #ifdef MILVUS_GPU_VERSION
int64_t MB = 1024 * 1024; int64_t MB = 1024 * 1024;
milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, MB * 200, MB * 600, 1); milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_GPU0, MB * 200, MB * 600, 1);
#endif #endif
int nsg_dim = 256; int nsg_dim = 256;
Generate(nsg_dim, nb, nq); Generate(nsg_dim, 20000, nq);
index_ = std::make_shared<milvus::knowhere::NSG>(); index_ = std::make_shared<milvus::knowhere::NSG>();
train_conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, 256}, train_conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, 256},
...@@ -80,12 +80,14 @@ TEST_F(NSGInterfaceTest, basic_test) { ...@@ -80,12 +80,14 @@ TEST_F(NSGInterfaceTest, basic_test) {
fiu_init(0); fiu_init(0);
// untrained index // untrained index
{ {
ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf));
ASSERT_ANY_THROW(index_->Serialize()); ASSERT_ANY_THROW(index_->Serialize());
ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf));
ASSERT_ANY_THROW(index_->Add(base_dataset, search_conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, search_conf));
} }
train_conf[milvus::knowhere::meta::DEVICEID] = DEVICEID; train_conf[milvus::knowhere::meta::DEVICEID] = -1;
index_->Train(base_dataset, train_conf); index_->BuildAll(base_dataset, train_conf);
auto result = index_->Query(query_dataset, search_conf); auto result = index_->Query(query_dataset, search_conf);
AssertAnns(result, nq, k); AssertAnns(result, nq, k);
...@@ -96,16 +98,24 @@ TEST_F(NSGInterfaceTest, basic_test) { ...@@ -96,16 +98,24 @@ TEST_F(NSGInterfaceTest, basic_test) {
fiu_disable("NSG.Serialize.throw_exception"); fiu_disable("NSG.Serialize.throw_exception");
} }
auto new_index = std::make_shared<milvus::knowhere::NSG>(); /* test NSG GPU train */
new_index->Load(binaryset); auto new_index_1 = std::make_shared<milvus::knowhere::NSG>(DEVICE_GPU0);
train_conf[milvus::knowhere::meta::DEVICEID] = DEVICE_GPU0;
new_index_1->BuildAll(base_dataset, train_conf);
auto new_result_1 = new_index_1->Query(query_dataset, search_conf);
AssertAnns(new_result_1, nq, k);
/* test NSG index load */
auto new_index_2 = std::make_shared<milvus::knowhere::NSG>();
new_index_2->Load(binaryset);
{ {
fiu_enable("NSG.Load.throw_exception", 1, nullptr, 0); fiu_enable("NSG.Load.throw_exception", 1, nullptr, 0);
ASSERT_ANY_THROW(new_index->Load(binaryset)); ASSERT_ANY_THROW(new_index_2->Load(binaryset));
fiu_disable("NSG.Load.throw_exception"); fiu_disable("NSG.Load.throw_exception");
} }
auto new_result = new_index->Query(query_dataset, search_conf); auto new_result_2 = new_index_2->Query(query_dataset, search_conf);
AssertAnns(result, nq, k); AssertAnns(new_result_2, nq, k);
ASSERT_EQ(index_->Count(), nb); ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim); ASSERT_EQ(index_->Dim(), dim);
...@@ -129,7 +139,7 @@ TEST_F(NSGInterfaceTest, compare_test) { ...@@ -129,7 +139,7 @@ TEST_F(NSGInterfaceTest, compare_test) {
TEST_F(NSGInterfaceTest, delete_test) { TEST_F(NSGInterfaceTest, delete_test) {
assert(!xb.empty()); assert(!xb.empty());
train_conf[milvus::knowhere::meta::DEVICEID] = DEVICEID; train_conf[milvus::knowhere::meta::DEVICEID] = DEVICE_GPU0;
index_->Train(base_dataset, train_conf); index_->Train(base_dataset, train_conf);
auto result = index_->Query(query_dataset, search_conf); auto result = index_->Query(query_dataset, search_conf);
......
...@@ -60,7 +60,13 @@ INSTANTIATE_TEST_CASE_P(SPTAGParameters, SPTAGTest, Values("KDT", "BKT")); ...@@ -60,7 +60,13 @@ INSTANTIATE_TEST_CASE_P(SPTAGParameters, SPTAGTest, Values("KDT", "BKT"));
TEST_P(SPTAGTest, sptag_basic) { TEST_P(SPTAGTest, sptag_basic) {
assert(!xb.empty()); assert(!xb.empty());
index_->Train(base_dataset, conf); // null faiss index
{
ASSERT_ANY_THROW(index_->Add(nullptr, conf));
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
}
index_->BuildAll(base_dataset, conf);
// index_->Add(base_dataset, conf); // index_->Add(base_dataset, conf);
auto result = index_->Query(query_dataset, conf); auto result = index_->Query(query_dataset, conf);
AssertAnns(result, nq, k); AssertAnns(result, nq, k);
......
...@@ -161,6 +161,7 @@ AssertAnns(const milvus::knowhere::DatasetPtr& result, const int nq, const int k ...@@ -161,6 +161,7 @@ AssertAnns(const milvus::knowhere::DatasetPtr& result, const int nq, const int k
} }
} }
#if 0
void void
AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset,
const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, const CheckMode check_mode) { const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, const CheckMode check_mode) {
...@@ -194,18 +195,19 @@ AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::Da ...@@ -194,18 +195,19 @@ AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::Da
} }
void void
AssertBinVeceq(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, AssertBinVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset,
const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim) { const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, const CheckMode check_mode) {
auto base = base_dataset->Get<const uint8_t*>(milvus::knowhere::meta::TENSOR); auto base = (uint8_t*)base_dataset->Get<const void*>(milvus::knowhere::meta::TENSOR);
auto ids = id_dataset->Get<const int64_t*>(milvus::knowhere::meta::IDS); auto ids = id_dataset->Get<const int64_t*>(milvus::knowhere::meta::IDS);
auto x = result->Get<uint8_t*>(milvus::knowhere::meta::TENSOR); auto x = result->Get<float*>(milvus::knowhere::meta::TENSOR);
for (auto i = 0; i < 1; i++) { for (auto i = 0; i < 1; i++) {
auto id = ids[i]; auto id = ids[i];
for (auto j = 0; j < dim; j++) { for (auto j = 0; j < dim; j++) {
EXPECT_EQ(*(base + id * dim + j), *(x + i * dim + j)); ASSERT_EQ(*(base + id * dim + j), *(x + i * dim + j));
} }
} }
} }
#endif
void void
PrintResult(const milvus::knowhere::DatasetPtr& result, const int& nq, const int& k) { PrintResult(const milvus::knowhere::DatasetPtr& result, const int& nq, const int& k) {
......
...@@ -76,9 +76,9 @@ AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::Da ...@@ -76,9 +76,9 @@ AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::Da
const CheckMode check_mode = CheckMode::CHECK_EQUAL); const CheckMode check_mode = CheckMode::CHECK_EQUAL);
void void
AssertBinVeceq(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, AssertBinVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset,
const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim,
const CheckMode check_mode = CheckMode::CHECK_EQUAL); const CheckMode check_mode = CheckMode::CHECK_EQUAL);
void void
PrintResult(const milvus::knowhere::DatasetPtr& result, const int& nq, const int& k); PrintResult(const milvus::knowhere::DatasetPtr& result, const int& nq, const int& k);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册