From e1290c4fd7facfa9abfbb6e710ab3fa5f4ed3d10 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 28 Mar 2018 23:09:32 +0800 Subject: [PATCH] Make Average Model support for 'moving mean' and 'moving variance' of batch_normal op --- python/paddle/fluid/optimizer.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 180575c35dc..d21320f7058 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -850,23 +850,39 @@ class ModelAverage(Optimizer): self.min_average_window = min_average_window self.max_average_window = max_average_window self.params_grads = params_grads + + # append 'moving mean' and 'moving variance' to self.params_grads + pattern = re.compile(r"batch_norm_\d+\.w_[1,2]") + for param in framework.default_main_program().global_block( + ).all_parameters(): + if pattern.match(param.name) is not None: + self.params_grads.append((param, None)) + # create a tmp gradient variable to backup parameter value + # for parameter whose grad is None + for i, param_grad in enumerate(self.params_grads): + param, grad = param_grad + if grad is None: + grad = param.block.create_var( + name=unique_name.generate(".".join([param.name, 'tmp'])), + dtype=param.dtype, + persistable=False, + stop_gradient=stop_gradient) + self.params_grads[i] = (param, grad) + for param, grad in self.params_grads: - if grad is not None: - self._append_average_accumulate_op(param) + self._append_average_accumulate_op(param) self.apply_program = Program() block = self.apply_program.global_block() with program_guard(main_program=self.apply_program): for param_grad in self.params_grads: - if param_grad[1] is not None: - self._add_average_apply_op(block, param_grad) + self._add_average_apply_op(block, param_grad) self.restore_program = Program() block = self.restore_program.global_block() with program_guard(main_program=self.restore_program): for param_grad in self.params_grads: - if param_grad[1] is not None: - self._add_average_restore_op(block, param_grad) + self._add_average_restore_op(block, param_grad) def _add_average_apply_op(self, block, param_grad): param = block.clone_variable(param_grad[0]) -- GitLab