未验证 提交 07b68eb3 编写于 作者: D danleifeng 提交者: GitHub

[gpups]fix sparse config work (#44090)

上级 953024ff
......@@ -426,16 +426,26 @@ int HeterComm<KeyType, ValType, GradType>::get_index_by_devid(int devid) {
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::set_sparse_sgd(
const OptimizerConfig& optimizer_config) {
for (auto& table : tables_) {
table->set_sparse_sgd(optimizer_config);
for (int i = 0; i < resource_->total_device(); ++i) {
AnyDeviceGuard guard(resource_->dev_id(i));
if (!multi_mf_dim_) {
tables_[i]->set_sparse_sgd(optimizer_config);
} else {
ptr_tables_[i]->set_sparse_sgd(optimizer_config);
}
}
}
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::set_embedx_sgd(
const OptimizerConfig& optimizer_config) {
for (auto& table : tables_) {
table->set_embedx_sgd(optimizer_config);
for (int i = 0; i < resource_->total_device(); ++i) {
AnyDeviceGuard guard(resource_->dev_id(i));
if (!multi_mf_dim_) {
tables_[i]->set_embedx_sgd(optimizer_config);
} else {
ptr_tables_[i]->set_embedx_sgd(optimizer_config);
}
}
}
......
......@@ -323,26 +323,19 @@ class PSGPUWrapper {
float mf_max_bound = (config.find("mf_max_bound") == config.end())
? 1.0
: config["mf_max_bound"];
for (size_t i = 0; i < heter_devices_.size(); i++) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i]));
#elif defined(PADDLE_WITH_XPU_KP)
PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i]));
#endif
this->SetSparseSGD(nonclk_coeff,
clk_coeff,
min_bound,
max_bound,
learning_rate,
initial_g2sum,
initial_range);
this->SetEmbedxSGD(mf_create_thresholds,
mf_learning_rate,
mf_initial_g2sum,
mf_initial_range,
mf_min_bound,
mf_max_bound);
}
this->SetSparseSGD(nonclk_coeff,
clk_coeff,
min_bound,
max_bound,
learning_rate,
initial_g2sum,
initial_range);
this->SetEmbedxSGD(mf_create_thresholds,
mf_learning_rate,
mf_initial_g2sum,
mf_initial_range,
mf_min_bound,
mf_max_bound);
}
void SetDate(int year, int month, int day) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册