diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index a08db28f51be0e556eb5ca7de5fb03389710934a..db6231e99193d96714e4205fa174444a4ffede83 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -37,6 +37,10 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) { dump_fields_[i] = desc.dump_fields(i); } + for (int i = 0; i < param_.stat_var_names_size(); ++i) { + stat_var_name_map_[param_.stat_var_names(i)] = 1; + } + need_dump_param_ = false; dump_param_.resize(desc.dump_param_size()); for (int i = 0; i < desc.dump_param_size(); ++i) { diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 43b2dd63e4035e8585549a0bb094dd9b14e5f52b..f442063313f03321931112ed293ccdf8ebabeb89 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -59,7 +59,10 @@ message TrainerDesc { optional DataFeedDesc data_desc = 201; } -message HogwildWorkerParameter { repeated string skip_ops = 1; } +message HogwildWorkerParameter { + repeated string skip_ops = 1; + repeated string stat_var_names = 2; +} message DownpourWorkerParameter { repeated TableParameter sparse_table = 1; diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index 3c265a8f567feff5ca8fce797de7095ca1fcc1b4..1035e405a3710ea15fe28a453c63f5efd15337e6 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -102,6 +102,7 @@ class Hogwild(DeviceWorker): program_configs = opt_info["program_configs"] downpour = trainer_desc.downpour_param + hogwild = trainer_desc.hogwild_param for pid in program_configs: if pid == program_id: @@ -154,6 +155,7 @@ class Hogwild(DeviceWorker): sparse_table.label_var_name = "" if opt_info["stat_var_names"]: for i in opt_info["stat_var_names"]: + hogwild.stat_var_names.extend([i]) downpour.stat_var_names.extend([i]) for i in worker.get_desc().dense_table: @@ -163,10 +165,10 @@ class Hogwild(DeviceWorker): dense_table.dense_value_name.extend(i.dense_variable_name) dense_table.dense_grad_name.extend( i.dense_gradient_variable_name) - downpour.skip_ops.extend(worker.get_desc().skip_op) + hogwild.skip_ops.extend(worker.get_desc().skip_op) if self._infer: - downpour.push_dense = False - downpour.push_sparse = False + hogwild.skip_ops.extend( + ["push_sparse", "push_sparse_v2", "push_dense"]) class DownpourSGD(DeviceWorker):