未验证 提交 93ea9dd2 编写于 作者: X xujiaqi01 提交者: GitHub

fix stat var in hogwild worker (#23367)

* fix stat var in hogwild worker
* test=develop
上级 8c463700
......@@ -37,6 +37,10 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
dump_fields_[i] = desc.dump_fields(i);
}
for (int i = 0; i < param_.stat_var_names_size(); ++i) {
stat_var_name_map_[param_.stat_var_names(i)] = 1;
}
need_dump_param_ = false;
dump_param_.resize(desc.dump_param_size());
for (int i = 0; i < desc.dump_param_size(); ++i) {
......
......@@ -59,7 +59,10 @@ message TrainerDesc {
optional DataFeedDesc data_desc = 201;
}
message HogwildWorkerParameter { repeated string skip_ops = 1; }
message HogwildWorkerParameter {
repeated string skip_ops = 1;
repeated string stat_var_names = 2;
}
message DownpourWorkerParameter {
repeated TableParameter sparse_table = 1;
......
......@@ -102,6 +102,7 @@ class Hogwild(DeviceWorker):
program_configs = opt_info["program_configs"]
downpour = trainer_desc.downpour_param
hogwild = trainer_desc.hogwild_param
for pid in program_configs:
if pid == program_id:
......@@ -154,6 +155,7 @@ class Hogwild(DeviceWorker):
sparse_table.label_var_name = ""
if opt_info["stat_var_names"]:
for i in opt_info["stat_var_names"]:
hogwild.stat_var_names.extend([i])
downpour.stat_var_names.extend([i])
for i in worker.get_desc().dense_table:
......@@ -163,10 +165,10 @@ class Hogwild(DeviceWorker):
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.dense_grad_name.extend(
i.dense_gradient_variable_name)
downpour.skip_ops.extend(worker.get_desc().skip_op)
hogwild.skip_ops.extend(worker.get_desc().skip_op)
if self._infer:
downpour.push_dense = False
downpour.push_sparse = False
hogwild.skip_ops.extend(
["push_sparse", "push_sparse_v2", "push_dense"])
class DownpourSGD(DeviceWorker):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册