From fb4d5689eae67e5e859ed19619eb49b5acc76a1e Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Fri, 17 Sep 2021 02:09:33 -0500
Subject: [PATCH] Support EMA in Paddle2.x and Fleet (#35673)

* Support EMA in Paddle2.x and Fleet

* update

* update

* update

* modify ut of ema

* modify docs

* modify bugs

* update

* update

* update

* modify ut
---
 python/paddle/fluid/optimizer.py              | 101 +++++++++---------
 .../fluid/tests/unittests/test_ema_fleet.py   |  97 +++++++++++++++++
 python/paddle/static/__init__.py              |   2 +
 tools/parallel_UT_rule.py                     |   3 +-
 tools/static_mode_white_list.py               |   1 +
 5 files changed, 151 insertions(+), 53 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_ema_fleet.py

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 709b36ed8e3..8b2495fb2a9 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -3959,62 +3959,59 @@ class ExponentialMovingAverage(object):
 
     Args:
-        decay (float, optional): The exponential decay rate, usually close to 1, such as
-            0.999, 0.9999, ... . Default 0.999.
-        thres_steps (Variable|None): If not `None`, schedule the decay rate.
-            Default None.
-        name (str|None): For detailed information, please refer to
-            :ref:`api_guide_Name`. Usually name is no need to set and None by
-            default.
+        decay (float, optional): The exponential decay rate, usually close to 1, such as 0.999, 0.9999, ... . Default 0.999.
+        thres_steps (Variable|None, optional): If not `None`, schedule the decay rate. Default None.
+        name (str|None, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually, name does not need to be set and is None by default.
 
     Examples:
-        .. code-block:: python
-
-            import numpy
-            import paddle
-            import paddle.fluid as fluid
-
-            data = fluid.data(name='x', shape=[-1, 5], dtype='float32')
-            hidden = fluid.layers.fc(input=data, size=10)
-            cost = fluid.layers.mean(hidden)
-
-            test_program = fluid.default_main_program().clone(for_test=True)
-
-            optimizer = fluid.optimizer.Adam(learning_rate=0.001)
-            optimizer.minimize(cost)
-
-            global_steps = fluid.layers.autoincreased_step_counter()
-            ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
-            ema.update()
-
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            exe.run(fluid.default_startup_program())
-
-            for pass_id in range(3):
-                for batch_id in range(6):
-                    data = numpy.random.random(size=(10, 5)).astype('float32')
-                    exe.run(program=fluid.default_main_program(),
-                            feed={'x': data},
-                            fetch_list=[cost.name])
-
-                # usage 1
-                with ema.apply(exe):
-                    data = numpy.random.random(size=(10, 5)).astype('float32')
-                    exe.run(program=test_program,
-                            feed={'x': data},
-                            fetch_list=[hidden.name])
-
-
-                # usage 2
-                with ema.apply(exe, need_restore=False):
-                    data = numpy.random.random(size=(10, 5)).astype('float32')
-                    exe.run(program=test_program,
-                            feed={'x': data},
-                            fetch_list=[hidden.name])
-                ema.restore(exe)
+        .. code-block:: python
+
+            import numpy
+            import paddle
+            import paddle.static as static
+            from paddle.static import ExponentialMovingAverage
+
+            paddle.enable_static()
+
+            data = static.data(name='x', shape=[-1, 5], dtype='float32')
+            hidden = static.nn.fc(x=data, size=10)
+            cost = paddle.mean(hidden)
+
+            test_program = static.default_main_program().clone(for_test=True)
+            optimizer = paddle.optimizer.Adam(learning_rate=0.001)
+            optimizer.minimize(cost)
+
+            ema = ExponentialMovingAverage(0.999)
+            ema.update()
+
+            place = paddle.CPUPlace()
+            exe = static.Executor(place)
+            exe.run(static.default_startup_program())
+
+            for pass_id in range(3):
+                for batch_id in range(6):
+                    data = numpy.random.random(size=(10, 5)).astype('float32')
+                    exe.run(program=static.default_main_program(),
+                            feed={'x': data},
+                            fetch_list=[cost.name])
+
+                # usage 1
+                with ema.apply(exe):
+                    data = numpy.random.random(size=(10, 5)).astype('float32')
+                    exe.run(program=test_program,
+                            feed={'x': data},
+                            fetch_list=[hidden.name])
+
+                # usage 2
+                with ema.apply(exe, need_restore=False):
+                    data = numpy.random.random(size=(10, 5)).astype('float32')
+                    exe.run(program=test_program,
+                            feed={'x': data},
+                            fetch_list=[hidden.name])
+                ema.restore(exe)
+
     """
 
     def __init__(self, decay=0.999, thres_steps=None, name=None):
diff --git a/python/paddle/fluid/tests/unittests/test_ema_fleet.py b/python/paddle/fluid/tests/unittests/test_ema_fleet.py
new file mode 100644
index 00000000000..e0526deb59a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_ema_fleet.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.utils as utils
+import paddle.static as static
+
+
+def gen_data():
+    return np.random.random(size=(10, 5)).astype('float32')
+
+
+class TestFleetStaticEMA(unittest.TestCase):
+    def setUp(self):
+        self._places = [paddle.CPUPlace()]
+        if paddle.device.is_compiled_with_cuda():
+            self._places.append(paddle.CUDAPlace(0))
+        self._ema_decay = 0.999
+        self._param_name = "fc.weight"
+        self._train_program = static.Program()
+        self._startup_prog = static.Program()
+
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.without_graph_optimization = True
+        paddle.distributed.fleet.init(is_collective=True, strategy=strategy)
+
+        with static.program_guard(self._train_program, self._startup_prog):
+            with utils.unique_name.guard():
+                data = static.data(name='x', shape=[-1, 5], dtype='float32')
+                hidden = static.nn.fc(x=data,
+                                      size=10,
+                                      weight_attr=self._param_name)
+                cost = paddle.mean(hidden)
+
+                self._test_program = static.default_main_program().clone(
+                    for_test=True)
+
+                optimizer = paddle.optimizer.Adam(learning_rate=0.001)
+                optimizer = paddle.distributed.fleet.distributed_optimizer(
+                    optimizer, strategy)
+                optimizer.minimize(cost)
+
+                self._ema = static.ExponentialMovingAverage(self._ema_decay)
+                self._ema.update()
+
+    def train(self, place, restore):
+        exe = static.Executor(place)
+        exe.run(self._startup_prog)
+
+        params = []
+        for pass_id in range(2):
+            for batch_id in range(3):
+                exe.run(program=self._train_program, feed={'x': gen_data()})
+            tmp_param = np.array(static.global_scope().find_var(
+                self._param_name).get_tensor())
+            params.append(tmp_param)
+
+        with self._ema.apply(exe, restore):
+            final_ema = np.array(static.global_scope().find_var(
+                self._param_name).get_tensor())
+            exe.run(program=self._test_program, feed={'x': gen_data()})
+        if not restore:
+            self._ema.restore(exe)
+
+        return params, final_ema
+
+    def test_check_ema(self):
+        for place in self._places:
+            for restore in (True, False):
+                params, final_ema = self.train(place, restore)
+                manu_ema = np.zeros_like(final_ema)
+                if len(params) > 0:
+                    for param in params:
+                        manu_ema = self._ema_decay * manu_ema + (
+                            1 - self._ema_decay) * param
+                    manu_ema = manu_ema / (1.0 - self._ema_decay**len(params))
+                self.assertTrue(np.allclose(manu_ema, final_ema))
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py
index 93f34b22979..0f463b0c7d9 100644
--- a/python/paddle/static/__init__.py
+++ b/python/paddle/static/__init__.py
@@ -48,6 +48,7 @@ from ..fluid.layers.control_flow import Print  # noqa: F401
 from ..fluid.layers.nn import py_func  # noqa: F401
 from ..fluid.parallel_executor import ParallelExecutor  # noqa: F401
 from ..fluid.param_attr import WeightNormParamAttr  # noqa: F401
+from ..fluid.optimizer import ExponentialMovingAverage  # noqa: F401
 from ..fluid.io import save  # noqa: F401
 from ..fluid.io import load  # noqa: F401
 from ..fluid.io import load_program_state  # noqa: F401
@@ -76,6 +77,7 @@ __all__ = [  #noqa
     'ParallelExecutor',
     'program_guard',
     'WeightNormParamAttr',
+    'ExponentialMovingAverage',
     'default_main_program',
     'default_startup_program',
     'Program',
diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py
index b5e12d6f96d..54e8d608ac6 100644
--- a/tools/parallel_UT_rule.py
+++ b/tools/parallel_UT_rule.py
@@ -528,7 +528,7 @@ TETRAD_PARALLEL_JOB_NEW = [
     'test_trunc_op', 'test_custom_relu_model', 'test_backward',
     'test_bernoulli_op',
     'test_conv3d_transpose_part2_op', 'test_complex_transpose',
     'test_memory_reuse_exclude_feed_var', 'test_polygon_box_transform',
-    'math_function_gpu_test', 'test_program_prune_backward',
+    'math_function_gpu_test', 'test_program_prune_backward', 'test_ema_fleet',
     'test_fleet_amp_init', 'test_normalize', 'test_correlation',
     'test_conv_elementwise_add2_act_fuse_pass',
     'test_imperative_container_layerlist', 'test_dequantize_abs_max_op',
@@ -1324,6 +1324,7 @@ TWO_PARALLEL_JOB = [
     'test_slice_op',
     'test_cond',
     'test_ema',
+    'test_ema_fleet',
     'test_nan_inf',
     'test_isinstance',
     'test_box_clip_op',
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index 5fa3a25f4ca..43281d4375e 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -173,6 +173,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_elementwise_nn_grad',
     'test_elementwise_pow_op',
     'test_ema',
+    'test_ema_fleet',
     'test_embedding_id_stop_gradient',
     'test_empty_like_op',
     'test_entry_attr',
-- 
GitLab
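
For quick local verification, here is a minimal single-process sketch of the
usage this patch enables: ExponentialMovingAverage imported from paddle.static
and combined with fleet's distributed optimizer. It is condensed from the
docstring example and test_ema_fleet.py above; the layer size, decay value,
and step counts are illustrative assumptions, not values mandated by the
patch.

    import numpy as np
    import paddle
    import paddle.static as static

    paddle.enable_static()

    # Fleet setup mirrors the unit test: collective mode with graph
    # optimization disabled, runnable as a single process.
    strategy = paddle.distributed.fleet.DistributedStrategy()
    strategy.without_graph_optimization = True
    paddle.distributed.fleet.init(is_collective=True, strategy=strategy)

    data = static.data(name='x', shape=[-1, 5], dtype='float32')
    hidden = static.nn.fc(x=data, size=10)
    cost = paddle.mean(hidden)

    optimizer = paddle.optimizer.Adam(learning_rate=0.001)
    optimizer = paddle.distributed.fleet.distributed_optimizer(
        optimizer, strategy)
    optimizer.minimize(cost)

    # Both the docstring example and the unit test create the EMA after
    # minimize() and then call update() to append the averaging ops; this
    # sketch keeps that ordering.
    ema = static.ExponentialMovingAverage(0.999)
    ema.update()

    exe = static.Executor(paddle.CPUPlace())
    exe.run(static.default_startup_program())

    for step in range(6):
        exe.run(program=static.default_main_program(),
                feed={'x': np.random.random(size=(10, 5)).astype('float32')})

    # Inside apply() the parameters hold bias-corrected moving averages,
    # ema_t / (1 - decay**t), the same quantity test_check_ema reconstructs
    # by hand; the raw parameter values are restored on exiting the block.
    with ema.apply(exe):
        pass  # run the evaluation program here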