提交 2cce8978 编写于 作者: X xiaojun.lin

hybrid enable


Former-commit-id: d64d5a4607a750934716386c0e1f9fc0c8dc9c89
上级 9cc0c598
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "faiss/gpu/GpuIndexIVF.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "IndexIVFSQHybrid.h"
#include "faiss/AutoTune.h"
#include "faiss/gpu/GpuAutoTune.h"
namespace zilliz {
namespace knowhere {
IndexModelPtr IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << "," << "SQ8Hybrid";
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(temp_resource, gpu_id_, true);
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float *) p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
delete device_index;
delete build_index;
return std::make_shared<IVFIndexModel>(host_index);
} else {
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
}
}
VectorIndexPtr IVFSQHybrid::CopyGpuToCpu(const Config &config) {
std::lock_guard<std::mutex> lk(mutex_);
if (auto device_idx = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
faiss::Index *device_index = index_.get();
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index);
std::shared_ptr<faiss::Index> new_index;
new_index.reset(host_index);
return std::make_shared<IVFSQHybrid>(new_index);
} else {
// TODO(linxj): why? jinhai
return std::make_shared<IVFSQHybrid>(index_);
}
}
VectorIndexPtr IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
faiss::gpu::GpuClonerOptions option;
option.allInGpu = true;
faiss::IndexComposition index_composition;
index_composition.index = index_.get();
index_composition.quantizer = nullptr;
index_composition.mode = 0; // copy all
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, &index_composition, &option);
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
return std::make_shared<IVFSQHybrid>(device_index, device_id, res);
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
}
void IVFSQHybrid::LoadImpl(const BinarySet &index_binary) {
FaissBaseIndex::LoadImpl(index_binary); // load on cpu
}
void IVFSQHybrid::search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) {
if (gpu_mode) {
GPUIVF::search_impl(n, data, k, distances, labels, cfg);
} else {
IVF::search_impl(n, data, k, distances, labels, cfg);
}
}
QuantizerPtr IVFSQHybrid::LoadQuantizer(const Config &conf) {
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
quantizer_conf->CheckValid(); // throw exception
}
gpu_id_ = quantizer_conf->gpu_id;
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
ResScope rs(res, gpu_id_, false);
faiss::gpu::GpuClonerOptions option;
option.allInGpu = true;
auto index_composition = new faiss::IndexComposition;
index_composition->index = index_.get();
index_composition->quantizer = nullptr;
index_composition->mode = quantizer_conf->mode; // 1 or 2
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option);
delete gpu_index;
std::shared_ptr<FaissIVFQuantizer> q;
q->quantizer = index_composition;
return q;
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
}
void IVFSQHybrid::SetQuantizer(QuantizerPtr q) {
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(q);
if (ivf_quantizer == nullptr) {
KNOWHERE_THROW_MSG("Quantizer type error");
}
if (ivf_quantizer->quantizer->mode == 2) gpu_mode = true; // all in gpu
faiss::IndexIVF *ivf_index =
dynamic_cast<faiss::IndexIVF *>(index_.get());
faiss::gpu::GpuIndexFlat *is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat *>(ivf_index->quantizer);
if (is_gpu_flat_index == nullptr) {
delete ivf_index->quantizer;
ivf_index->quantizer = ivf_quantizer->quantizer->quantizer;
}
}
}
}
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <faiss/index_io.h>
#include "IndexGPUIVFSQ.h"
#include "Quantizer.h"
namespace zilliz {
namespace knowhere {
struct FaissIVFQuantizer : public Quantizer {
faiss::IndexComposition *quantizer = nullptr;
};
using FaissIVFQuantizerPtr = std::shared_ptr<FaissIVFQuantizer>;
class IVFSQHybrid : public GPUIVFSQ {
public:
explicit IVFSQHybrid(const int &device_id) : GPUIVFSQ(device_id) {}
explicit IVFSQHybrid(std::shared_ptr<faiss::Index> index) : GPUIVFSQ(-1) {gpu_mode = false;}
explicit IVFSQHybrid(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr &resource)
: GPUIVFSQ(index, device_id, resource) {
gpu_mode = true;
}
public:
QuantizerPtr
LoadQuantizer(const Config &conf);
void
SetQuantizer(QuantizerPtr q);
IndexModelPtr
Train(const DatasetPtr &dataset, const Config &config) override;
VectorIndexPtr
CopyGpuToCpu(const Config &config) override;
VectorIndexPtr
CopyCpuToGpu(const int64_t &device_id, const Config &config) override;
protected:
void
search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) override;
void LoadImpl(const BinarySet &index_binary) override;
protected:
bool gpu_mode = false;
};
}
}
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
namespace zilliz {
namespace knowhere {
struct Quantizer {
virtual ~Quantizer() = default;
};
using QuantizerPtr = std::shared_ptr<Quantizer>;
struct QuantizerCfg : Cfg {
uint64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
};
using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册