Unverified commit e5ed5257, authored by Yuang Liu, committed by GitHub

Support bfloat16 for adamw and adam optimizer. Fit the lr for pure bf16 training with tensor fusion. (#48041)

* add bfloat16 for adamw

* set lr not to bfloat16 for pure bf16 training

* update the logic

* update the adamw optimizer

* support bfloat for adam
Parent 4f57da5f
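
The diff below registers bfloat16 kernels for adam/adamw and routes bf16 parameters through the existing fp16 master-weight machinery. As a rough, hypothetical illustration of how the change might be exercised (not part of the commit; the layer, shapes and learning rate are invented, and it assumes a build where the default dtype can be set to bfloat16 and every op involved has a bf16 kernel):

import paddle

paddle.set_default_dtype("bfloat16")      # pure-bf16: parameters are created as bf16
net = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=net.parameters(),
    multi_precision=True,                 # keep fp32 master weights and fp32 moments
)
x = paddle.ones([4, 16], dtype="bfloat16")
loss = net(x).mean()
loss.backward()
opt.step()                                # dispatches to the bf16 adamw kernel added in this change
opt.clear_grad()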
......@@ -268,7 +268,8 @@ PyObject* eager_api_get_grads_types(PyObject* self,
if (meta && grad.initialized()) {
if (grad.is_dense_tensor() &&
(tensor.dtype() == paddle::experimental::DataType::FLOAT32 ||
tensor.dtype() == paddle::experimental::DataType::FLOAT16)) {
tensor.dtype() == paddle::experimental::DataType::FLOAT16 ||
tensor.dtype() == paddle::experimental::DataType::BFLOAT16)) {
ret.emplace_back(
paddle::framework::TransToProtoVarType(tensor.dtype()));
}
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -300,7 +301,8 @@ PD_REGISTER_KERNEL(adamw,
phi::AdamwDenseKernel,
float,
double,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
......
......@@ -28,7 +28,7 @@ from paddle import _C_ops, _legacy_C_ops
__all__ = []
GRAD_TYPES = [int(paddle.float32), int(paddle.float16)]
GRAD_TYPES = [int(paddle.float32), int(paddle.float16), int(paddle.bfloat16)]
class Adam(Optimizer):
......@@ -265,8 +265,8 @@ class Adam(Optimizer):
"""
if self._name is not None:
name = self._name + "_" + name
find_master = (
self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
param.dtype
)
target_param = (
self._master_weights[param.name] if find_master else param
......@@ -285,10 +285,7 @@ class Adam(Optimizer):
def _add_moments_pows(self, p):
acc_dtype = p.dtype
if (
acc_dtype == core.VarDesc.VarType.FP16
or acc_dtype == core.VarDesc.VarType.BF16
):
if self._is_dtype_fp16_or_bf16(acc_dtype):
acc_dtype = core.VarDesc.VarType.FP32
self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
......@@ -322,16 +319,16 @@ class Adam(Optimizer):
# Create accumulator tensors for first and second moments
for p in parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p)
self._add_moments_pows(master_p)
continue
if (
p.dtype == core.VarDesc.VarType.FP16
self._is_dtype_fp16_or_bf16(p.dtype)
and not self._multi_precision
):
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Adam optimizer."
)
self._add_moments_pows(p)
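
These accumulator hunks extend the fp16 master-weight scheme to bf16: with multi_precision enabled, the optimizer keeps an fp32 copy of each low-precision parameter and fp32 moments, updates in fp32, and casts the result back. A stripped-down sketch of that pattern (plain SGD-style step, moments omitted, all names invented for illustration):

import paddle

lr = 0.01
param_bf16 = paddle.zeros([8], dtype="bfloat16")       # what the model stores
master_fp32 = param_bf16.astype("float32")             # fp32 master weight held by the optimizer
grad_bf16 = paddle.full([8], 1e-3, dtype="bfloat16")   # gradient produced by a bf16 backward pass

master_fp32 = master_fp32 - lr * grad_bf16.astype("float32")  # update runs in fp32
param_bf16 = master_fp32.astype("bfloat16")                   # low-precision copy written back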
......@@ -353,9 +350,8 @@ class Adam(Optimizer):
beta2_pow_acc = self._get_accumulator(
self._beta2_pow_acc_str, param_and_grad[0]
)
find_master = (
self._multi_precision
and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
param_and_grad[0].dtype
)
master_weight = (
self._master_weights[param_and_grad[0].name]
......@@ -571,7 +567,7 @@ class Adam(Optimizer):
def _multi_tensor_init(self, target_block, parameters, param_group_idx):
"""
All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32).
All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (bfloat16, float16, float32).
This function will be overridden in the corresponding optimizer file.
Args:
target_block: the block in which the loss tensor is present
......@@ -604,7 +600,7 @@ class Adam(Optimizer):
self._beta2_pow_acc_dict['FP32_LODTensor'][
param_group_idx
].append(beta2_pow_acc)
elif param.dtype == paddle.float16:
elif self._is_dtype_fp16_or_bf16(param.dtype):
self._param_dict['FP16_LODTensor'][param_group_idx].append(
param
)
......@@ -628,7 +624,7 @@ class Adam(Optimizer):
self._master_weight_dict['FP16_LODTensor'] = None
else:
raise ValueError(
"Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR."
"Now multi_tensor_momentum only support fp32, fp16 or bf16 parameters and grad is LOD_TENSOR."
)
def _append_optimize_multi_tensor_op(
......@@ -656,7 +652,7 @@ class Adam(Optimizer):
)
lr = self._create_param_lr(parameters_and_grads[index])
lr_dict['FP32_LODTensor'].append(lr)
elif tp == GRAD_TYPES[1]:
elif tp == GRAD_TYPES[1] or tp == GRAD_TYPES[2]:
grad_dict['FP16_LODTensor'].append(
parameters_and_grads[index][1]
)
......@@ -678,7 +674,7 @@ class Adam(Optimizer):
lr = self._create_param_lr(param_and_grad)
lr_dict['FP32_LODTensor'].append(lr)
elif (
param_and_grad[0].dtype == paddle.float16
self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
and param_and_grad[1].type
== core.VarDesc.VarType.LOD_TENSOR
):
......@@ -711,7 +707,7 @@ class Adam(Optimizer):
lr = self._create_param_lr(param_and_grad)
lr_dict['FP32_LODTensor'].append(lr)
elif (
param_and_grad[0].dtype == paddle.float16
self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
and param_and_grad[1].type
== core.VarDesc.VarType.LOD_TENSOR
):
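
In the multi-tensor (tensor-fusion) path above, parameters and their gradients are grouped into an 'FP32_LODTensor' list and an 'FP16_LODTensor' list; with this change bf16 tensors share the low-precision list with fp16. A standalone restatement of the bucketing rule (simplified sketch; the real code also collects moments, beta pows, master weights and learning rates):

import paddle

def bucket_by_dtype(params):
    # hypothetical helper mirroring the dtype grouping used by the fused path
    buckets = {"FP32_LODTensor": [], "FP16_LODTensor": []}
    for p in params:
        if p.dtype == paddle.float32:
            buckets["FP32_LODTensor"].append(p)
        elif p.dtype in (paddle.float16, paddle.bfloat16):
            buckets["FP16_LODTensor"].append(p)   # bf16 rides in the fp16 bucket
        else:
            raise ValueError(
                "multi-tensor adam only supports fp32, fp16 or bf16 parameters"
            )
    return buckets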
......
......@@ -369,8 +369,8 @@ class AdamW(Optimizer):
"""
if self._name is not None:
name = self._name + "_" + name
find_master = (
self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
param.dtype
)
target_param = (
self._master_weights[param.name] if find_master else param
......@@ -389,7 +389,7 @@ class AdamW(Optimizer):
def _add_moments_pows(self, p):
acc_dtype = p.dtype
if acc_dtype == core.VarDesc.VarType.FP16:
if self._is_dtype_fp16_or_bf16(acc_dtype):
acc_dtype = core.VarDesc.VarType.FP32
self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
......@@ -423,16 +423,16 @@ class AdamW(Optimizer):
# Create accumulator tensors for first and second moments
for p in parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p)
self._add_moments_pows(master_p)
continue
if (
p.dtype == core.VarDesc.VarType.FP16
self._is_dtype_fp16_or_bf16(p.dtype)
and not self._multi_precision
):
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Adam optimizer."
)
self._add_moments_pows(p)
......@@ -463,9 +463,8 @@ class AdamW(Optimizer):
beta2_pow_acc = self._get_accumulator(
self._beta2_pow_acc_str, param_and_grad[0]
)
find_master = (
self._multi_precision
and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
param_and_grad[0].dtype
)
master_weight = (
self._master_weights[param_and_grad[0].name]
......
......@@ -421,16 +421,22 @@ class Optimizer:
return self._opti_name_list
def _create_global_learning_rate(self):
# lr var can't be float16, for pure fp16 training, should extra handle the dtype for lr
# lr var can't be float16 or bfloat16, for pure fp16 or bf16 training, should extra handle the dtype for lr
_lr_dtype = (
paddle.get_default_dtype() if self._dtype is None else self._dtype
)
_lr_dtype = (
paddle.float32
if (
(
paddle.get_default_dtype() != "float16"
and _lr_dtype == paddle.float16
)
or (
paddle.get_default_dtype() != "bfloat16"
and _lr_dtype == paddle.bfloat16
)
)
else _lr_dtype
)
if isinstance(self._learning_rate, LRScheduler):
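
The block above extends the fp16 learning-rate rule to bf16: the lr variable falls back to fp32 whenever the optimizer dtype is a low-precision type but the global default dtype is not that same type, and keeps the low-precision dtype only in a pure fp16/bf16 run. The same rule written as a small standalone function (illustrative name, dtypes given as strings):

def resolve_lr_dtype(opt_dtype, default_dtype):
    # mirrors the _lr_dtype selection above
    if opt_dtype in ("float16", "bfloat16") and default_dtype != opt_dtype:
        return "float32"
    return opt_dtype

assert resolve_lr_dtype("bfloat16", "float32") == "float32"    # bf16 optimizer under an fp32 default
assert resolve_lr_dtype("bfloat16", "bfloat16") == "bfloat16"  # pure-bf16 run keeps a bf16 lr
assert resolve_lr_dtype("float32", "float32") == "float32"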
......@@ -1526,3 +1532,17 @@ class Optimizer:
For Multi Tensor, append optimize merged_operator to block.
"""
pass
def _is_dtype_fp16_or_bf16(self, dtype):
"""
check the dtype is fp16 or the dtype is bf16
:param dtype: instance of core.VarDesc.VarType
:return: True if dtype is one of fp16 or bf16, False otherwise
"""
assert isinstance(
dtype, core.VarDesc.VarType
), "The dtype should be an instance of core.VarDesc.VarType."
return (
dtype == core.VarDesc.VarType.FP16
or dtype == core.VarDesc.VarType.BF16
)
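
A quick, hypothetical sanity check of the new helper on an optimizer instance (dygraph; the layer is only there because Adam requires parameters in dygraph mode):

import paddle
from paddle.fluid import core

net = paddle.nn.Linear(2, 2)
opt = paddle.optimizer.Adam(learning_rate=1e-3, parameters=net.parameters())
assert opt._is_dtype_fp16_or_bf16(core.VarDesc.VarType.BF16)
assert opt._is_dtype_fp16_or_bf16(core.VarDesc.VarType.FP16)
assert not opt._is_dtype_fp16_or_bf16(core.VarDesc.VarType.FP32)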