提交 18b0303f 编写于 作者: J JinHai-CN

Add one more interface: UnsetQuantizer


Former-commit-id: 34b6b4ac1f9b2841a8af6e5e4166e9ad45a9de1f
上级 2cce8978
......@@ -27,7 +27,8 @@
namespace zilliz {
namespace knowhere {
IndexModelPtr IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) {
IndexModelPtr
IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
......@@ -58,7 +59,8 @@ IndexModelPtr IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config
}
}
VectorIndexPtr IVFSQHybrid::CopyGpuToCpu(const Config &config) {
VectorIndexPtr
IVFSQHybrid::CopyGpuToCpu(const Config &config) {
std::lock_guard<std::mutex> lk(mutex_);
if (auto device_idx = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
......@@ -74,7 +76,8 @@ VectorIndexPtr IVFSQHybrid::CopyGpuToCpu(const Config &config) {
}
}
VectorIndexPtr IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
VectorIndexPtr
IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
faiss::gpu::GpuClonerOptions option;
......@@ -95,11 +98,13 @@ VectorIndexPtr IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config
}
}
void IVFSQHybrid::LoadImpl(const BinarySet &index_binary) {
void
IVFSQHybrid::LoadImpl(const BinarySet &index_binary) {
FaissBaseIndex::LoadImpl(index_binary); // load on cpu
}
void IVFSQHybrid::search_impl(int64_t n,
void
IVFSQHybrid::search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
......@@ -112,7 +117,8 @@ void IVFSQHybrid::search_impl(int64_t n,
}
}
QuantizerPtr IVFSQHybrid::LoadQuantizer(const Config &conf) {
QuantizerPtr
IVFSQHybrid::LoadQuantizer(const Config &conf) {
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
quantizer_conf->CheckValid(); // throw exception
......@@ -140,7 +146,8 @@ QuantizerPtr IVFSQHybrid::LoadQuantizer(const Config &conf) {
}
}
void IVFSQHybrid::SetQuantizer(QuantizerPtr q) {
void
IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) {
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(q);
if (ivf_quantizer == nullptr) {
KNOWHERE_THROW_MSG("Quantizer type error");
......@@ -158,5 +165,15 @@ void IVFSQHybrid::SetQuantizer(QuantizerPtr q) {
}
}
void
IVFSQHybrid::UnsetQuantizer() {
auto *ivf_index = dynamic_cast<faiss::IndexIVF *>(index_.get());
if(ivf_index == nullptr) {
KNOWHERE_THROW_MSG("Index type error");
}
ivf_index->quantizer = nullptr;
}
}
}
......@@ -49,7 +49,10 @@ class IVFSQHybrid : public GPUIVFSQ {
LoadQuantizer(const Config &conf);
void
SetQuantizer(QuantizerPtr q);
SetQuantizer(const QuantizerPtr& q);
void
UnsetQuantizer();
IndexModelPtr
Train(const DatasetPtr &dataset, const Config &config) override;
......
......@@ -277,7 +277,8 @@ IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
return Status::OK();
}
knowhere::QuantizerPtr IVFHybridIndex::LoadQuantizer(const Config& conf) {
knowhere::QuantizerPtr
IVFHybridIndex::LoadQuantizer(const Config& conf) {
// TODO(linxj): Hardcode here
if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)){
return new_idx->LoadQuantizer(conf);
......@@ -286,7 +287,8 @@ knowhere::QuantizerPtr IVFHybridIndex::LoadQuantizer(const Config& conf) {
}
}
Status IVFHybridIndex::SetQuantizer(knowhere::QuantizerPtr q) {
Status
IVFHybridIndex::SetQuantizer(const knowhere::QuantizerPtr& q) {
try {
// TODO(linxj): Hardcode here
if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
......@@ -304,6 +306,25 @@ Status IVFHybridIndex::SetQuantizer(knowhere::QuantizerPtr q) {
}
}
Status
IVFHybridIndex::UnsetQuantizer() {
try {
// TODO(linxj): Hardcode here
if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
new_idx->UnsetQuantizer();
} else {
WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type);
return Status(KNOWHERE_ERROR, "not support");
}
} catch (knowhere::KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_ERROR, e.what());
}
}
} // namespace engine
} // namespace milvus
} // namespace zilliz
......@@ -103,9 +103,14 @@ class IVFMixIndex : public VecIndexImpl {
class IVFHybridIndex : public IVFMixIndex {
public:
knowhere::QuantizerPtr LoadQuantizer(const Config& conf) override;
knowhere::QuantizerPtr
LoadQuantizer(const Config& conf) override;
Status SetQuantizer(knowhere::QuantizerPtr q) override;
Status
SetQuantizer(const knowhere::QuantizerPtr& q) override;
Status
UnsetQuantizer() override;
};
class BFIndex : public VecIndexImpl {
......
......@@ -105,11 +105,14 @@ class VecIndex {
// TODO(linxj): refactor later
virtual knowhere::QuantizerPtr
LoadQuantizer(const Config& conf) {}
LoadQuantizer(const Config& conf) { return Status::OK(); }
// TODO(linxj): refactor later
virtual Status
SetQuantizer(knowhere::QuantizerPtr q) {}
SetQuantizer(const knowhere::QuantizerPtr& q) { return Status::OK(); }
virtual Status
UnsetQuantizer() { return Status::OK(); }
};
extern Status
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册