diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index bc76f866d94eb2ef363fcb22688dc46027aa6798..e8f552607affc73c0c379644d18867a12622842b 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -199,47 +199,95 @@ def _is_gpu_bfloat16_supported():
     return prop[0] >= 8 and cuda_version_check
 
 
+def need_keep_fp32(layer, dtype):
+    need_keep_fp32 = False
+    # Highest priority. Because all the layers except BN will use bfloat16 params in bfloat16 training,
+    # here we provide an option to keep fp32 params.
+    if not layer._cast_to_low_precison:
+        need_keep_fp32 = True
+    # The BN layers will keep fp32
+    elif isinstance(
+        layer,
+        (
+            paddle.nn.BatchNorm,
+            paddle.nn.BatchNorm1D,
+            paddle.nn.BatchNorm2D,
+            paddle.nn.BatchNorm3D,
+            paddle.nn.SyncBatchNorm,
+        ),
+    ):
+        need_keep_fp32 = True
+    # layer._dtype is used to set params dtype. BF16 will use bf16 params.
+    elif (layer._dtype == 'float16') or (
+        (dtype == 'float16')
+        and isinstance(
+            layer,
+            (
+                paddle.nn.LayerNorm,
+                paddle.nn.InstanceNorm1D,
+                paddle.nn.InstanceNorm2D,
+                paddle.nn.InstanceNorm3D,
+            ),
+        )
+    ):
+        need_keep_fp32 = True
+
+    return need_keep_fp32
+
+
+def set_excluded_layers(models, excluded_layers):
+    excluded_layers_instances = []
+    excluded_layers_types = []
+    error_message = "excluded_layers must be either a nn.Layer instance/type or a list of nn.Layer instances/types."
+    if excluded_layers is None:
+        excluded_layers = []
+    elif isinstance(excluded_layers, paddle.nn.Layer):
+        excluded_layers_instances = [excluded_layers]
+    elif isinstance(excluded_layers, type) and issubclass(
+        excluded_layers, paddle.nn.Layer
+    ):
+        excluded_layers_types = [excluded_layers]
+    elif isinstance(excluded_layers, list):
+        for item in excluded_layers:
+            if isinstance(item, paddle.nn.Layer):
+                excluded_layers_instances.append(item)
+            elif issubclass(item, paddle.nn.Layer):
+                excluded_layers_types.append(item)
+            else:
+                raise TypeError(error_message)
+    else:
+        raise TypeError(error_message)
+
+    for idx in range(len(excluded_layers_instances)):
+        for layer in excluded_layers_instances[idx].sublayers(
+            include_self=True
+        ):
+            layer._cast_to_low_precison = False
+    for idx in range(len(models)):
+        for layer in models[idx].sublayers(include_self=True):
+            if type(layer) in excluded_layers_types:
+                layer._cast_to_low_precison = False
+
+
 @dygraph_only
-def pure_fp16_initialize(models):
+def amp_initialize(models, dtype, excluded_layers):
+    set_excluded_layers(models, excluded_layers)
     for idx in range(len(models)):
         for layer in models[idx].sublayers(include_self=True):
-            layer._casted_by_pure_fp16 = True
-            if (layer._dtype == 'float16') or isinstance(
-                layer,
-                (
-                    paddle.nn.BatchNorm,
-                    paddle.nn.BatchNorm1D,
-                    paddle.nn.BatchNorm2D,
-                    paddle.nn.BatchNorm3D,
-                    paddle.nn.LayerNorm,
-                    paddle.nn.SyncBatchNorm,
-                    paddle.nn.InstanceNorm1D,
-                    paddle.nn.InstanceNorm2D,
-                    paddle.nn.InstanceNorm3D,
-                ),
-            ):
+            if need_keep_fp32(layer, dtype):
                 continue
-            if isinstance(
+            if dtype == "float16" and isinstance(
                 layer,
                 (
                     paddle.incubate.nn.FusedFeedForward,
                     paddle.incubate.nn.FusedMultiHeadAttention,
                 ),
             ):
-                layer._amp_decorate(dtype='float16')
+                layer._amp_decorate(dtype=dtype)
                 continue
-            layer._to_impl(
-                dtype='float16', include_sublayers=False, floating_only=True
-            )
-    return models
-
-
-@dygraph_only
-def pure_bf16_initialize(models):
-    for idx in range(len(models)):
-        for layer in models[idx].sublayers(include_self=True):
             layer._to_impl(
-                dtype='bfloat16', include_sublayers=False, floating_only=True
+                dtype=dtype, include_sublayers=False, floating_only=True
             )
     return models
@@ -522,6 +570,7 @@ def amp_decorate(
     master_weight=None,
     save_dtype=None,
     master_grad=False,
+    excluded_layers=None,
 ):
     """
     Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
@@ -590,6 +639,8 @@ def amp_decorate(
         raise ValueError(
             "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
         )
+    if not (dtype in ['float16', 'bfloat16']):
+        raise ValueError("dtype only support float16 or bfloat16.")
 
     if level == 'O1':
         if optimizers is None:
@@ -609,12 +660,9 @@ def amp_decorate(
             raise TypeError(
                 "models must be either a single model or a list of models."
             )
-        if dtype == 'float16':
-            models = pure_fp16_initialize(models=models)
-        elif dtype == 'bfloat16':
-            models = pure_bf16_initialize(models=models)
-        else:
-            raise TypeError("dtype only support float16 or bfloat16.")
+
+        # initialize parameters of the model.
+        amp_initialize(models=models, dtype=dtype, excluded_layers=excluded_layers)
 
         if optimizers is not None:
             # check optimizers
@@ -741,6 +789,7 @@ def decorate(
     master_weight=None,
     save_dtype=None,
     master_grad=False,
+    excluded_layers=None,
 ):
     """
     Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
@@ -757,8 +806,10 @@ def decorate(
         master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
         save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None. The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
-        master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If it is enabled, the weight
-            gradients will be FP32 dtype after the backpropagation. Default is False.
+        master_grad(bool, optional): For level='O2', whether to use float32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If master_grad is enabled, the weight
+            gradients will be float32 dtype after the backpropagation. Default is False, in which case the weight gradients are float16.
+        excluded_layers(Layer|list of Layer, optional): Specify the layers not to be decorated. The weights of these layers will always keep float32 when level is O2. `excluded_layers` can be specified as
+            a Layer instance/type or a list of Layer instances/types. Default is None, meaning the weights of the whole model will be cast to float16 or bfloat16.
 
     Examples:
@@ -808,5 +859,12 @@ def decorate(
             print(output.dtype) # FP16
     """
     return amp_decorate(
-        models, optimizers, level, dtype, master_weight, save_dtype, master_grad
+        models,
+        optimizers,
+        level,
+        dtype,
+        master_weight,
+        save_dtype,
+        master_grad,
+        excluded_layers,
     )
diff --git a/python/paddle/nn/layer/layers.py b/python/paddle/nn/layer/layers.py
index 0babc935f1de798383e6c69794cb1d090db70431..29a6f49b5dc0bf242c3a90f55620069c30312b1b 100644
--- a/python/paddle/nn/layer/layers.py
+++ b/python/paddle/nn/layer/layers.py
@@ -401,7 +401,8 @@ class Layer:
         self._forward_pre_hooks = collections.OrderedDict()
         self._forward_post_hooks = collections.OrderedDict()
 
-        self._casted_by_pure_fp16 = False
+        # only used in AMP Training
+        self._cast_to_low_precison = True
 
         self._state_dict_hooks = collections.OrderedDict()
         # Records orignal functions after @to_static to support to rollback
diff --git a/test/amp/test_amp_decorate.py b/test/amp/test_amp_decorate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a77146cf1de78025eb339d33b0ac453b8923948
--- /dev/null
+++ b/test/amp/test_amp_decorate.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+import paddle.nn.functional as F
+
+
+class ConvBNLayer(paddle.nn.Layer):
+    def __init__(
+        self,
+        num_channels,
+        num_filters,
+        filter_size,
+        stride=1,
+        groups=1,
+        act=None,
+    ):
+        super().__init__()
+
+        self._conv = paddle.nn.Conv2D(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=groups,
+            bias_attr=None,
+        )
+
+        self._batch_norm = paddle.nn.BatchNorm(num_filters, act=act)
+
+    def forward(self, inputs):
+        y = self._conv(inputs)
+        y = self._batch_norm(y)
+
+        return y
+
+
+class Model(paddle.nn.Layer):
+    def __init__(
+        self, input_channel, hidden_size, fp16_conv=True, fp16_linear=True
+    ):
+        super().__init__()
+        self.conv = ConvBNLayer(input_channel, 8, 3)
+        self.linear = paddle.nn.Linear(8, hidden_size)
+        self.layernorm = paddle.nn.Sequential(
+            paddle.nn.LayerNorm(hidden_size),
+            paddle.nn.LayerNorm(hidden_size),
+        )
+        self.fp16_conv = fp16_conv
+        self.fp16_linear = fp16_linear
+
+    def forward(self, inputs):
+        with paddle.amp.auto_cast(enable=self.fp16_conv):
+            if not self.fp16_conv:
+                inputs = inputs.astype('float32')
+            x = self.conv(inputs)
+        with paddle.amp.auto_cast(enable=self.fp16_linear):
+            if not self.fp16_linear:
+                x = x.astype('float32')
+            x = self.linear(x)
+        x = F.relu(x)
+        x = self.layernorm(x)
+        return x
+
+
+class TestAMPDecorate(unittest.TestCase):
+    def check_results(self, fp32_layers=[], fp16_layers=[]):
+        for idx in range(len(fp32_layers)):
+            for layer in fp32_layers[idx].sublayers(include_self=False):
+                self.assertEqual(layer.weight.dtype, paddle.float32)
+                self.assertEqual(layer.bias.dtype, paddle.float32)
+
+        for idx in range(len(fp16_layers)):
+            for layer in fp16_layers[idx].sublayers(include_self=False):
+                self.assertEqual(layer.weight.dtype, paddle.float16)
+                self.assertEqual(layer.bias.dtype, paddle.float16)
+
+    def test_excluded_layers(self):
+        if not paddle.amp.is_float16_supported():
+            return
+        model = Model(4, 8, fp16_conv=False)
+        model = paddle.amp.decorate(
+            models=model,
+            level='O2',
+            dtype='float16',
+            excluded_layers=model.conv,
+        )
+        with paddle.amp.auto_cast(level='O2'):
+            out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
+        self.check_results(
+            fp32_layers=[model.conv, model.layernorm],
+            fp16_layers=[model.linear],
+        )
+
+    def test_excluded_layers_attr_list(self):
+        if not paddle.amp.is_float16_supported():
+            return
+        model = Model(4, 8, fp16_conv=False, fp16_linear=False)
+        model = paddle.amp.decorate(
+            models=model,
+            level='O2',
+            dtype='float16',
+            excluded_layers=[model.conv, model.linear],
+        )
+
+        with paddle.amp.auto_cast(level='O2'):
+            out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
+
+        self.check_results(
+            fp32_layers=[model.conv, model.linear, model.layernorm]
+        )
+
+    def test_excluded_layers_attr_types(self):
+        if not paddle.amp.is_float16_supported():
+            return
+        model = Model(4, 8)
+        model = paddle.amp.decorate(
+            models=model,
+            level='O2',
+            dtype='float16',
+            excluded_layers=[paddle.nn.Conv2D, model.linear],
+        )
+
+        with paddle.amp.auto_cast(level='O2'):
+            out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))
+
+        self.check_results(
+            fp32_layers=[model.conv, model.linear, model.layernorm]
+        )
+
+    def test_excluded_layers_attr_none(self):
+        if not paddle.amp.is_float16_supported():
+            return
+        model = Model(4, 8)
+        model = paddle.amp.decorate(
+            models=model,
+            level='O2',
+            dtype='float16',
+            excluded_layers=None,
+        )
+
+        with paddle.amp.auto_cast(level='O2'):
+            out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))
+
+        self.check_results(
+            fp32_layers=[model.layernorm, model.conv._batch_norm],
+            fp16_layers=[model.conv._conv, model.linear],
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
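
Usage sketch (not part of the patch): a minimal example of the new `excluded_layers` argument added here, assuming a CUDA device with float16 support. The `Net` model and its attribute names are illustrative only.

import paddle

class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(4, 8, 3, padding=1)
        self.linear = paddle.nn.Linear(8, 8)

    def forward(self, x):
        x = self.conv(x)         # [N, 8, H, W]
        x = x.mean(axis=[2, 3])  # [N, 8]
        return self.linear(x)

model = Net()
optimizer = paddle.optimizer.SGD(parameters=model.parameters())
# Exclude every Conv2D layer by type; a Layer instance such as `model.conv` would work as well.
model, optimizer = paddle.amp.decorate(
    models=model,
    optimizers=optimizer,
    level='O2',
    dtype='float16',
    excluded_layers=[paddle.nn.Conv2D],
)
print(model.conv.weight.dtype)    # paddle.float32 (excluded, kept in fp32)
print(model.linear.weight.dtype)  # paddle.float16

Internally, excluded layers get `_cast_to_low_precison = False`, so `need_keep_fp32` keeps their parameters in float32 while `amp_initialize` casts the rest of the model.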