From 18e9dcdcdb79467fdc799e332f0f504bfabfb6a1 Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Thu, 27 Apr 2023 14:11:35 +0800
Subject: [PATCH] [AMP] support OD level and skip dynamic loss scaling for bf16 (#53289)

* support OD level and skip dynamic loss scaling for bf16
---
 paddle/fluid/eager/amp_utils.h                |  6 +-
 paddle/fluid/imperative/amp_auto_cast.h       |  1 +
 paddle/fluid/pybind/imperative.cc             |  1 +
 python/paddle/amp/amp_lists.py                | 16 ++++-
 python/paddle/amp/auto_cast.py                | 36 +++++-----
 python/paddle/amp/grad_scaler.py              | 14 ++++
 ...perative_auto_mixed_precision_for_eager.py | 13 ++--
 test/amp/test_amp_api.py                      | 66 +++++++++++++++++++
 8 files changed, 129 insertions(+), 24 deletions(-)
 create mode 100644 test/amp/test_amp_api.py

diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h
index ac9edc569df..2e06eaab8ac 100644
--- a/paddle/fluid/eager/amp_utils.h
+++ b/paddle/fluid/eager/amp_utils.h
@@ -129,7 +129,11 @@ inline phi::DataType GetAmpDestDtype(
                  ->count(op_name)) {
     dst_type = phi::DataType::FLOAT32;
   } else {
-    dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
+    if (amp_level == paddle::imperative::AmpLevel::OD) {
+      dst_type = phi::DataType::FLOAT32;
+    } else {
+      dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
+    }
   }
 
   if (dst_type == amp_setting_dtype &&
diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h
index ced07b953d0..31dfc9dec57 100644
--- a/paddle/fluid/imperative/amp_auto_cast.h
+++ b/paddle/fluid/imperative/amp_auto_cast.h
@@ -31,6 +31,7 @@ enum class AmpLevel {
   O1,  // amp, mixed fp32-fp16
   O2,  // almost fp16
   O3,  // fp16
+  OD,  // only conv and matmul use low precision.
 };
 
 std::tuple,
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 372fae12ec3..8d5bd524a1c 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -2152,6 +2152,7 @@ void BindImperative(py::module *m_ptr) {
 
   py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic())
       .value("O0", paddle::imperative::AmpLevel::O0)
+      .value("OD", paddle::imperative::AmpLevel::OD)
       .value("O1", paddle::imperative::AmpLevel::O1)
       .value("O2", paddle::imperative::AmpLevel::O2)
      .value("O3", paddle::imperative::AmpLevel::O3)
diff --git a/python/paddle/amp/amp_lists.py b/python/paddle/amp/amp_lists.py
index f70c8f5ed7f..51c557b9481 100644
--- a/python/paddle/amp/amp_lists.py
+++ b/python/paddle/amp/amp_lists.py
@@ -91,10 +91,19 @@ BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
 BF16_BLACK_LIST = set()
 
 
+# At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
 def white_list():
     white_list = {
-        "float16": {"O1": FP16_WHITE_LIST, "O2": FP16_WHITE_LIST},
-        "bfloat16": {"O1": BF16_WHITE_LIST, "O2": BF16_WHITE_LIST},
+        "float16": {
+            "OD": FP16_WHITE_LIST,
+            "O1": FP16_WHITE_LIST,
+            "O2": FP16_WHITE_LIST,
+        },
+        "bfloat16": {
+            "OD": BF16_WHITE_LIST,
+            "O1": BF16_WHITE_LIST,
+            "O2": BF16_WHITE_LIST,
+        },
     }
     return white_list
@@ -102,9 +111,10 @@ def white_list():
 def black_list():
     black_list = {
         "float16": {
+            "OD": set(),
             "O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST,
             "O2": FP16_EXTRA_BLACK_LIST,
         },
-        "bfloat16": {"O1": BF16_BLACK_LIST, "O2": set()},
+        "bfloat16": {"OD": set(), "O1": BF16_BLACK_LIST, "O2": set()},
     }
     return black_list
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index ae9c957df68..d1eaf0dbd13 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -48,6 +48,7 @@ class AMPGlobalState:
         self.model_parameters = []
         self.use_master_grad = False
         self.already_register_final_backward_hook = False
+        self.amp_dtype = 'float32'
 
     def __setattr__(self, name, val):
         self.__dict__[name] = val
@@ -320,10 +321,8 @@ def amp_guard(
 
     # check amp_level: O0-O2
     level = level.upper()
-    if not (level in ['O0', 'O1', 'O2']):
-        raise ValueError(
-            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
-        )
+    if not (level in ['O0', 'OD', 'O1', 'O2']):
+        raise ValueError("level should be O0, OD, O1 or O2.")
 
     # check amp_dtype: float16 or bfloat16
     dtype = dtype.lower()
@@ -384,8 +383,11 @@ def amp_guard(
         )
 
     amp_dtype = dtype
+    amp_global_state().amp_dtype = amp_dtype
 
-    if level == 'O1':
+    if level == 'OD':
+        amp_level = AMP_LEVEL.OD
+    elif level == 'O1':
         amp_level = AMP_LEVEL.O1
     elif level == 'O2':
         amp_level = AMP_LEVEL.O2
@@ -642,22 +644,24 @@ def auto_cast(
 ):
     """
     Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
-    If enabled, the input data type (float32 or float16) of each operator is decided
+    If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided
     by autocast algorithm for better performance.
-    Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
-    imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.
+    Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in
+    imperative mode.
 
     Args:
         enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
-        custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support
-            fp16 calculation and are considered numerically-safe and performance-critical. These ops
-            will be converted to fp16.
-        custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
-            calculation and are considered numerically-dangerous and whose effects may also be
-            observed in downstream ops. These ops will not be converted to fp16.
-        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list;
-             O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
+        custom_white_list(set|list|tuple, optional): A default white list is already set.
+            Usually there is no need to set a custom white list.
+            The set of ops should be considered numerically-safe and performance-critical. These ops will be converted to float16/bfloat16.
+        custom_black_list(set|list|tuple, optional): A default black list is already set. You can set a custom black list according to the model.
+            The ops in this set are considered numerically-dangerous, and their effects may also be observed in downstream ops. These ops will not
+            be converted to float16/bfloat16.
+        level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list
+            will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2
+            level, the model's parameters will be cast to float16/bfloat16 by using `decorator`, operators that have all float16/bfloat16 inputs
+            will be converted to float16/bfloat16, and operators that have any float32 input will be computed in float32. At the OD level, operators
+            in the default white list will be computed in float16/bfloat16, and the others will be computed in float32. Default is O1.
         dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
 
     Examples:
diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py
index 0f6d9f21a32..2cade3482e9 100644
--- a/python/paddle/amp/grad_scaler.py
+++ b/python/paddle/amp/grad_scaler.py
@@ -24,6 +24,8 @@ from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import _dygraph_tracer, dygraph_only
 
+from .auto_cast import amp_global_state
+
 
 class OptimizerState(Enum):
     INIT = 0
@@ -179,6 +181,18 @@ class AmpScaler:
         """
         check_type(var, "var", core.eager.Tensor, 'AmpScaler.scale()')
 
+        if (
+            self._enable
+            and amp_global_state().amp_dtype != 'float16'
+            and self._use_dynamic_loss_scaling
+        ):
+            self._enable = False
+            self._use_dynamic_loss_scaling = False
+            warnings.warn(
+                'It is not recommended to use dynamic loss scaling for %s, so GradScaler is disabled by default.'
+                % (amp_global_state().amp_dtype)
+            )
+
         if not self._enable:
             return var
 
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
index 8d24febaff2..5de19dfb411 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
@@ -194,8 +194,11 @@ class TestAutoCast(unittest.TestCase):
 
 class TestAmpScaler(unittest.TestCase):
     def scale(self):
+        if not paddle.amp.is_float16_supported():
+            return
         with fluid.dygraph.guard():
-            data = paddle.rand([10, 1024])
+            with paddle.amp.auto_cast(dtype='float16'):
+                data = paddle.rand([10, 1024])
             scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
             scaled_data = scaler.scale(data)
             self.assertEqual(
@@ -333,9 +336,9 @@ class TestAmpScaler(unittest.TestCase):
             )
             scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
             data = fluid.dygraph.to_variable(inp_np)
-
-            out = model(data)
-            loss = paddle.mean(out)
+            with paddle.amp.auto_cast(dtype='float16'):
+                out = model(data)
+                loss = paddle.mean(out)
             scaled_loss = scaler.scale(loss)
             scaled_loss.backward()
             optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss)
@@ -348,6 +351,8 @@ class TestAmpScaler(unittest.TestCase):
         )
 
     def test_nan_inf(self):
+        if not paddle.amp.is_float16_supported():
+            return
         self.nan_inf()
 
     def step_update_exception(self):
diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py
new file mode 100644
index 00000000000..7d397c70432
--- /dev/null
+++ b/test/amp/test_amp_api.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from amp_base_models import AmpTestBase
+
+import paddle
+
+
+class TestAutoCast(AmpTestBase):
+    def test_amp_OD_level(self):
+        conv = paddle.nn.Conv2D(
+            in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
+        )
+        linear = paddle.nn.Linear(in_features=4, out_features=4)
+        with paddle.amp.auto_cast(level='OD'):
+            out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
+            out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
+            out3 = linear(out2)
+
+        self.assertEqual(out1.dtype, paddle.float16)
+        self.assertEqual(out2.dtype, paddle.float32)
+        self.assertEqual(out3.dtype, paddle.float32)
+
+
+class TestGradScaler(AmpTestBase):
+    def test_amp_grad_scaler(self):
+        model = paddle.nn.Conv2D(3, 2, 3)
+        optimizer = paddle.optimizer.SGD(
+            learning_rate=0.01, parameters=model.parameters()
+        )
+        scaler = paddle.amp.GradScaler()
+        data = paddle.rand([1, 3, 8, 8], dtype='float32')
+        paddle.amp.debugging.enable_operator_stats_collection()
+        with paddle.amp.auto_cast(
+            custom_black_list=['conv2d'], dtype='bfloat16'
+        ):
+            out = model(data)
+            loss = out.mean()
+        scaled = scaler.scale(loss)
+        scaled.backward()
+        scaler.minimize(optimizer, scaled)
+        optimizer.clear_grad()
+        paddle.amp.debugging.disable_operator_stats_collection()
+        op_list = paddle.fluid.core.get_low_precision_op_list()
+
+        self.assertEqual(scaler._enable, False)
+        self.assertEqual(scaler._use_dynamic_loss_scaling, False)
+        self.assertTrue('scale' not in op_list)
+        self.assertTrue('check_finite_and_unscale' not in op_list)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
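
A minimal usage sketch of the two behaviors this patch adds, distilled from the new test/amp/test_amp_api.py above. It assumes a Paddle build where float16/bfloat16 kernels are available; the layer shapes and learning rate are illustrative only, not part of the patch.

import paddle

# OD level: only the default white-list ops (conv2d, matmul_v2) take low-precision
# inputs; every other op keeps computing in float32.
conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3, bias_attr=False)
x = paddle.rand([1, 1, 6, 6], dtype='float32')
with paddle.amp.auto_cast(level='OD'):
    y = conv(x)                                    # white-list op -> float16 output
    z = y + paddle.rand(y.shape, dtype='float16')  # add is not white-listed -> float32 output

# bfloat16 with GradScaler: on the first scale() call the scaler sees that the AMP
# dtype is not float16, warns, and disables dynamic loss scaling, so scale() and
# minimize() behave like a plain loss and optimizer step.
model = paddle.nn.Conv2D(3, 2, 3)
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler()
with paddle.amp.auto_cast(dtype='bfloat16'):
    loss = model(paddle.rand([1, 3, 8, 8], dtype='float32')).mean()
scaled = scaler.scale(loss)   # returned unscaled for bfloat16
scaled.backward()
scaler.minimize(opt, scaled)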