diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index c745905aad89d8f2f69e8a2dd17400cbbef9de9b..b2c11df7e0d7e715dde5af39ae198575c3bea086 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -528,6 +528,7 @@ paddle.fluid.optimizer.LambOptimizer.minimize (ArgSpec(args=['self', 'loss', 'st
 paddle.fluid.optimizer.ExponentialMovingAverage.__init__ (ArgSpec(args=['self', 'decay', 'thres_steps', 'name'], varargs=None, keywords=None, defaults=(0.999, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.optimizer.ExponentialMovingAverage.apply (ArgSpec(args=['self', 'executor', 'need_restore'], varargs=None, keywords=None, defaults=(True,)), ('document', '30f494752ac8921dc5835a63637f453a'))
 paddle.fluid.optimizer.ExponentialMovingAverage.restore (ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None), ('document', '8c8a1791608b02a1ede53d6dd3a4fcec'))
+paddle.fluid.optimizer.ExponentialMovingAverage.update (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'ea10f08af6d7aac3b7974aa976e4085f'))
 paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '08a5dd9f6f376ff3d55e0b1d92115cbd'))
 paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 50705bba2e93ece695991212f635ab7ab8010b8c..a85344ecf216aa8dbe28287fa2797b4646ef5519 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -2333,10 +2333,10 @@ class ExponentialMovingAverage(object):
 
         \\text{EMA}_t & = \\text{decay} * \\text{EMA}_{t-1} + (1 - \\text{decay}) * \\theta_t
 
-    The average results will be saved in temporary variables which are created
-    and maintained by the object, and can be applied to parameters of current
-    model by calling **apply()** method. And the **restore()** method is used to
-    restore the parameters.
+    The average results calculated by the **update()** method will be saved in
+    temporary variables, which are created and maintained by the object, and can
+    be applied to the parameters of the current model by calling the **apply()**
+    method. The **restore()** method is used to restore the parameters.
 
     **Bias correction**. All EMAs are initialized to :math:`0` and hence they will be zero biased,
     which can be corrected by divided by a factor
@@ -2382,6 +2382,7 @@ class ExponentialMovingAverage(object):
 
             global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
             ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
+            ema.update()
 
             # pseudo code
             for pass_id in range(args.pass_num):
@@ -2407,7 +2408,7 @@ class ExponentialMovingAverage(object):
         self._name = name if name is not None else ''
         self._decay_var = self._get_ema_decay()
 
-        self.params_tmps = []
+        self._params_tmps = []
         for param in default_main_program().global_block().all_parameters():
             if param.do_model_average != False:
                 tmp = param.block.create_var(
@@ -2416,22 +2417,22 @@ class ExponentialMovingAverage(object):
                     dtype=param.dtype,
                     persistable=False,
                     stop_gradient=True)
-                self.params_tmps.append((param, tmp))
+                self._params_tmps.append((param, tmp))
 
-        ema_vars = {}
-        for param, tmp in self.params_tmps:
+        self._ema_vars = {}
+        for param, tmp in self._params_tmps:
             with param.block.program._optimized_guard(
                 [param, tmp]), name_scope('moving_average'):
-                ema_vars[param.name] = self._append_ema_ops(param)
+                self._ema_vars[param.name] = self._create_ema_vars(param)
 
         self.apply_program = Program()
         block = self.apply_program.global_block()
         with program_guard(main_program=self.apply_program):
             decay_pow = self._get_decay_pow(block)
-            for param, tmp in self.params_tmps:
+            for param, tmp in self._params_tmps:
                 param = block._clone_variable(param)
                 tmp = block._clone_variable(tmp)
-                ema = block._clone_variable(ema_vars[param.name])
+                ema = block._clone_variable(self._ema_vars[param.name])
                 layers.assign(input=param, output=tmp)
                 # bias correction
                 ema = ema / (1.0 - decay_pow)
@@ -2440,7 +2441,7 @@ class ExponentialMovingAverage(object):
         self.restore_program = Program()
         block = self.restore_program.global_block()
         with program_guard(main_program=self.restore_program):
-            for param, tmp in self.params_tmps:
+            for param, tmp in self._params_tmps:
                 tmp = block._clone_variable(tmp)
                 param = block._clone_variable(param)
                 layers.assign(input=tmp, output=param)
@@ -2472,7 +2473,7 @@ class ExponentialMovingAverage(object):
        decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
        return decay_pow_acc
 
-    def _append_ema_ops(self, param):
+    def _create_ema_vars(self, param):
         param_ema = layers.create_global_var(
             name=unique_name.generate(self._name + param.name + '_ema'),
             shape=param.shape,
@@ -2480,10 +2481,21 @@ class ExponentialMovingAverage(object):
             dtype=param.dtype,
             persistable=True)
 
-        ema_t = param_ema * self._decay_var + param * (1 - self._decay_var)
-        layers.assign(input=ema_t, output=param_ema)
         return param_ema
 
+    def update(self):
+        """
+        Update the Exponential Moving Average. This method should only be
+        called in the train program.
+        """
+        for param, tmp in self._params_tmps:
+            with param.block.program._optimized_guard(
+                [param, tmp]), name_scope('moving_average'):
+                param_ema = self._ema_vars[param.name]
+                ema_t = param_ema * self._decay_var + param * (1 -
+                                                               self._decay_var)
+                layers.assign(input=ema_t, output=param_ema)
+
     @signature_safe_contextmanager
     def apply(self, executor, need_restore=True):
         """
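Below is a minimal usage sketch of the new `update()` API, following the docstring example touched in this diff. It assumes the fluid 1.x interfaces shown above (`ExponentialMovingAverage(0.999, thres_steps=...)`, `update()`, the `apply()` context manager, and `restore()`); the toy network, batch size, and step counts are illustrative and not part of the patch.

```python
import numpy as np
import paddle.fluid as fluid

# Toy network; shapes and names are placeholders for illustration.
data = fluid.layers.data(name='x', shape=[5], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
cost = fluid.layers.mean(hidden)

# Clone the evaluation program before optimization ops are appended.
test_program = fluid.default_main_program().clone(for_test=True)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(cost)

# Create the EMA object, then explicitly insert the averaging ops
# into the train program with the new update() method.
global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
ema.update()

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Each train step now also advances the moving averages.
for step in range(10):
    x = np.random.random(size=(8, 5)).astype('float32')
    exe.run(program=fluid.default_main_program(), feed={'x': x})

# apply() swaps the (bias-corrected) averages into the parameters for
# evaluation and, with need_restore=True (the default), restores the
# original parameters when the context exits.
with ema.apply(exe):
    x = np.random.random(size=(8, 5)).astype('float32')
    exe.run(program=test_program, feed={'x': x})
```

Moving the averaging ops out of `__init__` (the old `_append_ema_ops`) into an explicit `update()` call means constructing the object no longer silently mutates the train program; the caller decides whether and where the EMA update is inserted.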
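The bias-correction line in the apply program, `ema = ema / (1.0 - decay_pow)`, follows from the zero initialization of the EMAs. A plain-Python sketch of the arithmetic, using a simplified 1-based step count `t` in place of `global_steps + 1`:

```python
def corrected_ema(thetas, decay=0.999):
    """EMA_0 = 0; EMA_t = decay * EMA_{t-1} + (1 - decay) * theta_t;
    the bias-corrected value is EMA_t / (1 - decay**t)."""
    ema, out = 0.0, []
    for t, theta in enumerate(thetas, start=1):
        ema = decay * ema + (1 - decay) * theta
        out.append(ema / (1 - decay**t))
    return out

# For a constant parameter the corrected EMA recovers it exactly, while the
# raw EMA would still be pulled toward the zero initialization.
print(corrected_ema([5.0, 5.0, 5.0]))  # ~[5.0, 5.0, 5.0]
```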