Commit e1290c4f authored by W wanghaoshuang

Make ModelAverage support the 'moving mean' and 'moving variance' of the batch_norm op

Parent 123cf165
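Context for the change: ModelAverage keeps running averages of trainable parameters, but a batch_norm layer's moving mean and moving variance have no gradient and were previously skipped, so the averaged model used stale normalization statistics. Below is a minimal usage sketch of the optimizer after this commit; the constructor argument order (params_grads first, then the window rate), the toy network, and the apply() context-manager usage are assumptions based on the fluid API of this era, not part of the diff itself.

import paddle.fluid as fluid

# Hypothetical toy network with a batch_norm layer, so that parameters
# named batch_norm_0.w_1 / batch_norm_0.w_2 (moving mean / moving
# variance) exist in the default main program.
x = fluid.layers.data(name='x', shape=[8], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
hidden = fluid.layers.batch_norm(input=fluid.layers.fc(input=x, size=16))
pred = fluid.layers.fc(input=hidden, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

sgd = fluid.optimizer.SGD(learning_rate=0.01)
_, params_grads = sgd.minimize(loss)

# With this commit, ModelAverage also accumulates averages for the
# gradient-less batch_norm statistics it finds in the main program.
model_average = fluid.optimizer.ModelAverage(
    params_grads, 0.15,
    min_average_window=10000,
    max_average_window=20000)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# ... training loop runs here ...

# For evaluation, apply() swaps the averaged values in and restores the
# originals on exit, via the apply/restore programs built in __init__
# (see the diff below).
with model_average.apply(exe):
    pass  # run the inference program here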
@@ -850,22 +850,38 @@ class ModelAverage(Optimizer):
        self.min_average_window = min_average_window
        self.max_average_window = max_average_window
        self.params_grads = params_grads
        # append 'moving mean' and 'moving variance' to self.params_grads
        pattern = re.compile(r"batch_norm_\d+\.w_[12]")
        for param in framework.default_main_program().global_block(
        ).all_parameters():
            if pattern.match(param.name) is not None:
                self.params_grads.append((param, None))
        # create a tmp variable to back up the parameter value
        # for parameters whose grad is None
        for i, param_grad in enumerate(self.params_grads):
            param, grad = param_grad
            if grad is None:
                grad = param.block.create_var(
                    name=unique_name.generate(".".join([param.name, 'tmp'])),
                    dtype=param.dtype,
                    persistable=False,
                    stop_gradient=True)
                self.params_grads[i] = (param, grad)
        for param, grad in self.params_grads:
            if grad is not None:
                self._append_average_accumulate_op(param)

        self.apply_program = Program()
        block = self.apply_program.global_block()
        with program_guard(main_program=self.apply_program):
            for param_grad in self.params_grads:
                if param_grad[1] is not None:
                    self._add_average_apply_op(block, param_grad)

        self.restore_program = Program()
        block = self.restore_program.global_block()
        with program_guard(main_program=self.restore_program):
            for param_grad in self.params_grads:
                if param_grad[1] is not None:
                    self._add_average_restore_op(block, param_grad)

    def _add_average_apply_op(self, block, param_grad):
......
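The pattern above relies on PaddlePaddle's default parameter naming, under which a batch_norm layer's moving mean and moving variance are named batch_norm_N.w_1 and batch_norm_N.w_2. A small self-contained check of what the pattern matches; the sample names and the per-name comments are illustrative assumptions, not taken from the diff:

import re

# Same pattern as in __init__ above (with the character class written
# as [12]; the original [1,2] would also have matched a literal comma).
pattern = re.compile(r"batch_norm_\d+\.w_[12]")

names = [
    "batch_norm_0.w_0",  # scale parameter: trained normally, not matched
    "batch_norm_0.w_1",  # moving mean: matched, appended with grad=None
    "batch_norm_0.w_2",  # moving variance: matched, appended with grad=None
    "fc_0.w_0",          # ordinary weight: not matched
]
print([n for n in names if pattern.match(n)])
# -> ['batch_norm_0.w_1', 'batch_norm_0.w_2']

Matched parameters are appended with a None gradient, which is why __init__ then creates the tmp backup variables: every entry in params_grads needs a paired variable to hold the original value while the averaged value is applied.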