From 4d7e9b5535c2294fd7bb26e3fa8b34b5f0405902 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Fri, 5 May 2023 11:21:16 +0800 Subject: [PATCH] [AMP] Cherry-pick AMP (#53442) Cherry-pick AMP --- .flake8 | 1 + paddle/fluid/eager/amp_utils.h | 6 +- paddle/fluid/imperative/amp_auto_cast.h | 1 + paddle/fluid/pybind/imperative.cc | 1 + python/paddle/amp/__init__.py | 6 +- python/paddle/amp/amp_lists.py | 120 ++++ python/paddle/amp/auto_cast.py | 278 ++++---- python/paddle/amp/grad_scaler.py | 14 + python/paddle/fluid/optimizer.py | 4 +- ...perative_auto_mixed_precision_for_eager.py | 21 +- .../fluid/tests/unittests/test_adadelta_op.py | 8 +- .../fluid/tests/unittests/test_adagrad_op.py | 8 +- .../fluid/tests/unittests/test_adam_op.py | 4 +- .../fluid/tests/unittests/test_adamax_op.py | 8 +- .../fluid/tests/unittests/test_momentum_op.py | 4 +- .../fluid/tests/unittests/test_rmsprop_op.py | 8 +- .../fluid/tests/unittests/test_sgd_op.py | 8 +- .../paddle/jit/dy2static/partial_program.py | 4 +- python/paddle/nn/layer/layers.py | 3 +- python/paddle/static/amp/__init__.py | 2 +- python/paddle/static/amp/debugging.py | 101 ++- python/paddle/static/amp/decorator.py | 148 ++++- python/paddle/static/amp/fp16_lists.py | 54 +- python/paddle/static/amp/fp16_utils.py | 602 +++++++++--------- python/paddle/static/amp/function_overload.py | 142 +++++ test/amp/amp_base_models.py | 128 +++- test/amp/test_amp_api.py | 66 ++ test/amp/test_amp_decorate.py | 166 +++++ test/amp/test_amp_list.py | 63 +- test/amp/test_amp_promote.py | 103 +++ test/amp/test_model_cast_to_bf16.py | 98 ++- .../contrib/test_image_classification_fp16.py | 35 +- test/ir/test_fuse_resnet_unit.py | 6 +- ...t_standalone_executor_aot_choose_kernel.py | 4 +- 34 files changed, 1622 insertions(+), 603 deletions(-) create mode 100644 python/paddle/amp/amp_lists.py create mode 100644 python/paddle/static/amp/function_overload.py create mode 100644 test/amp/test_amp_api.py create mode 100644 test/amp/test_amp_decorate.py create mode 100644 test/amp/test_amp_promote.py diff --git a/.flake8 b/.flake8 index 0315113df58..eeee9c2329a 100644 --- a/.flake8 +++ b/.flake8 @@ -39,3 +39,4 @@ per-file-ignores = .cmake-format.py: F821 test/dygraph_to_static/test_loop.py: F821 test/dygraph_to_static/test_closure_analysis.py: F821 + python/paddle/static/amp/decorator.py: F811 diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h index 68b5f7b11dd..8abb287122a 100644 --- a/paddle/fluid/eager/amp_utils.h +++ b/paddle/fluid/eager/amp_utils.h @@ -131,7 +131,11 @@ inline phi::DataType GetAmpDestDtype( ->count(op_name)) { dst_type = phi::DataType::FLOAT32; } else { - dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype); + if (amp_level == paddle::imperative::AmpLevel::OD) { + dst_type = phi::DataType::FLOAT32; + } else { + dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype); + } } if (dst_type == amp_setting_dtype && diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h index ced07b953d0..31dfc9dec57 100644 --- a/paddle/fluid/imperative/amp_auto_cast.h +++ b/paddle/fluid/imperative/amp_auto_cast.h @@ -31,6 +31,7 @@ enum class AmpLevel { O1, // amp, mixed fp32-fp16 O2, // almost fp16 O3, // fp16 + OD, // only conv and matmul use low precison. 
}; std::tuple, diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index e78a5bfd35d..1be8371ad4f 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -2154,6 +2154,7 @@ void BindImperative(py::module *m_ptr) { py::enum_(m, "AmpLevel", py::arithmetic()) .value("O0", paddle::imperative::AmpLevel::O0) + .value("OD", paddle::imperative::AmpLevel::OD) .value("O1", paddle::imperative::AmpLevel::O1) .value("O2", paddle::imperative::AmpLevel::O2) .value("O3", paddle::imperative::AmpLevel::O3) diff --git a/python/paddle/amp/__init__.py b/python/paddle/amp/__init__.py index 60df9de03ad..5fa8055ba23 100644 --- a/python/paddle/amp/__init__.py +++ b/python/paddle/amp/__init__.py @@ -16,10 +16,8 @@ from .auto_cast import auto_cast # noqa: F401 from .auto_cast import decorate # noqa: F401 from .auto_cast import amp_guard # noqa: F401 from .auto_cast import amp_decorate # noqa: F401 -from .auto_cast import FP16_WHITE_LIST # noqa: F401 -from .auto_cast import FP16_BLACK_LIST # noqa: F401 -from .auto_cast import PURE_FP16_WHITE_LIST # noqa: F401 -from .auto_cast import PURE_FP16_BLACK_LIST # noqa: F401 +from .amp_lists import white_list # noqa: F401 +from .amp_lists import black_list # noqa: F401 from . import grad_scaler # noqa: F401 from .grad_scaler import GradScaler # noqa: F401 diff --git a/python/paddle/amp/amp_lists.py b/python/paddle/amp/amp_lists.py new file mode 100644 index 00000000000..51c557b9481 --- /dev/null +++ b/python/paddle/amp/amp_lists.py @@ -0,0 +1,120 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The set of ops that support fp16 calculation and are considered numerically- +# safe and performance-critical. These ops are always converted to fp16. +FP16_WHITE_LIST = { + 'conv2d', + 'matmul', + 'matmul_v2', + 'max_pool2d_with_index', + 'mul', + 'fake_quantize_dequantize_abs_max', + 'fake_quantize_dequantize_moving_average_abs_max', +} + +# The set of ops that support fp16 calculation and are considered numerically- +# dangerous and whose effects may also be observed in downstream ops. +FP16_BLACK_LIST = { + 'tan', + 'acos', + 'asin', + 'sinh', + 'cosh', + 'atanh', + 'tanh_shrink', + 'cos_sim', + 'erfinv', + 'exp', + 'expm1', + 'log', + 'log10', + 'log2', + 'reciprocal', + 'rsqrt', + 'pow', + 'square', + 'reduce_sum', + 'mean', + 'reduce_mean', + 'reduce_prod', + 'cumprod', + 'cumsum', + 'dist', + 'pnorm', + 'frobenius_norm', + 'renorm', + 'group_norm', + 'layer_norm', + 'softmax', + 'softmin', + 'softplus', + 'log_softmax', + 'softmax_with_cross_entropy', + 'sigmoid_cross_entropy_with_logits', + 'c_softmax_with_cross_entropy', + 'cross_entropy', + 'cross_entropy2', + 'nll_loss', + 'huber_loss', + 'triplet_margin_loss', + 'log_loss', + 'hsigmoid_loss', + 'margin_cross_entropy', +} + +# FP16 performance of grad op is worse than that of FP32. Use FP32 by default. 
+FP16_EXTRA_BLACK_LIST = { + 'linear_interp_v2', + 'nearest_interp_v2', + 'bilinear_interp_v2', + 'bicubic_interp_v2', + 'trilinear_interp_v2', + 'lookup_table', + 'lookup_table_v2', + 'scatter', + 'depthwise_conv2d', +} + +BF16_WHITE_LIST = {'conv2d', 'matmul_v2'} +BF16_BLACK_LIST = set() + + +# At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32. +def white_list(): + white_list = { + "float16": { + "OD": FP16_WHITE_LIST, + "O1": FP16_WHITE_LIST, + "O2": FP16_WHITE_LIST, + }, + "bfloat16": { + "OD": BF16_WHITE_LIST, + "O1": BF16_WHITE_LIST, + "O2": BF16_WHITE_LIST, + }, + } + return white_list + + +def black_list(): + black_list = { + "float16": { + "OD": set(), + "O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST, + "O2": FP16_EXTRA_BLACK_LIST, + }, + "bfloat16": {"OD": set(), "O1": BF16_BLACK_LIST, "O2": set()}, + } + return black_list diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index bc76f866d94..76dd8b270c2 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -20,45 +20,7 @@ from paddle.fluid import core from paddle.fluid.framework import _dygraph_tracer, dygraph_only from paddle.fluid.wrapped_decorator import signature_safe_contextmanager -AMP_LEVEL = core.AmpLevel - -# The set of ops that support fp16 calculation and are considered numerically- -# safe and performance-critical. These ops are always converted to fp16. -FP16_WHITE_LIST = { - 'conv2d', - 'matmul', - 'matmul_v2', - 'max_pool2d_with_index', - 'mul', - 'fake_quantize_dequantize_abs_max', - 'fake_quantize_dequantize_moving_average_abs_max', -} - -# The set of ops that support fp16 calculation and are considered numerically- -# dangerous and whose effects may also be observed in downstream ops. -FP16_BLACK_LIST = { - 'exp', - 'square', - 'log', - 'mean', - 'sum', - 'cos_sim', - 'softmax', - 'softmax_with_cross_entropy', - 'sigmoid_cross_entropy_with_logits', - 'c_softmax_with_cross_entropy', - 'cross_entropy', - 'cross_entropy2', - # default fp32 can avoid return inf when the sum value large than 65504 - 'reduce_sum', - # FP16 performance of grad op is worse than that of FP32. Use FP32 by default. - 'linear_interp_v2', - 'nearest_interp_v2', - 'bilinear_interp_v2', - 'bicubic_interp_v2', - 'trilinear_interp_v2', -} - +from .amp_lists import black_list, white_list AMP_RELATED_FLAGS = [ 'FLAGS_cudnn_exhaustive_search', @@ -72,27 +34,7 @@ AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, } -PURE_FP16_WHITE_LIST = copy.copy(FP16_WHITE_LIST) - -PURE_FP16_BLACK_LIST = { - 'lookup_table', - 'lookup_table_v2', - 'scatter', - 'scatter_grad', - # FP16 performance of grad op is worse than that of FP32. Use FP32 by default. - 'linear_interp_v2', - 'nearest_interp_v2', - 'bilinear_interp_v2', - 'bicubic_interp_v2', - 'trilinear_interp_v2', -} - -BF16_WHITE_LIST = {'conv2d', 'matmul_v2'} -BF16_BLACK_LIST = set() - -PURE_BF16_WHITE_LIST = copy.copy(BF16_WHITE_LIST) -PURE_BF16_BLACK_LIST = set() - +AMP_LEVEL = core.AmpLevel _g_amp_state_ = None @@ -106,6 +48,7 @@ class AMPGlobalState: self.model_parameters = [] self.use_master_grad = False self.already_register_final_backward_hook = False + self.amp_dtype = 'float32' def __setattr__(self, name, val): self.__dict__[name] = val @@ -126,20 +69,12 @@ def _update_list( """ Update black and white list according to users' custom list. 
""" - if dtype == 'float16': - if level == 'O1': - _white_list = copy.copy(FP16_WHITE_LIST) - _black_list = copy.copy(FP16_BLACK_LIST) - else: - _white_list = copy.copy(PURE_FP16_WHITE_LIST) - _black_list = copy.copy(PURE_FP16_BLACK_LIST) - else: - if level == 'O1': - _white_list = copy.copy(BF16_WHITE_LIST) - _black_list = copy.copy(BF16_BLACK_LIST) - else: - _white_list = copy.copy(PURE_BF16_WHITE_LIST) - _black_list = copy.copy(PURE_BF16_BLACK_LIST) + if level == 'O0': + _white_list = set() + _black_list = set() + return _white_list, _black_list + _white_list = copy.copy(white_list()[dtype][level]) + _black_list = copy.copy(black_list()[dtype][level]) if custom_white_list and custom_black_list: for op_name in custom_white_list: if op_name in custom_black_list: @@ -199,47 +134,95 @@ def _is_gpu_bfloat16_supported(): return prop[0] >= 8 and cuda_version_check +def need_keep_fp32(layer, dtype): + need_keep_fp32 = False + # Highest prority. Because all the layers except BN will use bfloat16 params in bfoat16 training, + # here we provide a option to keep fp32 param. + if not layer._cast_to_low_precison: + need_keep_fp32 = True + # The BN layers will keep fp32 + elif isinstance( + layer, + ( + paddle.nn.BatchNorm, + paddle.nn.BatchNorm1D, + paddle.nn.BatchNorm2D, + paddle.nn.BatchNorm3D, + paddle.nn.SyncBatchNorm, + ), + ): + need_keep_fp32 = True + # layer._dtype is used to set params dtype. BF16 will use bf16 params. + elif (layer._dtype == 'float16') or ( + (dtype == 'float16') + and isinstance( + layer, + ( + paddle.nn.LayerNorm, + paddle.nn.InstanceNorm1D, + paddle.nn.InstanceNorm2D, + paddle.nn.InstanceNorm3D, + ), + ) + ): + need_keep_fp32 = True + + return need_keep_fp32 + + +def set_excluded_layers(models, excluded_layers): + excluded_layers_instances = [] + excluded_layers_types = [] + error_message = "excluded_layers must be either a nn.Layer instance/type or a list of nn.Layer instances/types." 
+ if excluded_layers is None: + excluded_layers = [] + elif isinstance(excluded_layers, paddle.nn.Layer): + excluded_layers_instances = [excluded_layers] + elif isinstance(excluded_layers, type) and issubclass( + excluded_layers, paddle.nn.Layer + ): + excluded_layers_types = [excluded_layers] + elif isinstance(excluded_layers, list): + for item in excluded_layers: + if isinstance(item, paddle.nn.Layer): + excluded_layers_instances.append(item) + elif issubclass(item, paddle.nn.Layer): + excluded_layers_types.append(item) + else: + raise TypeError(error_message) + else: + raise TypeError(error_message) + + for idx in range(len(excluded_layers_instances)): + for layer in excluded_layers_instances[idx].sublayers( + include_self=True + ): + layer._cast_to_low_precison = False + for idx in range(len(models)): + for layer in models[idx].sublayers(include_self=True): + if type(layer) in excluded_layers_types: + layer._cast_to_low_precison = False + + @dygraph_only -def pure_fp16_initialize(models): +def amp_initialize(models, dtype, excluded_layers): + set_excluded_layers(models, excluded_layers) for idx in range(len(models)): for layer in models[idx].sublayers(include_self=True): - layer._casted_by_pure_fp16 = True - if (layer._dtype == 'float16') or isinstance( - layer, - ( - paddle.nn.BatchNorm, - paddle.nn.BatchNorm1D, - paddle.nn.BatchNorm2D, - paddle.nn.BatchNorm3D, - paddle.nn.LayerNorm, - paddle.nn.SyncBatchNorm, - paddle.nn.InstanceNorm1D, - paddle.nn.InstanceNorm2D, - paddle.nn.InstanceNorm3D, - ), - ): + if need_keep_fp32(layer, dtype): continue - if isinstance( + if dtype == "float16" and isinstance( layer, ( paddle.incubate.nn.FusedFeedForward, paddle.incubate.nn.FusedMultiHeadAttention, ), ): - layer._amp_decorate(dtype='float16') + layer._amp_decorate(dtype=dtype) continue - layer._to_impl( - dtype='float16', include_sublayers=False, floating_only=True - ) - return models - -@dygraph_only -def pure_bf16_initialize(models): - for idx in range(len(models)): - for layer in models[idx].sublayers(include_self=True): layer._to_impl( - dtype='bfloat16', include_sublayers=False, floating_only=True + dtype=dtype, include_sublayers=False, floating_only=True ) return models @@ -338,10 +321,8 @@ def amp_guard( # check amp_level: O0-O2 level = level.upper() - if not (level in ['O0', 'O1', 'O2']): - raise ValueError( - "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode." 
- ) + if not (level in ['O0', 'OD', 'O1', 'O2']): + raise ValueError("level should be O0, OD, O1 or O2.") # check amp_dtype: float16 or bfloat16 dtype = dtype.lower() @@ -402,37 +383,20 @@ def amp_guard( ) amp_dtype = dtype + amp_global_state().amp_dtype = amp_dtype - if level == 'O1': + if level == 'OD': + amp_level = AMP_LEVEL.OD + elif level == 'O1': amp_level = AMP_LEVEL.O1 - if dtype == 'float16': - _white_list = FP16_WHITE_LIST - _black_list = FP16_BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST - elif level == 'O2': amp_level = AMP_LEVEL.O2 - if dtype == 'float16': - _white_list = PURE_FP16_WHITE_LIST - _black_list = PURE_FP16_BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST elif level == 'O0': amp_level = AMP_LEVEL.O0 - if dtype == 'float16': - _white_list = FP16_WHITE_LIST - _black_list = FP16_BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST - - if custom_white_list or custom_black_list: - _white_list, _black_list = _update_list( - custom_white_list, custom_black_list, level, dtype - ) + + _white_list, _black_list = _update_list( + custom_white_list, custom_black_list, level, dtype + ) if not enable: amp_level = AMP_LEVEL.O0 @@ -522,6 +486,7 @@ def amp_decorate( master_weight=None, save_dtype=None, master_grad=False, + excluded_layers=None, ): """ Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing. @@ -590,6 +555,8 @@ def amp_decorate( raise ValueError( "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode." ) + if not (dtype in ['float16', 'bfloat16']): + raise ValueError("dtype only support float16 or bfloat16.") if level == 'O1': if optimizers is None: @@ -609,12 +576,9 @@ def amp_decorate( raise TypeError( "models must be either a single model or a list of models." ) - if dtype == 'float16': - models = pure_fp16_initialize(models=models) - elif dtype == 'bfloat16': - models = pure_bf16_initialize(models=models) - else: - raise TypeError("dtype only support float16 or bfloat16.") + + # initialize parameters of the model. + amp_initialize(models=models, dtype=dtype, excluded_layers=excluded_layers) if optimizers is not None: # check optimizers @@ -680,22 +644,24 @@ def auto_cast( ): """ Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode. - If enabled, the input data type (float32 or float16) of each operator is decided + If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided by autocast algorithm for better performance. - Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in - imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode. + Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in + imperative mode. Args: enable(bool, optional): Enable auto-mixed-precision or not. Default is True. - custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support - fp16 calculation and are considered numerically-safe and performance-critical. These ops - will be converted to fp16. - custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16 - calculation and are considered numerically-dangerous and whose effects may also be - observed in downstream ops. 
These ops will not be converted to fp16. - level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; - O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp) + custom_white_list(set|list|tuple, optional): A default white list is already set. Usually there is no need to set custom white list. + The set of ops should be considered numerically-safe and performance-critical. These ops will be converted to float16/bfloat16. + custom_black_list(set|list|tuple, optional): A default black list is already set. You can set a custom black list according to the model. + The set of ops are considered numerically-dangerous and whose effects may also be observed in downstream ops. These ops will not be + converted to float16/bfloat16. + level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list + will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2 + level, model's parameters will be casted to float16/bfloat16 by using `decorator`, and operators that have all float16/bfloat16 inputs + will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in + default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1. dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. Examples: @@ -741,6 +707,7 @@ def decorate( master_weight=None, save_dtype=None, master_grad=False, + excluded_layers=None, ): """ Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing. @@ -757,8 +724,10 @@ def decorate( master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None. save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None. The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None. - master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If it is enabled, the weight - gradients will be FP32 dtype after the backpropagation. Default is False. + master_grad(bool, optional): For level='O2', whether to use float32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If master_grad is enabled, the weight + gradients will be float32 dtype after the backpropagation. Default is False, there is only float16 weight gradients. + excluded_layers(Layer|list of Layer, optional): Specify the layers not to be decorated. The weights of these layers will always keep float32 when level is O2. `excluded_layers` can be specified as + an Layer instance/type or a list of Layer instances/types. Default is None, the weights of the whole model will be casted to float16 or bfloat16. 
Examples: @@ -808,5 +777,12 @@ def decorate( print(output.dtype) # FP16 """ return amp_decorate( - models, optimizers, level, dtype, master_weight, save_dtype, master_grad + models, + optimizers, + level, + dtype, + master_weight, + save_dtype, + master_grad, + excluded_layers, ) diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index 5c2d033d336..d0b3cf31880 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -24,6 +24,8 @@ from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph import to_variable from paddle.fluid.framework import _dygraph_tracer, dygraph_only +from .auto_cast import amp_global_state + class OptimizerState(Enum): INIT = 0 @@ -179,6 +181,18 @@ class AmpScaler: """ check_type(var, "var", core.eager.Tensor, 'AmpScaler.scale()') + if ( + self._enable + and amp_global_state().amp_dtype != 'float16' + and self._use_dynamic_loss_scaling + ): + self._enable = False + self._use_dynamic_loss_scaling = False + warnings.warn( + 'It is not recommended to use dynamic loss scaling for %s, so GradScaler is disable by default.' + % (amp_global_state().amp_dtype) + ) + if not self._enable: return var diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index db483b151e4..21b08c82d5e 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4982,8 +4982,8 @@ class PipelineOptimizer: device = post_op.attr(self._op_device_key) assert device, "The post op must have op_device set." op._set_attr(self._op_device_key, device) - elif (op.type == "cast" or op.type == "scale") and self._is_backward_op( - op + elif (op.type == "cast" or op.type == "scale") and ( + self._is_backward_op(op) or self._is_forward_op(op) ): prev_op = self._find_prev_op(idx, op.desc.input("X")[0]) op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key)) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py index 7a7d65d27d5..5de19dfb411 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py @@ -88,8 +88,8 @@ class TestAutoCast(unittest.TestCase): def custom_op_list(self): with fluid.dygraph.guard(): tracer = fluid.framework._dygraph_tracer() - base_white_list = paddle.amp.FP16_WHITE_LIST - base_black_list = paddle.amp.FP16_BLACK_LIST + base_white_list = paddle.amp.white_list()["float16"]["O1"] + base_black_list = paddle.amp.black_list()["float16"]["O1"] with paddle.amp.amp_guard( custom_white_list=["log"], custom_black_list=["conv2d"] ): @@ -104,8 +104,8 @@ class TestAutoCast(unittest.TestCase): == (set(base_black_list) - {"log"}) | {"conv2d"} ) - base_white_list = paddle.amp.PURE_FP16_WHITE_LIST - base_black_list = paddle.amp.PURE_FP16_BLACK_LIST + base_white_list = paddle.amp.white_list()["float16"]["O2"] + base_black_list = paddle.amp.black_list()["float16"]["O2"] with paddle.amp.amp_guard( custom_white_list=["log"], custom_black_list=["conv2d"], @@ -194,8 +194,11 @@ class TestAutoCast(unittest.TestCase): class TestAmpScaler(unittest.TestCase): def scale(self): + if not paddle.amp.is_float16_supported(): + return with fluid.dygraph.guard(): - data = paddle.rand([10, 1024]) + with paddle.amp.auto_cast(dtype='float16'): + data = 
paddle.rand([10, 1024]) scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) scaled_data = scaler.scale(data) self.assertEqual( @@ -333,9 +336,9 @@ class TestAmpScaler(unittest.TestCase): ) scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) data = fluid.dygraph.to_variable(inp_np) - - out = model(data) - loss = paddle.mean(out) + with paddle.amp.auto_cast(dtype='float16'): + out = model(data) + loss = paddle.mean(out) scaled_loss = scaler.scale(loss) scaled_loss.backward() optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss) @@ -348,6 +351,8 @@ class TestAmpScaler(unittest.TestCase): ) def test_nan_inf(self): + if not paddle.amp.is_float16_supported(): + return self.nan_inf() def step_update_exception(self): diff --git a/python/paddle/fluid/tests/unittests/test_adadelta_op.py b/python/paddle/fluid/tests/unittests/test_adadelta_op.py index f3eca8fec9c..14e791ce18f 100644 --- a/python/paddle/fluid/tests/unittests/test_adadelta_op.py +++ b/python/paddle/fluid/tests/unittests/test_adadelta_op.py @@ -356,7 +356,9 @@ class TestAdadeltaMultiPrecision2_0(unittest.TestCase): exe.run(startup_program) if use_amp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') @@ -467,7 +469,9 @@ class TestAdadeltaMultiPrecision1_0(unittest.TestCase): exe.run(startup_program) if use_amp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_adagrad_op.py b/python/paddle/fluid/tests/unittests/test_adagrad_op.py index f90bd83490b..d89e9233b33 100644 --- a/python/paddle/fluid/tests/unittests/test_adagrad_op.py +++ b/python/paddle/fluid/tests/unittests/test_adagrad_op.py @@ -322,7 +322,9 @@ class TestAdagradMultiPrecision2_0(unittest.TestCase): exe.run(startup_program) if use_amp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') @@ -431,7 +433,9 @@ class TestAdagradMultiPrecision1_0(unittest.TestCase): exe.run(startup_program) if use_amp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index 8e4428131c6..d4f97402084 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -1235,7 +1235,9 @@ class TestMultiTensorAdam(unittest.TestCase): optimizer.minimize(loss) exe.run(startup_program) if use_amp: - optimizer.amp_init(place=place, scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') diff --git 
a/python/paddle/fluid/tests/unittests/test_adamax_op.py b/python/paddle/fluid/tests/unittests/test_adamax_op.py index 8acad2b4bfb..a473b4fece8 100644 --- a/python/paddle/fluid/tests/unittests/test_adamax_op.py +++ b/python/paddle/fluid/tests/unittests/test_adamax_op.py @@ -352,7 +352,9 @@ class TestAdamaxMultiPrecision2_0(unittest.TestCase): exe.run(startup_program) if use_amp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') @@ -459,7 +461,9 @@ class TestAdamaxMultiPrecision1_0(unittest.TestCase): exe.run(startup_program) if use_amp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index 14081d1cd73..30172c8d758 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -1059,7 +1059,9 @@ class TestMultiTensorMomentumStatic(unittest.TestCase): optimizer.minimize(loss) exe.run(startup_program) if use_amp: - optimizer.amp_init(place=place, scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = numpy.random.random(size=(2, 2)).astype('float16') else: x = numpy.random.random(size=(2, 2)).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py index 127abef2883..5f9579aaa07 100644 --- a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py +++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py @@ -474,7 +474,9 @@ class TestRMSPropMultiPrecision2_0(unittest.TestCase): exe.run(startup_program) if use_amp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') @@ -585,7 +587,9 @@ class TestRMSPropMultiPrecision1_0(unittest.TestCase): exe.run(startup_program) if use_amp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op.py b/python/paddle/fluid/tests/unittests/test_sgd_op.py index 6fa74c14a87..86337e8a0f2 100644 --- a/python/paddle/fluid/tests/unittests/test_sgd_op.py +++ b/python/paddle/fluid/tests/unittests/test_sgd_op.py @@ -382,7 +382,9 @@ class TestSGDMultiPrecision2_0(unittest.TestCase): exe.run(startup_program) if mp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') @@ -492,7 +494,9 @@ class TestSGDMultiPrecision1_0(unittest.TestCase): exe.run(startup_program) if mp: - 
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) x = np.random.random(size=(2, 2)).astype('float16') else: x = np.random.random(size=(2, 2)).astype('float32') diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index c192316c775..d24a335c894 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -294,8 +294,8 @@ class PartialProgramLayer: def _create_amp_program(self, is_infer_mode=False): amp_program = self._origin_main_program.clone(for_test=is_infer_mode) with program_guard(amp_program): - paddle.static.amp.fp16_utils.rewrite_program( - amp_program, self._amp_list + paddle.static.amp.fp16_utils.cast_model_to_fp16( + amp_program, self._amp_list, use_fp16_guard=False, level='O1' ) if is_infer_mode: if self._hooker: diff --git a/python/paddle/nn/layer/layers.py b/python/paddle/nn/layer/layers.py index 0babc935f1d..29a6f49b5dc 100644 --- a/python/paddle/nn/layer/layers.py +++ b/python/paddle/nn/layer/layers.py @@ -401,7 +401,8 @@ class Layer: self._forward_pre_hooks = collections.OrderedDict() self._forward_post_hooks = collections.OrderedDict() - self._casted_by_pure_fp16 = False + # only used in AMP Training + self._cast_to_low_precison = True self._state_dict_hooks = collections.OrderedDict() # Records orignal functions after @to_static to support to rollback diff --git a/python/paddle/static/amp/__init__.py b/python/paddle/static/amp/__init__.py index 48856cc1af9..843be1443a4 100644 --- a/python/paddle/static/amp/__init__.py +++ b/python/paddle/static/amp/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from . import decorator -from .decorator import decorate, amp_decorate +from .decorator import decorate from . import fp16_lists from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists from . import fp16_utils diff --git a/python/paddle/static/amp/debugging.py b/python/paddle/static/amp/debugging.py index 28abe84c39b..5a894495d98 100644 --- a/python/paddle/static/amp/debugging.py +++ b/python/paddle/static/amp/debugging.py @@ -13,8 +13,14 @@ # limitations under the License. 
import copy +import logging import paddle +from paddle.fluid.log_helper import get_logger + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) class OperatorStatsUnit: @@ -76,7 +82,7 @@ def _get_var_dtype_from_block(block, op, arg_name, is_input): var = block._var_recursive(var_name) return var.dtype except: - print( + _logger.warning( "Operator < {} > gets {} < {} : {} > error!".format( op.type, "input" if is_input else "output", arg_name, var_name ) @@ -99,7 +105,7 @@ def _extract_compute_dtype(op, block): if _is_floating_point(compute_dtype) and _is_floating_point( var_dtype ): - print( + _logger.warning( "Operator < {} > has different input data types, input_names = {}, output_names = {}.".format( op.type, op.input_names, op.output_names ) @@ -125,7 +131,7 @@ def _extract_compute_dtype(op, block): if _is_floating_point(compute_dtype) and _is_floating_point( var_dtype ): - print( + _logger.warning( "Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format( op.type, op.input_names, op.output_names ) @@ -145,6 +151,15 @@ def _merge_op_stats(op_stats_list): def _get_op_stats_list(program): + def _is_special_ops_with_input_x(op_type): + # operators have input X and have inputs different dtypes. + special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm'] + if op_type in special_op_list: + return True + if op_type.replace("_grad", "") in special_op_list: + return True + return False + op_stats_list = [] for block in program.blocks: block_op_stats_dict = {} @@ -161,13 +176,7 @@ def _get_op_stats_list(program): 'create_double_buffer_reader', ]: compute_dtype = None - elif op.type in [ - 'cast', - 'layer_norm', - 'layer_norm_grad', - 'batch_norm', - 'batch_norm_grad', - ]: + elif _is_special_ops_with_input_x(op.type): # Not check the input and output dtype difference for this operators. compute_dtype = _get_var_dtype_from_block(block, op, 'X', True) elif "Param" in op.input_names: @@ -183,6 +192,78 @@ def _get_op_stats_list(program): def collect_operator_stats(program=None, print_subblocks=False): + """ + Collect the number of operators for different data types through parsing + the program. The statistical data are categorized according to four data + types, namely float32, float16, bfloat16 and others. + + Args: + program(Program, optional): The program to parse. Default None, and the default main_program will be parsed. + print_subblocks(bool, optional): Whether to print the operator stats for each subblock. Default False. + + Examples: + + .. 
code-block:: python + + import paddle + + paddle.enable_static() + + class SimpleConvNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3) + self.linear = paddle.nn.Linear(in_features=26, out_features=10) + + def forward(self, x): + out = self.conv(x) + out = paddle.nn.functional.relu(out) + out = self.linear(out) + out = paddle.nn.functional.softmax(out) + return out + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.utils.unique_name.guard(): + with paddle.static.program_guard(main_program, startup_program): + model = SimpleConvNet() + x = paddle.static.data( + name='input', shape=[None, 1, 28, 28], dtype='float32' + ) + out = model(x) + loss = paddle.mean(out) + optimizer = paddle.optimizer.AdamW() + optimizer = paddle.static.amp.decorate(optimizer) + optimizer.minimize(loss) + paddle.static.amp.debugging.collect_operator_stats(main_program) + # <------------------------------------------------ op list of all blocks -------------------------------------------------> + # <------------------------------------------------------- op list --------------------------------------------------------> + # <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | --- FP32 Calls--- | -- Other Calls --> + # adamw | 0 | 0 | 4 | 0 + # cast | 5 | 0 | 6 | 0 + # check_finite_and_unscale | 0 | 0 | 1 | 0 + # conv2d | 1 | 0 | 0 | 0 + # conv2d_grad | 1 | 0 | 0 | 0 + # elementwise_add | 2 | 0 | 0 | 0 + # elementwise_add_grad | 2 | 0 | 0 | 0 + # elementwise_mul | 0 | 0 | 1 | 0 + # elementwise_mul_grad | 0 | 0 | 1 | 0 + # fill_constant | 0 | 0 | 1 | 0 + # matmul_v2 | 1 | 0 | 0 | 0 + # matmul_v2_grad | 1 | 0 | 0 | 0 + # memcpy | 0 | 0 | 0 | 1 + # reduce_mean | 0 | 0 | 1 | 0 + # reduce_mean_grad | 0 | 0 | 1 | 0 + # relu | 1 | 0 | 0 | 0 + # relu_grad | 1 | 0 | 0 | 0 + # reshape2 | 0 | 0 | 1 | 0 + # reshape2_grad | 0 | 0 | 1 | 0 + # softmax | 0 | 0 | 1 | 0 + # softmax_grad | 0 | 0 | 1 | 0 + # update_loss_scaling | 0 | 0 | 1 | 0 + # <----------------------------------------------------- op count: 22 -----------------------------------------------------> + """ + def _convert_to_list(op_stats_unit_dict): for key, value in op_stats_unit_dict.items(): op_stats_unit_dict[key] = value.convert_to_list() diff --git a/python/paddle/static/amp/decorator.py b/python/paddle/static/amp/decorator.py index ae3a98c37b3..fc0aaac92bf 100644 --- a/python/paddle/static/amp/decorator.py +++ b/python/paddle/static/amp/decorator.py @@ -29,9 +29,24 @@ from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype from .fp16_utils import ( cast_model_to_fp16, cast_parameters_to_fp16, - rewrite_program, update_role_var_grad, ) +from .function_overload import FunctionType, overload + + +def _set_multi_precision(optimizer, multi_precision): + if not isinstance( + optimizer, + (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer), + ): + raise RuntimeError( + "Current AMP training level is O2, optimizer is expected to be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format( + type(optimizer) + ) + ) + + if multi_precision and hasattr(optimizer, "_multi_precision"): + optimizer._multi_precision = multi_precision class OptimizerWithMixedPrecision: @@ -66,6 +81,7 @@ class OptimizerWithMixedPrecision: the loss scaling. use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program. 
Default None, which means that its value is equal to `use_pure_fp16`. + use_promote(bool): Whether to promotes to fp32 when op has any float32 inputs. Default is False. """ def __init__( @@ -81,6 +97,7 @@ class OptimizerWithMixedPrecision: incr_ratio, decr_ratio, use_amp_guard=None, + use_promote=False, ): self._optimizer = optimizer self._amp_lists = amp_lists @@ -115,6 +132,7 @@ class OptimizerWithMixedPrecision: self._decr_ratio = decr_ratio self._num_good_steps = None self._num_bad_steps = None + self.use_promote = use_promote def _set_distributed(self, flag): # if distributed, all cards will communication with each other, @@ -230,10 +248,18 @@ class OptimizerWithMixedPrecision: self._amp_lists, self._use_fp16_guard, self._amp_vartype, + level='O2', + use_promote=self.use_promote, ) else: - rewrite_program( - self._train_program, self._amp_lists, self._amp_vartype + # use_fp16_guard is not support amp-o1. + cast_model_to_fp16( + self._train_program, + self._amp_lists, + use_fp16_guard=False, + dest_type=self._amp_vartype, + level='O1', + use_promote=self.use_promote, ) if loss.dtype != core.VarDesc.VarType.FP32: @@ -361,10 +387,18 @@ class OptimizerWithMixedPrecision: self._amp_lists, self._use_fp16_guard, self._amp_vartype, + level='O2', + use_promote=self.use_promote, ) elif use_fp16_test: - rewrite_program( - test_program, self._amp_lists, self._amp_vartype + # use_fp16_guard is not support amp-o1. + cast_model_to_fp16( + test_program, + self._amp_lists, + use_fp16_guard=False, + dest_type=self._amp_vartype, + level='O1', + use_promote=self.use_promote, ) def apply_gradients(self, params_grads): @@ -610,6 +644,7 @@ class OptimizerWithMixedPrecision: return optimize_ops, scaled_params_grads +@overload(key=FunctionType.FP16_ONLY) def decorate( optimizer, amp_lists=None, @@ -622,6 +657,7 @@ def decorate( use_pure_fp16=False, use_fp16_guard=None, use_bf16=False, + use_promote=False, ): """ Decorate the given optimizer to adapt to the mixed-precision training. @@ -734,31 +770,108 @@ def decorate( incr_ratio=incr_ratio, decr_ratio=decr_ratio, use_amp_guard=use_fp16_guard, + use_promote=use_promote, ) return mp_optimizer -def amp_decorate( +@overload(key=FunctionType.COMMON) +def decorate( optimizer, amp_lists=None, level='O1', dtype='float16', + master_weight=None, init_loss_scaling=2**15, incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, incr_ratio=2.0, decr_ratio=0.8, - use_dynamic_loss_scaling=True, + use_dynamic_loss_scaling=None, use_amp_guard=False, + use_promote=False, ): """ Decorate the given optimizer to adapt to the mixed-precision training. - """ - amp_dtype = check_amp_dtype(dtype) - if amp_lists is None: - amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype) + Args: + optimizer(Optimizer): A common Optimizer. + amp_lists(CustomOpLists, optional): An CustomOpLists object. The default + white_list and black_list will be used for AMP training when it is + not set. Default is None. + level(str, optional): Auto mixed precision level. Accepted values are + "O1" and "O2": O1 represent mixed precision, the input data type of + each operator will be casted by white_list and black_list; + O2 represent pure FP16 / BF16 training, all operators parameters + and input data will be casted to FP16 / BF16, except operators in + black_list, don't support FP16 / BF16 kernel and batch_norm. Default is O1. + dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. + master_weight(bool, optinal): For level='O2', whether to use multi-precision + during weight updating. 
If master_weight is None, in O2 level optimizer + will use multi-precision. Default is None. + init_loss_scaling(float, optional): The initial loss scaling factor. + Default is 32768. + incr_every_n_steps(int, optional): Increases loss scaling every n + consecutive steps with finite gradients. Default is 1000. + decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n + accumulated steps with nan or inf gradients. Default is 2. + incr_ratio(float, optional): The multiplier to use when increasing the + loss scaling. Default is 2. + decr_ratio(float, optional): The less-than-one-multiplier to use when + decreasing the loss scaling. Default is 0.8. + use_dynamic_loss_scaling(bool, None): Whether to use dynamic loss + scaling. Default is None, which means True for float16, and False + for bfloat16. + + Returns: + An optimizer acting like a normal one but with mixed-precision training + + Examples: + + .. code-block:: python + + import paddle + + paddle.enable_static() + + class SimpleConvNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3) + self.linear = paddle.nn.Linear(in_features=26, out_features=10) + + def forward(self, x): + out = self.conv(x) + out = paddle.nn.functional.relu(out) + out = self.linear(out) + out = paddle.nn.functional.softmax(out) + return out + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.utils.unique_name.guard(): + with paddle.static.program_guard(main_program, startup_program): + model = SimpleConvNet() + x = paddle.static.data( + name='input', shape=[None, 1, 28, 28], dtype='float32' + ) + out = model(x) + loss = paddle.mean(out) + optimizer = paddle.optimizer.AdamW() + optimizer = paddle.static.amp.decorate(optimizer, level="O2", dtype="float16") + optimizer.minimize(loss) + + if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(startup_program) + + # Call `amp_init` after FP32 parameters initialization, such as `exe.run(startup_program)`, + # to convert FP32 parameters to low precision FP16 / BF16. + optimizer.amp_init(place, scope=paddle.static.global_scope()) + + """ # check amp_level: O0-O2 level = level.upper() if not (level in ['O0', 'O1', 'O2']): @@ -766,6 +879,18 @@ def amp_decorate( "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode." 
) + amp_dtype = check_amp_dtype(dtype) + if amp_lists is None: + amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype) + + if use_dynamic_loss_scaling is None: + use_dynamic_loss_scaling = dtype == "float16" + + if optimizer is not None: + # support master_weight + multi_precision = not (master_weight is False) + _set_multi_precision(optimizer, multi_precision) + mp_optimizer = OptimizerWithMixedPrecision( optimizer, amp_lists, @@ -778,6 +903,7 @@ def amp_decorate( incr_ratio=incr_ratio, decr_ratio=decr_ratio, use_amp_guard=use_amp_guard, + use_promote=use_promote, ) return mp_optimizer diff --git a/python/paddle/static/amp/fp16_lists.py b/python/paddle/static/amp/fp16_lists.py index f310ae49126..6e0a4a5254c 100644 --- a/python/paddle/static/amp/fp16_lists.py +++ b/python/paddle/static/amp/fp16_lists.py @@ -98,6 +98,20 @@ def _get_sys_unsupported_list(dtype): else: device = 'GPU' _, _, sys_unsupported_list = core.op_supported_infos(device, var_type) + + # sys_unsupported_list will include the following ops. + supported_fp16_list = { + "conditional_block", + "conditional_block_infer", + "select_input", + "while", + "cast", + "tensor_array_to_tensor", + "lod_array_length", + "write_to_array", + } + sys_unsupported_list -= supported_fp16_list + return device, sys_unsupported_list @@ -108,6 +122,29 @@ def _get_unsupported_list(dtype): return unsupported_list +# The three sets listed below are changed dynamiclly. They don't contain all +# paddle ops currently. + +# The set of ops that support fp16 calculation and are considered numerically- +# safe and performance-critical. These ops are always converted to fp16. + +_only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'} + +white_list = { + 'conv2d', + 'matmul', + 'matmul_v2', + 'mul', +} + + +def _get_white_list(dtype): + white_list_for_dtype = copy.copy(white_list) + if dtype == 'float16': + white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list + return white_list_for_dtype + + class AutoMixedPrecisionLists: """ AutoMixedPrecisionLists is a class for black/white list. It can update @@ -132,7 +169,7 @@ class AutoMixedPrecisionLists: self.amp_dtype = check_amp_dtype(dtype) self._custom_white_list = custom_white_list self._custom_black_list = custom_black_list - self.white_list = copy.copy(white_list) + self.white_list = copy.copy(_get_white_list(self.amp_dtype)) self.black_list = copy.copy(black_list) self.gray_list = copy.copy(gray_list) self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype)) @@ -143,6 +180,9 @@ class AutoMixedPrecisionLists: """ Update black and white list according to users' custom list. """ + _logger.debug(f"---- custom_white_list {self._custom_white_list} ---- ") + _logger.debug(f"---- custom_black_list {self._custom_black_list} ---- ") + _logger.debug(f"---- custom_black_varnames {self.black_varnames} ---- ") if self._custom_white_list and self._custom_black_list: for op_name in self._custom_white_list: if op_name in self._custom_black_list: @@ -177,18 +217,6 @@ class AutoMixedPrecisionLists: ) -# The three sets listed below are changed dynamiclly. They don't contain all -# paddle ops currently. - -# The set of ops that support fp16 calculation and are considered numerically- -# safe and performance-critical. These ops are always converted to fp16. -white_list = { - 'conv2d', - 'matmul', - 'matmul_v2', - 'mul', -} - # The set of ops that support fp16 calculation and are considered numerically- # dangerous and whose effects may also be observed in downstream ops. 
black_list = { diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py index 21b5268aa40..740930769cb 100644 --- a/python/paddle/static/amp/fp16_utils.py +++ b/python/paddle/static/amp/fp16_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging import numpy as np @@ -22,7 +21,11 @@ from paddle.fluid import core, framework, global_scope from paddle.fluid.log_helper import get_logger from paddle.fluid.wrapped_decorator import signature_safe_contextmanager -from .fp16_lists import AutoMixedPrecisionLists, get_low_precision_dtypestr +from .fp16_lists import ( + AutoMixedPrecisionLists, + black_list, + get_low_precision_dtypestr, +) _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' @@ -144,7 +147,7 @@ def _keep_fp32_output(op, out_name): def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): """ - Insert cast op and rename args of input and output. + Insert cast op and rename op's input. Args: block (Program): The block in which the operator is. @@ -167,8 +170,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): in_var = block._find_var_recursive(in_var_name) if in_var.type not in _valid_types or in_var.dtype == dest_dtype: continue + # op's input is already casted to dest_dtype before. Set the in_var.name to cast_name. + cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype) + casted_var = block._find_var_recursive(cast_name) + if casted_var and casted_var.dtype == dest_dtype: + _rename_arg(op, in_var.name, casted_var.name) + continue + + # insert cast for op's input. if in_var.dtype == src_dtype: - cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype) out_var = block.vars.get(cast_name) if out_var is None or out_var.dtype != dest_dtype: op_device = op.attr('op_device') @@ -206,6 +216,13 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): stop_gradient=in_var.stop_gradient, ) + # Only forward program will be inserted cast op, but some ops + # has no op_role attr, so here set it direcly. eg. resnet_unit. 
+ op_role = ( + int(core.op_proto_and_checker_maker.OpRole.Forward) + if not op.has_attr('op_role') + else op.attr('op_role') + ) block._insert_op_without_sync( idx, type="cast", @@ -215,70 +232,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): "in_dtype": in_var.dtype, "out_dtype": out_var.dtype, "op_device": op_device, - "op_role": op.attr("op_role"), + "op_role": op_role, }, ) num_cast_ops += 1 _rename_arg(op, in_var.name, out_var.name) - else: - if op.has_attr('in_dtype'): - op._set_attr('in_dtype', dest_dtype) - if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype in [ - core.VarDesc.VarType.FP16, - core.VarDesc.VarType.BF16, - ]: - for out_name in op.output_names: - if _keep_fp32_output(op, out_name): - continue - for out_var_name in op.output(out_name): - out_var = block.var(out_var_name) - if out_var.type not in _valid_types: - continue - if out_var.dtype == core.VarDesc.VarType.FP32: - out_var.desc.set_dtype(dest_dtype) - if op.has_attr('out_dtype'): - op._set_attr('out_dtype', dest_dtype) - return num_cast_ops - -def _insert_cast_post_op( - block, op, idx, src_dtype, dest_dtype, target_name, op_var_rename_map -): - num_cast_ops = 0 - - target_var = block.var(target_name) - if target_var.type not in _valid_types or target_var.dtype == dest_dtype: - return num_cast_ops - - assert ( - target_var.dtype == src_dtype - ), "The real dtype({}) is not equal to the src dtype({})".format( - _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype) - ) - - cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype) - cast_var = block.vars.get(cast_name) - if cast_var is None or cast_var.dtype != dest_dtype: - cast_var = block.create_var( - name=cast_name, - dtype=dest_dtype, - persistable=False, - stop_gradient=target_var.stop_gradient, - ) - block._insert_op( - idx, - type="cast", - inputs={"X": target_var}, - outputs={"Out": cast_var}, - attrs={ - "in_dtype": target_var.dtype, - "out_dtype": cast_var.dtype, - "op_device": op.attr("op_device"), - "op_role": op.attr("op_role"), - }, - ) - num_cast_ops += 1 - op_var_rename_map[block.idx][target_var.name] = cast_var.name + for attr_name in ['in_dtype', 'out_dtype', 'dtype']: + if op.has_attr(attr_name) and is_float_dtype(op.attr(attr_name)): + op._set_attr(attr_name, dest_dtype) return num_cast_ops @@ -420,11 +382,204 @@ def fp16_guard(): yield +def is_float_dtype(dtype): + return ( + dtype == core.VarDesc.VarType.FP32 + or dtype == core.VarDesc.VarType.FP16 + or dtype == core.VarDesc.VarType.BF16 + or dtype == core.VarDesc.VarType.FP64 + ) + + +def set_var_dst_dtype( + op, var_names, block, global_block, dtype, need_set_dtype +): + low_precison_var_names = set() + for var_name in var_names: + var = None + try: + var = block._var_recursive(var_name) + except ValueError as e: + _logger.debug(f"-- {e}, try to get it in the global block --") + var = global_block.var(var_name) + if var is not None: + _logger.debug( + f"-- var {var_name} is got in the global block --" + ) + + if var is None or var.type not in _valid_types: + continue + + if is_float_dtype(var.dtype): + low_precison_var_names.add(var_name) + if need_set_dtype: + var.desc.set_dtype(dtype) + + _logger.debug( + "---- op type: {}, var name: {}, var dtype: {} ----".format( + op.type, var_name, var.dtype + ) + ) + + return low_precison_var_names + + +def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level): + if level == "O1": + return + keep_fp32_var_names = set() + all_parameters = [] + for block in program.blocks: + 
all_parameters.extend(block.all_parameters()) + ops = block.ops + for op in ops: + if op_need_keep_fp32(op, amp_lists, use_fp16_guard): + for in_name in op.input_names: + keep_fp32_var_names = keep_fp32_var_names.union( + op.input(in_name) + ) + else: + for in_name in op.input_names: + if not core.is_compiled_with_ipu() and _keep_fp32_input( + op, in_name + ): + keep_fp32_var_names = keep_fp32_var_names.union( + op.input(in_name) + ) + + for param in all_parameters: + if param.name not in keep_fp32_var_names: + _logger.debug(f"-- set param {param.name} to {dtype} --.") + param.desc.set_dtype(dtype) + + +def op_need_keep_fp32(op, amp_lists, use_fp16_guard): + need_keep_fp32 = False + if _need_keep_fp32( + op, + amp_lists.unsupported_list, + use_fp16_guard, + ): + need_keep_fp32 = True + elif amp_lists.black_varnames is not None and _is_in_black_varnames( + op, amp_lists + ): + need_keep_fp32 = True + elif op.type in amp_lists.black_list: + need_keep_fp32 = True + + return need_keep_fp32 + + +def get_promote_dtype(op, amp_dtype, block): + dst_dtype = amp_dtype + for in_name in op.input_names: + # for ipu, all inputs must be converted to fp16 + if not core.is_compiled_with_ipu() and _keep_fp32_input(op, in_name): + _logger.debug( + "---- Input {} {} should be kept fp32 ----".format( + in_name, op.input(in_name) + ) + ) + continue + # if this op has inputs + if in_name: + for in_var_name in op.input(in_name): + in_var = block._find_var_recursive(in_var_name) + if in_var and in_var.dtype == core.VarDesc.VarType.FP32: + dst_dtype = core.VarDesc.VarType.FP32 + break + else: + dst_dtype = core.VarDesc.VarType.FP32 + + return dst_dtype + + +def get_amp_dst_dtype( + op, amp_dtype, level, block, amp_lists, keep_fp32_ops, keep_fp16_ops +): + if level == 'O2': + return amp_dtype + + ops = block.ops + dst_dtype = amp_dtype + if op.type in amp_lists.gray_list: + keep_fp32 = False + keep_fp16 = False + for in_name in op.input_names: + # if this op has inputs + if in_name: + for in_var_name in op.input(in_name): + in_var = block._find_var_recursive(in_var_name) + # this in_var isn't the output of other op + if in_var.op is None: + continue + elif in_var.op is op: + prev_op = find_true_prev_op(ops, op, in_var_name) + if prev_op is None: + continue + else: + prev_op = in_var.op + + # if it's one of inputs + if ( + prev_op in keep_fp32_ops + or prev_op.type in amp_lists.black_list + ): + dst_dtype = core.VarDesc.VarType.FP32 + elif ( + prev_op in keep_fp16_ops + or prev_op.type in amp_lists.white_list + ): + dst_dtype = amp_dtype + else: + # For numerical safe, we apply fp32 computation on ops that + # are not determined which list they should stay. + dst_dtype = core.VarDesc.VarType.FP32 + return dst_dtype + + +def process_op_input_and_outputs(op, block, global_block, dtype): + low_precison_var_names = set() + # Get the FP16 input because the low_precison_var_names is required for the parameter casting. + # The dtype of the input is not set to fp16, because it is done in the step 3 of cast_model_to_fp16. + for in_name in op.input_names: + # for ipu, all inputs must be converted to fp16 + if not core.is_compiled_with_ipu() and _keep_fp32_input(op, in_name): + continue + in_vars = set_var_dst_dtype( + op, + op.input(in_name), + block, + global_block, + dtype, + need_set_dtype=False, + ) + low_precison_var_names = low_precison_var_names.union(in_vars) + # Set the output to FP16 because its consumer OP needs to determine if the dtype needs + # to be promoted. 
+ for out_name in op.output_names: + # for ipu, all outputs must be converted to fp16 + if not core.is_compiled_with_ipu() and _keep_fp32_output(op, out_name): + continue + set_var_dst_dtype( + op, + op.output(out_name), + block, + global_block, + dtype, + need_set_dtype=True, + ) + return low_precison_var_names + + def cast_model_to_fp16( program, amp_lists=None, use_fp16_guard=True, dest_type=core.VarDesc.VarType.FP16, + level='O2', + use_promote=False, ): """ Traverse all ops in the whole model and set their inputs and outputs @@ -438,158 +593,132 @@ def cast_model_to_fp16( constructing the program. Default True. dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16. """ - + _logger.debug("---- before cast model to fp16 ----") + _logger.debug(program) if amp_lists is None: dtype = get_low_precision_dtypestr(dest_type) amp_lists = AutoMixedPrecisionLists(dtype) - amp_lists.unsupported_list -= { - "conditional_block_grad", - "conditional_block", - "conditional_block_infer", - "select_input", - "while", - "while_grad", - "cast", - "tensor_array_to_tensor", - "lod_array_length", - "write_to_array", - } + + # For amp o2 there is no blacklist by default. + if level == 'O2': + amp_lists.black_list = amp_lists.black_list - black_list + global_block = program.global_block() keep_fp32_ops = set() + keep_fp16_ops = set() to_fp16_var_names = set() - origin_ops = [] - for block in program.blocks: - origin_ops.extend(block.ops) + # step 1: set params dtype. + set_param_dtype( + program, + dtype=dest_type, + amp_lists=amp_lists, + use_fp16_guard=use_fp16_guard, + level=level, + ) + + def need_process(op): + need_process = True + if op.type in ["cast", "create_py_reader", "read"]: + need_process = False + else: + for attr_name in ['out_dtype', 'dtype']: + if op.has_attr(attr_name) and is_float_dtype( + op.attr(attr_name) + ): + need_process = False + + return need_process + + # step 2: divide op into different sets according to the black/unsupported and white lists. 
for block in program.blocks: ops = block.ops for op in ops: - if op.type == 'create_py_reader' or op.type == 'read': + _logger.debug(f"-- process op: {op} --") + if not need_process(op): + _logger.debug("---- The op does not need to be processed ----.") continue - if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard): + if op_need_keep_fp32(op, amp_lists, use_fp16_guard): keep_fp32_ops.add(op) - continue # processed below - for in_name in op.input_names: - # for ipu, all inputs must be converted to fp16 - if not core.is_compiled_with_ipu() and _keep_fp32_input( - op, in_name - ): - continue - for in_var_name in op.input(in_name): - in_var = None - try: - in_var = block._var_recursive(in_var_name) - except ValueError as e: - _logger.debug( - "-- {}, try to get it in the global block --".format( - e - ) - ) - in_var = global_block.var(in_var_name) - if in_var is not None: - _logger.debug( - "-- var {} is got in the global block --".format( - in_var_name - ) - ) - - if in_var is None or in_var.type not in _valid_types: - continue - - if in_var.dtype == core.VarDesc.VarType.FP32: - in_var.desc.set_dtype(dest_type) - to_fp16_var_names.add(in_var_name) + process_op_input_and_outputs( + op, block, global_block, core.VarDesc.VarType.FP32 + ) + _logger.debug( + "---- Add into keep_fp32_ops because the op needs to be kept fp32 ----" + ) + elif op.type in amp_lists.white_list: + keep_fp16_ops.add(op) + # get fp16 inputs and set op's outputs to fp16 for promote judgments + fp16_var_names = process_op_input_and_outputs( + op, block, global_block, dest_type + ) + to_fp16_var_names = to_fp16_var_names.union(fp16_var_names) + _logger.debug( + "---- Add into keep_fp16_ops because the op in white_list ----" + ) + else: + # divide others ops into fp16/fp32 sets according to promoting principle. 
+ dst_dtype = dest_type + if not use_promote: + dst_dtype = get_amp_dst_dtype( + op, + dest_type, + level, + block, + amp_lists, + keep_fp32_ops, + keep_fp16_ops, + ) + else: + dst_dtype = get_promote_dtype(op, dest_type, block) + if dst_dtype == dest_type: + keep_fp16_ops.add(op) + fp16_var_names = process_op_input_and_outputs( + op, block, global_block, dest_type + ) + to_fp16_var_names = to_fp16_var_names.union(fp16_var_names) _logger.debug( - "-- op type: {}, in var name: {}, in var dtype: {} --".format( - op.type, in_var_name, in_var.dtype - ) + "---- Add into keep_fp16_ops because it should be promoted to fp16 ----" + ) + else: + keep_fp32_ops.add(op) + process_op_input_and_outputs( + op, block, global_block, core.VarDesc.VarType.FP32 ) - - for out_name in op.output_names: - # for ipu, all outputs must be converted to fp16 - if not core.is_compiled_with_ipu() and _keep_fp32_output( - op, out_name - ): - continue - for out_var_name in op.output(out_name): - out_var = None - try: - out_var = block._var_recursive(out_var_name) - except ValueError as e: - _logger.debug( - "-- {}, try to get it in the global block --".format( - e - ) - ) - out_var = global_block.var(out_var_name) - if out_var is not None: - _logger.debug( - "-- var {} is got in the global block --".format( - out_var_name - ) - ) - - if out_var is None or out_var.type not in _valid_types: - continue - - if out_var.dtype == core.VarDesc.VarType.FP32: - out_var.desc.set_dtype(dest_type) - _logger.debug( - "-- op type: {}, out var name: {}, out var dtype: {} --".format( - op.type, out_var_name, out_var.dtype - ) + "---- Add into keep_fp32_ops because it should be promoted to fp32 ----" ) - for attr_name in ['in_dtype', 'out_dtype', 'dtype']: - if ( - op.has_attr(attr_name) - and op.attr(attr_name) == core.VarDesc.VarType.FP32 - ): - op._set_attr(attr_name, dest_type) - # process ops in keep_fp32_ops - op_var_rename_map = [ - collections.OrderedDict() for _ in range(len(program.blocks)) - ] + # step 3: insert cast op for op's inputs. 
for block in program.blocks: ops = block.ops idx = 0 while idx < len(ops): op = ops[idx] num_cast_ops = 0 + if op in keep_fp16_ops: + in_var_cast_num = _insert_cast_op( + block, + op, + idx, + core.VarDesc.VarType.FP32, + dest_type, + ) + num_cast_ops += in_var_cast_num if op in keep_fp32_ops: - pre_cast_num = _insert_cast_op( + in_var_cast_num = _insert_cast_op( block, op, idx, dest_type, core.VarDesc.VarType.FP32, ) - num_cast_ops += pre_cast_num - for out_var_name in op.output_arg_names: - out_var = block.vars.get(out_var_name) - if out_var is None or out_var.type not in _valid_types: - continue - if out_var.dtype == dest_type: - out_var.desc.set_dtype(core.VarDesc.VarType.FP32) - post_ops = find_true_post_op(ops, op, out_var_name) - for post_op in post_ops: - if post_op in keep_fp32_ops: - continue - post_cast_num = _insert_cast_post_op( - block, - op, - idx + pre_cast_num + 1, - core.VarDesc.VarType.FP32, - dest_type, - out_var_name, - op_var_rename_map, - ) - num_cast_ops += post_cast_num - idx += num_cast_ops + 1 + num_cast_ops += in_var_cast_num - _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops) + idx += num_cast_ops + 1 + _logger.debug("---- after cast model to fp16 ----") + _logger.debug(program) return to_fp16_var_names @@ -626,11 +755,14 @@ def cast_parameters_to_fp16( for block in program.blocks: all_parameters.extend(block.all_parameters()) + dtype_str = get_low_precision_dtypestr(dest_type) fp16_var_names = to_fp16_var_names if to_fp16_var_names else set() var_scope = scope if scope else global_scope() for param in all_parameters: if param.name in fp16_var_names: - _logger.debug(f"---- cast {param.name} to fp16/bf16 dtype ----") + _logger.debug( + f"-- cast {param.name} to {dtype_str}, place is {place}" + ) if var_scope.find_var(param.name): param_t = var_scope.find_var(param.name).get_tensor() data = np.array(param_t) @@ -643,108 +775,6 @@ def cast_parameters_to_fp16( _logger.warning(f"Cannot find {param.name}") -def rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16): - """ - Traverse all ops in current block and insert cast op according to - which set current op belongs to. - - 1. When an op belongs to the black list, add it to black set - 2. When an op belongs to the white list, add it to white set - 3. When an op belongs to the gray list. If one - of its inputs is the output of black set op or black list op, - add it to black set. If all of its previous ops are not black - op and one of its inputs is the output of white set op or - white list op, add it to white set. - 4. When an op isn't in the lists, add it to black op set. - 5. Add necessary cast ops to make sure that black set op will be - computed in fp32 mode, while white set op will be computed in - fp16 mode. - - Args: - main_prog (Program): The main program for training. - dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16. - """ - block = main_prog.global_block() - block._sync_with_cpp() - ops = block.ops - white_op_set = set() - black_op_set = set() - for op in ops: - - # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder, - # we don't need to handle reader op and the input of 'create_py_reader' is not - # in block, which may result in errors. - # See GeneratorLoader._init_non_iterable() for details. 
- if op.type == 'create_py_reader' or op.type == 'read': - continue - - if amp_lists.black_varnames is not None and _is_in_black_varnames( - op, amp_lists - ): - black_op_set.add(op) - continue - - if op.type in amp_lists.black_list: - black_op_set.add(op) - elif op.type in amp_lists.white_list: - white_op_set.add(op) - elif op.type in amp_lists.gray_list: - is_black_op = False - is_white_op = False - for in_name in op.input_names: - # if this op has inputs - if in_name: - for in_var_name in op.input(in_name): - in_var = block.var(in_var_name) - # this in_var isn't the output of other op - if in_var.op is None: - continue - elif in_var.op is op: - prev_op = find_true_prev_op(ops, op, in_var_name) - if prev_op is None: - continue - else: - prev_op = in_var.op - # if it's one of inputs - if ( - prev_op in black_op_set - or prev_op.type in amp_lists.black_list - ): - is_black_op = True - elif ( - prev_op in white_op_set - or prev_op.type in amp_lists.white_list - ): - is_white_op = True - if is_black_op: - black_op_set.add(op) - elif is_white_op: - white_op_set.add(op) - else: - pass - else: - # For numerical safe, we apply fp32 computation on ops that - # are not determined which list they should stay. - black_op_set.add(op) - - idx = 0 - while idx < len(ops): - op = ops[idx] - num_cast_ops = 0 - if op in black_op_set: - num_cast_ops = _insert_cast_op( - block, op, idx, dest_type, core.VarDesc.VarType.FP32 - ) - elif op in white_op_set: - num_cast_ops = _insert_cast_op( - block, op, idx, core.VarDesc.VarType.FP32, dest_type - ) - else: - pass - - idx += num_cast_ops + 1 - - def update_role_var_grad(main_prog, params_grads): """ Update op_role_var attr for some ops to make sure the gradients diff --git a/python/paddle/static/amp/function_overload.py b/python/paddle/static/amp/function_overload.py new file mode 100644 index 00000000000..8139401c21d --- /dev/null +++ b/python/paddle/static/amp/function_overload.py @@ -0,0 +1,142 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation refers to https://arpitbhayani.me/blogs/function-overloading. +# Note: it is customed for paddle.static.amp.decorate function. + +import inspect +import logging +from enum import Enum + +from paddle.fluid.log_helper import get_logger + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) + + +class FunctionType(Enum): + FP16_ONLY = 0 + COMMON = 1 + + +class Function: + """ + Function is a wrap over standard python function + An instance of this Function class is also callable + just like the python function that it wrapped. + When the instance is "called" like a function it fetches + the function to be invoked from the virtual namespace and then + invokes the same. + """ + + def __init__(self, fn): + self.fn = fn + + def __call__(self, *args, **kwargs): + """ + Overriding the __call__ function which makes the + instance callable. 
+ """ + # fetching the function to be invoked from the virtual namespace + # through the arguments. + fn = Namespace.get_instance().get(*args, **kwargs) + # invoking the wrapped function and returning the value. + return fn(*args, **kwargs) + + +class Namespace: + """ + Namespace is the singleton class that is responsible + for holding all the functions. + """ + + __instance = None + + def __init__(self): + if self.__instance is None: + self.function_map = {} + Namespace.__instance = self + else: + raise Exception("cannot instantiate Namespace again.") + + @staticmethod + def get_instance(): + if Namespace.__instance is None: + Namespace() + return Namespace.__instance + + def register(self, fn, key): + """ + Register the function in the virtual namespace and return + an instance of callable Function that wraps the function fn. + + Args: + fn (function): the native python function handle. + key (FunctionType): the specified type. + """ + assert isinstance( + key, FunctionType + ), f"The type of key is expected to be FunctionType, but recieved {type(key)}." + func = Function(fn) + self.function_map[key] = fn + return func + + def get(self, *args, **kwargs): + """ + Get the matching function from the virtual namespace according to the actual arguments. + Return None if it did not find any matching function. + """ + _logger.debug(f"get function: args={args}, kwargs={kwargs}") + satisfied_function_keys = set(self.function_map.keys()) + num_actual_args = len(args) + len(kwargs) + for func_key in self.function_map.keys(): + if func_key not in satisfied_function_keys: + continue + fn = self.function_map[func_key] + specs = inspect.getfullargspec(fn) + if len(specs) < len(args) + len(kwargs): + # Remove the not satisfied function according to the number of actual arguments. + _logger.debug( + f"fn={fn} (key={func_key}) is not satisfied and removed." + ) + satisfied_function_keys.remove(func_key) + continue + if len(kwargs) > 0: + # Remove the not satisfied function according to argument keys in kwargs. + for arg_name, value in kwargs.items(): + if arg_name not in specs.args: + _logger.debug( + f"fn={fn} (key={func_key}) is not satisfied and removed." + ) + satisfied_function_keys.remove(func_key) + break + if len(satisfied_function_keys) == 1: + key = list(satisfied_function_keys)[0] + elif len(args) >= 3 and isinstance(args[2], float): + key = FunctionType.FP16_ONLY + else: + key = FunctionType.COMMON + return self.function_map.get(key) + + +def overload(key): + """overload is the decorator that wraps the function + and returns a callable object of type Function. + """ + + def decorator(fn): + return Namespace.get_instance().register(fn, key) + + return decorator diff --git a/test/amp/amp_base_models.py b/test/amp/amp_base_models.py index 7f97a923f04..8b63b2391c0 100644 --- a/test/amp/amp_base_models.py +++ b/test/amp/amp_base_models.py @@ -12,16 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import unittest + import numpy as np import paddle from paddle import nn +from paddle.fluid import core _fixed_add_param = np.random.random(size=[16, 16]).astype("float32") def _build_optimizer( - use_amp, amp_dtype="float16", amp_level="O1", use_grad_clip=False + use_amp, + amp_dtype="float16", + amp_level="O1", + amp_lists=None, + use_grad_clip=False, + use_promote=False, ): if use_grad_clip: grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) @@ -34,16 +42,14 @@ def _build_optimizer( beta2=0.836, epsilon=1e-4, weight_decay=0.01, - multi_precision=True, ) if use_amp: - amp_lists = paddle.static.amp.AutoMixedPrecisionLists( - custom_white_list=["elementwise_add"], - custom_black_list=["reduce_mean"], + optimizer = paddle.static.amp.decorate( + optimizer, + amp_lists, + level=amp_level, dtype=amp_dtype, - ) - optimizer = paddle.static.amp.amp_decorate( - optimizer, amp_lists=amp_lists, level=amp_level, dtype=amp_dtype + use_promote=use_promote, ) return optimizer @@ -65,7 +71,9 @@ class SimpleAddNet(nn.Layer): return x + self.weight -def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"): +def build_add_model( + use_amp, amp_dtype="float16", amp_level="O1", use_promote=False +): main_program = paddle.static.Program() startup_program = paddle.static.Program() with paddle.utils.unique_name.guard(): @@ -80,7 +88,22 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"): x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype) out = model(x) loss = paddle.mean(out) - optimizer = _build_optimizer(use_amp, amp_dtype, amp_level) + + if use_amp: + amp_lists = paddle.static.amp.AutoMixedPrecisionLists( + custom_white_list=["elementwise_add"], + custom_black_list=["reduce_mean"], + dtype=amp_dtype, + ) + else: + amp_lists = None + optimizer = _build_optimizer( + use_amp, + amp_dtype, + amp_level, + amp_lists, + use_promote=use_promote, + ) optimizer.minimize(loss) feed_vars = [x] fetch_vars = [loss] @@ -91,30 +114,37 @@ class SimpleConvNet(nn.Layer): def __init__(self): super().__init__() self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3) - self.linear = nn.Linear(in_features=6, out_features=10) + self.linear = nn.Linear(in_features=96, out_features=4) def forward(self, x): out = self.conv(x) out = nn.functional.relu(out) + out = out.flatten(start_axis=1, stop_axis=3) out = self.linear(out) out = nn.functional.softmax(out) return out -def build_conv_model(use_amp, amp_dtype="float16", amp_level="O1"): +def build_conv_model( + use_amp, amp_dtype="float16", amp_level="O1", use_promote=False +): main_program = paddle.static.Program() startup_program = paddle.static.Program() with paddle.utils.unique_name.guard(): with paddle.static.program_guard(main_program, startup_program): model = SimpleConvNet() x = paddle.static.data( - name='input', shape=[None, 1, 28, 28], dtype='float32' + name='input', shape=[None, 1, 6, 6], dtype='float32' ) out = model(x) loss = paddle.mean(out) - optimizer = _build_optimizer(use_amp, amp_dtype, amp_level) + optimizer = _build_optimizer( + use_amp, amp_dtype, amp_level, use_promote=use_promote + ) optimizer.minimize(loss) - return main_program, startup_program + feed_vars = [x] + fetch_vars = [loss] + return main_program, startup_program, optimizer, feed_vars, fetch_vars class SimpleEmbeddingNet(nn.Layer): @@ -136,7 +166,9 @@ class SimpleEmbeddingNet(nn.Layer): return out -def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"): +def build_embedding_model( + use_amp, amp_dtype="float16", 
amp_level="O1", use_promote=False +): main_program = paddle.static.Program() startup_program = paddle.static.Program() with paddle.utils.unique_name.guard(): @@ -145,7 +177,14 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"): x = paddle.static.data(name='x', shape=[None, 32], dtype='int64') out = model(x) loss = paddle.mean(out) - optimizer = _build_optimizer(use_amp, amp_dtype, amp_level, True) + optimizer = _build_optimizer( + use_amp, + amp_dtype, + amp_level, + None, + True, + use_promote=use_promote, + ) optimizer.minimize(loss) return main_program, startup_program @@ -186,3 +225,58 @@ def build_while_model(): out = model(x) loss = paddle.mean(out) return main_program, startup_program + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not complied with CUDA and not support amp.", +) +class AmpTestBase(unittest.TestCase): + def setUp(self): + self.amp_dtype = None + self.amp_level = None + + def _check_op_calls( + self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={} + ): + for op_type, value in expected_bf16_calls.items(): + self.assertEqual( + op_stats_dict[op_type].bf16_calls, + value, + f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.", + ) + for op_type, value in expected_fp16_calls.items(): + self.assertEqual( + op_stats_dict[op_type].fp16_calls, + value, + f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.", + ) + + def run_program( + self, + main_program, + startup_program, + optimizer, + feed_vars, + fetch_vars, + place, + exe, + x_np, + max_iters, + level, + ): + losses = [] + scope = paddle.static.Scope() + with paddle.static.scope_guard(scope): + exe.run(startup_program) + if level == 'O2': + optimizer.amp_init(place) + for iter_id in range(max_iters): + results = exe.run( + program=main_program, + feed={feed_vars[0].name: x_np}, + fetch_list=fetch_vars, + ) + print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}") + losses.append(results[0]) + return losses diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py new file mode 100644 index 00000000000..7d397c70432 --- /dev/null +++ b/test/amp/test_amp_api.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from amp_base_models import AmpTestBase + +import paddle + + +class TestAutoCast(AmpTestBase): + def test_amp_OD_level(self): + conv = paddle.nn.Conv2D( + in_channels=1, out_channels=6, kernel_size=3, bias_attr=False + ) + linear = paddle.nn.Linear(in_features=4, out_features=4) + with paddle.amp.auto_cast(level='OD'): + out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32')) + out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16') + out3 = linear(out2) + + self.assertEqual(out1.dtype, paddle.float16) + self.assertEqual(out2.dtype, paddle.float32) + self.assertEqual(out3.dtype, paddle.float32) + + +class TestGradScaler(AmpTestBase): + def test_amp_grad_scaler(self): + model = paddle.nn.Conv2D(3, 2, 3) + optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=model.parameters() + ) + scaler = paddle.amp.GradScaler() + data = paddle.rand([1, 3, 8, 8], dtype='float32') + paddle.amp.debugging.enable_operator_stats_collection() + with paddle.amp.auto_cast( + custom_black_list=['conv2d'], dtype='bfloat16' + ): + out = model(data) + loss = out.mean() + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) + optimizer.clear_grad() + paddle.amp.debugging.disable_operator_stats_collection() + op_list = paddle.fluid.core.get_low_precision_op_list() + + self.assertEqual(scaler._enable, False) + self.assertEqual(scaler._use_dynamic_loss_scaling, False) + self.assertTrue('scale' not in op_list) + self.assertTrue('check_finite_and_unscale' not in op_list) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/amp/test_amp_decorate.py b/test/amp/test_amp_decorate.py new file mode 100644 index 00000000000..1a77146cf1d --- /dev/null +++ b/test/amp/test_amp_decorate.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import paddle +import paddle.nn.functional as F + + +class ConvBNLayer(paddle.nn.Layer): + def __init__( + self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + ): + super().__init__() + + self._conv = paddle.nn.Conv2D( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + bias_attr=None, + ) + + self._batch_norm = paddle.nn.BatchNorm(num_filters, act=act) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + + return y + + +class Model(paddle.nn.Layer): + def __init__( + self, input_channel, hidden_size, fp16_conv=True, fp16_linear=True + ): + super().__init__() + self.conv = ConvBNLayer(input_channel, 8, 3) + self.linear = paddle.nn.Linear(8, hidden_size) + self.layernorm = paddle.nn.Sequential( + paddle.nn.LayerNorm(hidden_size), + paddle.nn.LayerNorm(hidden_size), + ) + self.fp16_conv = fp16_conv + self.fp16_linear = fp16_linear + + def forward(self, inputs): + with paddle.amp.auto_cast(enable=self.fp16_conv): + if not self.fp16_conv: + inputs = inputs.astype('float32') + x = self.conv(inputs) + with paddle.amp.auto_cast(enable=self.fp16_linear): + if not self.fp16_linear: + x = x.astype('float32') + x = self.linear(x) + x = F.relu(x) + x = self.layernorm(x) + return x + + +class TestAMPDecorate(unittest.TestCase): + def check_results(self, fp32_layers=[], fp16_layers=[]): + for idx in range(len(fp32_layers)): + for layer in fp32_layers[idx].sublayers(include_self=False): + self.assertEqual(layer.weight.dtype, paddle.float32) + self.assertEqual(layer.bias.dtype, paddle.float32) + + for idx in range(len(fp16_layers)): + for layer in fp16_layers[idx].sublayers(include_self=False): + self.assertEqual(layer.weight.dtype, paddle.float16) + self.assertEqual(layer.bias.dtype, paddle.float16) + + def test_excluded_layers(self): + if not paddle.amp.is_float16_supported(): + return + model = Model(4, 8, fp16_conv=False) + model = paddle.amp.decorate( + models=model, + level='O2', + dtype='float16', + excluded_layers=model.conv, + ) + with paddle.amp.auto_cast(level='O2'): + out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32')) + self.check_results( + fp32_layers=[model.conv, model.layernorm], + fp16_layers=[model.linear], + ) + + def test_excluded_layers_attr_list(self): + if not paddle.amp.is_float16_supported(): + return + model = Model(4, 8, fp16_conv=False, fp16_linear=False) + model = paddle.amp.decorate( + models=model, + level='O2', + dtype='float16', + excluded_layers=[model.conv, model.linear], + ) + + with paddle.amp.auto_cast(level='O2'): + out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32')) + + self.check_results( + fp32_layers=[model.conv, model.linear, model.layernorm] + ) + + def test_excluded_layers_attr_types(self): + if not paddle.amp.is_float16_supported(): + return + model = Model(4, 8) + model = paddle.amp.decorate( + models=model, + level='O2', + dtype='float16', + excluded_layers=[paddle.nn.Conv2D, model.linear], + ) + + with paddle.amp.auto_cast(level='O2'): + out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16')) + + self.check_results( + fp32_layers=[model.conv, model.linear, model.layernorm] + ) + + def test_excluded_layers_attr_none(self): + if not paddle.amp.is_float16_supported(): + return + model = Model(4, 8) + model = paddle.amp.decorate( + models=model, + level='O2', + dtype='float16', + excluded_layers=None, + ) + + with 
paddle.amp.auto_cast(level='O2'): + out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16')) + + self.check_results( + fp32_layers=[model.layernorm, model.conv._batch_norm], + fp16_layers=[model.conv._conv, model.linear], + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/amp/test_amp_list.py b/test/amp/test_amp_list.py index 11bcdbfd3ba..9b0bf5129c3 100644 --- a/test/amp/test_amp_list.py +++ b/test/amp/test_amp_list.py @@ -14,32 +14,63 @@ import unittest +import paddle from paddle.fluid import core -from paddle.static.amp import fp16_lists -from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists +from paddle.static.amp import AutoMixedPrecisionLists, fp16_lists class TestAMPList(unittest.TestCase): - def test_main(self): - custom_white_list = [ - 'lookup_table', - 'lookup_table_v2', - ] - amp_list = AutoMixedPrecisionLists(custom_white_list=custom_white_list) - for op in custom_white_list: - self.assertTrue(op in amp_list.white_list) - self.assertTrue(op not in amp_list.black_list) - self.assertTrue(op not in amp_list.unsupported_list) - - default_black_list = [ + def setUp(self): + self.default_black_list = [ 'linear_interp_v2', 'nearest_interp_v2', 'bilinear_interp_v2', 'bicubic_interp_v2', 'trilinear_interp_v2', ] - for op in default_black_list: - self.assertTrue(op in amp_list.black_list) + self.custom_white_list = [ + 'lookup_table', + 'lookup_table_v2', + ] + + def check_if_op_in_list(self, op_list, amp_list): + for op in op_list: + self.assertTrue(op in amp_list) + + def check_if_op_not_in_list(self, op_list, amp_list): + for op in op_list: + self.assertTrue(op not in amp_list) + + def test_static(self): + amp_list = AutoMixedPrecisionLists( + custom_white_list=self.custom_white_list + ) + self.check_if_op_in_list(self.default_black_list, amp_list.black_list) + self.check_if_op_in_list(self.custom_white_list, amp_list.white_list) + self.check_if_op_not_in_list( + self.custom_white_list, amp_list.black_list + ) + self.check_if_op_not_in_list( + self.custom_white_list, amp_list.unsupported_list + ) + + def test_eager(self): + if not paddle.amp.is_float16_supported(): + return + white_list = paddle.amp.white_list() + black_list = paddle.amp.black_list() + self.check_if_op_in_list( + self.default_black_list, black_list["float16"]["O2"] + ) + self.check_if_op_not_in_list(['log', 'elementwise_add'], white_list) + with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}): + out1 = paddle.rand([2, 3]) + paddle.rand([2, 3]) + out2 = out1.mean() + out3 = paddle.log(out2) + self.check_if_op_not_in_list(['log', 'elementwise_add'], white_list) + self.assertEqual(out1.dtype, paddle.float16) + self.assertEqual(out2.dtype, paddle.float32) + self.assertEqual(out3.dtype, paddle.float32) def test_apis(self): def _run_check_dtype(): diff --git a/test/amp/test_amp_promote.py b/test/amp/test_amp_promote.py new file mode 100644 index 00000000000..8aa5d71a79e --- /dev/null +++ b/test/amp/test_amp_promote.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from amp_base_models import AmpTestBase, build_conv_model + +import paddle +from paddle.static import amp + +paddle.enable_static() + + +class TestAMPPromote(AmpTestBase): + def check_promote_results( + self, use_amp, dtype, level, use_promote, expected_op_calls + ): + ( + main_program, + startup_program, + optimizer, + feed_vars, + fetch_vars, + ) = build_conv_model(use_amp, dtype, level, use_promote) + self.assertEqual(main_program.num_blocks, 1) + + amp.debugging.collect_operator_stats(main_program) + op_stats_list = amp.debugging._get_op_stats_list(main_program) + + self._check_op_calls( + op_stats_list[0], expected_fp16_calls=expected_op_calls + ) + + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + + max_iters = 2 + x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32") + print(main_program) + losses_o1 = self.run_program( + main_program, + startup_program, + optimizer, + feed_vars, + fetch_vars, + place, + exe, + x_fp32, + max_iters, + level, + ) + + def test_static_amp_o1(self): + expected_fp16_calls = { + "conv2d": 1, + "elementwise_add": 0, + "relu": 0, + "matmul_v2": 1, + "softmax": 0, + "reduce_mean": 0, + "adamw": 0, + } + self.check_promote_results( + True, + 'float16', + 'O1', + use_promote=True, + expected_op_calls=expected_fp16_calls, + ) + + def test_static_amp_o2(self): + expected_fp16_calls = { + "conv2d": 1, + "elementwise_add": 2, + "relu": 1, + "matmul_v2": 1, + "softmax": 1, + "reduce_mean": 1, + "adamw": 4, + } + self.check_promote_results( + True, + 'float16', + 'O2', + use_promote=True, + expected_op_calls=expected_fp16_calls, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/amp/test_model_cast_to_bf16.py b/test/amp/test_model_cast_to_bf16.py index c09c15e37d2..3002b623b18 100644 --- a/test/amp/test_model_cast_to_bf16.py +++ b/test/amp/test_model_cast_to_bf16.py @@ -17,7 +17,7 @@ import struct import unittest import numpy as np -from amp_base_models import build_add_model, build_embedding_model +from amp_base_models import AmpTestBase, build_add_model, build_embedding_model import paddle from paddle import fluid @@ -220,24 +220,30 @@ class TestModelCastBF16(unittest.TestCase): ) -@unittest.skipIf( - not core.is_compiled_with_cuda(), - "core is not complied with CUDA and not support the bfloat16", -) -class TestProgramBF16(unittest.TestCase): - def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls): - for op_type, value in expected_bf16_calls.items(): - self.assertEqual( - op_stats_dict[op_type].bf16_calls, - value, - f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.", - ) +class TestProgramBF16(AmpTestBase): + def _check_optimizer(self, program, expected_num_mp): + optimizers = [] + for block in program.blocks: + for op in block.ops: + if "Param" in op.input_names and "Grad" in op.input_names: + optimizers.append(op) + + actual_num_mp = 0 + for op in optimizers: + if op.has_attr("multi_precision") and op.attr("multi_precision"): + actual_num_mp += 1 + self.assertEqual( + actual_num_mp, + expected_num_mp, + f"The number of optimizers with multi_precison = True is expected to be {expected_num_mp}, but recieved {actual_num_mp}.", + ) def test_amp_bf16_o1(self): main_program, startup_program = build_embedding_model( True, "bfloat16", "O1" ) self.assertEqual(main_program.num_blocks, 1) + 
self._check_optimizer(main_program, 0) amp.debugging.collect_operator_stats(main_program) op_stats_list = amp.debugging._get_op_stats_list(main_program) @@ -249,7 +255,7 @@ class TestProgramBF16(unittest.TestCase): "squared_l2_norm": 0, "adamw": 0, } - self._check_bf16_calls(op_stats_list[0], expected_bf16_calls) + self._check_op_calls(op_stats_list[0], expected_bf16_calls) def test_amp_bf16_o2(self): main_program, startup_program = build_embedding_model( @@ -267,14 +273,15 @@ class TestProgramBF16(unittest.TestCase): "squared_l2_norm": 2, "adamw": 2, } - self._check_bf16_calls(op_stats_list[0], expected_bf16_calls) + self._check_optimizer( + main_program, + expected_bf16_calls["matmul_v2"] + + expected_bf16_calls["elementwise_add"], + ) + self._check_op_calls(op_stats_list[0], expected_bf16_calls) -@unittest.skipIf( - not core.is_compiled_with_cuda(), - "core is not complied with CUDA and not support the bfloat16", -) -class TestStaticBF16(unittest.TestCase): +class TestStaticBF16(AmpTestBase): def _generate_feed_x(self): x = np.random.random(size=[16, 16]).astype("float32") x_bf16 = convert_float_to_uint16(x) @@ -282,60 +289,35 @@ class TestStaticBF16(unittest.TestCase): return x_fp32, x_bf16 def test_compare_o1_o2(self): - def _run_o1(exe, x_np, max_iters): + def _run(place, exe, x_np, max_iters, level): ( main_program, startup_program, optimizer, feed_vars, fetch_vars, - ) = build_add_model(True, "bfloat16", "O1") - - losses = [] - scope = paddle.static.Scope() - with paddle.static.scope_guard(scope): - exe.run(startup_program) - for iter_id in range(max_iters): - results = exe.run( - program=main_program, - feed={feed_vars[0].name: x_np}, - fetch_list=fetch_vars, - ) - print(f"-- [BF16 O1] iter={iter_id}, loss={results[0]}") - losses.append(results[0]) - return losses + ) = build_add_model(True, "bfloat16", level) - def _run_o2(exe, x_np, max_iters): - ( + losses = self.run_program( main_program, startup_program, optimizer, feed_vars, fetch_vars, - ) = build_add_model(True, "bfloat16", "O2") - - losses = [] - scope = paddle.static.Scope() - with paddle.static.scope_guard(scope): - exe.run(startup_program) - optimizer.amp_init(place) - for iter_id in range(max_iters): - results = exe.run( - program=main_program, - feed={feed_vars[0].name: x_np}, - fetch_list=fetch_vars, - ) - print(f"-- [BF16 O2] iter={iter_id}, loss={results[0]}") - losses.append(results[0]) + place, + exe, + x_np, + max_iters, + level, + ) return losses - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - max_iters = 2 x_fp32, x_bf16 = self._generate_feed_x() - losses_o1 = _run_o1(exe, x_fp32, max_iters) - losses_o2 = _run_o2(exe, x_bf16, max_iters) + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1') + losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2') if __name__ == '__main__': diff --git a/test/contrib/test_image_classification_fp16.py b/test/contrib/test_image_classification_fp16.py index 48bb126431d..fb1bafdc861 100644 --- a/test/contrib/test_image_classification_fp16.py +++ b/test/contrib/test_image_classification_fp16.py @@ -314,7 +314,10 @@ class TestImageClassification(unittest.TestCase): # infer(use_cuda, save_dirname) def test_amp_lists(self): - white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) + white_list = ( + copy.copy(paddle.static.amp.fp16_lists.white_list) + | paddle.static.amp.fp16_lists._only_supported_fp16_list + ) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) gray_list = 
copy.copy(paddle.static.amp.fp16_lists.gray_list) @@ -324,7 +327,10 @@ class TestImageClassification(unittest.TestCase): self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_1(self): - white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) + white_list = ( + copy.copy(paddle.static.amp.fp16_lists.white_list) + | paddle.static.amp.fp16_lists._only_supported_fp16_list + ) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) @@ -338,7 +344,10 @@ class TestImageClassification(unittest.TestCase): self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_2(self): - white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) + white_list = ( + copy.copy(paddle.static.amp.fp16_lists.white_list) + | paddle.static.amp.fp16_lists._only_supported_fp16_list + ) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) @@ -352,7 +361,10 @@ class TestImageClassification(unittest.TestCase): self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_3(self): - white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) + white_list = ( + copy.copy(paddle.static.amp.fp16_lists.white_list) + | paddle.static.amp.fp16_lists._only_supported_fp16_list + ) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) @@ -365,7 +377,10 @@ class TestImageClassification(unittest.TestCase): self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_4(self): - white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) + white_list = ( + copy.copy(paddle.static.amp.fp16_lists.white_list) + | paddle.static.amp.fp16_lists._only_supported_fp16_list + ) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) @@ -381,7 +396,10 @@ class TestImageClassification(unittest.TestCase): self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_5(self): - white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) + white_list = ( + copy.copy(paddle.static.amp.fp16_lists.white_list) + | paddle.static.amp.fp16_lists._only_supported_fp16_list + ) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) @@ -397,7 +415,10 @@ class TestImageClassification(unittest.TestCase): self.assertEqual(amp_lists.gray_list, gray_list) def test_amp_lists_6(self): - white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) + white_list = ( + copy.copy(paddle.static.amp.fp16_lists.white_list) + | paddle.static.amp.fp16_lists._only_supported_fp16_list + ) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) diff --git a/test/ir/test_fuse_resnet_unit.py b/test/ir/test_fuse_resnet_unit.py index bcadccf5fd6..d76a806c0c8 100644 --- a/test/ir/test_fuse_resnet_unit.py +++ b/test/ir/test_fuse_resnet_unit.py @@ -39,7 +39,7 @@ class TestFuseResNetUnit(unittest.TestCase): startup_program = paddle.static.Program() with paddle.static.amp.fp16_guard(): with paddle.static.program_guard(program, startup_program): - x = paddle.static.data("x", [1, 64, 64, 8]) + x = paddle.static.data("x", [1, 64, 64, 8], dtype="float16") conv2d = paddle.nn.Conv2D( 8, 32, 1, bias_attr=False, data_format='NHWC' ) @@ -66,3 +66,7 @@ class TestFuseResNetUnit(unittest.TestCase): 
np.testing.assert_allclose( before_out[0], after_out[0], rtol=1e-05, atol=0.005 ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/standalone_executor/test_standalone_executor_aot_choose_kernel.py b/test/standalone_executor/test_standalone_executor_aot_choose_kernel.py index 8ac34385936..334ac8ffc1a 100644 --- a/test/standalone_executor/test_standalone_executor_aot_choose_kernel.py +++ b/test/standalone_executor/test_standalone_executor_aot_choose_kernel.py @@ -25,10 +25,10 @@ paddle.enable_static() def build_resnet50(use_amp=False): main_program = paddle.static.Program() startup_program = paddle.static.Program() - + dtype = 'float16' if use_amp else 'float32' with paddle.static.program_guard(main_program, startup_program): image = paddle.static.data( - name='image', shape=[32, 3, 224, 224], dtype='float32' + name='image', shape=[32, 3, 224, 224], dtype=dtype ) label = paddle.static.data(name='label', shape=[32], dtype='int64') model = paddle.vision.models.resnet50() -- GitLab
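
A condensed usage sketch of the user-facing APIs this patch exercises in
test/amp/test_amp_api.py and test/amp/amp_base_models.py: the new dynamic-graph
'OD' autocast level and the reworked paddle.static.amp.decorate taking
amp_lists/level/dtype/use_promote. It is only a sketch, not part of the patch:
the model, shapes, optimizer settings, and names such as x_np/loss_v are
illustrative, and a CUDA build with float16 support is assumed.

import numpy as np
import paddle
from paddle import nn

# Dynamic graph: under level='OD', only conv/matmul-like ops run in float16;
# per the new test, other ops (e.g. elementwise_add) are kept/promoted to float32.
conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3, bias_attr=False)
with paddle.amp.auto_cast(level='OD'):
    out = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))  # float16
    out = out + paddle.rand(shape=out.shape, dtype='float16')     # float32

# Static graph: decorate() is called as in _build_optimizer of the new test base.
paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    model = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
    x = paddle.static.data(name='x', shape=[None, 1, 6, 6], dtype='float32')
    loss = paddle.mean(model(x))
    optimizer = paddle.optimizer.AdamW(learning_rate=0.01)
    optimizer = paddle.static.amp.decorate(
        optimizer, None, level='O2', dtype='float16', use_promote=True
    )
    optimizer.minimize(loss)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
optimizer.amp_init(place)  # cast float16 parameters, as in AmpTestBase.run_program
x_np = np.random.random(size=[2, 1, 6, 6]).astype('float32')
loss_v = exe.run(main_prog, feed={'x': x_np}, fetch_list=[loss])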