From e8809d995713084cf22051797856c01713ffa616 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Wed, 14 Sep 2022 11:17:24 +0800
Subject: [PATCH] [AMP] Support AMP-O2 for bfloat16 (#45541)

* support bfloat16 for amp_decorate

* add check_finite for bf16

* fix bug

* add ut

* add ut

* refine code
---
 paddle/phi/kernels/gpu/amp_kernel.cu          |  6 ++-
 python/paddle/amp/auto_cast.py                | 13 +++---
 python/paddle/fluid/dygraph/amp/auto_cast.py  | 41 +++++++++++++++----
 .../paddle/fluid/dygraph/amp/loss_scaler.py   | 20 ++++++++-
 python/paddle/fluid/framework.py              |  5 ++-
 .../test_imperative_auto_mixed_precision.py   |  4 ++
 ...perative_auto_mixed_precision_for_eager.py |  4 ++
 7 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/paddle/phi/kernels/gpu/amp_kernel.cu b/paddle/phi/kernels/gpu/amp_kernel.cu
index 230eb801d20..919663a75e6 100644
--- a/paddle/phi/kernels/gpu/amp_kernel.cu
+++ b/paddle/phi/kernels/gpu/amp_kernel.cu
@@ -357,7 +357,8 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
                    phi::CheckFiniteAndUnscaleKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(update_loss_scaling,
                    GPU,
@@ -365,6 +366,7 @@ PD_REGISTER_KERNEL(update_loss_scaling,
                    phi::UpdateLossScalingKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
 }
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 96a94d89846..4cf628abe05 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -81,21 +81,23 @@ def auto_cast(enable=True,
 def decorate(models,
              optimizers=None,
              level='O1',
+             dtype='float16',
              master_weight=None,
              save_dtype=None):
     """
     Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
-    When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm.
+    When level is O2(pure float16/bfloat16), the decorate will cast all parameters of models to float16/bfloat16, except BatchNorm and LayerNorm.
 
-    Commonly, it is used together with `auto_cast` to achieve Pure fp16 in imperative mode.
+    Commonly, it is used together with `auto_cast` to achieve Pure float16/bfloat16 in imperative mode.
 
     Args:
         models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
         optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
         level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
-             O2 represent Pure fp16, the decorator will cast all parameters of models to FP16, except BatchNorm and LayerNorm. Default is O1(amp)
+             O2 represent Pure float16/bfloat16, the decorator will cast all parameters of models to float16/bfloat16, except BatchNorm and LayerNorm. Default is O1(amp)
+        dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
         master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
-        save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, float32, float64 or None.
+        save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
              The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
 
     Examples:
@@ -145,4 +147,5 @@ def decorate(models,
             output = model(data)
             print(output.dtype) # FP16
     """
-    return amp_decorate(models, optimizers, level, master_weight, save_dtype)
+    return amp_decorate(models, optimizers, level, dtype, master_weight,
+                        save_dtype)
diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index 87df8082136..d1d53853740 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -76,7 +76,7 @@ AMP_RELATED_FLAGS_SETTING = {
     'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
 }
 
-PURE_FP16_WHITE_LIST = {' '}
+PURE_FP16_WHITE_LIST = {''}
 PURE_FP16_BLACK_LIST = {
     'lookup_table',
     'lookup_table_v2',
@@ -91,7 +91,10 @@ PURE_FP16_BLACK_LIST = {
 }
 
 BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
-BF16_BLACK_LIST = {' '}
+BF16_BLACK_LIST = {''}
+
+PURE_BF16_WHITE_LIST = {''}
+PURE_BF16_BLACK_LIST = {''}
 
 _g_amp_state_ = None
 
@@ -118,8 +121,12 @@ def _update_list(custom_white_list,
         _white_list = copy.copy(PURE_FP16_WHITE_LIST)
         _black_list = copy.copy(PURE_FP16_BLACK_LIST)
     else:
-        _white_list = copy.copy(BF16_WHITE_LIST)
-        _black_list = copy.copy(BF16_BLACK_LIST)
+        if level == 'O1':
+            _white_list = copy.copy(BF16_WHITE_LIST)
+            _black_list = copy.copy(BF16_BLACK_LIST)
+        else:
+            _white_list = copy.copy(PURE_BF16_WHITE_LIST)
+            _black_list = copy.copy(PURE_BF16_BLACK_LIST)
     if custom_white_list and custom_black_list:
         for op_name in custom_white_list:
             if op_name in custom_black_list:
@@ -198,6 +205,16 @@ def pure_fp16_initialize(models):
     return models
 
 
+@dygraph_only
+def pure_bf16_initialize(models):
+    for idx in range(len(models)):
+        for layer in models[idx].sublayers(include_self=True):
+            layer._to_impl(dtype='bfloat16',
+                           include_sublayers=False,
+                           floating_only=True)
+    return models
+
+
 def check_models(models):
     for model in models:
         if not isinstance(model, paddle.nn.Layer):
@@ -424,6 +441,7 @@ class StateDictHook(object):
 def amp_decorate(models,
                  optimizers=None,
                  level='O1',
+                 dtype='float16',
                  master_weight=None,
                  save_dtype=None):
     """
@@ -436,9 +454,10 @@ def amp_decorate(models,
         models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
         optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
         level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
-             O2 represent Pure fp16, the decorator will cast all parameters of models to FP16, except BatchNorm and LayerNorm. Default is O1(amp)
+             O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm. Default is O1(amp)
+        dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
         master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
-        save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, float32, float64 or None.
+        save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
              The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
 
     Examples:
@@ -510,8 +529,12 @@ def amp_decorate(models,
     else:
         raise TypeError(
             "models must be either a single model or a list of models.")
-
-    models = pure_fp16_initialize(models=models)
+    if dtype == 'float16':
+        models = pure_fp16_initialize(models=models)
+    elif dtype == 'bfloat16':
+        models = pure_bf16_initialize(models=models)
+    else:
+        raise TypeError("dtype only support float16 or bfloat16.")
 
     if optimizers is not None:
         # check optimizers
@@ -538,7 +561,7 @@ def amp_decorate(models,
                 optimizers[idx_opt]._multi_precision = True
 
     if save_dtype is not None:
-        if not (save_dtype in ['float16', 'float32', 'float64']):
+        if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
             raise ValueError(
                 "save_dtype can only be float16 float32 or float64, but your input save_dtype is %s." % save_dtype)
 
diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py
index aeb4c730975..f86bdf18506 100644
--- a/python/paddle/fluid/dygraph/amp/loss_scaler.py
+++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py
@@ -132,6 +132,8 @@ class AmpScaler(object):
         self._found_inf = to_variable(np.array([0]).astype(np.bool_))
         self._temp_found_inf_fp16 = to_variable(
             np.array([0]).astype(np.bool_))
+        self._temp_found_inf_bf16 = to_variable(
+            np.array([0]).astype(np.bool_))
         self._temp_found_inf_fp32 = to_variable(
             np.array([0]).astype(np.bool_))
         self._scale = to_variable(
@@ -262,6 +264,7 @@ class AmpScaler(object):
                 optimizer._param_groups[0], dict):
             param_grads = []
             param_grads_fp16 = []
+            param_grads_bf16 = []
             param_grads_fp32 = []
             for group in optimizer._param_groups:
                 for param in group['params']:
@@ -270,6 +273,9 @@ class AmpScaler(object):
                         if param._grad_ivar(
                         ).dtype == core.VarDesc.VarType.FP16:
                             param_grads_fp16.append(param._grad_ivar())
+                        elif param._grad_ivar(
+                        ).dtype == core.VarDesc.VarType.BF16:
+                            param_grads_bf16.append(param._grad_ivar())
                         else:
                             param_grads_fp32.append(param._grad_ivar())
         else:
@@ -281,6 +287,10 @@ class AmpScaler(object):
                 param for param in param_grads
                 if param.dtype == core.VarDesc.VarType.FP16
             ]
+            param_grads_bf16 = [
+                param for param in param_grads
+                if param.dtype == core.VarDesc.VarType.BF16
+            ]
             param_grads_fp32 = [
                 param for param in param_grads
                 if param.dtype == core.VarDesc.VarType.FP32
             ]
@@ -293,6 +303,10 @@ class AmpScaler(object):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp16, self._scale, float_status,
                     param_grads_fp16, self._temp_found_inf_fp16)
+            if len(param_grads_bf16):
+                _legacy_C_ops.check_finite_and_unscale(
+                    param_grads_bf16, self._scale, float_status,
+                    param_grads_bf16, self._temp_found_inf_bf16)
             if len(param_grads_fp32):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp32, self._scale, float_status,
@@ -302,12 +316,16 @@ class AmpScaler(object):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp16, self._scale, param_grads_fp16,
                     self._temp_found_inf_fp16)
+            if len(param_grads_bf16):
+                _legacy_C_ops.check_finite_and_unscale(
+                    param_grads_bf16, self._scale, param_grads_bf16,
+                    self._temp_found_inf_bf16)
             if len(param_grads_fp32):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp32, self._scale, param_grads_fp32,
                     self._temp_found_inf_fp32)
 
-        self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
+        self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_bf16 or self._temp_found_inf_fp32
 
         optimizer_state["state"] = OptimizerState.UNSCALED
 
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index bf56b125fd7..82a887be414 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1169,7 +1169,10 @@ def convert_np_dtype_to_dtype_(np_dtype):
         core.VarDesc.VarType: the data type in Paddle.
 
     """
-    dtype = np.dtype(np_dtype)
+    if np_dtype == "bfloat16":
+        dtype = np.uint16
+    else:
+        dtype = np.dtype(np_dtype)
     if dtype == np.float32:
         return core.VarDesc.VarType.FP32
     elif dtype == np.float64:
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
index 3491345d67e..e08e069692a 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
@@ -1310,6 +1310,10 @@ class TestBf16(unittest.TestCase):
         paddle.seed(100)
         input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
         conv = paddle.nn.Conv2D(4, 6, (3, 3))
+        if amp_level == 'O2':
+            conv = paddle.amp.decorate(models=conv,
+                                       level=amp_level,
+                                       dtype='bfloat16')
         with paddle.amp.auto_cast(enable=enable_amp,
                                   level=amp_level,
                                   dtype='bfloat16'):
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
index 015e3a8f459..541ca91f996 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
@@ -1297,6 +1297,10 @@ class TestBf16(unittest.TestCase):
         paddle.seed(100)
         input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
         conv = paddle.nn.Conv2D(4, 6, (3, 3))
+        if amp_level == 'O2':
+            conv = paddle.amp.decorate(models=conv,
+                                       level=amp_level,
+                                       dtype='bfloat16')
         with paddle.amp.auto_cast(enable=enable_amp,
                                   level=amp_level,
                                   dtype='bfloat16'):
-- 
GitLab
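
For reference, a minimal end-to-end sketch of the path this patch enables, not taken from the patch itself: decorate casts the model to bfloat16 at level O2, auto_cast runs the forward pass in bf16, and GradScaler.minimize exercises the new bf16 branch of check_finite_and_unscale. The Linear model, SGD optimizer, batch shape, and loss-scaling value are illustrative placeholders, and a CUDA device with bfloat16 support is assumed.

    import paddle

    model = paddle.nn.Linear(4, 4)  # placeholder model; any Layer works
    optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                     parameters=model.parameters())

    # O2 + bfloat16: parameters (except BatchNorm/LayerNorm) are cast to bf16,
    # and the optimizer is switched to multi-precision updates.
    model, optimizer = paddle.amp.decorate(models=model,
                                           optimizers=optimizer,
                                           level='O2',
                                           dtype='bfloat16')

    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)  # value is illustrative
    data = paddle.rand([2, 4])

    with paddle.amp.auto_cast(enable=True, level='O2', dtype='bfloat16'):
        loss = model(data).mean()

    scaled = scaler.scale(loss)          # scale the loss before backward
    scaled.backward()
    scaler.minimize(optimizer, scaled)   # unscale bf16/fp32 grads, check finite, step
    optimizer.clear_grad()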