提交 8540b356 编写于 作者: W wxyu

SQ8H in GPU part2


Former-commit-id: 4c48987574ed24bf8a543d97520eb3a6b554fca5
上级 e675158b
......@@ -180,7 +180,7 @@ IVFSQHybrid::UnsetQuantizer() {
ivf_index->quantizer = nullptr;
}
void
VectorIndexPtr
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
......@@ -207,8 +207,10 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
index_composition->mode = quantizer_conf->mode; // only 2
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option);
index_.reset(gpu_index);
gpu_mode = 2; // all in gpu
std::shared_ptr<faiss::Index> new_idx;
new_idx.reset(gpu_index);
auto sq_idx = std::make_shared<IVFSQHybrid>(new_idx, gpu_id_, res);
return sq_idx;
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
......
......@@ -60,8 +60,7 @@ class IVFSQHybrid : public GPUIVFSQ {
void
UnsetQuantizer();
// todo(xiaojun): return void => VecIndex
void
VectorIndexPtr
LoadData(const knowhere::QuantizerPtr& q, const Config& conf);
IndexModelPtr
......
......@@ -253,9 +253,9 @@ TEST_P(IVFTest, hybrid) {
quantizer_conf->gpu_id = device_id;
auto q = hybrid_2_idx->LoadQuantizer(quantizer_conf);
quantizer_conf->mode = 2;
hybrid_2_idx->LoadData(q, quantizer_conf);
auto gpu_idx = hybrid_2_idx->LoadData(q, quantizer_conf);
auto result = hybrid_2_idx->Search(query_dataset, conf);
auto result = gpu_idx->Search(query_dataset, conf);
AssertAnns(result, nq, conf->k);
PrintResult(result, nq, k);
}
......
......@@ -256,11 +256,14 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
conf->gpu_id = device_id;
if (quantizer) {
std::cout << "cache hit" << std::endl;
// cache hit
conf->mode = 2;
index_->SetQuantizer(quantizer->Data());
index_->LoadData(quantizer->Data(), conf);
auto new_index = index_->LoadData(quantizer->Data(), conf);
index_ = new_index;
} else {
std::cout << "cache miss" << std::endl;
// cache hit
// cache miss
if (index_ == nullptr) {
ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to copy to gpu";
......@@ -268,9 +271,9 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
}
conf->mode = 1;
auto q = index_->LoadQuantizer(conf);
index_->SetQuantizer(q);
conf->mode = 2;
index_->LoadData(q, conf);
auto new_index = index_->LoadData(q, conf);
index_ = new_index;
// cache
auto cached_quantizer = std::make_shared<CachedQuantizer>(q);
......@@ -445,7 +448,9 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
auto status = index_->Search(n, data, distances, labels, conf);
HybridUnset();
if (hybrid) {
HybridUnset();
}
if (!status.ok()) {
ENGINE_LOG_ERROR << "Search error";
......
......@@ -315,24 +315,21 @@ IVFHybridIndex::UnsetQuantizer() {
return Status::OK();
}
Status
VecIndexPtr
IVFHybridIndex::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
try {
// TODO(linxj): Hardcode here
if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
new_idx->LoadData(q, conf);
return std::make_shared<IVFHybridIndex>(new_idx->LoadData(q, conf), type);
} else {
WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type);
return Status(KNOWHERE_ERROR, "not support");
}
} catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_ERROR, e.what());
}
return Status::OK();
return nullptr;
}
} // namespace engine
......
......@@ -106,7 +106,7 @@ class IVFHybridIndex : public IVFMixIndex {
Status
UnsetQuantizer() override;
Status
VecIndexPtr
LoadData(const knowhere::QuantizerPtr& q, const Config& conf) override;
};
......
......@@ -103,9 +103,9 @@ class VecIndex : public cache::DataObj {
return nullptr;
}
virtual Status
virtual VecIndexPtr
LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
return Status::OK();
return nullptr;
}
virtual Status
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册