Unverified commit 9ded5707, authored by JZ-LIANG, committed by GitHub

[Auto Parallel Performance] Support BF16 Training (#51285)

* update env setting

* update pass logic

* dist op support bf16

* backward cast update

* update setting

* update backward

* revert amp pass

* update fp16 backward logic

* register c_embedding bf16

* revert engine

* add unitest

* add unitest

* update unitest

* update cmake

* update math

* update math.py

* update unitest

* update unitest

* revise unitest

* revise unitest

* update unitest

* update unitest

* update unitest
Parent 3094d475
@@ -198,8 +198,14 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_embedding,
                         ops::CEmbeddingCUDAKernel<float>,
                         ops::CEmbeddingCUDAKernel<double>,
+#if NCCL_VERSION_CODE >= 21000
+                        ops::CEmbeddingCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CEmbeddingCUDAKernel<plat::float16>);

 REGISTER_OP_CUDA_KERNEL(c_embedding_grad,
                         ops::CEmbeddingGradCUDAKernel<float>,
                         ops::CEmbeddingGradCUDAKernel<double>,
+#if NCCL_VERSION_CODE >= 21000
+                        ops::CEmbeddingGradCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CEmbeddingGradCUDAKernel<plat::float16>);
@@ -63,6 +63,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
 #########################################
 AMP = "amp"
 set_field_default_config(AMP, "enable", False)
+set_field_default_config(AMP, "dtype", "float16")
+set_field_default_config(AMP, "level", "o1")
 set_field_default_config(AMP, "init_loss_scaling", 32768.0)
 set_field_default_config(AMP, "incr_every_n_steps", 1000)
 set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
@@ -72,15 +74,12 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
 set_field_default_config(AMP, "custom_white_list", [])
 set_field_default_config(AMP, "custom_black_list", [])
 set_field_default_config(AMP, "custom_black_varnames", [])
-set_field_default_config(AMP, "use_pure_fp16", False)
-set_field_default_config(AMP, "use_fp16_guard", True)
+set_field_default_config(AMP, "use_fp16_guard", False)
 set_field_default_config(AMP, "use_optimizer_fp16", False)

-set_field_default_config(AMP, "enable_bf16", False)
 set_field_default_config(AMP, "custom_bf16_list", [])
 set_field_default_config(AMP, "custom_fp32_list", [])
 set_field_default_config(AMP, "custom_fp32_varnames", [])
-set_field_default_config(AMP, "use_pure_bf16", False)
 set_field_default_config(AMP, "use_bf16_guard", False)

 #########################################
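The removed use_pure_fp16, enable_bf16, and use_pure_bf16 switches are folded into the two new fields above: amp.dtype picks float16 or bfloat16, and amp.level picks o1/o2/o3. A minimal usage sketch, mirroring the new unit test later in this commit (building the actual model, loss, optimizer, and dataset is assumed):

    import paddle
    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    amp = strategy.amp
    amp.enable = True
    amp.dtype = "bfloat16"  # new field: "float16" (default) or "bfloat16"
    amp.level = "o2"        # new field: "o1" (default), "o2", or "o3"
    # engine = auto.Engine(model, loss, optimizer, strategy=strategy)  # model/loss/optimizer assumed
    # engine.fit(dataset, batch_size=2)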
@@ -455,7 +455,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         check_variable_and_dtype(
             Out_var,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             'c_allreduce_sum',
         )
@@ -645,7 +645,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         check_variable_and_dtype(
             Out_grad,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             '_c_identity',
         )
@@ -687,12 +687,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
             },
         )
         check_variable_and_dtype(
-            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         check_dtype(
             intermediate_var_0.dtype,
             'dtype',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )
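The 'uint16' entries added throughout these checks are how bfloat16 is spelled in Paddle's string-based dtype validation: bf16 values are carried in a uint16 container, and the internal dtype-to-string conversion maps VarType.BF16 to 'uint16'. A tiny illustration of the convention (the helper import path is an internal one and is an assumption here):

    from paddle.fluid.data_feeder import convert_dtype  # internal helper, assumed path
    from paddle.framework import core

    # bfloat16 variables are reported as 'uint16' by check_variable_and_dtype / check_dtype
    assert convert_dtype(core.VarDesc.VarType.BF16) == 'uint16'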
@@ -220,27 +220,26 @@ class Parallelizer:
                 self._dist_context.serial_feed_vars["inputs"]
                 + self._dist_context.serial_feed_vars["labels"]
             )
-            if config["enable_bf16"]:
-                auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config)
-                auto_parallel_bf16_pass.apply(
-                    [main_program], [startup_program], self._pass_context
-                )
-                loss = auto_parallel_bf16_pass.get_loss()
-
-            elif config["use_pure_fp16"]:
+            self._logger.info(
+                "Applying AMP-{}-{} ...".format(
+                    config["dtype"], config['level']
+                ),
+            )
+            if config['level'] == "o1":
+                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
+                auto_parallel_amp_pass.apply(
+                    [main_program], [startup_program], self._pass_context
+                )
+                loss = auto_parallel_amp_pass.get_loss()
+            elif config['level'] in ['o2', 'o3']:
                 config["base_opt"] = optimizer
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply(
                     [main_program], [startup_program], self._pass_context
                 )
                 loss = auto_parallel_fp16_pass.get_loss()
             else:
-                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
-                auto_parallel_amp_pass.apply(
-                    [main_program], [startup_program], self._pass_context
-                )
-                loss = auto_parallel_amp_pass.get_loss()
+                raise ValueError("AMP level should be one of o1, o2, o3")

             # apply quantization pass
             # The pass can be applied when mode must be 'train'
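Pass selection is now driven by amp.level instead of the removed boolean flags: o1 applies the op-level auto_parallel_amp pass, while o2 and o3 route to the auto_parallel_fp16 pass, which handles both float16 and bfloat16 through config["dtype"]. A condensed sketch of that mapping (pass names as used above; not the full Parallelizer logic):

    def select_amp_pass(config):
        # o1 -> op-level mixed precision; o2/o3 -> (near-)pure low-precision training
        if config["level"] == "o1":
            return "auto_parallel_amp"
        if config["level"] in ("o2", "o3"):
            return "auto_parallel_fp16"  # also covers bfloat16, selected by config["dtype"]
        raise ValueError("AMP level should be one of o1, o2, o3")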
@@ -632,6 +632,7 @@ class AMPPass(PassBase):
         self.set_attr("use_dynamic_loss_scaling", False)
         self.set_attr("input_data", [])
         self.set_attr("params_grads", [])
+        self.set_attr("dtype", "")  # fp16/bf16
         self._loss = None
         self._loss_scaling = None
         self._num_good_steps = None
@@ -639,6 +640,8 @@ class AMPPass(PassBase):
         self._loss = None

     def _check_self(self):
+        if self.get_attr("dtype") not in ["float16", "bfloat16"]:
+            return False
         if self.get_attr("init_loss_scaling") < 0:
             return False
         if self.get_attr("incr_every_n_steps") < 0:
@@ -29,13 +29,9 @@ from paddle.distributed.auto_parallel.utils import (
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 from paddle.framework import core
 from paddle.static import default_main_program, default_startup_program
-from paddle.static.amp.fp16_utils import (
-    AutoMixedPrecisionLists,
-    _dtype_to_str,
-    _keep_layer_norm_scale_bias_to_fp32,
-    _need_keep_fp32,
-    _valid_types,
-)
+
+# NOTE bf16 and fp16 may have diff logic for _keep_layer_norm_scale_bias_to_fp32
+from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
 from paddle.utils import unique_name

 from ..auto_parallel.process_mesh import ProcessMesh
@@ -50,6 +46,8 @@ __amp_skip_ops__ = [
     'while',
     'cast',
 ]
+__target_dtype__ = None
+__amp_utils__ = None


 def set_op_dtype_to_fp16(op):
@@ -57,17 +55,24 @@ def set_op_dtype_to_fp16(op):
         op.has_attr('in_dtype')
         and op.attr('in_dtype') == core.VarDesc.VarType.FP32
     ):
-        op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
+        op._set_attr('in_dtype', __target_dtype__)
     if (
         op.has_attr('out_dtype')
         and op.attr('out_dtype') == core.VarDesc.VarType.FP32
     ):
-        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
+        op._set_attr('out_dtype', __target_dtype__)
     if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
-        op._set_attr('dtype', core.VarDesc.VarType.FP16)
+        op._set_attr('dtype', __target_dtype__)
+    if __target_dtype__ == core.VarDesc.VarType.BF16:
+        if op.has_attr('use_mkldnn'):
+            op._set_attr('use_mkldnn', True)
+        if op.has_attr('mkldnn_data_type'):
+            op._set_attr('mkldnn_data_type', 'bfloat16')


 # adapot for backward op
+# TODO check if bf16 and fp16 still share the same logic
 def _keep_fp32_input(op, in_name):
     op_type = op.type
     if op_type == 'batch_norm':
@@ -96,6 +101,7 @@ def _keep_fp32_input(op, in_name):
     return False


+# TODO check if bf16 and fp16 still share the same logic
 def _keep_fp32_output(op, out_name):
     op_type = op.type
     if op_type in ['batch_norm', 'fused_bn_add_activation']:
@@ -208,7 +214,7 @@ class FP16State:
             self._op_fp16_dict[op.desc.original_id()] = True
             return

-        if _need_keep_fp32(
+        if __amp_utils__._need_keep_fp32(
             op, self.amp_list.unsupported_list, self.use_fp16_guard
         ):
             self._op_fp16_dict[op.desc.original_id()] = False
@@ -240,11 +246,15 @@ class FP16State:
         # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
         # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
-        if var is None or var.type not in _valid_types or "array_" in var_name:
+        if (
+            var is None
+            or var.type not in __amp_utils__._valid_types
+            or "array_" in var_name
+        ):
             return

         if var.dtype == core.VarDesc.VarType.FP32:
-            var.desc.set_dtype(core.VarDesc.VarType.FP16)
+            var.desc.set_dtype(__target_dtype__)

     def resolute_tensor_dtype(self, block):
@@ -274,9 +284,12 @@ class FP16State:
             elif self._is_fp16_op(op.desc.original_id()) is False:
                 for out_var_name in op.output_arg_names:
                     out_var = block.vars.get(out_var_name)
-                    if out_var is None or out_var.type not in _valid_types:
+                    if (
+                        out_var is None
+                        or out_var.type not in __amp_utils__._valid_types
+                    ):
                         continue
-                    if out_var.dtype == core.VarDesc.VarType.FP16:
+                    if out_var.dtype == __target_dtype__:
                         out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
             elif is_backward_op(op):
                 if self._is_fp16_op(op.desc.original_id()) is True:
@@ -290,9 +303,12 @@ class FP16State:
             elif self._is_fp16_op(op.desc.original_id()) is False:
                 for out_var_name in op.output_arg_names:
                     out_var = block.vars.get(out_var_name)
-                    if out_var is None or out_var.type not in _valid_types:
+                    if (
+                        out_var is None
+                        or out_var.type not in __amp_utils__._valid_types
+                    ):
                         continue
-                    if out_var.dtype == core.VarDesc.VarType.FP16:
+                    if out_var.dtype == __target_dtype__:
                         out_var.desc.set_dtype(core.VarDesc.VarType.FP32)

     def cast_block(self, block):
@@ -311,7 +327,7 @@ class FP16State:
                         op,
                         idx,
                         block,
-                        core.VarDesc.VarType.FP16,
+                        __target_dtype__,
                         core.VarDesc.VarType.FP32,
                         self.dist_context,
                     )
@@ -321,7 +337,7 @@ class FP16State:
                         idx,
                         block,
                         core.VarDesc.VarType.FP32,
-                        core.VarDesc.VarType.FP16,
+                        __target_dtype__,
                         self.dist_context,
                     )
             elif is_backward_op(op):
@@ -331,7 +347,7 @@ class FP16State:
                         op,
                         idx,
                         block,
-                        core.VarDesc.VarType.FP16,
+                        __target_dtype__,
                         core.VarDesc.VarType.FP32,
                         self.dist_context,
                     )
@@ -341,7 +357,7 @@ class FP16State:
                         idx,
                         block,
                         core.VarDesc.VarType.FP32,
-                        core.VarDesc.VarType.FP16,
+                        __target_dtype__,
                         self.dist_context,
                     )
             elif op.type == "sum":
@@ -379,14 +395,16 @@ class FP16State:
             in_var = block._find_var_recursive(in_var_name)
             if (
                 in_var is None
-                or in_var.type not in _valid_types
+                or in_var.type not in __amp_utils__._valid_types
                 or in_var.dtype == dst_dtype
             ):
                 continue

             if in_var.dtype == src_dtype:
                 cast_name = (
-                    in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
+                    in_var.name
+                    + '.cast_'
+                    + __amp_utils__._dtype_to_str(dst_dtype)
                 )
                 cast_var = block.vars.get(cast_name)

                 self.forward_input_cast_ops[op.desc.original_id()] += [
@@ -476,14 +494,15 @@ class FP16State:
                 slot_name,
             ) in self.forward_input_cast_ops[forward_op_id]:

-                # rename input
                 # some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
-                if slot_name not in op.input_names:
-                    continue
-
-                assert src_name in op.input(
-                    slot_name
-                ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
+                if slot_name in op.input_names:
+
+                    # rename input
+                    assert src_name in op.input(
+                        slot_name
+                    ), "var: {} not in op's {}. {}".format(
+                        src_name, slot_name, str(op)
+                    )
                     src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
                     assert src_var_dist_attr is not None
                     op._rename_input(src_name, cast_name)
@@ -491,9 +510,7 @@ class FP16State:
                 # create cast grad
                 grad_slot_name = slot_name + "@GRAD"
-                if grad_slot_name not in op.output_names:
-                    continue
-
-                # some forward input maybe stop_gradient=True, e.g. input_mask
-                if len(op.output(grad_slot_name)) == 0:
-                    continue
+                if grad_slot_name in op.output_names:
+                    # some forward input maybe stop_gradient=True, e.g. input_mask
+                    if len(op.output(grad_slot_name)) == 0:
+                        continue
@@ -521,7 +538,9 @@ class FP16State:
                         cast_grad, grad_dist_attr
                     )
                     op._rename_output(grad_name, cast_grad.name)
-                    grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)
+                    grad_op_attr.set_output_dist_attr(
+                        cast_grad.name, grad_dist_attr
+                    )

                 # add cast
                 cast_op = block._insert_op_without_sync(
@@ -604,7 +623,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
 def _split_grads(params_grads):
     grads = [g for _, g in params_grads]
     fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
-    fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
+    fp16_grads = [g for g in grads if g.dtype == __target_dtype__]
     assert len(fp32_grads) + len(fp16_grads) == len(
         grads
     ), "Data types of all grads must be either fp16 or fp32."
@@ -707,17 +726,17 @@ def cast_startup_program():
     for op in startup_program.global_block().ops:
         if is_initialization_op(op):
             output_name = op.output_arg_names[0]
-            if (
-                param_to_dtype.get(output_name, None)
-                == core.VarDesc.VarType.FP16
-            ):
+            if param_to_dtype.get(output_name, None) == __target_dtype__:
                 assert op.has_attr(
                     'dtype'
                 ), "initialization op is supported to has dtype attribute but got {}.".format(
                     str(op)
                 )
+                out_var = startup_program.global_block().var(output_name)
+                if out_var.dtype == core.VarDesc.VarType.FP32:
+                    out_var.desc.set_dtype(__target_dtype__)
                 if op.attr('dtype') == core.VarDesc.VarType.FP32:
-                    op._set_attr('dtype', core.VarDesc.VarType.FP16)
+                    op._set_attr('dtype', __target_dtype__)


 @register_pass("auto_parallel_fp16")
@@ -730,9 +749,37 @@ class FP16Pass(AMPPass):
     # in distributed scenario, all ranks should have the same modification.
     def _apply_single_impl(self, main_program, startup_program, context):
         self.dist_context = self.get_attr("dist_context")
+        self.target_dtype = self.get_attr("dtype")
         params_grads = self.get_attr("params_grads")

-        amp_list = AutoMixedPrecisionLists(
+        self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None)
+        if self.use_optimizer_fp16 is None:
+            self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"
+
+        # swith enviroment for fp16 / bf16.
+        if self.target_dtype == "float16":
+            import paddle.static.amp.fp16_utils as amp_utils
+
+            AMPList = amp_utils.AutoMixedPrecisionLists
+            __target_dtype = core.VarDesc.VarType.FP16
+
+        elif self.target_dtype == "bfloat16":
+            import paddle.static.amp.bf16.amp_utils as amp_utils
+
+            AMPList = amp_utils.AutoMixedPrecisionListsBF16
+            __target_dtype = core.VarDesc.VarType.BF16
+
+        else:
+            raise NotImplementedError(
+                "target dtype [{}] is for amp o2 not supported yet.".format(
+                    self.target_dtype
+                )
+            )
+
+        global __target_dtype__
+        __target_dtype__ = __target_dtype
+        global __amp_utils__
+        __amp_utils__ = amp_utils
+
+        amp_list = AMPList(
             set(self.get_attr("custom_white_list")),
             set(self.get_attr("custom_black_list")),
             None,
@@ -747,7 +794,9 @@ class FP16Pass(AMPPass):
             main_program,
             amp_list,
             self.dist_context,
-            self.get_attr("use_fp16_guard"),
+            self.get_attr(
+                "use_fp16_guard"
+            ),  # TODO unify to use_amp_guard to be compatible with amp o1
             input_data_var_names,
         )
         is_train = fp16_state._build_state()
@@ -755,6 +804,7 @@ class FP16Pass(AMPPass):
         cast_startup_program()

         if is_train:
+            if self.target_dtype == "fp16":
                 with paddle.static.program_guard(main_program, startup_program):
                     # TODO (JZ-LIANG)support cast forward program only when inference
                     self._init_amp_var()
@@ -864,11 +914,12 @@ class FP16Pass(AMPPass):
             # modify optimizer
             base_opt = self.get_attr("base_opt")
             base_opt._multi_precision = True
-            if self.get_attr("use_optimizer_fp16"):
+            if self.use_optimizer_fp16:
                 base_opt._multi_precision = False

+            if self.target_dtype == "fp16":
                 if isinstance(
-                    base_opt,
-                    (paddle.static.Adam, paddle.optimizer.AdamW),
+                    base_opt, (paddle.static.Adam, paddle.optimizer.AdamW)
                 ):
                     with main_program._optimized_guard([]):
                         # found_inf = paddle.tensor.creation._memcpy(
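The whole pass is now parameterized by two module-level globals: __target_dtype__ (the VarType that FP32 tensors and op attributes are cast to) and __amp_utils__ (the fp16 or bf16 helper module providing _valid_types, _need_keep_fp32, _dtype_to_str, and friends). A condensed sketch of the switch that _apply_single_impl performs (module paths as in the diff; treat the exact attribute names as assumptions):

    from paddle.framework import core

    def resolve_amp_environment(target_dtype):
        # "float16"  -> VarType.FP16 + paddle.static.amp.fp16_utils
        # "bfloat16" -> VarType.BF16 + paddle.static.amp.bf16.amp_utils
        if target_dtype == "float16":
            import paddle.static.amp.fp16_utils as amp_utils
            return core.VarDesc.VarType.FP16, amp_utils
        if target_dtype == "bfloat16":
            import paddle.static.amp.bf16.amp_utils as amp_utils
            return core.VarDesc.VarType.BF16, amp_utils
        raise NotImplementedError(
            "target dtype [{}] is not supported yet.".format(target_dtype)
        )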
@@ -49,6 +49,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_pass_amp MODULES test_pass_amp)
   set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                 TIMEOUT 50)
+  py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass)
+  set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
+                                                   TIMEOUT 50)
   py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
   set_tests_properties(test_engine_callbacks
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
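Once registered, the new test behaves like the other exclusive GPU tests and can be run from the CMake build directory with, for example, "ctest -R test_amp_o2_pass --output-on-failure" on a machine with at least two GPUs (a build with WITH_DISTRIBUTE and WITH_GPU enabled is assumed).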
New test script added by this commit (amp_o2_pass.py, launched by the test_amp_o2_pass module below):

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import re
import unittest

import numpy as np
from get_gpt_model import FakeDataset, generate_model

import paddle
from paddle.distributed.fleet import auto
from paddle.framework import core

paddle.enable_static()


def get_cuda_version():
    result = os.popen("nvcc --version").read()
    regex = r'release (\S+),'
    match = re.search(regex, result)
    if match:
        num = str(match.group(1))
        integer, decimal = num.split('.')
        return int(integer) * 1000 + int(float(decimal) * 10)
    else:
        return -1


def apply_pass(use_amp=False, amp_dtype="bfloat16"):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_amp:
        amp = strategy.amp
        amp.enable = True
        amp.dtype = amp_dtype
        amp.level = "o2"
        amp.custom_black_list = [
            'c_softmax_with_cross_entropy',
            'elementwise_div',
            'reduce_sum',
        ]
    return strategy


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class TestShardingStage2WithNewEXE(unittest.TestCase):
    def setUp(self):
        self.batch_size = 2
        self.batch_num = 10
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2022)
        np.random.seed(2022)
        random.seed(2022)
        place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
        reset_prog()

        strategy = apply_pass(use_amp, amp_dtype)
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("mp")

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_bf16(self, program):
        num_bf16 = 0
        num_fp16 = 0
        num_fp32 = 0

        for p in program.all_parameters():
            if p.dtype == core.VarDesc.VarType.FP32:
                num_fp32 += 1
            if p.dtype == core.VarDesc.VarType.FP16:
                num_fp16 += 1
            if p.dtype == core.VarDesc.VarType.BF16:
                num_bf16 += 1

        self.assertEqual(num_bf16, 26)
        self.assertEqual(num_fp16, 0)
        self.assertEqual(num_fp32, 10)

    def test_param_grad_fuse_overlap(self):
        # std
        mp_engine = self.get_engine(use_amp=False)
        mp_history = mp_engine.fit(
            self.dataset,
            3,
            epochs=1,
            steps_per_epoch=self.batch_num,
            log_freq=1,
            batch_size=self.batch_size,
        )
        loss0 = mp_history.history['loss'][0]

        # bf16
        mp_bf16_engine = self.get_engine(use_amp=True)
        if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
            return

        mp_bf16_history = mp_bf16_engine.fit(
            self.dataset,
            3,
            epochs=1,
            steps_per_epoch=self.batch_num,
            log_freq=1,
            batch_size=self.batch_size,
        )
        loss1 = mp_bf16_history.history['loss'][0]
        np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)

        self.check_bf16(mp_bf16_engine.main_program)


if __name__ == "__main__":
    unittest.main()
@@ -37,7 +37,7 @@ def apply_pass(use_amp=False, level=None):
         ]
         amp.init_loss_scaling = 32768
         amp.use_fp16_guard = False
-        amp.use_pure_fp16 = level in ["o2", "o3"]
+        amp.level = level
         amp.use_optimizer_fp16 = level == "o3"
         print("amp level: ", level)
     return strategy
@@ -39,7 +39,7 @@ def apply_pass():
     ]
     amp.init_loss_scaling = 32768
     amp.use_fp16_guard = False
-    amp.use_pure_fp16 = True
+    amp.level = "o2"

     qat = dist_strategy.qat
     qat.enable = True
New test module added by this commit (test_amp_o2_pass.py, the launcher registered in the CMake change above):

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys
import tempfile
import unittest


class TestAMPO2(unittest.TestCase):
    def test_bf16(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
    unittest.main()
@@ -28,6 +28,8 @@ class TestStrategy(unittest.TestCase):
         amp = strategy.amp
         self.assertEqual(amp.enable, False)
+        self.assertAlmostEqual(amp.dtype, "float16")
+        self.assertAlmostEqual(amp.level, "o1")
         self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
         self.assertEqual(amp.incr_every_n_steps, 1000)
         self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
@@ -37,15 +39,11 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(amp.custom_black_list, [])
         self.assertEqual(amp.custom_white_list, [])
         self.assertEqual(amp.custom_black_varnames, [])
-        self.assertEqual(amp.use_pure_fp16, False)
-        self.assertEqual(amp.use_fp16_guard, True)
+        self.assertEqual(amp.use_fp16_guard, False)
         self.assertEqual(amp.use_optimizer_fp16, False)
-        self.assertEqual(amp.enable_bf16, False)
         self.assertEqual(amp.custom_bf16_list, [])
         self.assertEqual(amp.custom_fp32_list, [])
         self.assertEqual(amp.custom_fp32_varnames, [])
-        self.assertEqual(amp.use_pure_bf16, False)
         self.assertEqual(amp.use_bf16_guard, False)

         sharding = strategy.sharding
@@ -102,7 +100,6 @@ class TestStrategy(unittest.TestCase):
         amp.custom_white_list = ["x"]
         amp.custom_black_list = ["y"]
         amp.custom_black_varnames = ["z"]
-        amp.use_pure_fp16 = True
         amp.use_fp16_guard = False
         amp.use_optimizer_fp16 = True
         self.assertEqual(amp.enable, True)
@@ -115,7 +112,6 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(amp.custom_white_list, ["x"])
         self.assertEqual(amp.custom_black_list, ["y"])
         self.assertEqual(amp.custom_black_varnames, ["z"])
-        self.assertEqual(amp.use_pure_fp16, True)
         self.assertEqual(amp.use_fp16_guard, False)
         self.assertEqual(amp.use_optimizer_fp16, True)
@@ -14,6 +14,7 @@
 import copy
 import warnings
+from sqlite3 import NotSupportedError

 import paddle
 import paddle.autograd as imperative_base
@@ -217,7 +218,9 @@ def _squared_l2_norm(x):
         return _C_ops.squared_l2_norm(x)

     op_type = 'squared_l2_norm'
-    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'float16'], op_type)
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'float16', 'uint16'], op_type
+    )
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
@@ -557,6 +560,20 @@ def _allow_pure_fp16_global_norm_clip(*args):
         return old_value


+_allow_pure_bf16_global_norm_clip_flag = False
+
+
+def _allow_pure_bf16_global_norm_clip(*args):
+    global _allow_pure_bf16_global_norm_clip_flag
+    if len(args) == 0:
+        return _allow_pure_bf16_global_norm_clip_flag
+    else:
+        assert len(args) == 1 and isinstance(args[0], bool)
+        old_value = _allow_pure_bf16_global_norm_clip_flag
+        _allow_pure_bf16_global_norm_clip_flag = args[0]
+        return old_value
+
+
 class ClipGradByGlobalNorm(ClipGradBase):
     r"""
     Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
@@ -720,6 +737,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
         params_and_grads = []
         sum_square_list = []
         sum_square_list_fp16 = []
+        sum_square_list_bf16 = []
         sum_square_list_fp32 = []
         with framework.name_scope('gradient_clip'):
             for p, g in params_grads:
@@ -735,17 +753,29 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 sum_square = _squared_l2_norm(merge_grad)
                 if sum_square.dtype == core.VarDesc.VarType.FP16:
                     sum_square_list_fp16.append(sum_square)
+                elif sum_square.dtype == core.VarDesc.VarType.BF16:
+                    sum_square_list_bf16.append(sum_square)
                 elif sum_square.dtype == core.VarDesc.VarType.FP32:
                     sum_square_list_fp32.append(sum_square)
                 else:
                     sum_square_list.append(sum_square)

+        if len(sum_square_list_fp16) > 0 and len(sum_square_list_bf16) > 0:
+            raise NotSupportedError(
+                'FP16 and BF16 are not supported at the same time.'
+            )
+
         # all parameters have been filterd out
         if (
             len(sum_square_list)
             + len(sum_square_list_fp16)
             + len(sum_square_list_fp32)
             == 0
+        ) and (
+            len(sum_square_list)
+            + len(sum_square_list_bf16)
+            + len(sum_square_list_fp32)
+            == 0
         ):
             return params_grads
@@ -765,6 +795,18 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 )
             else:
                 global_norm_var.append(global_norm_var_fp16)
+        if len(sum_square_list_bf16) > 0:
+            global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16)
+            if (
+                sum_square_list_fp32
+                or sum_square_list
+                or not _allow_pure_bf16_global_norm_clip()
+            ):
+                global_norm_var.append(
+                    global_norm_var_bf16.astype(sum_dtype)
+                )
+            else:
+                global_norm_var.append(global_norm_var_bf16)
         if len(sum_square_list_fp32) > 0:
             global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
             if sum_dtype == 'float32':
@@ -804,12 +846,18 @@ class ClipGradByGlobalNorm(ClipGradBase):
             with p.block.program._optimized_guard([p, g]):
                 new_g = _cast_to_mp_type_if_enabled(g)
                 # inplace
-                scale_input = (
-                    scale_var.astype('float16')
-                    if new_g.dtype == core.VarDesc.VarType.FP16
-                    and scale_var.dtype != core.VarDesc.VarType.FP16
-                    else scale_var
-                )
+                if (
+                    new_g.dtype == core.VarDesc.VarType.FP16
+                    and scale_var.dtype != core.VarDesc.VarType.FP16
+                ):
+                    scale_input = scale_var.astype('float16')
+                elif (
+                    new_g.dtype == core.VarDesc.VarType.BF16
+                    and scale_var.dtype != core.VarDesc.VarType.BF16
+                ):
+                    scale_input = scale_var.astype('bfloat16')
+                else:
+                    scale_input = scale_var
                 # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
                 # will be in different blocks with the gradient clip related ops.
                 # We need to handle the correct block, otherwise will encounter
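With these changes, ClipGradByGlobalNorm accumulates squared norms in separate fp16 / bf16 / fp32 buckets, upcasts the low-precision partial sums before the final reduction (unless pure fp16/bf16 clipping is explicitly allowed), and rejects the mixed fp16-plus-bf16 case. A minimal dygraph sketch of the bucketing idea, illustrative only and not the static-graph implementation (bf16 arithmetic assumes a device with bf16 kernels):

    import paddle

    def global_grad_norm(grads, sum_dtype="float32"):
        # bucket squared L2 norms by dtype, then upcast low-precision partial sums
        buckets = {}
        for g in grads:
            buckets.setdefault(g.dtype, []).append(paddle.sum(paddle.square(g)))
        partials = []
        for dtype, vals in buckets.items():
            partial = paddle.add_n(vals)
            if dtype not in (paddle.float32, paddle.float64):
                partial = partial.astype(sum_dtype)
            partials.append(partial)
        return paddle.sqrt(paddle.add_n(partials))

    # usage with hypothetical gradients
    grads = [paddle.randn([4, 4]), paddle.randn([4, 4]).astype("bfloat16")]
    print(global_grad_norm(grads))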
@@ -1657,14 +1657,21 @@ def add_n(inputs, name=None):
             check_variable_and_dtype(
                 input,
                 "inputs",
-                ['float16', 'float32', 'float64', 'int32', 'int64'],
+                [
+                    'float16',
+                    'float32',
+                    'float64',
+                    'int32',
+                    'int64',
+                    'uint16',
+                ],
                 'add_n',
             )
     else:
         check_variable_and_dtype(
             inputs,
             "inputs",
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             'add_n',
         )