提交 8540b356 编写于 作者: W wxyu

SQ8H in GPU part2


Former-commit-id: 4c48987574ed24bf8a543d97520eb3a6b554fca5
上级 e675158b
...@@ -180,7 +180,7 @@ IVFSQHybrid::UnsetQuantizer() { ...@@ -180,7 +180,7 @@ IVFSQHybrid::UnsetQuantizer() {
ivf_index->quantizer = nullptr; ivf_index->quantizer = nullptr;
} }
void VectorIndexPtr
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf); auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) { if (quantizer_conf != nullptr) {
...@@ -207,8 +207,10 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { ...@@ -207,8 +207,10 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
index_composition->mode = quantizer_conf->mode; // only 2 index_composition->mode = quantizer_conf->mode; // only 2
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option); auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option);
index_.reset(gpu_index); std::shared_ptr<faiss::Index> new_idx;
gpu_mode = 2; // all in gpu new_idx.reset(gpu_index);
auto sq_idx = std::make_shared<IVFSQHybrid>(new_idx, gpu_id_, res);
return sq_idx;
} else { } else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
} }
......
...@@ -60,8 +60,7 @@ class IVFSQHybrid : public GPUIVFSQ { ...@@ -60,8 +60,7 @@ class IVFSQHybrid : public GPUIVFSQ {
void void
UnsetQuantizer(); UnsetQuantizer();
// todo(xiaojun): return void => VecIndex VectorIndexPtr
void
LoadData(const knowhere::QuantizerPtr& q, const Config& conf); LoadData(const knowhere::QuantizerPtr& q, const Config& conf);
IndexModelPtr IndexModelPtr
......
...@@ -253,9 +253,9 @@ TEST_P(IVFTest, hybrid) { ...@@ -253,9 +253,9 @@ TEST_P(IVFTest, hybrid) {
quantizer_conf->gpu_id = device_id; quantizer_conf->gpu_id = device_id;
auto q = hybrid_2_idx->LoadQuantizer(quantizer_conf); auto q = hybrid_2_idx->LoadQuantizer(quantizer_conf);
quantizer_conf->mode = 2; quantizer_conf->mode = 2;
hybrid_2_idx->LoadData(q, quantizer_conf); auto gpu_idx = hybrid_2_idx->LoadData(q, quantizer_conf);
auto result = hybrid_2_idx->Search(query_dataset, conf); auto result = gpu_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k); AssertAnns(result, nq, conf->k);
PrintResult(result, nq, k); PrintResult(result, nq, k);
} }
......
...@@ -256,11 +256,14 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) { ...@@ -256,11 +256,14 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
conf->gpu_id = device_id; conf->gpu_id = device_id;
if (quantizer) { if (quantizer) {
std::cout << "cache hit" << std::endl;
// cache hit // cache hit
conf->mode = 2; conf->mode = 2;
index_->SetQuantizer(quantizer->Data()); auto new_index = index_->LoadData(quantizer->Data(), conf);
index_->LoadData(quantizer->Data(), conf); index_ = new_index;
} else { } else {
std::cout << "cache miss" << std::endl;
// no cached quantizer was found: load one below and cache it
// cache miss // cache miss
if (index_ == nullptr) { if (index_ == nullptr) {
ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to copy to gpu"; ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to copy to gpu";
...@@ -268,9 +271,9 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) { ...@@ -268,9 +271,9 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
} }
conf->mode = 1; conf->mode = 1;
auto q = index_->LoadQuantizer(conf); auto q = index_->LoadQuantizer(conf);
index_->SetQuantizer(q);
conf->mode = 2; conf->mode = 2;
index_->LoadData(q, conf); auto new_index = index_->LoadData(q, conf);
index_ = new_index;
// cache // cache
auto cached_quantizer = std::make_shared<CachedQuantizer>(q); auto cached_quantizer = std::make_shared<CachedQuantizer>(q);
...@@ -445,7 +448,9 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr ...@@ -445,7 +448,9 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
auto status = index_->Search(n, data, distances, labels, conf); auto status = index_->Search(n, data, distances, labels, conf);
HybridUnset(); if (hybrid) {
HybridUnset();
}
if (!status.ok()) { if (!status.ok()) {
ENGINE_LOG_ERROR << "Search error"; ENGINE_LOG_ERROR << "Search error";
......
...@@ -315,24 +315,21 @@ IVFHybridIndex::UnsetQuantizer() { ...@@ -315,24 +315,21 @@ IVFHybridIndex::UnsetQuantizer() {
return Status::OK(); return Status::OK();
} }
Status VecIndexPtr
IVFHybridIndex::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { IVFHybridIndex::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
try { try {
// TODO(linxj): Hardcode here // TODO(linxj): Hardcode here
if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) { if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
new_idx->LoadData(q, conf); return std::make_shared<IVFHybridIndex>(new_idx->LoadData(q, conf), type);
} else { } else {
WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type); WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type);
return Status(KNOWHERE_ERROR, "not support");
} }
} catch (knowhere::KnowhereException& e) { } catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception& e) { } catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_ERROR, e.what());
} }
return Status::OK(); return nullptr;
} }
} // namespace engine } // namespace engine
......
...@@ -106,7 +106,7 @@ class IVFHybridIndex : public IVFMixIndex { ...@@ -106,7 +106,7 @@ class IVFHybridIndex : public IVFMixIndex {
Status Status
UnsetQuantizer() override; UnsetQuantizer() override;
Status VecIndexPtr
LoadData(const knowhere::QuantizerPtr& q, const Config& conf) override; LoadData(const knowhere::QuantizerPtr& q, const Config& conf) override;
}; };
......
...@@ -103,9 +103,9 @@ class VecIndex : public cache::DataObj { ...@@ -103,9 +103,9 @@ class VecIndex : public cache::DataObj {
return nullptr; return nullptr;
} }
virtual Status virtual VecIndexPtr
LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
return Status::OK(); return nullptr;
} }
virtual Status virtual Status
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册