From 412951d7d25c1dd91cccb7f674bd52a67c8ec592 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 23 Jun 2019 15:57:44 +0800 Subject: [PATCH] Fix ema's example & fp16 update (#18273) test=develop, test=document_preview --- python/paddle/fluid/optimizer.py | 96 +++++++++++++++++++++----------- 1 file changed, 63 insertions(+), 33 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 67d12f0648f..4fa5738d327 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -2472,36 +2472,50 @@ class ExponentialMovingAverage(object): Examples: .. code-block:: python - - import paddle.fluid as fluid - - data = fluid.layers.data(name='x', shape=[5], dtype='float32') - hidden = fluid.layers.fc(input=data, size=10) - cost = fluid.layers.mean(hidden) - - optimizer = fluid.optimizer.Adam(learning_rate=0.001) - optimizer.minimize(cost) - - global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter() - ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps) - ema.update() - - # pseudo code - for pass_id in range(args.pass_num): - for data in train_reader(): - exe.run(fluid.default_main_program()...) - - # usage 1 - with ema.apply(exe): - for data in test_reader(): - exe.run(inference_program...) - - # usage 2 - with ema.apply(exe, need_restore=False): - for data in test_reader(): - exe.run(inference_program...) - ... - ema.restore(exe) + + import numpy + import paddle + import paddle.fluid as fluid + + data = fluid.layers.data(name='x', shape=[5], dtype='float32') + hidden = fluid.layers.fc(input=data, size=10) + cost = fluid.layers.mean(hidden) + + test_program = fluid.default_main_program().clone(for_test=True) + + optimizer = fluid.optimizer.Adam(learning_rate=0.001) + optimizer.minimize(cost) + + global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter() + ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps) + ema.update() + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + for pass_id in range(3): + for batch_id in range(6): + data = numpy.random.random(size=(10, 5)).astype('float32') + exe.run(program=fluid.default_main_program(), + feed={'x': data}, + fetch_list=[cost.name]) + + # usage 1 + with ema.apply(exe): + data = numpy.random.random(size=(10, 5)).astype('float32') + exe.run(program=test_program, + feed={'x': data}, + fetch_list=[hidden.name]) + + + # usage 2 + with ema.apply(exe, need_restore=False): + data = numpy.random.random(size=(10, 5)).astype('float32') + exe.run(program=test_program, + feed={'x': data}, + fetch_list=[hidden.name]) + ema.restore(exe) """ def __init__(self, decay=0.999, thres_steps=None, name=None): @@ -2590,13 +2604,29 @@ class ExponentialMovingAverage(object): Update Exponential Moving Average. Should only call this method in train program. """ + param_master_emas = [] for param, tmp in self._params_tmps: with param.block.program._optimized_guard( [param, tmp]), name_scope('moving_average'): param_ema = self._ema_vars[param.name] - ema_t = param_ema * self._decay_var + param * (1 - - self._decay_var) - layers.assign(input=ema_t, output=param_ema) + if self._ema_vars.has_key(param.name + '.master'): + master_ema = self._ema_vars[param.name + '.master'] + param_master_emas.append([param_ema, master_ema]) + else: + ema_t = param_ema * self._decay_var + param * ( + 1 - self._decay_var) + layers.assign(input=ema_t, output=param_ema) + + # for fp16 params + for param_ema, master_ema in param_master_emas: + default_main_program().global_block().append_op( + type="cast", + inputs={"X": master_ema}, + outputs={"Out": param_ema}, + attrs={ + "in_dtype": master_ema.dtype, + "out_dtype": param_ema.dtype + }) @signature_safe_contextmanager def apply(self, executor, need_restore=True): -- GitLab