未验证 提交 e8809d99 编写于 作者: Z zhangbo9674 提交者: GitHub

[AMP] Support AMP-O2 for bfloat16 (#45541)

* support bfloat16 for amp_decorate

* add check_finite for bf16

* fix bug

* add ut

* add ut

* refine code
上级 12b5b74f
......@@ -357,7 +357,8 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
phi::CheckFiniteAndUnscaleKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(update_loss_scaling,
GPU,
......@@ -365,6 +366,7 @@ PD_REGISTER_KERNEL(update_loss_scaling,
phi::UpdateLossScalingKernel,
float,
double,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -81,21 +81,23 @@ def auto_cast(enable=True,
def decorate(models,
optimizers=None,
level='O1',
dtype='float16',
master_weight=None,
save_dtype=None):
"""
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm.
When level is O2(pure float16/bfloat16), the decorate will cast all parameters of models to float16/bfloat16, except BatchNorm and LayerNorm.
Commonly, it is used together with `auto_cast` to achieve Pure fp16 in imperative mode.
Commonly, it is used together with `auto_cast` to achieve Pure float16/bfloat16 in imperative mode.
Args:
models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
O2 represent Pure fp16, the decorator will cast all parameters of models to FP16, except BatchNorm and LayerNorm. Default is O1(amp)
O2 represent Pure float16/bfloat16, the decorator will cast all parameters of models to float16/bfloat16, except BatchNorm and LayerNorm. Default is O1(amp)
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, float32, float64 or None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
Examples:
......@@ -145,4 +147,5 @@ def decorate(models,
output = model(data)
print(output.dtype) # FP16
"""
return amp_decorate(models, optimizers, level, master_weight, save_dtype)
return amp_decorate(models, optimizers, level, dtype, master_weight,
save_dtype)
......@@ -76,7 +76,7 @@ AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}
PURE_FP16_WHITE_LIST = {' '}
PURE_FP16_WHITE_LIST = {''}
PURE_FP16_BLACK_LIST = {
'lookup_table',
'lookup_table_v2',
......@@ -91,7 +91,10 @@ PURE_FP16_BLACK_LIST = {
}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = {' '}
BF16_BLACK_LIST = {''}
PURE_BF16_WHITE_LIST = {''}
PURE_BF16_BLACK_LIST = {''}
_g_amp_state_ = None
......@@ -118,8 +121,12 @@ def _update_list(custom_white_list,
_white_list = copy.copy(PURE_FP16_WHITE_LIST)
_black_list = copy.copy(PURE_FP16_BLACK_LIST)
else:
if level == 'O1':
_white_list = copy.copy(BF16_WHITE_LIST)
_black_list = copy.copy(BF16_BLACK_LIST)
else:
_white_list = copy.copy(PURE_BF16_WHITE_LIST)
_black_list = copy.copy(PURE_BF16_BLACK_LIST)
if custom_white_list and custom_black_list:
for op_name in custom_white_list:
if op_name in custom_black_list:
......@@ -198,6 +205,16 @@ def pure_fp16_initialize(models):
return models
@dygraph_only
def pure_bf16_initialize(models):
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._to_impl(dtype='bfloat16',
include_sublayers=False,
floating_only=True)
return models
def check_models(models):
for model in models:
if not isinstance(model, paddle.nn.Layer):
......@@ -424,6 +441,7 @@ class StateDictHook(object):
def amp_decorate(models,
optimizers=None,
level='O1',
dtype='float16',
master_weight=None,
save_dtype=None):
"""
......@@ -436,9 +454,10 @@ def amp_decorate(models,
models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
O2 represent Pure fp16, the decorator will cast all parameters of models to FP16, except BatchNorm and LayerNorm. Default is O1(amp)
O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm. Default is O1(amp)
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, float32, float64 or None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
Examples:
......@@ -510,8 +529,12 @@ def amp_decorate(models,
else:
raise TypeError(
"models must be either a single model or a list of models.")
if dtype == 'float16':
models = pure_fp16_initialize(models=models)
elif dtype == 'bfloat16':
models = pure_bf16_initialize(models=models)
else:
raise TypeError("dtype only support float16 or bfloat16.")
if optimizers is not None:
# check optimizers
......@@ -538,7 +561,7 @@ def amp_decorate(models,
optimizers[idx_opt]._multi_precision = True
if save_dtype is not None:
if not (save_dtype in ['float16', 'float32', 'float64']):
if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
raise ValueError(
"save_dtype can only be float16 float32 or float64, but your input save_dtype is %s."
% save_dtype)
......
......@@ -132,6 +132,8 @@ class AmpScaler(object):
self._found_inf = to_variable(np.array([0]).astype(np.bool_))
self._temp_found_inf_fp16 = to_variable(
np.array([0]).astype(np.bool_))
self._temp_found_inf_bf16 = to_variable(
np.array([0]).astype(np.bool_))
self._temp_found_inf_fp32 = to_variable(
np.array([0]).astype(np.bool_))
self._scale = to_variable(
......@@ -262,6 +264,7 @@ class AmpScaler(object):
optimizer._param_groups[0], dict):
param_grads = []
param_grads_fp16 = []
param_grads_bf16 = []
param_grads_fp32 = []
for group in optimizer._param_groups:
for param in group['params']:
......@@ -270,6 +273,9 @@ class AmpScaler(object):
if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16:
param_grads_fp16.append(param._grad_ivar())
elif param._grad_ivar(
).dtype == core.VarDesc.VarType.BF16:
param_grads_bf16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
......@@ -281,6 +287,10 @@ class AmpScaler(object):
param for param in param_grads
if param.dtype == core.VarDesc.VarType.FP16
]
param_grads_bf16 = [
param for param in param_grads
if param.dtype == core.VarDesc.VarType.BF16
]
param_grads_fp32 = [
param for param in param_grads
if param.dtype == core.VarDesc.VarType.FP32
......@@ -293,6 +303,10 @@ class AmpScaler(object):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp16, self._scale, float_status,
param_grads_fp16, self._temp_found_inf_fp16)
if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_bf16, self._scale, float_status,
param_grads_bf16, self._temp_found_inf_bf16)
if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, self._scale, float_status,
......@@ -302,12 +316,16 @@ class AmpScaler(object):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp16, self._scale, param_grads_fp16,
self._temp_found_inf_fp16)
if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_bf16, self._scale, param_grads_bf16,
self._temp_found_inf_bf16)
if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, self._scale, param_grads_fp32,
self._temp_found_inf_fp32)
self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_bf16 or self._temp_found_inf_fp32
optimizer_state["state"] = OptimizerState.UNSCALED
......
......@@ -1169,6 +1169,9 @@ def convert_np_dtype_to_dtype_(np_dtype):
core.VarDesc.VarType: the data type in Paddle.
"""
if np_dtype == "bfloat16":
dtype = np.uint16
else:
dtype = np.dtype(np_dtype)
if dtype == np.float32:
return core.VarDesc.VarType.FP32
......
......@@ -1310,6 +1310,10 @@ class TestBf16(unittest.TestCase):
paddle.seed(100)
input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = paddle.nn.Conv2D(4, 6, (3, 3))
if amp_level == 'O2':
conv = paddle.amp.decorate(models=conv,
level=amp_level,
dtype='bfloat16')
with paddle.amp.auto_cast(enable=enable_amp,
level=amp_level,
dtype='bfloat16'):
......
......@@ -1297,6 +1297,10 @@ class TestBf16(unittest.TestCase):
paddle.seed(100)
input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = paddle.nn.Conv2D(4, 6, (3, 3))
if amp_level == 'O2':
conv = paddle.amp.decorate(models=conv,
level=amp_level,
dtype='bfloat16')
with paddle.amp.auto_cast(enable=enable_amp,
level=amp_level,
dtype='bfloat16'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册