From 93ea9dd27a2ad4ca174c26381e67d2e373f8d82e Mon Sep 17 00:00:00 2001 From: xujiaqi01 <173596896@qq.com> Date: Thu, 2 Apr 2020 11:07:15 +0800 Subject: [PATCH] fix stat var in hogwild worker (#23367) * fix stat var in hogwild worker * test=develop --- paddle/fluid/framework/hogwild_worker.cc | 4 ++++ paddle/fluid/framework/trainer_desc.proto | 5 ++++- python/paddle/fluid/device_worker.py | 8 +++++--- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index a08db28f51b..db6231e9919 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -37,6 +37,10 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) { dump_fields_[i] = desc.dump_fields(i); } + for (int i = 0; i < param_.stat_var_names_size(); ++i) { + stat_var_name_map_[param_.stat_var_names(i)] = 1; + } + need_dump_param_ = false; dump_param_.resize(desc.dump_param_size()); for (int i = 0; i < desc.dump_param_size(); ++i) { diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 43b2dd63e40..f442063313f 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -59,7 +59,10 @@ message TrainerDesc { optional DataFeedDesc data_desc = 201; } -message HogwildWorkerParameter { repeated string skip_ops = 1; } +message HogwildWorkerParameter { + repeated string skip_ops = 1; + repeated string stat_var_names = 2; +} message DownpourWorkerParameter { repeated TableParameter sparse_table = 1; diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index 3c265a8f567..1035e405a37 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -102,6 +102,7 @@ class Hogwild(DeviceWorker): program_configs = opt_info["program_configs"] downpour = trainer_desc.downpour_param + hogwild = trainer_desc.hogwild_param for pid in program_configs: if pid == program_id: @@ -154,6 +155,7 @@ class Hogwild(DeviceWorker): sparse_table.label_var_name = "" if opt_info["stat_var_names"]: for i in opt_info["stat_var_names"]: + hogwild.stat_var_names.extend([i]) downpour.stat_var_names.extend([i]) for i in worker.get_desc().dense_table: @@ -163,10 +165,10 @@ class Hogwild(DeviceWorker): dense_table.dense_value_name.extend(i.dense_variable_name) dense_table.dense_grad_name.extend( i.dense_gradient_variable_name) - downpour.skip_ops.extend(worker.get_desc().skip_op) + hogwild.skip_ops.extend(worker.get_desc().skip_op) if self._infer: - downpour.push_dense = False - downpour.push_sparse = False + hogwild.skip_ops.extend( + ["push_sparse", "push_sparse_v2", "push_dense"]) class DownpourSGD(DeviceWorker): -- GitLab