Unverified Commit 5d4a1106 authored by whs, committed by GitHub

Fix model average on multi-GPUs. (#11814) (#11826)

* Fix average_accumulate_op for parallel executor.

* Fix model average on multi-GPUs.
Parent 494cecd6
```diff
@@ -1113,7 +1113,6 @@ class ModelAverage(Optimizer):
 
     Args:
         average_window_rate: The rate of average window.
-        params_grads: A list of parameter-grad variable pairs.
         min_average_window: The minimum size of average window.
         max_average_window: The maximum size of average window.
 
@@ -1122,8 +1121,8 @@ class ModelAverage(Optimizer):
         .. code-block:: python
 
             optimizer = fluid.optimizer.Momentum()
-            _, params_grads = optimizer.minimize(cost)
-            model_average = fluid.optimizer.ModelAverage(params_grads, 0.15,
+            optimizer.minimize(cost)
+            model_average = fluid.optimizer.ModelAverage(0.15,
                                                          min_average_window=10000,
                                                          max_average_window=20000)
             for pass_id in range(args.pass_num):
@@ -1137,7 +1136,6 @@ class ModelAverage(Optimizer):
 
     def __init__(self,
                  average_window_rate,
-                 params_grads=None,
                  min_average_window=10000,
                  max_average_window=10000,
                  **kwargs):
@@ -1146,21 +1144,16 @@ class ModelAverage(Optimizer):
         self.min_average_window = min_average_window
         self.max_average_window = max_average_window
 
-        self.params_grads = [] if params_grads is None else params_grads
-        params = {}
-        for param, grad in self.params_grads:
-            if param.do_model_average != False:
-                params[param.name] = (param, grad)
+        self.params_grads = []
         for param in framework.default_main_program().global_block(
         ).all_parameters():
-            if param.name not in params and param.do_model_average != False:
+            if param.do_model_average != False:
                 grad = param.block.create_var(
                     name=unique_name.generate(".".join([param.name, 'tmp'])),
                     dtype=param.dtype,
                     persistable=False,
                     stop_gradient=True)
-                params[param.name] = (param, grad)
-        self.params_grads = params.values()
+                self.params_grads.append((param, grad))
 
         for param, grad in self.params_grads:
             self._append_average_accumulate_op(param)
...
```
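In effect, callers no longer pass the `params_grads` list returned by `optimizer.minimize()`; `ModelAverage` now enumerates every parameter of `fluid.default_main_program()` itself and creates a temporary, non-persistable gradient buffer per parameter, which is what the commit message credits with fixing model averaging under ParallelExecutor on multiple GPUs. Below is a minimal sketch of the updated calling convention, assuming an already-built network with loss variable `cost`; the names `exe`, `feeder`, `pass_num`, `train_reader`, and `run_evaluation` are hypothetical scaffolding, while `apply()` is the class's existing context manager for swapping in the averaged weights:

```python
import paddle.fluid as fluid

# Hypothetical setup: `cost` is the loss of an already-built network,
# `exe` a fluid.Executor, `feeder` a DataFeeder over the input variables.
optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer.minimize(cost)

# Before this commit:
#   _, params_grads = optimizer.minimize(cost)
#   model_average = fluid.optimizer.ModelAverage(params_grads, 0.15, ...)
# After it: no params_grads argument; the parameters are discovered from
# fluid.default_main_program() inside __init__.
model_average = fluid.optimizer.ModelAverage(
    0.15, min_average_window=10000, max_average_window=20000)

exe.run(fluid.default_startup_program())
for pass_id in range(pass_num):
    for data in train_reader():
        exe.run(fluid.default_main_program(), feed=feeder.feed(data))
    # Temporarily replace the parameters with their running averages
    # (e.g. for evaluation); the raw weights are restored on exit.
    with model_average.apply(exe):
        run_evaluation(exe)  # hypothetical evaluation routine
```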