diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp index 4f0e28c07c35d0b04b8f976f7a39cead1ffb345d..b2110df952b63313a6e75138d698c95351654920 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp @@ -27,7 +27,8 @@ namespace zilliz { namespace knowhere { -IndexModelPtr IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) { +IndexModelPtr +IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { build_cfg->CheckValid(); // throw exception @@ -58,7 +59,8 @@ IndexModelPtr IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config } } -VectorIndexPtr IVFSQHybrid::CopyGpuToCpu(const Config &config) { +VectorIndexPtr +IVFSQHybrid::CopyGpuToCpu(const Config &config) { std::lock_guard lk(mutex_); if (auto device_idx = std::dynamic_pointer_cast(index_)) { @@ -74,7 +76,8 @@ VectorIndexPtr IVFSQHybrid::CopyGpuToCpu(const Config &config) { } } -VectorIndexPtr IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) { +VectorIndexPtr +IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) { if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { ResScope rs(res, device_id, false); faiss::gpu::GpuClonerOptions option; @@ -95,11 +98,13 @@ VectorIndexPtr IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config } } -void IVFSQHybrid::LoadImpl(const BinarySet &index_binary) { +void +IVFSQHybrid::LoadImpl(const BinarySet &index_binary) { FaissBaseIndex::LoadImpl(index_binary); // load on cpu } -void IVFSQHybrid::search_impl(int64_t n, +void +IVFSQHybrid::search_impl(int64_t n, const float *data, int64_t k, float *distances, @@ -112,7 +117,8 @@ void IVFSQHybrid::search_impl(int64_t n, } } -QuantizerPtr IVFSQHybrid::LoadQuantizer(const Config &conf) { +QuantizerPtr +IVFSQHybrid::LoadQuantizer(const Config &conf) { auto quantizer_conf = std::dynamic_pointer_cast(conf); if (quantizer_conf != nullptr) { quantizer_conf->CheckValid(); // throw exception @@ -140,7 +146,8 @@ QuantizerPtr IVFSQHybrid::LoadQuantizer(const Config &conf) { } } -void IVFSQHybrid::SetQuantizer(QuantizerPtr q) { +void +IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) { auto ivf_quantizer = std::dynamic_pointer_cast(q); if (ivf_quantizer == nullptr) { KNOWHERE_THROW_MSG("Quantizer type error"); @@ -158,5 +165,15 @@ void IVFSQHybrid::SetQuantizer(QuantizerPtr q) { } } +void +IVFSQHybrid::UnsetQuantizer() { + auto *ivf_index = dynamic_cast(index_.get()); + if(ivf_index == nullptr) { + KNOWHERE_THROW_MSG("Index type error"); + } + + ivf_index->quantizer = nullptr; +} + } } diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.h index 020fe110b40c23dd815b098e54c05aba3b3ad870..f2f7c2003921cb9a19c73b521b306591bb08dbb6 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.h @@ -49,7 +49,10 @@ class IVFSQHybrid : public GPUIVFSQ { LoadQuantizer(const Config &conf); void - SetQuantizer(QuantizerPtr q); + SetQuantizer(const QuantizerPtr& q); + + void + UnsetQuantizer(); IndexModelPtr Train(const DatasetPtr &dataset, const Config &config) override; diff --git a/cpp/src/wrapper/VecImpl.cpp b/cpp/src/wrapper/VecImpl.cpp index 0bbe97cbbed1b265911146b241938ba36fc26ddf..4530757f6810facda4e2e0343b97f92c71aadc02 100644 --- a/cpp/src/wrapper/VecImpl.cpp +++ b/cpp/src/wrapper/VecImpl.cpp @@ -277,7 +277,8 @@ IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { return Status::OK(); } -knowhere::QuantizerPtr IVFHybridIndex::LoadQuantizer(const Config& conf) { +knowhere::QuantizerPtr +IVFHybridIndex::LoadQuantizer(const Config& conf) { // TODO(linxj): Hardcode here if (auto new_idx = std::dynamic_pointer_cast(index_)){ return new_idx->LoadQuantizer(conf); @@ -286,7 +287,8 @@ knowhere::QuantizerPtr IVFHybridIndex::LoadQuantizer(const Config& conf) { } } -Status IVFHybridIndex::SetQuantizer(knowhere::QuantizerPtr q) { +Status +IVFHybridIndex::SetQuantizer(const knowhere::QuantizerPtr& q) { try { // TODO(linxj): Hardcode here if (auto new_idx = std::dynamic_pointer_cast(index_)) { @@ -304,6 +306,25 @@ Status IVFHybridIndex::SetQuantizer(knowhere::QuantizerPtr q) { } } +Status +IVFHybridIndex::UnsetQuantizer() { + try { + // TODO(linxj): Hardcode here + if (auto new_idx = std::dynamic_pointer_cast(index_)) { + new_idx->UnsetQuantizer(); + } else { + WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type); + return Status(KNOWHERE_ERROR, "not support"); + } + } catch (knowhere::KnowhereException &e) { + WRAPPER_LOG_ERROR << e.what(); + return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); + } catch (std::exception &e) { + WRAPPER_LOG_ERROR << e.what(); + return Status(KNOWHERE_ERROR, e.what()); + } +} + } // namespace engine } // namespace milvus } // namespace zilliz diff --git a/cpp/src/wrapper/VecImpl.h b/cpp/src/wrapper/VecImpl.h index 7a131f7e07fe28c03600dd104a400ed1ab25e7d1..ba0934cf923fdfc957cc383a52a4da4305c6fb9a 100644 --- a/cpp/src/wrapper/VecImpl.h +++ b/cpp/src/wrapper/VecImpl.h @@ -103,9 +103,14 @@ class IVFMixIndex : public VecIndexImpl { class IVFHybridIndex : public IVFMixIndex { public: - knowhere::QuantizerPtr LoadQuantizer(const Config& conf) override; + knowhere::QuantizerPtr + LoadQuantizer(const Config& conf) override; - Status SetQuantizer(knowhere::QuantizerPtr q) override; + Status + SetQuantizer(const knowhere::QuantizerPtr& q) override; + + Status + UnsetQuantizer() override; }; class BFIndex : public VecIndexImpl { diff --git a/cpp/src/wrapper/VecIndex.h b/cpp/src/wrapper/VecIndex.h index 8a9c91377bf7c47bafea18992def56d2d0e849ef..24352b8d973847ba4b2ddfe3c8b936b7a78e9254 100644 --- a/cpp/src/wrapper/VecIndex.h +++ b/cpp/src/wrapper/VecIndex.h @@ -105,11 +105,14 @@ class VecIndex { // TODO(linxj): refactor later virtual knowhere::QuantizerPtr - LoadQuantizer(const Config& conf) {} + LoadQuantizer(const Config& conf) { return Status::OK(); } // TODO(linxj): refactor later virtual Status - SetQuantizer(knowhere::QuantizerPtr q) {} + SetQuantizer(const knowhere::QuantizerPtr& q) { return Status::OK(); } + + virtual Status + UnsetQuantizer() { return Status::OK(); } }; extern Status