提交 d4ea0151 编写于 作者: J jinhai

Merge branch 'ms544' into 'branch-0.4.0'

MS-544 fix

See merge request megasearch/milvus!552

Former-commit-id: 2c99ce1fce5f15e36cfa854d9d688bc52fc68a02
......@@ -26,17 +26,17 @@ namespace knowhere {
IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
auto nlist = config["nlist"].as<size_t>();
auto gpu_device = config.get_with_default("gpu_id", gpu_id_);
gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
auto metric_type = config["metric_type"].as_string() == "L2" ?
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
GETTENSOR(dataset)
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(gpu_device, temp_resource);
ResScope rs(gpu_id_, temp_resource);
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_device;
idx_config.device = gpu_id_;
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index.train(rows, (float *) p_data);
......@@ -204,7 +204,7 @@ VectorIndexPtr GPUIVFPQ::CopyGpuToCpu(const Config &config) {
IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
auto nlist = config["nlist"].as<size_t>();
auto nbits = config["nbits"].as<size_t>(); // TODO(linxj): gpu only support SQ4 SQ8 SQ16
auto gpu_num = config.get_with_default("gpu_id", gpu_id_);
gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
auto metric_type = config["metric_type"].as_string() == "L2" ?
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
......@@ -214,10 +214,10 @@ IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
index_type << "IVF" << nlist << "," << "SQ" << nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_num);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(gpu_num, temp_resource );
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_num, build_index);
ResScope rs(gpu_id_, temp_resource );
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float *) p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册