Unverified commit 534efcb6, authored by Zhang Ting and committed by GitHub

support excluded_layers for amp.decorate (#52871)

Parent 864aa75d
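
For orientation, a minimal usage sketch of the new argument (illustrative only, not part of the patch; it assumes a GPU with float16 support and mirrors the unit tests added at the bottom of this diff):

import paddle

model = paddle.nn.Sequential(
    paddle.nn.Conv2D(4, 8, 3),
    paddle.nn.Linear(8, 8),
)
# Keep the Conv2D in float32 while the rest of the model is cast to float16.
model = paddle.amp.decorate(
    models=model,
    level='O2',
    dtype='float16',
    excluded_layers=model[0],
)
print(model[0].weight.dtype)  # expected: paddle.float32
print(model[1].weight.dtype)  # expected: paddle.float16
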
@@ -199,47 +199,95 @@ def _is_gpu_bfloat16_supported():
    return prop[0] >= 8 and cuda_version_check
def need_keep_fp32(layer, dtype):
need_keep_fp32 = False
    # Highest priority. Because all the layers except BN will use bfloat16 params in bfloat16 training,
    # here we provide an option to keep the fp32 param.
if not layer._cast_to_low_precison:
need_keep_fp32 = True
# The BN layers will keep fp32
elif isinstance(
layer,
(
paddle.nn.BatchNorm,
paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D,
paddle.nn.BatchNorm3D,
paddle.nn.SyncBatchNorm,
),
):
need_keep_fp32 = True
# layer._dtype is used to set params dtype. BF16 will use bf16 params.
elif (layer._dtype == 'float16') or (
(dtype == 'float16')
and isinstance(
layer,
(
paddle.nn.LayerNorm,
paddle.nn.InstanceNorm1D,
paddle.nn.InstanceNorm2D,
paddle.nn.InstanceNorm3D,
),
)
):
need_keep_fp32 = True
return need_keep_fp32
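
To illustrate the branches above: BatchNorm-like layers always keep float32 parameters, while LayerNorm/InstanceNorm keep float32 only for float16 training and are cast like any other layer for bfloat16. A hedged sketch of the resulting dtypes (assumes an Ampere-or-newer GPU so that bfloat16 is supported):

import paddle

net = paddle.nn.Sequential(
    paddle.nn.Linear(8, 8),
    paddle.nn.LayerNorm(8),
    paddle.nn.BatchNorm1D(8),
)
net = paddle.amp.decorate(models=net, level='O2', dtype='bfloat16')
print(net[0].weight.dtype)  # expected: paddle.bfloat16
print(net[1].weight.dtype)  # expected: paddle.bfloat16 (LayerNorm is only kept in fp32 for float16)
print(net[2].weight.dtype)  # expected: paddle.float32  (BatchNorm always keeps fp32)
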
def set_excluded_layers(models, excluded_layers):
excluded_layers_instances = []
excluded_layers_types = []
error_message = "excluded_layers must be either a nn.Layer instance/type or a list of nn.Layer instances/types."
if excluded_layers is None:
excluded_layers = []
elif isinstance(excluded_layers, paddle.nn.Layer):
excluded_layers_instances = [excluded_layers]
elif isinstance(excluded_layers, type) and issubclass(
excluded_layers, paddle.nn.Layer
):
excluded_layers_types = [excluded_layers]
elif isinstance(excluded_layers, list):
for item in excluded_layers:
if isinstance(item, paddle.nn.Layer):
excluded_layers_instances.append(item)
elif issubclass(item, paddle.nn.Layer):
excluded_layers_types.append(item)
else:
raise TypeError(error_message)
else:
raise TypeError(error_message)
for idx in range(len(excluded_layers_instances)):
for layer in excluded_layers_instances[idx].sublayers(
include_self=True
):
layer._cast_to_low_precison = False
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
if type(layer) in excluded_layers_types:
layer._cast_to_low_precison = False
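
A small sketch of the bookkeeping this helper performs (illustrative only, since `_cast_to_low_precison` is an internal attribute introduced on Layer further down in this diff; a GPU with float16 support is assumed): excluding a layer clears the flag on that layer and all of its sublayers, and need_keep_fp32 then keeps them in fp32.

import paddle

block = paddle.nn.Sequential(paddle.nn.Linear(4, 4), paddle.nn.Linear(4, 4))
block = paddle.amp.decorate(
    models=block, level='O2', dtype='float16', excluded_layers=block[0]
)
print(block[0]._cast_to_low_precison)  # expected: False (excluded)
print(block[1]._cast_to_low_precison)  # expected: True  (default set in Layer.__init__)
print(block[0].weight.dtype)           # expected: paddle.float32
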
@dygraph_only
-def pure_fp16_initialize(models):
+def amp_initialize(models, dtype, excluded_layers):
+    set_excluded_layers(models, excluded_layers)
    for idx in range(len(models)):
        for layer in models[idx].sublayers(include_self=True):
-            layer._casted_by_pure_fp16 = True
-            if (layer._dtype == 'float16') or isinstance(
-                layer,
-                (
-                    paddle.nn.BatchNorm,
-                    paddle.nn.BatchNorm1D,
-                    paddle.nn.BatchNorm2D,
-                    paddle.nn.BatchNorm3D,
-                    paddle.nn.LayerNorm,
-                    paddle.nn.SyncBatchNorm,
-                    paddle.nn.InstanceNorm1D,
-                    paddle.nn.InstanceNorm2D,
-                    paddle.nn.InstanceNorm3D,
-                ),
-            ):
+            if need_keep_fp32(layer, dtype):
                continue
-            if isinstance(
+            if dtype == "float16" and isinstance(
                layer,
                (
                    paddle.incubate.nn.FusedFeedForward,
                    paddle.incubate.nn.FusedMultiHeadAttention,
                ),
            ):
-                layer._amp_decorate(dtype='float16')
+                layer._amp_decorate(dtype=dtype)
                continue
-            layer._to_impl(
-                dtype='float16', include_sublayers=False, floating_only=True
-            )
-    return models
-
-
-@dygraph_only
-def pure_bf16_initialize(models):
-    for idx in range(len(models)):
-        for layer in models[idx].sublayers(include_self=True):
            layer._to_impl(
-                dtype='bfloat16', include_sublayers=False, floating_only=True
+                dtype=dtype, include_sublayers=False, floating_only=True
            )
    return models
@@ -522,6 +570,7 @@ def amp_decorate(
    master_weight=None,
    save_dtype=None,
    master_grad=False,
+    excluded_layers=None,
):
    """
    Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
@@ -590,6 +639,8 @@ def amp_decorate(
        raise ValueError(
            "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
        )
+    if not (dtype in ['float16', 'bfloat16']):
+        raise ValueError("dtype only support float16 or bfloat16.")

    if level == 'O1':
        if optimizers is None:
@@ -609,12 +660,9 @@ def amp_decorate(
        raise TypeError(
            "models must be either a single model or a list of models."
        )
-    if dtype == 'float16':
-        models = pure_fp16_initialize(models=models)
-    elif dtype == 'bfloat16':
-        models = pure_bf16_initialize(models=models)
-    else:
-        raise TypeError("dtype only support float16 or bfloat16.")
+    # initialize parameters of the model.
+    amp_initialize(models=models, dtype=dtype, excluded_layers=excluded_layers)

    if optimizers is not None:
        # check optimizers
@@ -741,6 +789,7 @@ def decorate(
    master_weight=None,
    save_dtype=None,
    master_grad=False,
+    excluded_layers=None,
):
    """
    Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
@@ -757,8 +806,10 @@ def decorate(
        master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
        save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`, it should be float16, bfloat16, float32, float64 or None.
            The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
-        master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If it is enabled, the weight
-            gradients will be FP32 dtype after the backpropagation. Default is False.
+        master_grad(bool, optional): For level='O2', whether to use float32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If master_grad is enabled, the weight
+            gradients will be float32 dtype after the backpropagation. Default is False, in which case the weight gradients remain float16.
+        excluded_layers(Layer|list of Layer, optional): Specify the layers not to be decorated. The weights of these layers will always be kept in float32 when level is O2. `excluded_layers` can be specified as
+            a Layer instance/type or a list of Layer instances/types. Default is None, meaning the weights of the whole model will be cast to float16 or bfloat16.

    Examples:
@@ -808,5 +859,12 @@ def decorate(
            print(output.dtype) # FP16
    """
    return amp_decorate(
-        models, optimizers, level, dtype, master_weight, save_dtype, master_grad
+        models,
+        optimizers,
+        level,
+        dtype,
+        master_weight,
+        save_dtype,
+        master_grad,
+        excluded_layers,
    )
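
As the new docstring and the unit tests below show, `excluded_layers` also accepts layer types, or a mixed list of instances and types; every matching sublayer of the decorated models then keeps float32 weights. A hedged sketch (toy model, float16-capable GPU assumed):

import paddle

model = paddle.nn.Sequential(
    paddle.nn.Conv2D(4, 8, 3),
    paddle.nn.Linear(8, 8),
    paddle.nn.Linear(8, 8),
)
model = paddle.amp.decorate(
    models=model,
    level='O2',
    dtype='float16',
    excluded_layers=[paddle.nn.Conv2D, model[1]],  # every Conv2D plus one specific Linear
)
print(model[0].weight.dtype)  # expected: paddle.float32 (excluded by type)
print(model[1].weight.dtype)  # expected: paddle.float32 (excluded by instance)
print(model[2].weight.dtype)  # expected: paddle.float16
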
@@ -401,7 +401,8 @@ class Layer:
        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

-        self._casted_by_pure_fp16 = False
+        # only used in AMP Training
+        self._cast_to_low_precison = True

        self._state_dict_hooks = collections.OrderedDict()
        # Records orignal functions after @to_static to support to rollback
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn.functional as F
class ConvBNLayer(paddle.nn.Layer):
def __init__(
self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
):
super().__init__()
self._conv = paddle.nn.Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=None,
)
self._batch_norm = paddle.nn.BatchNorm(num_filters, act=act)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class Model(paddle.nn.Layer):
def __init__(
self, input_channel, hidden_size, fp16_conv=True, fp16_linear=True
):
super().__init__()
self.conv = ConvBNLayer(input_channel, 8, 3)
self.linear = paddle.nn.Linear(8, hidden_size)
self.layernorm = paddle.nn.Sequential(
paddle.nn.LayerNorm(hidden_size),
paddle.nn.LayerNorm(hidden_size),
)
self.fp16_conv = fp16_conv
self.fp16_linear = fp16_linear
def forward(self, inputs):
with paddle.amp.auto_cast(enable=self.fp16_conv):
if not self.fp16_conv:
inputs = inputs.astype('float32')
x = self.conv(inputs)
with paddle.amp.auto_cast(enable=self.fp16_linear):
if not self.fp16_linear:
x = x.astype('float32')
x = self.linear(x)
x = F.relu(x)
x = self.layernorm(x)
return x
class TestAMPDecorate(unittest.TestCase):
def check_results(self, fp32_layers=[], fp16_layers=[]):
for idx in range(len(fp32_layers)):
for layer in fp32_layers[idx].sublayers(include_self=False):
self.assertEqual(layer.weight.dtype, paddle.float32)
self.assertEqual(layer.bias.dtype, paddle.float32)
for idx in range(len(fp16_layers)):
for layer in fp16_layers[idx].sublayers(include_self=False):
self.assertEqual(layer.weight.dtype, paddle.float16)
self.assertEqual(layer.bias.dtype, paddle.float16)
def test_excluded_layers(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8, fp16_conv=False)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=model.conv,
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
self.check_results(
fp32_layers=[model.conv, model.layernorm],
fp16_layers=[model.linear],
)
def test_excluded_layers_attr_list(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8, fp16_conv=False, fp16_linear=False)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=[model.conv, model.linear],
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
self.check_results(
fp32_layers=[model.conv, model.linear, model.layernorm]
)
def test_excluded_layers_attr_types(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=[paddle.nn.Conv2D, model.linear],
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))
self.check_results(
fp32_layers=[model.conv, model.linear, model.layernorm]
)
def test_excluded_layers_attr_none(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=None,
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))
self.check_results(
fp32_layers=[model.layernorm, model.conv._batch_norm],
fp16_layers=[model.conv._conv, model.linear],
)
if __name__ == '__main__':
unittest.main()