未验证 提交 07b68eb3 编写于 作者: D danleifeng 提交者: GitHub

[gpups]fix sparse config work (#44090)

上级 953024ff
...@@ -426,16 +426,26 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) { ...@@ -426,16 +426,26 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) {
template <typename KeyType, typename ValType, typename GradType> template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::set_sparse_sgd( void HeterComm<KeyType, ValType, GradType>::set_sparse_sgd(
const OptimizerConfig& optimizer_config) { const OptimizerConfig& optimizer_config) {
for (auto& table : tables_) { for (int i = 0; i < resource_->total_device(); ++i) {
table->set_sparse_sgd(optimizer_config); AnyDeviceGuard guard(resource_->dev_id(i));
if (!multi_mf_dim_) {
tables_[i]->set_sparse_sgd(optimizer_config);
} else {
ptr_tables_[i]->set_sparse_sgd(optimizer_config);
}
} }
} }
template <typename KeyType, typename ValType, typename GradType> template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::set_embedx_sgd( void HeterComm<KeyType, ValType, GradType>::set_embedx_sgd(
const OptimizerConfig& optimizer_config) { const OptimizerConfig& optimizer_config) {
for (auto& table : tables_) { for (int i = 0; i < resource_->total_device(); ++i) {
table->set_embedx_sgd(optimizer_config); AnyDeviceGuard guard(resource_->dev_id(i));
if (!multi_mf_dim_) {
tables_[i]->set_embedx_sgd(optimizer_config);
} else {
ptr_tables_[i]->set_embedx_sgd(optimizer_config);
}
} }
} }
......
...@@ -323,26 +323,19 @@ class PSGPUWrapper { ...@@ -323,26 +323,19 @@ class PSGPUWrapper {
float mf_max_bound = (config.find("mf_max_bound") == config.end()) float mf_max_bound = (config.find("mf_max_bound") == config.end())
? 1.0 ? 1.0
: config["mf_max_bound"]; : config["mf_max_bound"];
for (size_t i = 0; i < heter_devices_.size(); i++) { this->SetSparseSGD(nonclk_coeff,
#ifdef PADDLE_WITH_CUDA clk_coeff,
PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i])); min_bound,
#elif defined(PADDLE_WITH_XPU_KP) max_bound,
PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i])); learning_rate,
#endif initial_g2sum,
this->SetSparseSGD(nonclk_coeff, initial_range);
clk_coeff, this->SetEmbedxSGD(mf_create_thresholds,
min_bound, mf_learning_rate,
max_bound, mf_initial_g2sum,
learning_rate, mf_initial_range,
initial_g2sum, mf_min_bound,
initial_range); mf_max_bound);
this->SetEmbedxSGD(mf_create_thresholds,
mf_learning_rate,
mf_initial_g2sum,
mf_initial_range,
mf_min_bound,
mf_max_bound);
}
} }
void SetDate(int year, int month, int day) { void SetDate(int year, int month, int day) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册