提交 12dd738e 编写于 作者: X xj.lin

MS-544 fix


Former-commit-id: 2d9d4afb74ce2522b8cf5bf29ade982d3d6722ad
上级 a16b6a4f
......@@ -26,17 +26,17 @@ namespace knowhere {
IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
auto nlist = config["nlist"].as<size_t>();
auto gpu_device = config.get_with_default("gpu_id", gpu_id_);
gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
auto metric_type = config["metric_type"].as_string() == "L2" ?
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
GETTENSOR(dataset)
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(gpu_device, temp_resource);
ResScope rs(gpu_id_, temp_resource);
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_device;
idx_config.device = gpu_id_;
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index.train(rows, (float *) p_data);
......@@ -204,7 +204,7 @@ VectorIndexPtr GPUIVFPQ::CopyGpuToCpu(const Config &config) {
IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
auto nlist = config["nlist"].as<size_t>();
auto nbits = config["nbits"].as<size_t>(); // TODO(linxj): gpu only support SQ4 SQ8 SQ16
auto gpu_num = config.get_with_default("gpu_id", gpu_id_);
gpu_id_ = config.get_with_default("gpu_id", gpu_id_);
auto metric_type = config["metric_type"].as_string() == "L2" ?
faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
......@@ -214,10 +214,10 @@ IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
index_type << "IVF" << nlist << "," << "SQ" << nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_num);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(gpu_num, temp_resource );
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_num, build_index);
ResScope rs(gpu_id_, temp_resource );
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float *) p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册