From 07b68eb34b41c21df900f05c9003f4224b52f441 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Wed, 6 Jul 2022 10:00:44 +0800 Subject: [PATCH] [gpups]fix sparse config work (#44090) --- .../framework/fleet/heter_ps/heter_comm_inl.h | 18 +++++++--- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 33 ++++++++----------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index ace533cb0c7..a7333cd01c6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -426,16 +426,26 @@ int HeterComm::get_index_by_devid(int devid) { template void HeterComm::set_sparse_sgd( const OptimizerConfig& optimizer_config) { - for (auto& table : tables_) { - table->set_sparse_sgd(optimizer_config); + for (int i = 0; i < resource_->total_device(); ++i) { + AnyDeviceGuard guard(resource_->dev_id(i)); + if (!multi_mf_dim_) { + tables_[i]->set_sparse_sgd(optimizer_config); + } else { + ptr_tables_[i]->set_sparse_sgd(optimizer_config); + } } } template void HeterComm::set_embedx_sgd( const OptimizerConfig& optimizer_config) { - for (auto& table : tables_) { - table->set_embedx_sgd(optimizer_config); + for (int i = 0; i < resource_->total_device(); ++i) { + AnyDeviceGuard guard(resource_->dev_id(i)); + if (!multi_mf_dim_) { + tables_[i]->set_embedx_sgd(optimizer_config); + } else { + ptr_tables_[i]->set_embedx_sgd(optimizer_config); + } } } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index fae30a45d2e..65f86acce91 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -323,26 +323,19 @@ class PSGPUWrapper { float mf_max_bound = (config.find("mf_max_bound") == config.end()) ? 1.0 : config["mf_max_bound"]; - for (size_t i = 0; i < heter_devices_.size(); i++) { -#ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i])); -#elif defined(PADDLE_WITH_XPU_KP) - PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i])); -#endif - this->SetSparseSGD(nonclk_coeff, - clk_coeff, - min_bound, - max_bound, - learning_rate, - initial_g2sum, - initial_range); - this->SetEmbedxSGD(mf_create_thresholds, - mf_learning_rate, - mf_initial_g2sum, - mf_initial_range, - mf_min_bound, - mf_max_bound); - } + this->SetSparseSGD(nonclk_coeff, + clk_coeff, + min_bound, + max_bound, + learning_rate, + initial_g2sum, + initial_range); + this->SetEmbedxSGD(mf_create_thresholds, + mf_learning_rate, + mf_initial_g2sum, + mf_initial_range, + mf_min_bound, + mf_max_bound); } void SetDate(int year, int month, int day) { -- GitLab