Unverified Commit 6e11f977 authored by Yibing Liu, committed by GitHub

Add exponential moving average (#17562)

* Add exponential moving average

test=develop, test=document_preview

* Polish documents

test=develop, test=document_preview

* Update API spec

test=develop, test=document_preview
Parent 0600b370
@@ -522,6 +522,9 @@ paddle.fluid.optimizer.LambOptimizer.apply_optimize (ArgSpec(args=['self', 'loss
paddle.fluid.optimizer.LambOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.LambOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LambOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
+paddle.fluid.optimizer.ExponentialMovingAverage.__init__ (ArgSpec(args=['self', 'decay', 'zero_init', 'name'], varargs=None, keywords=None, defaults=(0.999, False, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
+paddle.fluid.optimizer.ExponentialMovingAverage.apply (ArgSpec(args=['self', 'executor', 'need_restore'], varargs=None, keywords=None, defaults=(True,)), ('document', '30f494752ac8921dc5835a63637f453a'))
+paddle.fluid.optimizer.ExponentialMovingAverage.restore (ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None), ('document', '8c8a1791608b02a1ede53d6dd3a4fcec'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '08a5dd9f6f376ff3d55e0b1d92115cbd'))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
@@ -41,7 +41,8 @@ __all__ = [
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
-    'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer'
+    'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer',
+    'ExponentialMovingAverage'
]
@@ -1999,10 +2000,10 @@ Lamb = LambOptimizer
class ModelAverage(Optimizer):
"""Accumulate the average of parameters whtin sliding window. The average
"""Accumulate the average of parameters within sliding window. The average
result will be saved in temporary variables which can be applied to
parameter variables of current model by calling 'apply()' method. And the
-    'restore()' method is used to restored the parameter values of current model.
+    'restore()' method is used to restore the parameter values of current model.
The size of average window is determined by average_window_rate,
min_average_window, max_average_window and current update times.
@@ -2155,3 +2156,140 @@ class ModelAverage(Optimizer):
"""Restore parameter values of current model.
"""
executor.run(self.restore_program)


class ExponentialMovingAverage(object):
    """
    Compute the moving average of parameters with exponential decay.
    Given a parameter :math:`\\theta`, its exponential moving average (EMA)
    will be

    ..  math::

        \\text{EMA}_t = \\text{decay} * \\text{EMA}_{t-1} + (1 - \\text{decay}) * \\theta_t

    The average results will be saved in temporary variables, which can be
    applied to the parameters of the current model by calling the `apply()`
    method, and the `restore()` method is used to restore the parameters.

    Args:
        decay (float|Variable): The exponential decay rate, which can be
            scheduled like a learning rate.
        zero_init (bool): Whether to initialize the EMA Variable with zeros.
            If set to `True`, :math:`\\text{EMA}_0 = 0.0`, else
            :math:`\\text{EMA}_0 = \\theta_0`.
        name (str|None): An optional name prefix.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            data = fluid.layers.data(name='x', shape=[5], dtype='float32')
            hidden = fluid.layers.fc(input=data, size=10)
            cost = fluid.layers.mean(hidden)

            optimizer = fluid.optimizer.Adam(learning_rate=0.001)
            optimizer.minimize(cost)

            ema = fluid.optimizer.ExponentialMovingAverage(0.99)

            # pseudo code
            for pass_id in range(args.pass_num):
                for data in train_reader():
                    exe.run(fluid.default_main_program()...)

                # parameters are temporarily replaced by their moving
                # averages inside this context and restored on exit
                with ema.apply(exe):
                    for data in test_reader():
                        exe.run(inference_program...)
    """
def __init__(self, decay=0.999, zero_init=False, name=None):
self._decay = decay
self._zero_init = zero_init
self._name = name if name is not None else ''
self.params_tmps = []
for param in framework.default_main_program().global_block(
).all_parameters():
if param.do_model_average != False:
tmp = param.block.create_var(
name=unique_name.generate(".".join(
[self._name + param.name, 'ema_tmp'])),
dtype=param.dtype,
persistable=False,
stop_gradient=True)
self.params_tmps.append((param, tmp))
startup_block = default_startup_program().global_block()
ema_vars = {}
for param, tmp in self.params_tmps:
with param.block.program._optimized_guard(
[param, tmp]), name_scope('moving_average'):
ema_vars[param.name] = self._append_ema_ops(startup_block,
param)
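
        # apply_program: back each parameter up into its temporary variable,
        # then overwrite the parameter with its accumulated EMA value.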
self.apply_program = Program()
block = self.apply_program.global_block()
with program_guard(main_program=self.apply_program):
for param, tmp in self.params_tmps:
param = block._clone_variable(param)
tmp = block._clone_variable(tmp)
ema = block._clone_variable(ema_vars[param.name])
layers.assign(input=param, output=tmp)
layers.assign(input=ema, output=param)
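
        # restore_program: copy the original values saved in the temporary
        # variables back into the parameters.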
self.restore_program = Program()
block = self.restore_program.global_block()
with program_guard(main_program=self.restore_program):
for param, tmp in self.params_tmps:
tmp = block._clone_variable(tmp)
param = block._clone_variable(param)
layers.assign(input=tmp, output=param)

    def _append_ema_ops(self, startup_block, param):
param_ema = layers.create_global_var(
name=unique_name.generate(self._name + param.name + '_ema'),
shape=param.shape,
value=0.0,
dtype=param.dtype,
persistable=True)
        # t = 0: when zero_init is False, initialize EMA_0 to theta_0;
        # otherwise EMA_0 keeps the 0.0 set by create_global_var above
        if self._zero_init is not True:
startup_p_ema = startup_block._clone_variable(param_ema)
startup_p = startup_block.var(param.name)
startup_block.append_op(
type="assign",
inputs={"X": startup_p},
outputs={"Out": startup_p_ema})
        # t > 0: EMA_t = decay * EMA_{t-1} + (1 - decay) * theta_t,
        # computed here in the algebraically equivalent form
        ema_t = param_ema * self._decay - param * (self._decay - 1)
        layers.assign(input=ema_t, output=param_ema)
return param_ema

    @signature_safe_contextmanager
    def apply(self, executor, need_restore=True):
        """
        Apply moving average to parameters for evaluation.

        Args:
            executor (Executor): The Executor to execute applying.
            need_restore (bool): Whether to restore parameters after applying.
        """
        executor.run(self.apply_program)
        try:
            yield
        finally:
            if need_restore:
                self.restore(executor)

    def restore(self, executor):
        """Restore parameters.

        Args:
            executor (Executor): The Executor to execute restoring.
        """
        executor.run(self.restore_program)
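
As a sanity check on the update rule, here is a minimal NumPy sketch (illustration only, not part of this commit; the decay value, step count, and tensor contents are arbitrary) of the recurrence that _append_ema_ops builds, including the rearranged expression the code uses:

import numpy as np

decay = 0.999
theta = np.array([1.0, 2.0, 3.0])  # current parameter values
ema = theta.copy()                 # zero_init=False: EMA_0 = theta_0
# ema = np.zeros_like(theta)       # zero_init=True:  EMA_0 = 0.0

for t in range(100):
    theta = theta - 0.01           # stand-in for an optimizer update
    # EMA_t = decay * EMA_{t-1} + (1 - decay) * theta_t, written below
    # in the equivalent form used by the commit
    ema = ema * decay - theta * (decay - 1)

print(ema)  # trails theta, smoothed by the decay factor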