From 69a3339aaa7429e72b9dc512143c101aad2ceeed Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Wed, 30 Sep 2020 18:04:11 +0800
Subject: [PATCH] Move dygraph amp api to paddle-2.0 (#27681)

* move dygraph amp api to paddle

* refine code and add unit test
---
 python/paddle/__init__.py                     |   1 +
 python/paddle/amp/__init__.py                 |  18 +++
 python/paddle/amp/auto_cast.py                |  52 +++++++
 python/paddle/amp/grad_scaler.py              | 136 ++++++++++++++++++
 .../test_imperative_auto_mixed_precision.py   |  78 ++++++++++
 python/setup.py.in                            |   1 +
 6 files changed, 286 insertions(+)
 create mode 100644 python/paddle/amp/__init__.py
 create mode 100644 python/paddle/amp/auto_cast.py
 create mode 100644 python/paddle/amp/grad_scaler.py

diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 84713d513fb..3c52bbdccca 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -272,6 +272,7 @@ from .fluid.dygraph.base import no_grad_ as no_grad #DEFINE_ALIAS
 
 from . import jit
 from . import static
+from . import amp
 
 # high-level api
 from .hapi import Model
diff --git a/python/paddle/amp/__init__.py b/python/paddle/amp/__init__.py
new file mode 100644
index 00000000000..32587938512
--- /dev/null
+++ b/python/paddle/amp/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .auto_cast import auto_cast
+from .grad_scaler import GradScaler
+
+__all__ = ['auto_cast', 'GradScaler']
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
new file mode 100644
index 00000000000..e33f6e2afc8
--- /dev/null
+++ b/python/paddle/amp/auto_cast.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.fluid.dygraph.amp import amp_guard
+
+__all__ = ['auto_cast']
+
+
+def auto_cast(enable=True, custom_white_list=None, custom_black_list=None):
+    """
+    Create a context which enables auto-mixed-precision (AMP) of operators executed in dynamic graph mode.
+    If enabled, the input data type (float32 or float16) of each operator is decided
+    by the autocast algorithm for better performance.
+
+    Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
+    imperative mode.
+
+    Args:
+        enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
+        custom_white_list(set|list, optional): The custom white_list.
+        custom_black_list(set|list, optional): The custom black_list.
+
+    Examples:
+
+     .. code-block:: python
+
+        import paddle
+
+        conv2d = paddle.nn.Conv2d(3, 2, 3, bias_attr=False)
+        data = paddle.rand([10, 3, 32, 32])
+
+        with paddle.amp.auto_cast():
+            conv = conv2d(data)
+            print(conv.dtype) # FP16
+
+        with paddle.amp.auto_cast(enable=False):
+            conv = conv2d(data)
+            print(conv.dtype) # FP32
+
+    """
+    return amp_guard(enable, custom_white_list, custom_black_list)
diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py
new file mode 100644
index 00000000000..9476f3765b3
--- /dev/null
+++ b/python/paddle/amp/grad_scaler.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.fluid.dygraph.amp import AmpScaler
+
+__all__ = ['GradScaler']
+
+
+class GradScaler(AmpScaler):
+    """
+    GradScaler is used for Auto-Mixed-Precision training/inferring in dynamic graph
+    mode. It controls the scaling of the loss and helps avoid numerical overflow.
+    The object of this class has two methods `scale()` and `minimize()`.
+
+    `scale()` is used to multiply the loss by a scale ratio.
+    `minimize()` is similar to `Optimizer.minimize()` and performs parameter updating.
+
+    Commonly, it is used together with `paddle.amp.auto_cast` to achieve Auto-Mixed-Precision in
+    dynamic graph mode.
+
+    Args:
+        enable(bool, optional): Enable loss scaling or not. Default is True.
+        init_loss_scaling(float, optional): The initial loss scaling factor. Default is 2**15.
+        incr_ratio(float, optional): The multiplier to use when increasing the loss
+                        scaling. Default is 2.0.
+        decr_ratio(float, optional): The less-than-one multiplier to use when decreasing
+                        the loss scaling. Default is 0.5.
+        incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
+                                steps with finite gradients. Default is 1000.
+        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
+                                accumulated steps with nan or inf gradients. Default is 1.
+        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
+    Returns:
+        A GradScaler object.
+
+    Examples:
+
+     .. code-block:: python
+
+        import paddle
+
+        model = paddle.nn.Conv2d(3, 2, 3, bias_attr=True)
+        optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+        data = paddle.rand([10, 3, 32, 32])
+        with paddle.amp.auto_cast():
+            conv = model(data)
+            loss = paddle.reduce_mean(conv)
+            scaled = scaler.scale(loss)  # scale the loss
+            scaled.backward()  # do backward
+            scaler.minimize(optimizer, scaled)  # update parameters
+    """
+
+    def __init__(self,
+                 enable=True,
+                 init_loss_scaling=2.**15,
+                 incr_ratio=2.0,
+                 decr_ratio=0.5,
+                 incr_every_n_steps=1000,
+                 decr_every_n_nan_or_inf=1,
+                 use_dynamic_loss_scaling=True):
+        super(GradScaler, self).__init__(enable, init_loss_scaling, incr_ratio,
+                                         decr_ratio, incr_every_n_steps,
+                                         decr_every_n_nan_or_inf,
+                                         use_dynamic_loss_scaling)
+
+    def scale(self, var):
+        """
+        Multiplies a Tensor by the scale factor and returns scaled outputs.
+        If this instance of :class:`GradScaler` is not enabled, the output is returned unmodified.
+
+        Args:
+            var (Tensor): The tensor to scale.
+        Returns:
+            The scaled tensor or original tensor.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+
+                model = paddle.nn.Conv2d(3, 2, 3, bias_attr=True)
+                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+                data = paddle.rand([10, 3, 32, 32])
+                with paddle.amp.auto_cast():
+                    conv = model(data)
+                    loss = paddle.reduce_mean(conv)
+                    scaled = scaler.scale(loss)  # scale the loss
+                    scaled.backward()  # do backward
+                    scaler.minimize(optimizer, scaled)  # update parameters
+        """
+        return super(GradScaler, self).scale(var)
+
+    def minimize(self, optimizer, *args, **kwargs):
+        """
+        This function is similar to `Optimizer.minimize()`, which performs parameter updating.
+
+        If the scaled gradients of parameters contain NaN or Inf, the parameter update is skipped.
+        Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
+
+        Finally, the loss scaling ratio is updated.
+
+        Args:
+            optimizer(Optimizer): The optimizer used to update parameters.
+            args: Arguments, which will be forwarded to `Optimizer.minimize()`.
+            kwargs: Keyword arguments, which will be forwarded to `Optimizer.minimize()`.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+
+                model = paddle.nn.Conv2d(3, 2, 3, bias_attr=True)
+                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+                data = paddle.rand([10, 3, 32, 32])
+                with paddle.amp.auto_cast():
+                    conv = model(data)
+                    loss = paddle.reduce_mean(conv)
+                    scaled = scaler.scale(loss)  # scale the loss
+                    scaled.backward()  # do backward
+                    scaler.minimize(optimizer, scaled)  # update parameters
+        """
+        return super(GradScaler, self).minimize(optimizer, *args, **kwargs)
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index fdf7adbfb45..71381ecfde7 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -196,6 +196,84 @@ class TestAmpScaler(unittest.TestCase):
                 np.array_equal(param.numpy(), params_init[param.name]))
 
 
+class TestResnet2(unittest.TestCase):
+    def train_resnet(self, enable_amp=True):
+        seed = 90
+
+        batch_size = train_parameters["batch_size"]
+        batch_num = 1
+
+        paddle.disable_static()
+
+        paddle.manual_seed(seed)
+        paddle.framework.random._manual_program_seed(seed)
+
+        resnet = ResNet(use_cudnn=True)
+        optimizer = optimizer_setting(
+            train_parameters, parameter_list=resnet.parameters())
+        np.random.seed(seed)
+        train_reader = paddle.batch(
+            paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
+
+        dy_param_init_value = {}
+        for param in resnet.parameters():
+            dy_param_init_value[param.name] = param.numpy()
+
+        program = None
+        scaler = paddle.amp.GradScaler(
+            enable=enable_amp, init_loss_scaling=2.**10)
+
+        for batch_id, data in enumerate(train_reader()):
+            if batch_id >= batch_num:
+                break
+            dy_x_data = np.array(
+                [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
+            if len(np.array([x[1]
+                             for x in data]).astype('int64')) != batch_size:
+                continue
+            y_data = np.array(
+                [x[1] for x in data]).astype('int64').reshape(-1, 1)
+            img = paddle.to_tensor(dy_x_data)
+            label = paddle.to_tensor(y_data)
+            label.stop_gradient = True
+
+            with paddle.amp.auto_cast(enable=enable_amp):
+                out = resnet(img)
+
+            loss = paddle.nn.functional.cross_entropy(input=out, label=label)
+            avg_loss = paddle.mean(x=loss)
+
+            dy_out = avg_loss.numpy()
+
+            scaled_loss = scaler.scale(avg_loss)
+            scaled_loss.backward()
+
+            scaler.minimize(optimizer, scaled_loss)
+
+            dy_grad_value = {}
+            for param in resnet.parameters():
+                if param.trainable:
+                    np_array = np.array(param._grad_ivar().value().get_tensor())
+                    dy_grad_value[param.name +
+                                  fluid.core.grad_var_suffix()] = np_array
+
+            resnet.clear_gradients()
+
+            dy_param_value = {}
+            for param in resnet.parameters():
+                dy_param_value[param.name] = param.numpy()
+
+        paddle.enable_static()
+
+        return dy_out, dy_param_value, dy_grad_value
+
+    def test_resnet(self):
+        out_fp32 = self.train_resnet(enable_amp=False)
+        out_amp = self.train_resnet(enable_amp=True)
+        print(out_fp32[0], out_amp[0])
+        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))
+
+
 class TestResnet(unittest.TestCase):
     def train_resnet(self, enable_amp=True):
         seed = 90
diff --git a/python/setup.py.in b/python/setup.py.in
index 414258a3b37..f09c189a68e 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -192,6 +192,7 @@ packages=['paddle',
           'paddle.fluid.incubate.fleet.parameter_server.ir',
           'paddle.fluid.incubate.fleet.collective',
           'paddle.fluid.incubate.fleet.utils',
+          'paddle.amp',
           'paddle.hapi',
           'paddle.vision',
           'paddle.vision.models',
-- 
GitLab
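
For reference, a minimal sketch of how the two APIs added by this patch (paddle.amp.auto_cast and paddle.amp.GradScaler) fit together in one dygraph training step. It only mirrors the docstring examples above; the Linear model, SGD optimizer, and random data are placeholders, not anything prescribed by the patch.

    import paddle

    paddle.disable_static()  # make sure we are in dygraph mode, as in the unit test

    # Placeholder network and optimizer; any dygraph model is driven the same way.
    model = paddle.nn.Linear(10, 10)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                     parameters=model.parameters())

    # GradScaler keeps the dynamic loss-scaling state.
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    data = paddle.rand([4, 10])
    with paddle.amp.auto_cast():  # ops choose float16/float32 per the white/black lists
        out = model(data)
        loss = paddle.reduce_mean(out)

    scaled = scaler.scale(loss)         # multiply the loss by the current scale
    scaled.backward()                   # backward runs on the scaled loss
    scaler.minimize(optimizer, scaled)  # unscale grads, skip the step on nan/inf, update the scale
    model.clear_gradients()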