Unverified · Commit 4d7e9b55 · Authored by Zhang Ting · Committed by GitHub

[AMP] Cherry-pick AMP (#53442)

 Cherry-pick AMP 
Parent: ec849efd
...@@ -39,3 +39,4 @@ per-file-ignores = ...@@ -39,3 +39,4 @@ per-file-ignores =
.cmake-format.py: F821 .cmake-format.py: F821
test/dygraph_to_static/test_loop.py: F821 test/dygraph_to_static/test_loop.py: F821
test/dygraph_to_static/test_closure_analysis.py: F821 test/dygraph_to_static/test_closure_analysis.py: F821
python/paddle/static/amp/decorator.py: F811
...@@ -131,7 +131,11 @@ inline phi::DataType GetAmpDestDtype( ...@@ -131,7 +131,11 @@ inline phi::DataType GetAmpDestDtype(
->count(op_name)) { ->count(op_name)) {
dst_type = phi::DataType::FLOAT32; dst_type = phi::DataType::FLOAT32;
} else { } else {
dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype); if (amp_level == paddle::imperative::AmpLevel::OD) {
dst_type = phi::DataType::FLOAT32;
} else {
dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
}
} }
if (dst_type == amp_setting_dtype && if (dst_type == amp_setting_dtype &&
......
...@@ -31,6 +31,7 @@ enum class AmpLevel { ...@@ -31,6 +31,7 @@ enum class AmpLevel {
O1, // amp, mixed fp32-fp16 O1, // amp, mixed fp32-fp16
O2, // almost fp16 O2, // almost fp16
O3, // fp16 O3, // fp16
OD, // only conv and matmul use low precision.
}; };
std::tuple<std::unordered_set<std::string>, std::tuple<std::unordered_set<std::string>,
......
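The dispatch rule introduced by the OD level can be summarized in Python. Below is a minimal sketch (not code from this patch) of how a destination dtype is chosen; the white-list membership check and the promotion helper are passed in as assumptions rather than taken from the C++ implementation:

    def choose_dst_dtype(op_name, amp_level, amp_dtype, white_list, promote_fn):
        # White-list ops (e.g. conv2d, matmul) always use the AMP dtype.
        if op_name in white_list:
            return amp_dtype
        # New OD level: every other op stays in float32.
        if amp_level == "OD":
            return "float32"
        # O1/O2 keep the existing promotion behaviour.
        return promote_fn(op_name)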
...@@ -2154,6 +2154,7 @@ void BindImperative(py::module *m_ptr) { ...@@ -2154,6 +2154,7 @@ void BindImperative(py::module *m_ptr) {
py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic()) py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic())
.value("O0", paddle::imperative::AmpLevel::O0) .value("O0", paddle::imperative::AmpLevel::O0)
.value("OD", paddle::imperative::AmpLevel::OD)
.value("O1", paddle::imperative::AmpLevel::O1) .value("O1", paddle::imperative::AmpLevel::O1)
.value("O2", paddle::imperative::AmpLevel::O2) .value("O2", paddle::imperative::AmpLevel::O2)
.value("O3", paddle::imperative::AmpLevel::O3) .value("O3", paddle::imperative::AmpLevel::O3)
......
...@@ -16,10 +16,8 @@ from .auto_cast import auto_cast # noqa: F401 ...@@ -16,10 +16,8 @@ from .auto_cast import auto_cast # noqa: F401
from .auto_cast import decorate # noqa: F401 from .auto_cast import decorate # noqa: F401
from .auto_cast import amp_guard # noqa: F401 from .auto_cast import amp_guard # noqa: F401
from .auto_cast import amp_decorate # noqa: F401 from .auto_cast import amp_decorate # noqa: F401
from .auto_cast import FP16_WHITE_LIST # noqa: F401 from .amp_lists import white_list # noqa: F401
from .auto_cast import FP16_BLACK_LIST # noqa: F401 from .amp_lists import black_list # noqa: F401
from .auto_cast import PURE_FP16_WHITE_LIST # noqa: F401
from .auto_cast import PURE_FP16_BLACK_LIST # noqa: F401
from . import grad_scaler # noqa: F401 from . import grad_scaler # noqa: F401
from .grad_scaler import GradScaler # noqa: F401 from .grad_scaler import GradScaler # noqa: F401
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
FP16_WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
'max_pool2d_with_index',
'mul',
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
FP16_BLACK_LIST = {
'tan',
'acos',
'asin',
'sinh',
'cosh',
'atanh',
'tanh_shrink',
'cos_sim',
'erfinv',
'exp',
'expm1',
'log',
'log10',
'log2',
'reciprocal',
'rsqrt',
'pow',
'square',
'reduce_sum',
'mean',
'reduce_mean',
'reduce_prod',
'cumprod',
'cumsum',
'dist',
'pnorm',
'frobenius_norm',
'renorm',
'group_norm',
'layer_norm',
'softmax',
'softmin',
'softplus',
'log_softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
'nll_loss',
'huber_loss',
'triplet_margin_loss',
'log_loss',
'hsigmoid_loss',
'margin_cross_entropy',
}
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
FP16_EXTRA_BLACK_LIST = {
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
'lookup_table',
'lookup_table_v2',
'scatter',
'depthwise_conv2d',
}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()
# At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
def white_list():
white_list = {
"float16": {
"OD": FP16_WHITE_LIST,
"O1": FP16_WHITE_LIST,
"O2": FP16_WHITE_LIST,
},
"bfloat16": {
"OD": BF16_WHITE_LIST,
"O1": BF16_WHITE_LIST,
"O2": BF16_WHITE_LIST,
},
}
return white_list
def black_list():
black_list = {
"float16": {
"OD": set(),
"O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST,
"O2": FP16_EXTRA_BLACK_LIST,
},
"bfloat16": {"OD": set(), "O1": BF16_BLACK_LIST, "O2": set()},
}
return black_list
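Since white_list() and black_list() are re-exported from paddle.amp (see the __init__.py change above), the per-dtype, per-level sets can be inspected directly; a small sketch:

    import paddle

    fp16_o1_white = paddle.amp.white_list()["float16"]["O1"]
    fp16_o1_black = paddle.amp.black_list()["float16"]["O1"]
    assert "conv2d" in fp16_o1_white
    assert "softmax" in fp16_o1_black
    # At the OD level the black list is empty: only white-list ops change dtype.
    assert paddle.amp.black_list()["float16"]["OD"] == set()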
...@@ -20,45 +20,7 @@ from paddle.fluid import core ...@@ -20,45 +20,7 @@ from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer, dygraph_only from paddle.fluid.framework import _dygraph_tracer, dygraph_only
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
AMP_LEVEL = core.AmpLevel from .amp_lists import black_list, white_list
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
FP16_WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
'max_pool2d_with_index',
'mul',
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
FP16_BLACK_LIST = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# default fp32 can avoid return inf when the sum value large than 65504
'reduce_sum',
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
AMP_RELATED_FLAGS = [ AMP_RELATED_FLAGS = [
'FLAGS_cudnn_exhaustive_search', 'FLAGS_cudnn_exhaustive_search',
...@@ -72,27 +34,7 @@ AMP_RELATED_FLAGS_SETTING = { ...@@ -72,27 +34,7 @@ AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
} }
PURE_FP16_WHITE_LIST = copy.copy(FP16_WHITE_LIST) AMP_LEVEL = core.AmpLevel
PURE_FP16_BLACK_LIST = {
'lookup_table',
'lookup_table_v2',
'scatter',
'scatter_grad',
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()
PURE_BF16_WHITE_LIST = copy.copy(BF16_WHITE_LIST)
PURE_BF16_BLACK_LIST = set()
_g_amp_state_ = None _g_amp_state_ = None
...@@ -106,6 +48,7 @@ class AMPGlobalState: ...@@ -106,6 +48,7 @@ class AMPGlobalState:
self.model_parameters = [] self.model_parameters = []
self.use_master_grad = False self.use_master_grad = False
self.already_register_final_backward_hook = False self.already_register_final_backward_hook = False
self.amp_dtype = 'float32'
def __setattr__(self, name, val): def __setattr__(self, name, val):
self.__dict__[name] = val self.__dict__[name] = val
...@@ -126,20 +69,12 @@ def _update_list( ...@@ -126,20 +69,12 @@ def _update_list(
""" """
Update black and white list according to users' custom list. Update black and white list according to users' custom list.
""" """
if dtype == 'float16': if level == 'O0':
if level == 'O1': _white_list = set()
_white_list = copy.copy(FP16_WHITE_LIST) _black_list = set()
_black_list = copy.copy(FP16_BLACK_LIST) return _white_list, _black_list
else: _white_list = copy.copy(white_list()[dtype][level])
_white_list = copy.copy(PURE_FP16_WHITE_LIST) _black_list = copy.copy(black_list()[dtype][level])
_black_list = copy.copy(PURE_FP16_BLACK_LIST)
else:
if level == 'O1':
_white_list = copy.copy(BF16_WHITE_LIST)
_black_list = copy.copy(BF16_BLACK_LIST)
else:
_white_list = copy.copy(PURE_BF16_WHITE_LIST)
_black_list = copy.copy(PURE_BF16_BLACK_LIST)
if custom_white_list and custom_black_list: if custom_white_list and custom_black_list:
for op_name in custom_white_list: for op_name in custom_white_list:
if op_name in custom_black_list: if op_name in custom_black_list:
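_update_list starts from these default sets and then merges in the user lists; in practice that is driven through the custom_white_list / custom_black_list arguments of auto_cast / amp_guard, roughly as in the sketch below (mirroring the unit test later in this commit):

    import paddle

    # 'log' is moved from the default black list to the white list and
    # 'conv2d' is moved from the white list to the black list.
    with paddle.amp.auto_cast(
        custom_white_list={"log"}, custom_black_list={"conv2d"}, level="O1"
    ):
        x = paddle.rand([2, 3])
        y = paddle.log(x)  # runs in float16 on devices with fp16 support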
...@@ -199,47 +134,95 @@ def _is_gpu_bfloat16_supported(): ...@@ -199,47 +134,95 @@ def _is_gpu_bfloat16_supported():
return prop[0] >= 8 and cuda_version_check return prop[0] >= 8 and cuda_version_check
def need_keep_fp32(layer, dtype):
need_keep_fp32 = False
# Highest priority. Because all layers except BN will use bfloat16 params in bfloat16 training,
# here we provide an option to keep fp32 params.
if not layer._cast_to_low_precison:
need_keep_fp32 = True
# The BN layers will keep fp32
elif isinstance(
layer,
(
paddle.nn.BatchNorm,
paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D,
paddle.nn.BatchNorm3D,
paddle.nn.SyncBatchNorm,
),
):
need_keep_fp32 = True
# layer._dtype is used to set params dtype. BF16 will use bf16 params.
elif (layer._dtype == 'float16') or (
(dtype == 'float16')
and isinstance(
layer,
(
paddle.nn.LayerNorm,
paddle.nn.InstanceNorm1D,
paddle.nn.InstanceNorm2D,
paddle.nn.InstanceNorm3D,
),
)
):
need_keep_fp32 = True
return need_keep_fp32
def set_excluded_layers(models, excluded_layers):
excluded_layers_instances = []
excluded_layers_types = []
error_message = "excluded_layers must be either a nn.Layer instance/type or a list of nn.Layer instances/types."
if excluded_layers is None:
excluded_layers = []
elif isinstance(excluded_layers, paddle.nn.Layer):
excluded_layers_instances = [excluded_layers]
elif isinstance(excluded_layers, type) and issubclass(
excluded_layers, paddle.nn.Layer
):
excluded_layers_types = [excluded_layers]
elif isinstance(excluded_layers, list):
for item in excluded_layers:
if isinstance(item, paddle.nn.Layer):
excluded_layers_instances.append(item)
elif issubclass(item, paddle.nn.Layer):
excluded_layers_types.append(item)
else:
raise TypeError(error_message)
else:
raise TypeError(error_message)
for idx in range(len(excluded_layers_instances)):
for layer in excluded_layers_instances[idx].sublayers(
include_self=True
):
layer._cast_to_low_precison = False
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
if type(layer) in excluded_layers_types:
layer._cast_to_low_precison = False
@dygraph_only @dygraph_only
def pure_fp16_initialize(models): def amp_initialize(models, dtype, excluded_layers):
set_excluded_layers(models, excluded_layers)
for idx in range(len(models)): for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True): for layer in models[idx].sublayers(include_self=True):
layer._casted_by_pure_fp16 = True if need_keep_fp32(layer, dtype):
if (layer._dtype == 'float16') or isinstance(
layer,
(
paddle.nn.BatchNorm,
paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D,
paddle.nn.BatchNorm3D,
paddle.nn.LayerNorm,
paddle.nn.SyncBatchNorm,
paddle.nn.InstanceNorm1D,
paddle.nn.InstanceNorm2D,
paddle.nn.InstanceNorm3D,
),
):
continue continue
if isinstance( if dtype == "float16" and isinstance(
layer, layer,
( (
paddle.incubate.nn.FusedFeedForward, paddle.incubate.nn.FusedFeedForward,
paddle.incubate.nn.FusedMultiHeadAttention, paddle.incubate.nn.FusedMultiHeadAttention,
), ),
): ):
layer._amp_decorate(dtype='float16') layer._amp_decorate(dtype=dtype)
continue continue
layer._to_impl(
dtype='float16', include_sublayers=False, floating_only=True
)
return models
@dygraph_only
def pure_bf16_initialize(models):
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._to_impl( layer._to_impl(
dtype='bfloat16', include_sublayers=False, floating_only=True dtype=dtype, include_sublayers=False, floating_only=True
) )
return models return models
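Together with need_keep_fp32 and set_excluded_layers above, the new excluded_layers argument lets O2 decoration skip selected layers. A minimal sketch of the intended usage (assuming a GPU with float16 support):

    import paddle

    model = paddle.nn.Sequential(paddle.nn.Linear(4, 4), paddle.nn.LayerNorm(4))
    opt = paddle.optimizer.AdamW(parameters=model.parameters())

    # All floating-point weights are cast to float16 except layers excluded
    # by instance or by type (LayerNorm here).
    model, opt = paddle.amp.decorate(
        models=model,
        optimizers=opt,
        level="O2",
        dtype="float16",
        excluded_layers=[paddle.nn.LayerNorm],
    )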
...@@ -338,10 +321,8 @@ def amp_guard( ...@@ -338,10 +321,8 @@ def amp_guard(
# check amp_level: O0-O2 # check amp_level: O0-O2
level = level.upper() level = level.upper()
if not (level in ['O0', 'O1', 'O2']): if not (level in ['O0', 'OD', 'O1', 'O2']):
raise ValueError( raise ValueError("level should be O0, OD, O1 or O2.")
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
# check amp_dtype: float16 or bfloat16 # check amp_dtype: float16 or bfloat16
dtype = dtype.lower() dtype = dtype.lower()
...@@ -402,37 +383,20 @@ def amp_guard( ...@@ -402,37 +383,20 @@ def amp_guard(
) )
amp_dtype = dtype amp_dtype = dtype
amp_global_state().amp_dtype = amp_dtype
if level == 'O1': if level == 'OD':
amp_level = AMP_LEVEL.OD
elif level == 'O1':
amp_level = AMP_LEVEL.O1 amp_level = AMP_LEVEL.O1
if dtype == 'float16':
_white_list = FP16_WHITE_LIST
_black_list = FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O2': elif level == 'O2':
amp_level = AMP_LEVEL.O2 amp_level = AMP_LEVEL.O2
if dtype == 'float16':
_white_list = PURE_FP16_WHITE_LIST
_black_list = PURE_FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O0': elif level == 'O0':
amp_level = AMP_LEVEL.O0 amp_level = AMP_LEVEL.O0
if dtype == 'float16':
_white_list = FP16_WHITE_LIST _white_list, _black_list = _update_list(
_black_list = FP16_BLACK_LIST custom_white_list, custom_black_list, level, dtype
elif dtype == 'bfloat16': )
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
if custom_white_list or custom_black_list:
_white_list, _black_list = _update_list(
custom_white_list, custom_black_list, level, dtype
)
if not enable: if not enable:
amp_level = AMP_LEVEL.O0 amp_level = AMP_LEVEL.O0
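With the OD branch above, the new level can be requested through auto_cast / amp_guard just like O1 and O2; a short sketch:

    import paddle

    model = paddle.nn.Linear(2, 4)
    x = paddle.rand([2, 2])

    # OD: only default white-list ops (conv2d / matmul, etc.) run in float16,
    # every other op keeps float32 inputs.
    with paddle.amp.auto_cast(level="OD", dtype="float16"):
        out = model(x)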
...@@ -522,6 +486,7 @@ def amp_decorate( ...@@ -522,6 +486,7 @@ def amp_decorate(
master_weight=None, master_weight=None,
save_dtype=None, save_dtype=None,
master_grad=False, master_grad=False,
excluded_layers=None,
): ):
""" """
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing. Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
...@@ -590,6 +555,8 @@ def amp_decorate( ...@@ -590,6 +555,8 @@ def amp_decorate(
raise ValueError( raise ValueError(
"level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode." "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
) )
if not (dtype in ['float16', 'bfloat16']):
raise ValueError("dtype only support float16 or bfloat16.")
if level == 'O1': if level == 'O1':
if optimizers is None: if optimizers is None:
...@@ -609,12 +576,9 @@ def amp_decorate( ...@@ -609,12 +576,9 @@ def amp_decorate(
raise TypeError( raise TypeError(
"models must be either a single model or a list of models." "models must be either a single model or a list of models."
) )
if dtype == 'float16':
models = pure_fp16_initialize(models=models) # initialize parameters of the model.
elif dtype == 'bfloat16': amp_initialize(models=models, dtype=dtype, excluded_layers=excluded_layers)
models = pure_bf16_initialize(models=models)
else:
raise TypeError("dtype only support float16 or bfloat16.")
if optimizers is not None: if optimizers is not None:
# check optimizers # check optimizers
...@@ -680,22 +644,24 @@ def auto_cast( ...@@ -680,22 +644,24 @@ def auto_cast(
): ):
""" """
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode. Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
If enabled, the input data type (float32 or float16) of each operator is decided If enabled, the input data type (float32, float16 or bfloat16) of each operator is decided
by autocast algorithm for better performance. by autocast algorithm for better performance.
Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in Commonly, it is used together with `GradScaler` and `decorator` to achieve Auto-Mixed-Precision in
imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode. imperative mode.
Args: Args:
enable(bool, optional): Enable auto-mixed-precision or not. Default is True. enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support custom_white_list(set|list|tuple, optional): A default white list is already set. Usually there is no need to set custom white list.
fp16 calculation and are considered numerically-safe and performance-critical. These ops The set of ops should be considered numerically-safe and performance-critical. These ops will be converted to float16/bfloat16.
will be converted to fp16. custom_black_list(set|list|tuple, optional): A default black list is already set. You can set a custom black list according to the model.
custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16 The set of ops are considered numerically-dangerous and whose effects may also be observed in downstream ops. These ops will not be
calculation and are considered numerically-dangerous and whose effects may also be converted to float16/bfloat16.
observed in downstream ops. These ops will not be converted to fp16. level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2
O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp) level, model's parameters will be casted to float16/bfloat16 by using `decorator`, and operators that have all float16/bfloat16 inputs
will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in
default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
Examples: Examples:
...@@ -741,6 +707,7 @@ def decorate( ...@@ -741,6 +707,7 @@ def decorate(
master_weight=None, master_weight=None,
save_dtype=None, save_dtype=None,
master_grad=False, master_grad=False,
excluded_layers=None,
): ):
""" """
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing. Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
...@@ -757,8 +724,10 @@ def decorate( ...@@ -757,8 +724,10 @@ def decorate(
master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None. master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None. save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None. The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
master_grad(bool, optional): For level='O2', whether to use FP32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If it is enabled, the weight master_grad(bool, optional): For level='O2', whether to use float32 weight gradients for calculations such as gradient clipping, weight decay, and weight updates. If master_grad is enabled, the weight
gradients will be FP32 dtype after the backpropagation. Default is False. gradients will be float32 dtype after the backpropagation. Default is False, there is only float16 weight gradients.
excluded_layers(Layer|list of Layer, optional): Specify the layers not to be decorated. The weights of these layers will always keep float32 when level is O2. `excluded_layers` can be specified as
a Layer instance/type or a list of Layer instances/types. Default is None, in which case the weights of the whole model will be cast to float16 or bfloat16.
Examples: Examples:
...@@ -808,5 +777,12 @@ def decorate( ...@@ -808,5 +777,12 @@ def decorate(
print(output.dtype) # FP16 print(output.dtype) # FP16
""" """
return amp_decorate( return amp_decorate(
models, optimizers, level, dtype, master_weight, save_dtype, master_grad models,
optimizers,
level,
dtype,
master_weight,
save_dtype,
master_grad,
excluded_layers,
) )
...@@ -24,6 +24,8 @@ from paddle.fluid.data_feeder import check_type ...@@ -24,6 +24,8 @@ from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import _dygraph_tracer, dygraph_only from paddle.fluid.framework import _dygraph_tracer, dygraph_only
from .auto_cast import amp_global_state
class OptimizerState(Enum): class OptimizerState(Enum):
INIT = 0 INIT = 0
...@@ -179,6 +181,18 @@ class AmpScaler: ...@@ -179,6 +181,18 @@ class AmpScaler:
""" """
check_type(var, "var", core.eager.Tensor, 'AmpScaler.scale()') check_type(var, "var", core.eager.Tensor, 'AmpScaler.scale()')
if (
self._enable
and amp_global_state().amp_dtype != 'float16'
and self._use_dynamic_loss_scaling
):
self._enable = False
self._use_dynamic_loss_scaling = False
warnings.warn(
'It is not recommended to use dynamic loss scaling for %s, so GradScaler is disabled by default.'
% (amp_global_state().amp_dtype)
)
if not self._enable: if not self._enable:
return var return var
......
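As a result of this check, a scaler created for a bfloat16 run is silently turned into a pass-through; roughly as in the sketch below (assuming a device where bfloat16 AMP is available):

    import paddle

    model = paddle.nn.Linear(4, 4)
    opt = paddle.optimizer.SGD(parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    with paddle.amp.auto_cast(dtype="bfloat16"):
        loss = paddle.mean(model(paddle.rand([4, 4])))

    # amp_dtype is 'bfloat16', so scale() warns once, disables dynamic loss
    # scaling and returns the loss unchanged.
    scaled = scaler.scale(loss)
    scaled.backward()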
...@@ -4982,8 +4982,8 @@ class PipelineOptimizer: ...@@ -4982,8 +4982,8 @@ class PipelineOptimizer:
device = post_op.attr(self._op_device_key) device = post_op.attr(self._op_device_key)
assert device, "The post op must have op_device set." assert device, "The post op must have op_device set."
op._set_attr(self._op_device_key, device) op._set_attr(self._op_device_key, device)
elif (op.type == "cast" or op.type == "scale") and self._is_backward_op( elif (op.type == "cast" or op.type == "scale") and (
op self._is_backward_op(op) or self._is_forward_op(op)
): ):
prev_op = self._find_prev_op(idx, op.desc.input("X")[0]) prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key)) op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
......
...@@ -88,8 +88,8 @@ class TestAutoCast(unittest.TestCase): ...@@ -88,8 +88,8 @@ class TestAutoCast(unittest.TestCase):
def custom_op_list(self): def custom_op_list(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
tracer = fluid.framework._dygraph_tracer() tracer = fluid.framework._dygraph_tracer()
base_white_list = paddle.amp.FP16_WHITE_LIST base_white_list = paddle.amp.white_list()["float16"]["O1"]
base_black_list = paddle.amp.FP16_BLACK_LIST base_black_list = paddle.amp.black_list()["float16"]["O1"]
with paddle.amp.amp_guard( with paddle.amp.amp_guard(
custom_white_list=["log"], custom_black_list=["conv2d"] custom_white_list=["log"], custom_black_list=["conv2d"]
): ):
...@@ -104,8 +104,8 @@ class TestAutoCast(unittest.TestCase): ...@@ -104,8 +104,8 @@ class TestAutoCast(unittest.TestCase):
== (set(base_black_list) - {"log"}) | {"conv2d"} == (set(base_black_list) - {"log"}) | {"conv2d"}
) )
base_white_list = paddle.amp.PURE_FP16_WHITE_LIST base_white_list = paddle.amp.white_list()["float16"]["O2"]
base_black_list = paddle.amp.PURE_FP16_BLACK_LIST base_black_list = paddle.amp.black_list()["float16"]["O2"]
with paddle.amp.amp_guard( with paddle.amp.amp_guard(
custom_white_list=["log"], custom_white_list=["log"],
custom_black_list=["conv2d"], custom_black_list=["conv2d"],
...@@ -194,8 +194,11 @@ class TestAutoCast(unittest.TestCase): ...@@ -194,8 +194,11 @@ class TestAutoCast(unittest.TestCase):
class TestAmpScaler(unittest.TestCase): class TestAmpScaler(unittest.TestCase):
def scale(self): def scale(self):
if not paddle.amp.is_float16_supported():
return
with fluid.dygraph.guard(): with fluid.dygraph.guard():
data = paddle.rand([10, 1024]) with paddle.amp.auto_cast(dtype='float16'):
data = paddle.rand([10, 1024])
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
scaled_data = scaler.scale(data) scaled_data = scaler.scale(data)
self.assertEqual( self.assertEqual(
...@@ -333,9 +336,9 @@ class TestAmpScaler(unittest.TestCase): ...@@ -333,9 +336,9 @@ class TestAmpScaler(unittest.TestCase):
) )
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(inp_np) data = fluid.dygraph.to_variable(inp_np)
with paddle.amp.auto_cast(dtype='float16'):
out = model(data) out = model(data)
loss = paddle.mean(out) loss = paddle.mean(out)
scaled_loss = scaler.scale(loss) scaled_loss = scaler.scale(loss)
scaled_loss.backward() scaled_loss.backward()
optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss) optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss)
...@@ -348,6 +351,8 @@ class TestAmpScaler(unittest.TestCase): ...@@ -348,6 +351,8 @@ class TestAmpScaler(unittest.TestCase):
) )
def test_nan_inf(self): def test_nan_inf(self):
if not paddle.amp.is_float16_supported():
return
self.nan_inf() self.nan_inf()
def step_update_exception(self): def step_update_exception(self):
......
...@@ -356,7 +356,9 @@ class TestAdadeltaMultiPrecision2_0(unittest.TestCase): ...@@ -356,7 +356,9 @@ class TestAdadeltaMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
...@@ -467,7 +469,9 @@ class TestAdadeltaMultiPrecision1_0(unittest.TestCase): ...@@ -467,7 +469,9 @@ class TestAdadeltaMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
......
...@@ -322,7 +322,9 @@ class TestAdagradMultiPrecision2_0(unittest.TestCase): ...@@ -322,7 +322,9 @@ class TestAdagradMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
...@@ -431,7 +433,9 @@ class TestAdagradMultiPrecision1_0(unittest.TestCase): ...@@ -431,7 +433,9 @@ class TestAdagradMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
......
...@@ -1235,7 +1235,9 @@ class TestMultiTensorAdam(unittest.TestCase): ...@@ -1235,7 +1235,9 @@ class TestMultiTensorAdam(unittest.TestCase):
optimizer.minimize(loss) optimizer.minimize(loss)
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
......
...@@ -352,7 +352,9 @@ class TestAdamaxMultiPrecision2_0(unittest.TestCase): ...@@ -352,7 +352,9 @@ class TestAdamaxMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
...@@ -459,7 +461,9 @@ class TestAdamaxMultiPrecision1_0(unittest.TestCase): ...@@ -459,7 +461,9 @@ class TestAdamaxMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
......
...@@ -1059,7 +1059,9 @@ class TestMultiTensorMomentumStatic(unittest.TestCase): ...@@ -1059,7 +1059,9 @@ class TestMultiTensorMomentumStatic(unittest.TestCase):
optimizer.minimize(loss) optimizer.minimize(loss)
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = numpy.random.random(size=(2, 2)).astype('float16') x = numpy.random.random(size=(2, 2)).astype('float16')
else: else:
x = numpy.random.random(size=(2, 2)).astype('float32') x = numpy.random.random(size=(2, 2)).astype('float32')
......
...@@ -474,7 +474,9 @@ class TestRMSPropMultiPrecision2_0(unittest.TestCase): ...@@ -474,7 +474,9 @@ class TestRMSPropMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
...@@ -585,7 +587,9 @@ class TestRMSPropMultiPrecision1_0(unittest.TestCase): ...@@ -585,7 +587,9 @@ class TestRMSPropMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if use_amp: if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
......
...@@ -382,7 +382,9 @@ class TestSGDMultiPrecision2_0(unittest.TestCase): ...@@ -382,7 +382,9 @@ class TestSGDMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if mp: if mp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
...@@ -492,7 +494,9 @@ class TestSGDMultiPrecision1_0(unittest.TestCase): ...@@ -492,7 +494,9 @@ class TestSGDMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program) exe.run(startup_program)
if mp: if mp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16') x = np.random.random(size=(2, 2)).astype('float16')
else: else:
x = np.random.random(size=(2, 2)).astype('float32') x = np.random.random(size=(2, 2)).astype('float32')
......
...@@ -294,8 +294,8 @@ class PartialProgramLayer: ...@@ -294,8 +294,8 @@ class PartialProgramLayer:
def _create_amp_program(self, is_infer_mode=False): def _create_amp_program(self, is_infer_mode=False):
amp_program = self._origin_main_program.clone(for_test=is_infer_mode) amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
with program_guard(amp_program): with program_guard(amp_program):
paddle.static.amp.fp16_utils.rewrite_program( paddle.static.amp.fp16_utils.cast_model_to_fp16(
amp_program, self._amp_list amp_program, self._amp_list, use_fp16_guard=False, level='O1'
) )
if is_infer_mode: if is_infer_mode:
if self._hooker: if self._hooker:
......
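The same O1 rewrite can be applied to a plain static Program; a rough sketch based on the call pattern above (the exact defaults of cast_model_to_fp16, such as dest_type, are assumptions):

    import paddle

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name="x", shape=[None, 16], dtype="float32")
        y = paddle.static.nn.fc(x, size=8)

    amp_list = paddle.static.amp.AutoMixedPrecisionLists(dtype="float16")
    # Replaces the removed rewrite_program() call.
    paddle.static.amp.fp16_utils.cast_model_to_fp16(
        main, amp_list, use_fp16_guard=False, level="O1"
    )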
...@@ -401,7 +401,8 @@ class Layer: ...@@ -401,7 +401,8 @@ class Layer:
self._forward_pre_hooks = collections.OrderedDict() self._forward_pre_hooks = collections.OrderedDict()
self._forward_post_hooks = collections.OrderedDict() self._forward_post_hooks = collections.OrderedDict()
self._casted_by_pure_fp16 = False # only used in AMP Training
self._cast_to_low_precison = True
self._state_dict_hooks = collections.OrderedDict() self._state_dict_hooks = collections.OrderedDict()
# Records orignal functions after @to_static to support to rollback # Records orignal functions after @to_static to support to rollback
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from . import decorator from . import decorator
from .decorator import decorate, amp_decorate from .decorator import decorate
from . import fp16_lists from . import fp16_lists
from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists
from . import fp16_utils from . import fp16_utils
......
...@@ -13,8 +13,14 @@ ...@@ -13,8 +13,14 @@
# limitations under the License. # limitations under the License.
import copy import copy
import logging
import paddle import paddle
from paddle.fluid.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
class OperatorStatsUnit: class OperatorStatsUnit:
...@@ -76,7 +82,7 @@ def _get_var_dtype_from_block(block, op, arg_name, is_input): ...@@ -76,7 +82,7 @@ def _get_var_dtype_from_block(block, op, arg_name, is_input):
var = block._var_recursive(var_name) var = block._var_recursive(var_name)
return var.dtype return var.dtype
except: except:
print( _logger.warning(
"Operator < {} > gets {} < {} : {} > error!".format( "Operator < {} > gets {} < {} : {} > error!".format(
op.type, "input" if is_input else "output", arg_name, var_name op.type, "input" if is_input else "output", arg_name, var_name
) )
...@@ -99,7 +105,7 @@ def _extract_compute_dtype(op, block): ...@@ -99,7 +105,7 @@ def _extract_compute_dtype(op, block):
if _is_floating_point(compute_dtype) and _is_floating_point( if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype var_dtype
): ):
print( _logger.warning(
"Operator < {} > has different input data types, input_names = {}, output_names = {}.".format( "Operator < {} > has different input data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names op.type, op.input_names, op.output_names
) )
...@@ -125,7 +131,7 @@ def _extract_compute_dtype(op, block): ...@@ -125,7 +131,7 @@ def _extract_compute_dtype(op, block):
if _is_floating_point(compute_dtype) and _is_floating_point( if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype var_dtype
): ):
print( _logger.warning(
"Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format( "Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names op.type, op.input_names, op.output_names
) )
...@@ -145,6 +151,15 @@ def _merge_op_stats(op_stats_list): ...@@ -145,6 +151,15 @@ def _merge_op_stats(op_stats_list):
def _get_op_stats_list(program): def _get_op_stats_list(program):
def _is_special_ops_with_input_x(op_type):
# Operators that take an input named X and whose inputs may have different dtypes.
special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']
if op_type in special_op_list:
return True
if op_type.replace("_grad", "") in special_op_list:
return True
return False
op_stats_list = [] op_stats_list = []
for block in program.blocks: for block in program.blocks:
block_op_stats_dict = {} block_op_stats_dict = {}
...@@ -161,13 +176,7 @@ def _get_op_stats_list(program): ...@@ -161,13 +176,7 @@ def _get_op_stats_list(program):
'create_double_buffer_reader', 'create_double_buffer_reader',
]: ]:
compute_dtype = None compute_dtype = None
elif op.type in [ elif _is_special_ops_with_input_x(op.type):
'cast',
'layer_norm',
'layer_norm_grad',
'batch_norm',
'batch_norm_grad',
]:
# Not check the input and output dtype difference for this operators. # Not check the input and output dtype difference for this operators.
compute_dtype = _get_var_dtype_from_block(block, op, 'X', True) compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
elif "Param" in op.input_names: elif "Param" in op.input_names:
...@@ -183,6 +192,78 @@ def _get_op_stats_list(program): ...@@ -183,6 +192,78 @@ def _get_op_stats_list(program):
def collect_operator_stats(program=None, print_subblocks=False): def collect_operator_stats(program=None, print_subblocks=False):
"""
Collect the number of operators for different data types through parsing
the program. The statistical data are categorized according to four data
types, namely float32, float16, bfloat16 and others.
Args:
program(Program, optional): The program to parse. Default None, and the default main_program will be parsed.
print_subblocks(bool, optional): Whether to print the operator stats for each subblock. Default False.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
class SimpleConvNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = paddle.nn.Linear(in_features=26, out_features=10)
def forward(self, x):
out = self.conv(x)
out = paddle.nn.functional.relu(out)
out = self.linear(out)
out = paddle.nn.functional.softmax(out)
return out
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = paddle.optimizer.AdamW()
optimizer = paddle.static.amp.decorate(optimizer)
optimizer.minimize(loss)
paddle.static.amp.debugging.collect_operator_stats(main_program)
# <------------------------------------------------ op list of all blocks ------------------------------------------------->
# <------------------------------------------------------- op list -------------------------------------------------------->
# <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | --- FP32 Calls--- | -- Other Calls -->
# adamw | 0 | 0 | 4 | 0
# cast | 5 | 0 | 6 | 0
# check_finite_and_unscale | 0 | 0 | 1 | 0
# conv2d | 1 | 0 | 0 | 0
# conv2d_grad | 1 | 0 | 0 | 0
# elementwise_add | 2 | 0 | 0 | 0
# elementwise_add_grad | 2 | 0 | 0 | 0
# elementwise_mul | 0 | 0 | 1 | 0
# elementwise_mul_grad | 0 | 0 | 1 | 0
# fill_constant | 0 | 0 | 1 | 0
# matmul_v2 | 1 | 0 | 0 | 0
# matmul_v2_grad | 1 | 0 | 0 | 0
# memcpy | 0 | 0 | 0 | 1
# reduce_mean | 0 | 0 | 1 | 0
# reduce_mean_grad | 0 | 0 | 1 | 0
# relu | 1 | 0 | 0 | 0
# relu_grad | 1 | 0 | 0 | 0
# reshape2 | 0 | 0 | 1 | 0
# reshape2_grad | 0 | 0 | 1 | 0
# softmax | 0 | 0 | 1 | 0
# softmax_grad | 0 | 0 | 1 | 0
# update_loss_scaling | 0 | 0 | 1 | 0
# <----------------------------------------------------- op count: 22 ----------------------------------------------------->
"""
def _convert_to_list(op_stats_unit_dict): def _convert_to_list(op_stats_unit_dict):
for key, value in op_stats_unit_dict.items(): for key, value in op_stats_unit_dict.items():
op_stats_unit_dict[key] = value.convert_to_list() op_stats_unit_dict[key] = value.convert_to_list()
......
...@@ -29,9 +29,24 @@ from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype ...@@ -29,9 +29,24 @@ from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
from .fp16_utils import ( from .fp16_utils import (
cast_model_to_fp16, cast_model_to_fp16,
cast_parameters_to_fp16, cast_parameters_to_fp16,
rewrite_program,
update_role_var_grad, update_role_var_grad,
) )
from .function_overload import FunctionType, overload
def _set_multi_precision(optimizer, multi_precision):
if not isinstance(
optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
):
raise RuntimeError(
"Current AMP training level is O2, optimizer is expected to be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format(
type(optimizer)
)
)
if multi_precision and hasattr(optimizer, "_multi_precision"):
optimizer._multi_precision = multi_precision
class OptimizerWithMixedPrecision: class OptimizerWithMixedPrecision:
...@@ -66,6 +81,7 @@ class OptimizerWithMixedPrecision: ...@@ -66,6 +81,7 @@ class OptimizerWithMixedPrecision:
the loss scaling. the loss scaling.
use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program. use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value is equal to `use_pure_fp16`. Default None, which means that its value is equal to `use_pure_fp16`.
use_promote(bool): Whether to promote to fp32 when an op has any float32 inputs. Default is False.
""" """
def __init__( def __init__(
...@@ -81,6 +97,7 @@ class OptimizerWithMixedPrecision: ...@@ -81,6 +97,7 @@ class OptimizerWithMixedPrecision:
incr_ratio, incr_ratio,
decr_ratio, decr_ratio,
use_amp_guard=None, use_amp_guard=None,
use_promote=False,
): ):
self._optimizer = optimizer self._optimizer = optimizer
self._amp_lists = amp_lists self._amp_lists = amp_lists
...@@ -115,6 +132,7 @@ class OptimizerWithMixedPrecision: ...@@ -115,6 +132,7 @@ class OptimizerWithMixedPrecision:
self._decr_ratio = decr_ratio self._decr_ratio = decr_ratio
self._num_good_steps = None self._num_good_steps = None
self._num_bad_steps = None self._num_bad_steps = None
self.use_promote = use_promote
def _set_distributed(self, flag): def _set_distributed(self, flag):
# if distributed, all cards will communication with each other, # if distributed, all cards will communication with each other,
...@@ -230,10 +248,18 @@ class OptimizerWithMixedPrecision: ...@@ -230,10 +248,18 @@ class OptimizerWithMixedPrecision:
self._amp_lists, self._amp_lists,
self._use_fp16_guard, self._use_fp16_guard,
self._amp_vartype, self._amp_vartype,
level='O2',
use_promote=self.use_promote,
) )
else: else:
rewrite_program( # use_fp16_guard is not supported for amp O1.
self._train_program, self._amp_lists, self._amp_vartype cast_model_to_fp16(
self._train_program,
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
use_promote=self.use_promote,
) )
if loss.dtype != core.VarDesc.VarType.FP32: if loss.dtype != core.VarDesc.VarType.FP32:
...@@ -361,10 +387,18 @@ class OptimizerWithMixedPrecision: ...@@ -361,10 +387,18 @@ class OptimizerWithMixedPrecision:
self._amp_lists, self._amp_lists,
self._use_fp16_guard, self._use_fp16_guard,
self._amp_vartype, self._amp_vartype,
level='O2',
use_promote=self.use_promote,
) )
elif use_fp16_test: elif use_fp16_test:
rewrite_program( # use_fp16_guard is not supported for amp O1.
test_program, self._amp_lists, self._amp_vartype cast_model_to_fp16(
test_program,
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
use_promote=self.use_promote,
) )
def apply_gradients(self, params_grads): def apply_gradients(self, params_grads):
...@@ -610,6 +644,7 @@ class OptimizerWithMixedPrecision: ...@@ -610,6 +644,7 @@ class OptimizerWithMixedPrecision:
return optimize_ops, scaled_params_grads return optimize_ops, scaled_params_grads
@overload(key=FunctionType.FP16_ONLY)
def decorate( def decorate(
optimizer, optimizer,
amp_lists=None, amp_lists=None,
...@@ -622,6 +657,7 @@ def decorate( ...@@ -622,6 +657,7 @@ def decorate(
use_pure_fp16=False, use_pure_fp16=False,
use_fp16_guard=None, use_fp16_guard=None,
use_bf16=False, use_bf16=False,
use_promote=False,
): ):
""" """
Decorate the given optimizer to adapt to the mixed-precision training. Decorate the given optimizer to adapt to the mixed-precision training.
...@@ -734,31 +770,108 @@ def decorate( ...@@ -734,31 +770,108 @@ def decorate(
incr_ratio=incr_ratio, incr_ratio=incr_ratio,
decr_ratio=decr_ratio, decr_ratio=decr_ratio,
use_amp_guard=use_fp16_guard, use_amp_guard=use_fp16_guard,
use_promote=use_promote,
) )
return mp_optimizer return mp_optimizer
def amp_decorate( @overload(key=FunctionType.COMMON)
def decorate(
optimizer, optimizer,
amp_lists=None, amp_lists=None,
level='O1', level='O1',
dtype='float16', dtype='float16',
master_weight=None,
init_loss_scaling=2**15, init_loss_scaling=2**15,
incr_every_n_steps=1000, incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2, decr_every_n_nan_or_inf=2,
incr_ratio=2.0, incr_ratio=2.0,
decr_ratio=0.8, decr_ratio=0.8,
use_dynamic_loss_scaling=True, use_dynamic_loss_scaling=None,
use_amp_guard=False, use_amp_guard=False,
use_promote=False,
): ):
""" """
Decorate the given optimizer to adapt to the mixed-precision training. Decorate the given optimizer to adapt to the mixed-precision training.
"""
amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
Args:
optimizer(Optimizer): A common Optimizer.
amp_lists(CustomOpLists, optional): A CustomOpLists object. The default
white_list and black_list will be used for AMP training when it is
not set. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represent mixed precision, the input data type of
each operator will be casted by white_list and black_list;
O2 represent pure FP16 / BF16 training, all operators parameters
and input data will be casted to FP16 / BF16, except operators in
black_list, don't support FP16 / BF16 kernel and batch_norm. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optional): For level='O2', whether to use multi-precision
during weight updating. If master_weight is None, in O2 level optimizer
will use multi-precision. Default is None.
init_loss_scaling(float, optional): The initial loss scaling factor.
Default is 32768.
incr_every_n_steps(int, optional): Increases loss scaling every n
consecutive steps with finite gradients. Default is 1000.
decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
accumulated steps with nan or inf gradients. Default is 2.
incr_ratio(float, optional): The multiplier to use when increasing the
loss scaling. Default is 2.
decr_ratio(float, optional): The less-than-one-multiplier to use when
decreasing the loss scaling. Default is 0.8.
use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss
scaling. Default is None, which means True for float16, and False
for bfloat16.
Returns:
An optimizer acting like a normal one but with mixed-precision training enabled.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
class SimpleConvNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = paddle.nn.Linear(in_features=26, out_features=10)
def forward(self, x):
out = self.conv(x)
out = paddle.nn.functional.relu(out)
out = self.linear(out)
out = paddle.nn.functional.softmax(out)
return out
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = paddle.optimizer.AdamW()
optimizer = paddle.static.amp.decorate(optimizer, level="O2", dtype="float16")
optimizer.minimize(loss)
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_program)
# Call `amp_init` after FP32 parameters initialization, such as `exe.run(startup_program)`,
# to convert FP32 parameters to low precision FP16 / BF16.
optimizer.amp_init(place, scope=paddle.static.global_scope())
"""
# check amp_level: O0-O2 # check amp_level: O0-O2
level = level.upper() level = level.upper()
if not (level in ['O0', 'O1', 'O2']): if not (level in ['O0', 'O1', 'O2']):
...@@ -766,6 +879,18 @@ def amp_decorate( ...@@ -766,6 +879,18 @@ def amp_decorate(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode." "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
) )
amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
if use_dynamic_loss_scaling is None:
use_dynamic_loss_scaling = dtype == "float16"
if optimizer is not None:
# support master_weight
multi_precision = not (master_weight is False)
_set_multi_precision(optimizer, multi_precision)
mp_optimizer = OptimizerWithMixedPrecision( mp_optimizer = OptimizerWithMixedPrecision(
optimizer, optimizer,
amp_lists, amp_lists,
...@@ -778,6 +903,7 @@ def amp_decorate( ...@@ -778,6 +903,7 @@ def amp_decorate(
incr_ratio=incr_ratio, incr_ratio=incr_ratio,
decr_ratio=decr_ratio, decr_ratio=decr_ratio,
use_amp_guard=use_amp_guard, use_amp_guard=use_amp_guard,
use_promote=use_promote,
) )
return mp_optimizer return mp_optimizer
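A short sketch of how the new arguments interact for a bfloat16 run: use_dynamic_loss_scaling is left as None and therefore resolves to False, while use_promote is passed straight through to OptimizerWithMixedPrecision:

    import paddle

    paddle.enable_static()
    optimizer = paddle.optimizer.AdamW()
    optimizer = paddle.static.amp.decorate(
        optimizer, level="O1", dtype="bfloat16", use_promote=True
    )
    # Dynamic loss scaling resolves to False because dtype is "bfloat16".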
...@@ -98,6 +98,20 @@ def _get_sys_unsupported_list(dtype): ...@@ -98,6 +98,20 @@ def _get_sys_unsupported_list(dtype):
else: else:
device = 'GPU' device = 'GPU'
_, _, sys_unsupported_list = core.op_supported_infos(device, var_type) _, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
# sys_unsupported_list will include the following ops, which are actually supported and are therefore removed from it.
supported_fp16_list = {
"conditional_block",
"conditional_block_infer",
"select_input",
"while",
"cast",
"tensor_array_to_tensor",
"lod_array_length",
"write_to_array",
}
sys_unsupported_list -= supported_fp16_list
return device, sys_unsupported_list return device, sys_unsupported_list
...@@ -108,6 +122,29 @@ def _get_unsupported_list(dtype): ...@@ -108,6 +122,29 @@ def _get_unsupported_list(dtype):
return unsupported_list return unsupported_list
# The three sets listed below are changed dynamically. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
_only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}
white_list = {
'conv2d',
'matmul',
'matmul_v2',
'mul',
}
def _get_white_list(dtype):
white_list_for_dtype = copy.copy(white_list)
if dtype == 'float16':
white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
return white_list_for_dtype
class AutoMixedPrecisionLists: class AutoMixedPrecisionLists:
""" """
AutoMixedPrecisionLists is a class for black/white list. It can update AutoMixedPrecisionLists is a class for black/white list. It can update
...@@ -132,7 +169,7 @@ class AutoMixedPrecisionLists: ...@@ -132,7 +169,7 @@ class AutoMixedPrecisionLists:
self.amp_dtype = check_amp_dtype(dtype) self.amp_dtype = check_amp_dtype(dtype)
self._custom_white_list = custom_white_list self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list) self.white_list = copy.copy(_get_white_list(self.amp_dtype))
self.black_list = copy.copy(black_list) self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list) self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype)) self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
...@@ -143,6 +180,9 @@ class AutoMixedPrecisionLists: ...@@ -143,6 +180,9 @@ class AutoMixedPrecisionLists:
""" """
Update black and white list according to users' custom list. Update black and white list according to users' custom list.
""" """
_logger.debug(f"---- custom_white_list {self._custom_white_list} ---- ")
_logger.debug(f"---- custom_black_list {self._custom_black_list} ---- ")
_logger.debug(f"---- custom_black_varnames {self.black_varnames} ---- ")
if self._custom_white_list and self._custom_black_list: if self._custom_white_list and self._custom_black_list:
for op_name in self._custom_white_list: for op_name in self._custom_white_list:
if op_name in self._custom_black_list: if op_name in self._custom_black_list:
...@@ -177,18 +217,6 @@ class AutoMixedPrecisionLists: ...@@ -177,18 +217,6 @@ class AutoMixedPrecisionLists:
) )
# The three sets listed below are changed dynamiclly. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
white_list = {
'conv2d',
'matmul',
'matmul_v2',
'mul',
}
# The set of ops that support fp16 calculation and are considered numerically- # The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops. # dangerous and whose effects may also be observed in downstream ops.
black_list = { black_list = {
......
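The effect of _get_white_list can be observed through AutoMixedPrecisionLists itself; a small sketch (the dtype keyword is used as in the decorate() change above):

    import paddle

    fp16_lists = paddle.static.amp.AutoMixedPrecisionLists(dtype="float16")
    bf16_lists = paddle.static.amp.AutoMixedPrecisionLists(dtype="bfloat16")

    # The fp16-only fused ops are added to the white list only for float16.
    assert "resnet_unit" in fp16_lists.white_list
    assert "resnet_unit" not in bf16_lists.white_list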
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The implementation refers to https://arpitbhayani.me/blogs/function-overloading.
# Note: it is customized for the paddle.static.amp.decorate function.
import inspect
import logging
from enum import Enum
from paddle.fluid.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
class FunctionType(Enum):
FP16_ONLY = 0
COMMON = 1
class Function:
"""
Function is a wrapper over a standard python function.
An instance of this Function class is also callable,
just like the python function that it wraps.
When the instance is "called" like a function it fetches
the function to be invoked from the virtual namespace and then
invokes the same.
"""
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
"""
Overriding the __call__ function which makes the
instance callable.
"""
# fetching the function to be invoked from the virtual namespace
# through the arguments.
fn = Namespace.get_instance().get(*args, **kwargs)
# invoking the wrapped function and returning the value.
return fn(*args, **kwargs)
class Namespace:
"""
Namespace is the singleton class that is responsible
for holding all the functions.
"""
__instance = None
def __init__(self):
if self.__instance is None:
self.function_map = {}
Namespace.__instance = self
else:
raise Exception("cannot instantiate Namespace again.")
@staticmethod
def get_instance():
if Namespace.__instance is None:
Namespace()
return Namespace.__instance
def register(self, fn, key):
"""
Register the function in the virtual namespace and return
an instance of callable Function that wraps the function fn.
Args:
fn (function): the native python function handle.
key (FunctionType): the specified type.
"""
assert isinstance(
key, FunctionType
), f"The type of key is expected to be FunctionType, but recieved {type(key)}."
func = Function(fn)
self.function_map[key] = fn
return func
def get(self, *args, **kwargs):
"""
Get the matching function from the virtual namespace according to the actual arguments.
Return None if no matching function is found.
"""
_logger.debug(f"get function: args={args}, kwargs={kwargs}")
satisfied_function_keys = set(self.function_map.keys())
num_actual_args = len(args) + len(kwargs)
for func_key in self.function_map.keys():
if func_key not in satisfied_function_keys:
continue
fn = self.function_map[func_key]
specs = inspect.getfullargspec(fn)
if len(specs.args) < num_actual_args:
# Remove the not satisfied function according to the number of actual arguments.
_logger.debug(
f"fn={fn} (key={func_key}) is not satisfied and removed."
)
satisfied_function_keys.remove(func_key)
continue
if len(kwargs) > 0:
# Remove the not satisfied function according to argument keys in kwargs.
for arg_name, value in kwargs.items():
if arg_name not in specs.args:
_logger.debug(
f"fn={fn} (key={func_key}) is not satisfied and removed."
)
satisfied_function_keys.remove(func_key)
break
if len(satisfied_function_keys) == 1:
key = list(satisfied_function_keys)[0]
elif len(args) >= 3 and isinstance(args[2], float):
key = FunctionType.FP16_ONLY
else:
key = FunctionType.COMMON
return self.function_map.get(key)
def overload(key):
"""overload is the decorator that wraps the function
and returns a callable object of type Function.
"""
def decorator(fn):
return Namespace.get_instance().register(fn, key)
return decorator
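# Editor's sketch (not part of the original patch): how the registry above
# dispatches between two overloads when this module is used on its own. The
# _example functions below are hypothetical and only illustrate the selection
# rules in Namespace.get.
@overload(FunctionType.FP16_ONLY)
def _example(optimizer, amp_lists, init_loss_scaling):
    return "fp16-only overload"


@overload(FunctionType.COMMON)
def _example(optimizer, amp_lists=None, level='O1', dtype='float16'):
    return "common overload"


def _sketch_dispatch():
    # A float in the third positional slot routes to FunctionType.FP16_ONLY;
    # a call that only matches the common signature routes to COMMON.
    assert _example(None, None, 128.0) == "fp16-only overload"
    assert _example(None, level='O2') == "common overload"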
...@@ -12,16 +12,24 @@ ...@@ -12,16 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
import numpy as np import numpy as np
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.fluid import core
_fixed_add_param = np.random.random(size=[16, 16]).astype("float32") _fixed_add_param = np.random.random(size=[16, 16]).astype("float32")
def _build_optimizer( def _build_optimizer(
use_amp, amp_dtype="float16", amp_level="O1", use_grad_clip=False use_amp,
amp_dtype="float16",
amp_level="O1",
amp_lists=None,
use_grad_clip=False,
use_promote=False,
): ):
if use_grad_clip: if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
...@@ -34,16 +42,14 @@ def _build_optimizer( ...@@ -34,16 +42,14 @@ def _build_optimizer(
beta2=0.836, beta2=0.836,
epsilon=1e-4, epsilon=1e-4,
weight_decay=0.01, weight_decay=0.01,
multi_precision=True,
) )
if use_amp: if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists( optimizer = paddle.static.amp.decorate(
custom_white_list=["elementwise_add"], optimizer,
custom_black_list=["reduce_mean"], amp_lists,
level=amp_level,
dtype=amp_dtype, dtype=amp_dtype,
) use_promote=use_promote,
optimizer = paddle.static.amp.amp_decorate(
optimizer, amp_lists=amp_lists, level=amp_level, dtype=amp_dtype
) )
return optimizer return optimizer
...@@ -65,7 +71,9 @@ class SimpleAddNet(nn.Layer): ...@@ -65,7 +71,9 @@ class SimpleAddNet(nn.Layer):
return x + self.weight return x + self.weight
def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"): def build_add_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program() main_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard(): with paddle.utils.unique_name.guard():
...@@ -80,7 +88,22 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"): ...@@ -80,7 +88,22 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype) x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype)
out = model(x) out = model(x)
loss = paddle.mean(out) loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=["elementwise_add"],
custom_black_list=["reduce_mean"],
dtype=amp_dtype,
)
else:
amp_lists = None
optimizer = _build_optimizer(
use_amp,
amp_dtype,
amp_level,
amp_lists,
use_promote=use_promote,
)
optimizer.minimize(loss) optimizer.minimize(loss)
feed_vars = [x] feed_vars = [x]
fetch_vars = [loss] fetch_vars = [loss]
...@@ -91,30 +114,37 @@ class SimpleConvNet(nn.Layer): ...@@ -91,30 +114,37 @@ class SimpleConvNet(nn.Layer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3) self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = nn.Linear(in_features=6, out_features=10) self.linear = nn.Linear(in_features=96, out_features=4)
def forward(self, x): def forward(self, x):
out = self.conv(x) out = self.conv(x)
out = nn.functional.relu(out) out = nn.functional.relu(out)
out = out.flatten(start_axis=1, stop_axis=3)
out = self.linear(out) out = self.linear(out)
out = nn.functional.softmax(out) out = nn.functional.softmax(out)
return out return out
def build_conv_model(use_amp, amp_dtype="float16", amp_level="O1"): def build_conv_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program() main_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard(): with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet() model = SimpleConvNet()
x = paddle.static.data( x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32' name='input', shape=[None, 1, 6, 6], dtype='float32'
) )
out = model(x) out = model(x)
loss = paddle.mean(out) loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level) optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, use_promote=use_promote
)
optimizer.minimize(loss) optimizer.minimize(loss)
return main_program, startup_program feed_vars = [x]
fetch_vars = [loss]
return main_program, startup_program, optimizer, feed_vars, fetch_vars
class SimpleEmbeddingNet(nn.Layer): class SimpleEmbeddingNet(nn.Layer):
...@@ -136,7 +166,9 @@ class SimpleEmbeddingNet(nn.Layer): ...@@ -136,7 +166,9 @@ class SimpleEmbeddingNet(nn.Layer):
return out return out
def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"): def build_embedding_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program() main_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard(): with paddle.utils.unique_name.guard():
...@@ -145,7 +177,14 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"): ...@@ -145,7 +177,14 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
x = paddle.static.data(name='x', shape=[None, 32], dtype='int64') x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
out = model(x) out = model(x)
loss = paddle.mean(out) loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level, True) optimizer = _build_optimizer(
use_amp,
amp_dtype,
amp_level,
None,
True,
use_promote=use_promote,
)
optimizer.minimize(loss) optimizer.minimize(loss)
return main_program, startup_program return main_program, startup_program
...@@ -186,3 +225,58 @@ def build_while_model(): ...@@ -186,3 +225,58 @@ def build_while_model():
out = model(x) out = model(x)
loss = paddle.mean(out) loss = paddle.mean(out)
return main_program, startup_program return main_program, startup_program
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support amp.",
)
class AmpTestBase(unittest.TestCase):
def setUp(self):
self.amp_dtype = None
self.amp_level = None
def _check_op_calls(
self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={}
):
for op_type, value in expected_bf16_calls.items():
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
)
for op_type, value in expected_fp16_calls.items():
self.assertEqual(
op_stats_dict[op_type].fp16_calls,
value,
f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.",
)
def run_program(
self,
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_np,
max_iters,
level,
):
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
if level == 'O2':
optimizer.amp_init(place)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from amp_base_models import AmpTestBase
import paddle
class TestAutoCast(AmpTestBase):
def test_amp_OD_level(self):
conv = paddle.nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
)
linear = paddle.nn.Linear(in_features=4, out_features=4)
with paddle.amp.auto_cast(level='OD'):
out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
out3 = linear(out2)
self.assertEqual(out1.dtype, paddle.float16)
self.assertEqual(out2.dtype, paddle.float32)
self.assertEqual(out3.dtype, paddle.float32)
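# Editor's sketch (not part of the original patch): a minimal dygraph training
# loop under the new OD level, assuming a float16-capable GPU. Only white-list
# ops such as conv2d/matmul run in float16 under OD; the rest of the graph,
# including the loss, stays in float32, so the GradScaler is omitted here for
# brevity.
def _od_level_training_sketch():
    model = paddle.nn.Linear(4, 4)
    opt = paddle.optimizer.SGD(
        learning_rate=0.01, parameters=model.parameters()
    )
    for _ in range(2):
        with paddle.amp.auto_cast(level='OD'):
            out = model(paddle.rand([2, 4]))  # matmul kernel runs in float16
            loss = out.mean()  # mean and the surrounding ops stay in float32
        loss.backward()
        opt.step()
        opt.clear_grad()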
class TestGradScaler(AmpTestBase):
def test_amp_grad_scaler(self):
model = paddle.nn.Conv2D(3, 2, 3)
optimizer = paddle.optimizer.SGD(
learning_rate=0.01, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler()
data = paddle.rand([1, 3, 8, 8], dtype='float32')
paddle.amp.debugging.enable_operator_stats_collection()
with paddle.amp.auto_cast(
custom_black_list=['conv2d'], dtype='bfloat16'
):
out = model(data)
loss = out.mean()
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
paddle.amp.debugging.disable_operator_stats_collection()
op_list = paddle.fluid.core.get_low_precision_op_list()
self.assertEqual(scaler._enable, False)
self.assertEqual(scaler._use_dynamic_loss_scaling, False)
self.assertTrue('scale' not in op_list)
self.assertTrue('check_finite_and_unscale' not in op_list)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn.functional as F
class ConvBNLayer(paddle.nn.Layer):
def __init__(
self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
):
super().__init__()
self._conv = paddle.nn.Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=None,
)
self._batch_norm = paddle.nn.BatchNorm(num_filters, act=act)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class Model(paddle.nn.Layer):
def __init__(
self, input_channel, hidden_size, fp16_conv=True, fp16_linear=True
):
super().__init__()
self.conv = ConvBNLayer(input_channel, 8, 3)
self.linear = paddle.nn.Linear(8, hidden_size)
self.layernorm = paddle.nn.Sequential(
paddle.nn.LayerNorm(hidden_size),
paddle.nn.LayerNorm(hidden_size),
)
self.fp16_conv = fp16_conv
self.fp16_linear = fp16_linear
def forward(self, inputs):
with paddle.amp.auto_cast(enable=self.fp16_conv):
if not self.fp16_conv:
inputs = inputs.astype('float32')
x = self.conv(inputs)
with paddle.amp.auto_cast(enable=self.fp16_linear):
if not self.fp16_linear:
x = x.astype('float32')
x = self.linear(x)
x = F.relu(x)
x = self.layernorm(x)
return x
class TestAMPDecorate(unittest.TestCase):
def check_results(self, fp32_layers=[], fp16_layers=[]):
for idx in range(len(fp32_layers)):
for layer in fp32_layers[idx].sublayers(include_self=False):
self.assertEqual(layer.weight.dtype, paddle.float32)
self.assertEqual(layer.bias.dtype, paddle.float32)
for idx in range(len(fp16_layers)):
for layer in fp16_layers[idx].sublayers(include_self=False):
self.assertEqual(layer.weight.dtype, paddle.float16)
self.assertEqual(layer.bias.dtype, paddle.float16)
def test_excluded_layers(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8, fp16_conv=False)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=model.conv,
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
self.check_results(
fp32_layers=[model.conv, model.layernorm],
fp16_layers=[model.linear],
)
def test_excluded_layers_attr_list(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8, fp16_conv=False, fp16_linear=False)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=[model.conv, model.linear],
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32'))
self.check_results(
fp32_layers=[model.conv, model.linear, model.layernorm]
)
def test_excluded_layers_attr_types(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=[paddle.nn.Conv2D, model.linear],
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))
self.check_results(
fp32_layers=[model.conv, model.linear, model.layernorm]
)
def test_excluded_layers_attr_none(self):
if not paddle.amp.is_float16_supported():
return
model = Model(4, 8)
model = paddle.amp.decorate(
models=model,
level='O2',
dtype='float16',
excluded_layers=None,
)
with paddle.amp.auto_cast(level='O2'):
out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float16'))
self.check_results(
fp32_layers=[model.layernorm, model.conv._batch_norm],
fp16_layers=[model.conv._conv, model.linear],
)
if __name__ == '__main__':
unittest.main()
...@@ -14,32 +14,63 @@ ...@@ -14,32 +14,63 @@
import unittest import unittest
import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.static.amp import fp16_lists from paddle.static.amp import AutoMixedPrecisionLists, fp16_lists
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists
class TestAMPList(unittest.TestCase): class TestAMPList(unittest.TestCase):
def test_main(self): def setUp(self):
custom_white_list = [ self.default_black_list = [
'lookup_table',
'lookup_table_v2',
]
amp_list = AutoMixedPrecisionLists(custom_white_list=custom_white_list)
for op in custom_white_list:
self.assertTrue(op in amp_list.white_list)
self.assertTrue(op not in amp_list.black_list)
self.assertTrue(op not in amp_list.unsupported_list)
default_black_list = [
'linear_interp_v2', 'linear_interp_v2',
'nearest_interp_v2', 'nearest_interp_v2',
'bilinear_interp_v2', 'bilinear_interp_v2',
'bicubic_interp_v2', 'bicubic_interp_v2',
'trilinear_interp_v2', 'trilinear_interp_v2',
] ]
for op in default_black_list: self.custom_white_list = [
self.assertTrue(op in amp_list.black_list) 'lookup_table',
'lookup_table_v2',
]
def check_if_op_in_list(self, op_list, amp_list):
for op in op_list:
self.assertTrue(op in amp_list)
def check_if_op_not_in_list(self, op_list, amp_list):
for op in op_list:
self.assertTrue(op not in amp_list)
def test_static(self):
amp_list = AutoMixedPrecisionLists(
custom_white_list=self.custom_white_list
)
self.check_if_op_in_list(self.default_black_list, amp_list.black_list)
self.check_if_op_in_list(self.custom_white_list, amp_list.white_list)
self.check_if_op_not_in_list(
self.custom_white_list, amp_list.black_list
)
self.check_if_op_not_in_list(
self.custom_white_list, amp_list.unsupported_list
)
def test_eager(self):
if not paddle.amp.is_float16_supported():
return
white_list = paddle.amp.white_list()
black_list = paddle.amp.black_list()
self.check_if_op_in_list(
self.default_black_list, black_list["float16"]["O2"]
)
self.check_if_op_not_in_list(['log', 'elementwise_add'], white_list)
with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}):
out1 = paddle.rand([2, 3]) + paddle.rand([2, 3])
out2 = out1.mean()
out3 = paddle.log(out2)
self.check_if_op_not_in_list(['log', 'elementwise_add'], white_list)
self.assertEqual(out1.dtype, paddle.float16)
self.assertEqual(out2.dtype, paddle.float32)
self.assertEqual(out3.dtype, paddle.float32)
def test_apis(self): def test_apis(self):
def _run_check_dtype(): def _run_check_dtype():
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from amp_base_models import AmpTestBase, build_conv_model
import paddle
from paddle.static import amp
paddle.enable_static()
class TestAMPPromote(AmpTestBase):
def check_promote_results(
self, use_amp, dtype, level, use_promote, expected_op_calls
):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_conv_model(use_amp, dtype, level, use_promote)
self.assertEqual(main_program.num_blocks, 1)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
self._check_op_calls(
op_stats_list[0], expected_fp16_calls=expected_op_calls
)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2
x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
print(main_program)
losses_o1 = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_fp32,
max_iters,
level,
)
def test_static_amp_o1(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 0,
"relu": 0,
"matmul_v2": 1,
"softmax": 0,
"reduce_mean": 0,
"adamw": 0,
}
self.check_promote_results(
True,
'float16',
'O1',
use_promote=True,
expected_op_calls=expected_fp16_calls,
)
def test_static_amp_o2(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 2,
"relu": 1,
"matmul_v2": 1,
"softmax": 1,
"reduce_mean": 1,
"adamw": 4,
}
self.check_promote_results(
True,
'float16',
'O2',
use_promote=True,
expected_op_calls=expected_fp16_calls,
)
if __name__ == '__main__':
unittest.main()
...@@ -17,7 +17,7 @@ import struct ...@@ -17,7 +17,7 @@ import struct
import unittest import unittest
import numpy as np import numpy as np
from amp_base_models import build_add_model, build_embedding_model from amp_base_models import AmpTestBase, build_add_model, build_embedding_model
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -220,24 +220,30 @@ class TestModelCastBF16(unittest.TestCase): ...@@ -220,24 +220,30 @@ class TestModelCastBF16(unittest.TestCase):
) )
@unittest.skipIf( class TestProgramBF16(AmpTestBase):
not core.is_compiled_with_cuda(), def _check_optimizer(self, program, expected_num_mp):
"core is not complied with CUDA and not support the bfloat16", optimizers = []
) for block in program.blocks:
class TestProgramBF16(unittest.TestCase): for op in block.ops:
def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls): if "Param" in op.input_names and "Grad" in op.input_names:
for op_type, value in expected_bf16_calls.items(): optimizers.append(op)
self.assertEqual(
op_stats_dict[op_type].bf16_calls, actual_num_mp = 0
value, for op in optimizers:
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.", if op.has_attr("multi_precision") and op.attr("multi_precision"):
) actual_num_mp += 1
self.assertEqual(
actual_num_mp,
expected_num_mp,
f"The number of optimizers with multi_precison = True is expected to be {expected_num_mp}, but recieved {actual_num_mp}.",
)
def test_amp_bf16_o1(self): def test_amp_bf16_o1(self):
main_program, startup_program = build_embedding_model( main_program, startup_program = build_embedding_model(
True, "bfloat16", "O1" True, "bfloat16", "O1"
) )
self.assertEqual(main_program.num_blocks, 1) self.assertEqual(main_program.num_blocks, 1)
self._check_optimizer(main_program, 0)
amp.debugging.collect_operator_stats(main_program) amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program) op_stats_list = amp.debugging._get_op_stats_list(main_program)
...@@ -249,7 +255,7 @@ class TestProgramBF16(unittest.TestCase): ...@@ -249,7 +255,7 @@ class TestProgramBF16(unittest.TestCase):
"squared_l2_norm": 0, "squared_l2_norm": 0,
"adamw": 0, "adamw": 0,
} }
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls) self._check_op_calls(op_stats_list[0], expected_bf16_calls)
def test_amp_bf16_o2(self): def test_amp_bf16_o2(self):
main_program, startup_program = build_embedding_model( main_program, startup_program = build_embedding_model(
...@@ -267,14 +273,15 @@ class TestProgramBF16(unittest.TestCase): ...@@ -267,14 +273,15 @@ class TestProgramBF16(unittest.TestCase):
"squared_l2_norm": 2, "squared_l2_norm": 2,
"adamw": 2, "adamw": 2,
} }
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls) self._check_optimizer(
main_program,
expected_bf16_calls["matmul_v2"]
+ expected_bf16_calls["elementwise_add"],
)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
@unittest.skipIf( class TestStaticBF16(AmpTestBase):
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support the bfloat16",
)
class TestStaticBF16(unittest.TestCase):
def _generate_feed_x(self): def _generate_feed_x(self):
x = np.random.random(size=[16, 16]).astype("float32") x = np.random.random(size=[16, 16]).astype("float32")
x_bf16 = convert_float_to_uint16(x) x_bf16 = convert_float_to_uint16(x)
...@@ -282,60 +289,35 @@ class TestStaticBF16(unittest.TestCase): ...@@ -282,60 +289,35 @@ class TestStaticBF16(unittest.TestCase):
return x_fp32, x_bf16 return x_fp32, x_bf16
def test_compare_o1_o2(self): def test_compare_o1_o2(self):
def _run_o1(exe, x_np, max_iters): def _run(place, exe, x_np, max_iters, level):
( (
main_program, main_program,
startup_program, startup_program,
optimizer, optimizer,
feed_vars, feed_vars,
fetch_vars, fetch_vars,
) = build_add_model(True, "bfloat16", "O1") ) = build_add_model(True, "bfloat16", level)
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O1] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
def _run_o2(exe, x_np, max_iters): losses = self.run_program(
(
main_program, main_program,
startup_program, startup_program,
optimizer, optimizer,
feed_vars, feed_vars,
fetch_vars, fetch_vars,
) = build_add_model(True, "bfloat16", "O2") place,
exe,
losses = [] x_np,
scope = paddle.static.Scope() max_iters,
with paddle.static.scope_guard(scope): level,
exe.run(startup_program) )
optimizer.amp_init(place)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O2] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses return losses
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2 max_iters = 2
x_fp32, x_bf16 = self._generate_feed_x() x_fp32, x_bf16 = self._generate_feed_x()
losses_o1 = _run_o1(exe, x_fp32, max_iters) place = paddle.CUDAPlace(0)
losses_o2 = _run_o2(exe, x_bf16, max_iters) exe = paddle.static.Executor(place)
losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -314,7 +314,10 @@ class TestImageClassification(unittest.TestCase): ...@@ -314,7 +314,10 @@ class TestImageClassification(unittest.TestCase):
# infer(use_cuda, save_dirname) # infer(use_cuda, save_dirname)
def test_amp_lists(self): def test_amp_lists(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
...@@ -324,7 +327,10 @@ class TestImageClassification(unittest.TestCase): ...@@ -324,7 +327,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list) self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_1(self): def test_amp_lists_1(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
...@@ -338,7 +344,10 @@ class TestImageClassification(unittest.TestCase): ...@@ -338,7 +344,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list) self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_2(self): def test_amp_lists_2(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
...@@ -352,7 +361,10 @@ class TestImageClassification(unittest.TestCase): ...@@ -352,7 +361,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list) self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_3(self): def test_amp_lists_3(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
...@@ -365,7 +377,10 @@ class TestImageClassification(unittest.TestCase): ...@@ -365,7 +377,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list) self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_4(self): def test_amp_lists_4(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
...@@ -381,7 +396,10 @@ class TestImageClassification(unittest.TestCase): ...@@ -381,7 +396,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list) self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_5(self): def test_amp_lists_5(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
...@@ -397,7 +415,10 @@ class TestImageClassification(unittest.TestCase): ...@@ -397,7 +415,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list) self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_6(self): def test_amp_lists_6(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list) white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list) black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list) gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......
...@@ -39,7 +39,7 @@ class TestFuseResNetUnit(unittest.TestCase): ...@@ -39,7 +39,7 @@ class TestFuseResNetUnit(unittest.TestCase):
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
with paddle.static.amp.fp16_guard(): with paddle.static.amp.fp16_guard():
with paddle.static.program_guard(program, startup_program): with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [1, 64, 64, 8]) x = paddle.static.data("x", [1, 64, 64, 8], dtype="float16")
conv2d = paddle.nn.Conv2D( conv2d = paddle.nn.Conv2D(
8, 32, 1, bias_attr=False, data_format='NHWC' 8, 32, 1, bias_attr=False, data_format='NHWC'
) )
...@@ -66,3 +66,7 @@ class TestFuseResNetUnit(unittest.TestCase): ...@@ -66,3 +66,7 @@ class TestFuseResNetUnit(unittest.TestCase):
np.testing.assert_allclose( np.testing.assert_allclose(
before_out[0], after_out[0], rtol=1e-05, atol=0.005 before_out[0], after_out[0], rtol=1e-05, atol=0.005
) )
if __name__ == '__main__':
unittest.main()
...@@ -25,10 +25,10 @@ paddle.enable_static() ...@@ -25,10 +25,10 @@ paddle.enable_static()
def build_resnet50(use_amp=False): def build_resnet50(use_amp=False):
main_program = paddle.static.Program() main_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
dtype = 'float16' if use_amp else 'float32'
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
image = paddle.static.data( image = paddle.static.data(
name='image', shape=[32, 3, 224, 224], dtype='float32' name='image', shape=[32, 3, 224, 224], dtype=dtype
) )
label = paddle.static.data(name='label', shape=[32], dtype='int64') label = paddle.static.data(name='label', shape=[32], dtype='int64')
model = paddle.vision.models.resnet50() model = paddle.vision.models.resnet50()
......