Unverified commit e5ed5257, authored by Yuang Liu and committed by GitHub

Support bfloat16 for adamw and adam optimizer. Fit the lr for pure bf16 training with tensor fusion. (#48041)

* add bfloat16 for adamw

* set lr not to bfloat16 for pure bf16 training

* update the logic

* update the adamw optimizer

* support bfloat16 for adam
Parent: 4f57da5f
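
Taken together, the changes below let a bfloat16 parameter/gradient path run end to end through Adam/AdamW. As a rough usage sketch only (not part of the commit; it assumes a GPU build with bf16 kernels, that Layer.to(dtype=...) accepts "bfloat16", and made-up layer sizes and data):

import paddle

# Illustrative pure-bf16 setup; see the assumptions noted above.
model = paddle.nn.Linear(128, 64).to(dtype="bfloat16")

opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=model.parameters(),
    weight_decay=0.01,
    multi_precision=True,  # keep fp32 master weights and moments for bf16 params
)

x = paddle.randn([32, 128]).astype("bfloat16")
loss = model(x).mean()
loss.backward()
opt.step()        # bf16 grads are handled by the adamw kernel registered below
opt.clear_grad()

With multi_precision=True the optimizer keeps fp32 master weights, which is what the FP16/BF16 warning added in this diff recommends.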
@@ -268,7 +268,8 @@ PyObject* eager_api_get_grads_types(PyObject* self,
     if (meta && grad.initialized()) {
       if (grad.is_dense_tensor() &&
           (tensor.dtype() == paddle::experimental::DataType::FLOAT32 ||
-           tensor.dtype() == paddle::experimental::DataType::FLOAT16)) {
+           tensor.dtype() == paddle::experimental::DataType::FLOAT16 ||
+           tensor.dtype() == paddle::experimental::DataType::BFLOAT16)) {
         ret.emplace_back(
             paddle::framework::TransToProtoVarType(tensor.dtype()));
       }
...
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -300,7 +301,8 @@ PD_REGISTER_KERNEL(adamw,
                    phi::AdamwDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
...
@@ -28,7 +28,7 @@ from paddle import _C_ops, _legacy_C_ops
 
 __all__ = []
 
-GRAD_TYPES = [int(paddle.float32), int(paddle.float16)]
+GRAD_TYPES = [int(paddle.float32), int(paddle.float16), int(paddle.bfloat16)]
 
 
 class Adam(Optimizer):
@@ -265,8 +265,8 @@ class Adam(Optimizer):
         """
         if self._name is not None:
             name = self._name + "_" + name
-        find_master = (
-            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
-        )
+        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
+            param.dtype
+        )
         target_param = (
             self._master_weights[param.name] if find_master else param
@@ -285,10 +285,7 @@ class Adam(Optimizer):
 
     def _add_moments_pows(self, p):
         acc_dtype = p.dtype
-        if (
-            acc_dtype == core.VarDesc.VarType.FP16
-            or acc_dtype == core.VarDesc.VarType.BF16
-        ):
+        if self._is_dtype_fp16_or_bf16(acc_dtype):
             acc_dtype = core.VarDesc.VarType.FP32
         self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
         self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
@@ -322,16 +319,16 @@ class Adam(Optimizer):
         # Create accumulator tensors for first and second moments
         for p in parameters:
-            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                 master_p = self._create_master_weight(p)
                 self._add_moments_pows(master_p)
                 continue
             if (
-                p.dtype == core.VarDesc.VarType.FP16
+                self._is_dtype_fp16_or_bf16(p.dtype)
                 and not self._multi_precision
             ):
                 warnings.warn(
-                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence."
                     "Consider using multi_precision=True option of the Adam optimizer."
                 )
             self._add_moments_pows(p)
@@ -353,9 +350,8 @@ class Adam(Optimizer):
         beta2_pow_acc = self._get_accumulator(
             self._beta2_pow_acc_str, param_and_grad[0]
         )
-        find_master = (
-            self._multi_precision
-            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
-        )
+        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
+            param_and_grad[0].dtype
+        )
         master_weight = (
             self._master_weights[param_and_grad[0].name]
@@ -571,7 +567,7 @@ class Adam(Optimizer):
     def _multi_tensor_init(self, target_block, parameters, param_group_idx):
         """
-        All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32).
+        All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (bfloat16, float16, float32).
         This function will be overridden in the corresponding optimizer file.
         Args:
             target_block: the block in which the loss tensor is present
@@ -604,7 +600,7 @@ class Adam(Optimizer):
                 self._beta2_pow_acc_dict['FP32_LODTensor'][
                     param_group_idx
                 ].append(beta2_pow_acc)
-            elif param.dtype == paddle.float16:
+            elif self._is_dtype_fp16_or_bf16(param.dtype):
                 self._param_dict['FP16_LODTensor'][param_group_idx].append(
                     param
                 )
@@ -628,7 +624,7 @@ class Adam(Optimizer):
                 self._master_weight_dict['FP16_LODTensor'] = None
             else:
                 raise ValueError(
-                    "Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR."
+                    "Now multi_tensor_momentum only support fp32, fp16 or bf16 parameters and grad is LOD_TENSOR."
                 )
 
     def _append_optimize_multi_tensor_op(
@@ -656,7 +652,7 @@ class Adam(Optimizer):
                     )
                     lr = self._create_param_lr(parameters_and_grads[index])
                     lr_dict['FP32_LODTensor'].append(lr)
-                elif tp == GRAD_TYPES[1]:
+                elif tp == GRAD_TYPES[1] or tp == GRAD_TYPES[2]:
                     grad_dict['FP16_LODTensor'].append(
                         parameters_and_grads[index][1]
                     )
@@ -678,7 +674,7 @@ class Adam(Optimizer):
                     lr = self._create_param_lr(param_and_grad)
                     lr_dict['FP32_LODTensor'].append(lr)
                 elif (
-                    param_and_grad[0].dtype == paddle.float16
+                    self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
                     and param_and_grad[1].type
                     == core.VarDesc.VarType.LOD_TENSOR
                 ):
@@ -711,7 +707,7 @@ class Adam(Optimizer):
                     lr = self._create_param_lr(param_and_grad)
                     lr_dict['FP32_LODTensor'].append(lr)
                 elif (
-                    param_and_grad[0].dtype == paddle.float16
+                    self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
                     and param_and_grad[1].type
                     == core.VarDesc.VarType.LOD_TENSOR
                 ):
...
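
In the multi-tensor (tensor fusion) path of adam.py above, the grouping rule after this change is: fp32 parameters and grads go into the 'FP32_LODTensor' bucket, while fp16 and bf16 both reuse the 'FP16_LODTensor' bucket rather than getting a bucket of their own. A standalone sketch of that rule (the helper name and dict layout are illustrative, not Paddle API):

import paddle

def group_params_by_dtype(params):
    # Sketch of the dtype bucketing used by the fused Adam path: bf16
    # parameters share the existing FP16 bucket.
    buckets = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
    for p in params:
        if p.dtype == paddle.float32:
            buckets['FP32_LODTensor'].append(p)
        elif p.dtype in (paddle.float16, paddle.bfloat16):
            buckets['FP16_LODTensor'].append(p)
        else:
            raise ValueError(
                "only fp32, fp16 and bf16 parameters are supported here"
            )
    return buckets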
@@ -369,8 +369,8 @@ class AdamW(Optimizer):
         """
         if self._name is not None:
             name = self._name + "_" + name
-        find_master = (
-            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
-        )
+        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
+            param.dtype
+        )
         target_param = (
             self._master_weights[param.name] if find_master else param
@@ -389,7 +389,7 @@ class AdamW(Optimizer):
 
     def _add_moments_pows(self, p):
         acc_dtype = p.dtype
-        if acc_dtype == core.VarDesc.VarType.FP16:
+        if self._is_dtype_fp16_or_bf16(acc_dtype):
             acc_dtype = core.VarDesc.VarType.FP32
         self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
         self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
@@ -423,16 +423,16 @@ class AdamW(Optimizer):
         # Create accumulator tensors for first and second moments
         for p in parameters:
-            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                 master_p = self._create_master_weight(p)
                 self._add_moments_pows(master_p)
                 continue
             if (
-                p.dtype == core.VarDesc.VarType.FP16
+                self._is_dtype_fp16_or_bf16(p.dtype)
                 and not self._multi_precision
            ):
                 warnings.warn(
-                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence."
                     "Consider using multi_precision=True option of the Adam optimizer."
                 )
             self._add_moments_pows(p)
@@ -463,9 +463,8 @@ class AdamW(Optimizer):
         beta2_pow_acc = self._get_accumulator(
             self._beta2_pow_acc_str, param_and_grad[0]
         )
-        find_master = (
-            self._multi_precision
-            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
-        )
+        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
+            param_and_grad[0].dtype
+        )
         master_weight = (
             self._master_weights[param_and_grad[0].name]
...
@@ -421,15 +421,21 @@ class Optimizer:
         return self._opti_name_list
 
     def _create_global_learning_rate(self):
-        # lr var can't be float16, for pure fp16 training, should extra handle the dtype for lr
+        # lr var can't be float16 or bfloat16, for pure fp16 or bf16 training, should extra handle the dtype for lr
        _lr_dtype = (
             paddle.get_default_dtype() if self._dtype is None else self._dtype
         )
         _lr_dtype = (
             paddle.float32
             if (
-                paddle.get_default_dtype() != "float16"
-                and _lr_dtype == paddle.float16
+                (
+                    paddle.get_default_dtype() != "float16"
+                    and _lr_dtype == paddle.float16
+                )
+                or (
+                    paddle.get_default_dtype() != "bfloat16"
+                    and _lr_dtype == paddle.bfloat16
+                )
             )
             else _lr_dtype
         )
@@ -1526,3 +1532,17 @@ class Optimizer:
         For Multi Tensor, append optimize merged_operator to block.
         """
         pass
+
+    def _is_dtype_fp16_or_bf16(self, dtype):
+        """
+        check the dtype is fp16 or the dtype is bf16
+        :param dtype: instance of core.VarDesc.VarType
+        :return: True if dtype is one of fp16 or bf16, False otherwise
+        """
+        assert isinstance(
+            dtype, core.VarDesc.VarType
+        ), "The dtype should be an instance of core.VarDesc.VarType."
+        return (
+            dtype == core.VarDesc.VarType.FP16
+            or dtype == core.VarDesc.VarType.BF16
+        )
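
The learning-rate handling above keeps the lr variable in fp16/bf16 only for pure low-precision training, i.e. when the global default dtype is itself "float16" or "bfloat16"; in any mixed setup the lr is promoted to fp32. A condensed sketch of that rule as a free function (the function name is illustrative, not Paddle API):

import paddle

def resolve_lr_dtype(opt_dtype=None):
    # Mirror of the _create_global_learning_rate dtype rule: the lr may stay
    # fp16/bf16 only when the default dtype is that same low-precision type.
    lr_dtype = opt_dtype if opt_dtype is not None else paddle.get_default_dtype()
    default = paddle.get_default_dtype()  # a string such as "float32"
    if (lr_dtype == paddle.float16 and default != "float16") or (
        lr_dtype == paddle.bfloat16 and default != "bfloat16"
    ):
        return paddle.float32
    return lr_dtype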