diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index ace533cb0c745897bf86d2bce476b3227209f30f..a7333cd01c6ec224104f38ad43666cd01ab2cd14 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -426,16 +426,26 @@ int HeterComm::get_index_by_devid(int devid) { template void HeterComm::set_sparse_sgd( const OptimizerConfig& optimizer_config) { - for (auto& table : tables_) { - table->set_sparse_sgd(optimizer_config); + for (int i = 0; i < resource_->total_device(); ++i) { + AnyDeviceGuard guard(resource_->dev_id(i)); + if (!multi_mf_dim_) { + tables_[i]->set_sparse_sgd(optimizer_config); + } else { + ptr_tables_[i]->set_sparse_sgd(optimizer_config); + } } } template void HeterComm::set_embedx_sgd( const OptimizerConfig& optimizer_config) { - for (auto& table : tables_) { - table->set_embedx_sgd(optimizer_config); + for (int i = 0; i < resource_->total_device(); ++i) { + AnyDeviceGuard guard(resource_->dev_id(i)); + if (!multi_mf_dim_) { + tables_[i]->set_embedx_sgd(optimizer_config); + } else { + ptr_tables_[i]->set_embedx_sgd(optimizer_config); + } } } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index fae30a45d2e5b7f6e781098a330716da748e9e22..65f86acce9151d8c776e7c67177f933a54fe87d7 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -323,26 +323,19 @@ class PSGPUWrapper { float mf_max_bound = (config.find("mf_max_bound") == config.end()) ? 1.0 : config["mf_max_bound"]; - for (size_t i = 0; i < heter_devices_.size(); i++) { -#ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i])); -#elif defined(PADDLE_WITH_XPU_KP) - PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i])); -#endif - this->SetSparseSGD(nonclk_coeff, - clk_coeff, - min_bound, - max_bound, - learning_rate, - initial_g2sum, - initial_range); - this->SetEmbedxSGD(mf_create_thresholds, - mf_learning_rate, - mf_initial_g2sum, - mf_initial_range, - mf_min_bound, - mf_max_bound); - } + this->SetSparseSGD(nonclk_coeff, + clk_coeff, + min_bound, + max_bound, + learning_rate, + initial_g2sum, + initial_range); + this->SetEmbedxSGD(mf_create_thresholds, + mf_learning_rate, + mf_initial_g2sum, + mf_initial_range, + mf_min_bound, + mf_max_bound); } void SetDate(int year, int month, int day) {