From e5ed5257083b92b018330812c33c746bae26fb41 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 17 Nov 2022 11:22:47 +0800 Subject: [PATCH] Support bfloat16 for adamw and adam optimizer. Fit the lr for pure bf16 training with tensor fusion. (#48041) * add bfloat16 for adamw * set lr not to bfloat16 for pure bf16 training * update the logic * update the adamw optimizer * support bfloat for adam --- paddle/fluid/pybind/eager_functions.cc | 3 ++- paddle/phi/kernels/gpu/adamw_kernel.cu | 4 ++- python/paddle/optimizer/adam.py | 34 ++++++++++++-------------- python/paddle/optimizer/adamw.py | 17 ++++++------- python/paddle/optimizer/optimizer.py | 26 +++++++++++++++++--- 5 files changed, 51 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index cdace567b2e..3389daf330c 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -268,7 +268,8 @@ PyObject* eager_api_get_grads_types(PyObject* self, if (meta && grad.initialized()) { if (grad.is_dense_tensor() && (tensor.dtype() == paddle::experimental::DataType::FLOAT32 || - tensor.dtype() == paddle::experimental::DataType::FLOAT16)) { + tensor.dtype() == paddle::experimental::DataType::FLOAT16 || + tensor.dtype() == paddle::experimental::DataType::BFLOAT16)) { ret.emplace_back( paddle::framework::TransToProtoVarType(tensor.dtype())); } diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 9ddaacdd5cc..6994c83f536 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -21,6 +21,7 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" @@ -300,7 +301,8 @@ PD_REGISTER_KERNEL(adamw, phi::AdamwDenseKernel, float, double, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { // Skip beta1_pow, beta2_pow, skip_update data transform kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 74499b05f24..aa76fb82759 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -28,7 +28,7 @@ from paddle import _C_ops, _legacy_C_ops __all__ = [] -GRAD_TYPES = [int(paddle.float32), int(paddle.float16)] +GRAD_TYPES = [int(paddle.float32), int(paddle.float16), int(paddle.bfloat16)] class Adam(Optimizer): @@ -265,8 +265,8 @@ class Adam(Optimizer): """ if self._name is not None: name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param.dtype ) target_param = ( self._master_weights[param.name] if find_master else param @@ -285,10 +285,7 @@ class Adam(Optimizer): def _add_moments_pows(self, p): acc_dtype = p.dtype - if ( - acc_dtype == core.VarDesc.VarType.FP16 - or acc_dtype == core.VarDesc.VarType.BF16 - ): + if self._is_dtype_fp16_or_bf16(acc_dtype): acc_dtype = core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) @@ -322,16 +319,16 @@ class Adam(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Adam optimizer." ) self._add_moments_pows(p) @@ -353,9 +350,8 @@ class Adam(Optimizer): beta2_pow_acc = self._get_accumulator( self._beta2_pow_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -571,7 +567,7 @@ class Adam(Optimizer): def _multi_tensor_init(self, target_block, parameters, param_group_idx): """ - All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32). + All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (bfloat16, float16, float32). This function will be overridden in the corresponding optimizer file. Args: target_block: the block in which the loss tensor is present @@ -604,7 +600,7 @@ class Adam(Optimizer): self._beta2_pow_acc_dict['FP32_LODTensor'][ param_group_idx ].append(beta2_pow_acc) - elif param.dtype == paddle.float16: + elif self._is_dtype_fp16_or_bf16(param.dtype): self._param_dict['FP16_LODTensor'][param_group_idx].append( param ) @@ -628,7 +624,7 @@ class Adam(Optimizer): self._master_weight_dict['FP16_LODTensor'] = None else: raise ValueError( - "Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR." + "Now multi_tensor_momentum only support fp32, fp16 or bf16 parameters and grad is LOD_TENSOR." ) def _append_optimize_multi_tensor_op( @@ -656,7 +652,7 @@ class Adam(Optimizer): ) lr = self._create_param_lr(parameters_and_grads[index]) lr_dict['FP32_LODTensor'].append(lr) - elif tp == GRAD_TYPES[1]: + elif tp == GRAD_TYPES[1] or tp == GRAD_TYPES[2]: grad_dict['FP16_LODTensor'].append( parameters_and_grads[index][1] ) @@ -678,7 +674,7 @@ class Adam(Optimizer): lr = self._create_param_lr(param_and_grad) lr_dict['FP32_LODTensor'].append(lr) elif ( - param_and_grad[0].dtype == paddle.float16 + self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype) and param_and_grad[1].type == core.VarDesc.VarType.LOD_TENSOR ): @@ -711,7 +707,7 @@ class Adam(Optimizer): lr = self._create_param_lr(param_and_grad) lr_dict['FP32_LODTensor'].append(lr) elif ( - param_and_grad[0].dtype == paddle.float16 + self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype) and param_and_grad[1].type == core.VarDesc.VarType.LOD_TENSOR ): diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index dca844b6682..5424331a71f 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -369,8 +369,8 @@ class AdamW(Optimizer): """ if self._name is not None: name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param.dtype ) target_param = ( self._master_weights[param.name] if find_master else param @@ -389,7 +389,7 @@ class AdamW(Optimizer): def _add_moments_pows(self, p): acc_dtype = p.dtype - if acc_dtype == core.VarDesc.VarType.FP16: + if self._is_dtype_fp16_or_bf16(acc_dtype): acc_dtype = core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) @@ -423,16 +423,16 @@ class AdamW(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Adam optimizer." ) self._add_moments_pows(p) @@ -463,9 +463,8 @@ class AdamW(Optimizer): beta2_pow_acc = self._get_accumulator( self._beta2_pow_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 26ae5b50269..59663bb8190 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -421,15 +421,21 @@ class Optimizer: return self._opti_name_list def _create_global_learning_rate(self): - # lr var can't be float16, for pure fp16 training, should extra handle the dtype for lr + # lr var can't be float16 or bfloat16, for pure fp16 or bf16 training, should extra handle the dtype for lr _lr_dtype = ( paddle.get_default_dtype() if self._dtype is None else self._dtype ) _lr_dtype = ( paddle.float32 if ( - paddle.get_default_dtype() != "float16" - and _lr_dtype == paddle.float16 + ( + paddle.get_default_dtype() != "float16" + and _lr_dtype == paddle.float16 + ) + or ( + paddle.get_default_dtype() != "bfloat16" + and _lr_dtype == paddle.bfloat16 + ) ) else _lr_dtype ) @@ -1526,3 +1532,17 @@ class Optimizer: For Multi Tensor, append optimize merged_operator to block. """ pass + + def _is_dtype_fp16_or_bf16(self, dtype): + """ + check the dtype is fp16 or the dtype is bf16 + :param dtype: instance of core.VarDesc.VarType + :return: True if dtype is one of fp16 or bf16, False otherwise + """ + assert isinstance( + dtype, core.VarDesc.VarType + ), "The dtype should be an instance of core.VarDesc.VarType." + return ( + dtype == core.VarDesc.VarType.FP16 + or dtype == core.VarDesc.VarType.BF16 + ) -- GitLab