未验证 提交 6f3c9643 编写于 作者: J JZ-LIANG 提交者: GitHub

Eb118 BF16 Adoption (#52827)

* pr1

* pr2

* pr3

* fixed unitest

* adopt for scale
上级 8cbc75ca
...@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False) ...@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
######################################### #########################################
AMP = "amp" AMP = "amp"
set_field_default_config(AMP, "enable", False) set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "dtype", "float16")
set_field_default_config(AMP, "level", "o1")
set_field_default_config(AMP, "init_loss_scaling", 32768.0) set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000) set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2) set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
...@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True) ...@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", []) set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", []) set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", []) set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False) set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_fp16_guard", True) set_field_default_config(AMP, "use_bf16_guard", False)
set_field_default_config(AMP, "use_optimizer_fp16", False) set_field_default_config(AMP, "use_optimizer_fp16", False)
######################################### #########################################
......
...@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype( check_variable_and_dtype(
Out_var, Out_var,
'tensor', 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'c_allreduce_sum', 'c_allreduce_sum',
) )
...@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype( check_variable_and_dtype(
Out_grad, Out_grad,
'tensor', 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity', '_c_identity',
) )
...@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
}, },
) )
check_variable_and_dtype( check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
) )
check_dtype( check_dtype(
intermediate_var_0.dtype, intermediate_var_0.dtype,
'dtype', 'dtype',
['float16', 'float32', 'float64'], ['float16', 'float32', 'float64', 'uint16'],
'linear', 'linear',
) )
......
...@@ -254,17 +254,26 @@ class Parallelizer: ...@@ -254,17 +254,26 @@ class Parallelizer:
self._dist_context.serial_feed_vars["inputs"] self._dist_context.serial_feed_vars["inputs"]
+ self._dist_context.serial_feed_vars["labels"] + self._dist_context.serial_feed_vars["labels"]
) )
if config["use_pure_fp16"]: self._logger.info(
"Applying AMP-{}-{} ...".format(
config["dtype"], config['level']
),
)
if config['level'] == "o1":
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_amp_pass.get_loss()
elif config['level'] in ['o2', 'o3']:
config["base_opt"] = optimizer config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply( auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context [main_program], [startup_program], self._pass_context
) )
loss = auto_parallel_fp16_pass.get_loss()
else: else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) raise ValueError("AMP level should be one of o1, o2, o3")
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
# apply recompute pass # apply recompute pass
# recompute is then train-only optimization # recompute is then train-only optimization
......
...@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import ( ...@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.fluid.contrib.mixed_precision.fp16_utils import ( from paddle.fluid.contrib.mixed_precision.fp16_lists import (
AutoMixedPrecisionLists, AutoMixedPrecisionLists,
) )
from paddle.fluid.contrib.mixed_precision.fp16_utils import ( from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_keep_layer_norm_scale_bias_to_fp32, _keep_layer_norm_scale_bias_to_fp32,
_need_keep_fp32, _need_keep_fp32,
_valid_types, _valid_types,
_dtype_to_str,
) )
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute, OperatorDistributedAttribute,
...@@ -55,6 +54,23 @@ __amp_skip_ops__ = [ ...@@ -55,6 +54,23 @@ __amp_skip_ops__ = [
'while', 'while',
'cast', 'cast',
] ]
__target_dtype__ = None
def _dtype_to_str(dtype):
    """
    Return the short string name of a variable type.

    Args:
        dtype (VarType): Variable type.

    Returns:
        str: 'fp16' for FP16, 'bf16' for BF16 and 'fp32' for any other type.
    """
    var_type = core.VarDesc.VarType
    if dtype == var_type.FP16:
        return 'fp16'
    if dtype == var_type.BF16:
        return 'bf16'
    # Every type that is not one of the two 16-bit types maps to 'fp32'.
    return 'fp32'
def set_op_dtype_to_fp16(op): def set_op_dtype_to_fp16(op):
...@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op): ...@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op):
op.has_attr('in_dtype') op.has_attr('in_dtype')
and op.attr('in_dtype') == core.VarDesc.VarType.FP32 and op.attr('in_dtype') == core.VarDesc.VarType.FP32
): ):
op._set_attr('in_dtype', core.VarDesc.VarType.FP16) op._set_attr('in_dtype', __target_dtype__)
if ( if (
op.has_attr('out_dtype') op.has_attr('out_dtype')
and op.attr('out_dtype') == core.VarDesc.VarType.FP32 and op.attr('out_dtype') == core.VarDesc.VarType.FP32
): ):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16) op._set_attr('out_dtype', __target_dtype__)
if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32: if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16) op._set_attr('dtype', __target_dtype__)
if __target_dtype__ == core.VarDesc.VarType.BF16:
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
if op.has_attr('mkldnn_data_type'):
op._set_attr('mkldnn_data_type', 'bfloat16')
# adapt for backward op # adapt for backward op
...@@ -156,6 +178,7 @@ class FP16State(object): ...@@ -156,6 +178,7 @@ class FP16State(object):
list list
) # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]} ) # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
self.is_train = False self.is_train = False
self.out_var_op_deps = {}
def _is_fp16_op(self, op_id): def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None) return self._op_fp16_dict.get(op_id, None)
...@@ -169,6 +192,14 @@ class FP16State(object): ...@@ -169,6 +192,14 @@ class FP16State(object):
# assume all backward block are behind forward blocks # assume all backward block are behind forward blocks
for block in self.program.blocks: for block in self.program.blocks:
for op in block.ops: for op in block.ops:
for name in op.output_arg_names:
if name not in self.out_var_op_deps:
self.out_var_op_deps[name] = [op.desc.original_id()]
else:
self.out_var_op_deps[name].extend(
[op.desc.original_id()]
)
self._mark_op(op) self._mark_op(op)
# set forward tensor dtype # set forward tensor dtype
...@@ -192,6 +223,18 @@ class FP16State(object): ...@@ -192,6 +223,18 @@ class FP16State(object):
if op.type == "assign" and "array_" in op.input_arg_names[0]: if op.type == "assign" and "array_" in op.input_arg_names[0]:
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
return return
# If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var.
if op.type == "assign":
out_name = op.output_arg_names[0]
if len(self.out_var_op_deps[out_name]) > 1:
if not self._op_fp16_dict[
self.out_var_op_deps[out_name][0]
]:
self._op_fp16_dict[op.desc.original_id()] = False
else:
self._op_fp16_dict[op.desc.original_id()] = True
return
if _need_keep_fp32( if _need_keep_fp32(
op, self.amp_list.unsupported_list, self.use_fp16_guard op, self.amp_list.unsupported_list, self.use_fp16_guard
): ):
...@@ -228,7 +271,7 @@ class FP16State(object): ...@@ -228,7 +271,7 @@ class FP16State(object):
return return
if var.dtype == core.VarDesc.VarType.FP32: if var.dtype == core.VarDesc.VarType.FP32:
var.desc.set_dtype(core.VarDesc.VarType.FP16) var.desc.set_dtype(__target_dtype__)
def resolute_tensor_dtype(self, block): def resolute_tensor_dtype(self, block):
...@@ -260,7 +303,7 @@ class FP16State(object): ...@@ -260,7 +303,7 @@ class FP16State(object):
out_var = block.vars.get(out_var_name) out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types: if out_var is None or out_var.type not in _valid_types:
continue continue
if out_var.dtype == core.VarDesc.VarType.FP16: if out_var.dtype == __target_dtype__:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32) out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
elif is_backward_op(op): elif is_backward_op(op):
if self._is_fp16_op(op.desc.original_id()) == True: if self._is_fp16_op(op.desc.original_id()) == True:
...@@ -276,7 +319,7 @@ class FP16State(object): ...@@ -276,7 +319,7 @@ class FP16State(object):
out_var = block.vars.get(out_var_name) out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types: if out_var is None or out_var.type not in _valid_types:
continue continue
if out_var.dtype == core.VarDesc.VarType.FP16: if out_var.dtype == __target_dtype__:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32) out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
def cast_block(self, block): def cast_block(self, block):
...@@ -295,7 +338,7 @@ class FP16State(object): ...@@ -295,7 +338,7 @@ class FP16State(object):
op, op,
idx, idx,
block, block,
core.VarDesc.VarType.FP16, __target_dtype__,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
self.dist_context, self.dist_context,
) )
...@@ -305,7 +348,7 @@ class FP16State(object): ...@@ -305,7 +348,7 @@ class FP16State(object):
idx, idx,
block, block,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, __target_dtype__,
self.dist_context, self.dist_context,
) )
elif is_backward_op(op): elif is_backward_op(op):
...@@ -315,7 +358,7 @@ class FP16State(object): ...@@ -315,7 +358,7 @@ class FP16State(object):
op, op,
idx, idx,
block, block,
core.VarDesc.VarType.FP16, __target_dtype__,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
self.dist_context, self.dist_context,
) )
...@@ -325,7 +368,7 @@ class FP16State(object): ...@@ -325,7 +368,7 @@ class FP16State(object):
idx, idx,
block, block,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, __target_dtype__,
self.dist_context, self.dist_context,
) )
elif op.type == "sum": elif op.type == "sum":
...@@ -399,6 +442,9 @@ class FP16State(object): ...@@ -399,6 +442,9 @@ class FP16State(object):
dist_context, cast_var, ref_mapping, ref_mesh dist_context, cast_var, ref_mapping, ref_mesh
) )
op_namescope = "/"
if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope')
cast_op = block._insert_op_without_sync( cast_op = block._insert_op_without_sync(
idx, idx,
type="cast", type="cast",
...@@ -410,6 +456,9 @@ class FP16State(object): ...@@ -410,6 +456,9 @@ class FP16State(object):
OP_ROLE_KEY: OpRole.Forward, OP_ROLE_KEY: OpRole.Forward,
}, },
) )
cast_op._set_attr(
'op_namescope', op_namescope
) # for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context cast_op, ref_mesh, ref_mapping, dist_context
) )
...@@ -455,22 +504,36 @@ class FP16State(object): ...@@ -455,22 +504,36 @@ class FP16State(object):
) in self.forward_input_cast_ops[forward_op_id]: ) in self.forward_input_cast_ops[forward_op_id]:
# rename input # rename input
# some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
if op.type != "scale" and slot_name in op.input_names:
assert src_name in op.input( assert src_name in op.input(
slot_name slot_name
), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op)) ), "var: {} not in op's {}. {}".format(
src_name, slot_name, str(op)
)
src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
assert src_var_dist_attr is not None assert src_var_dist_attr is not None
op._rename_input(src_name, cast_name) op._rename_input(src_name, cast_name)
grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr) grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr)
# NOTE Special for scale op, scale op's grad op is scale,
# so slot name map rule could not apply to grad scale op
# cast_name: mean_0.tmp_0.cast_bf16, src_name: mean_0.tmp_0, dst_dtype: paddle.bfloat16, src_dtype: paddle.float32, slot_name: X.
if op.type == "scale":
grad_slot_name = "X"
# create cast grad # create cast grad
else:
grad_slot_name = slot_name + "@GRAD" grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names
if grad_slot_name in op.output_names:
# some forward input maybe stop_gradient=True, e.g. input_mask
if len(op.output(grad_slot_name)) == 0: if len(op.output(grad_slot_name)) == 0:
var = block.var(src_name)
assert var.stop_gradient is True
continue continue
assert len(op.output(grad_slot_name)) == 1 assert (
len(op.output(grad_slot_name)) == 1
), "[{}], Current Op: {}".format(grad_slot_name, str(op))
grad_name = op.output(grad_slot_name)[0] grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name) grad = block.var(grad_name)
grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name) grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
...@@ -492,7 +555,9 @@ class FP16State(object): ...@@ -492,7 +555,9 @@ class FP16State(object):
cast_grad, grad_dist_attr cast_grad, grad_dist_attr
) )
op._rename_output(grad_name, cast_grad.name) op._rename_output(grad_name, cast_grad.name)
grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr) grad_op_attr.set_output_dist_attr(
cast_grad.name, grad_dist_attr
)
# add cast # add cast
cast_op = block._insert_op_without_sync( cast_op = block._insert_op_without_sync(
...@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): ...@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
def _split_grads(params_grads): def _split_grads(params_grads):
grads = [g for _, g in params_grads] grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] fp16_grads = [g for g in grads if g.dtype == __target_dtype__]
assert len(fp32_grads) + len(fp16_grads) == len( assert len(fp32_grads) + len(fp16_grads) == len(
grads grads
), "Data types of all grads must be either fp16 or fp32." ), "Data types of all grads must be either fp16 or fp32."
...@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): ...@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
# TODO to support CUDAPinned/NPU/XPU Places # TODO to support CUDAPinned/NPU/XPU Places
if direction == "D2H": if direction == "D2H":
dst_place_type = 0 dst_place_type = 0
elif direction == "D2H":
dst_place_type = 1
else: else:
raise NotImplementedError( raise NotImplementedError(
"direction [{}] is not supported yet.".format(direction) f"direction [{direction}] is not supported yet."
) )
attrs = {'dst_place_type': dst_place_type} attrs = {'dst_place_type': dst_place_type}
new_op = block._insert_op_without_sync( new_op = block._insert_op_without_sync(
index=idx, index=idx,
type='memcpy', type='memcpy_d2h',
inputs={'X': [src_var]}, inputs={'X': [src_var]},
outputs={'Out': [output_var]}, outputs={'Out': [output_var]},
attrs=attrs, attrs=attrs,
...@@ -678,17 +741,17 @@ def cast_startup_program(): ...@@ -678,17 +741,17 @@ def cast_startup_program():
for op in startup_program.global_block().ops: for op in startup_program.global_block().ops:
if is_initialization_op(op): if is_initialization_op(op):
output_name = op.output_arg_names[0] output_name = op.output_arg_names[0]
if ( if param_to_dtype.get(output_name, None) == __target_dtype__:
param_to_dtype.get(output_name, None)
== core.VarDesc.VarType.FP16
):
assert op.has_attr( assert op.has_attr(
'dtype' 'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format( ), "initialization op is supported to has dtype attribute but got {}.".format(
str(op) str(op)
) )
out_var = startup_program.global_block().var(output_name)
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(__target_dtype__)
if op.attr('dtype') == core.VarDesc.VarType.FP32: if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16) op._set_attr('dtype', __target_dtype__)
@register_pass("auto_parallel_fp16") @register_pass("auto_parallel_fp16")
...@@ -701,14 +764,44 @@ class FP16Pass(AMPPass): ...@@ -701,14 +764,44 @@ class FP16Pass(AMPPass):
# in distributed scenario, all ranks should have the same modification. # in distributed scenario, all ranks should have the same modification.
def _apply_single_impl(self, main_program, startup_program, context): def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context") self.dist_context = self.get_attr("dist_context")
self.target_dtype = self.get_attr("dtype")
params_grads = self.get_attr("params_grads") params_grads = self.get_attr("params_grads")
self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None)
if self.use_optimizer_fp16 is None:
self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"
# switch environment for fp16 / bf16.
if self.target_dtype == "float16":
__target_dtype = core.VarDesc.VarType.FP16
elif self.target_dtype == "bfloat16":
__target_dtype = core.VarDesc.VarType.BF16
else:
raise NotImplementedError(
"target dtype [{}] is for amp o2 not supported yet.".format(
self.target_dtype
)
)
global __target_dtype__
__target_dtype__ = __target_dtype
amp_list = AutoMixedPrecisionLists( amp_list = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")), set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")), set(self.get_attr("custom_black_list")),
None, dtype=self.target_dtype,
) )
amp_list.unsupported_list -= {
"conditional_block_grad",
"conditional_block",
"conditional_block_infer",
"select_input",
"while",
"while_grad",
"cast",
"tensor_array_to_tensor",
"lod_array_length",
"write_to_array",
}
# NOTE: do not change the input data dtype, since it is controlled by the dataloader # NOTE: do not change the input data dtype, since it is controlled by the dataloader
# and which is out of control of FP16 Pass # and which is out of control of FP16 Pass
input_data_var_names = [var.name for var in self.get_attr("input_data")] input_data_var_names = [var.name for var in self.get_attr("input_data")]
...@@ -726,6 +819,7 @@ class FP16Pass(AMPPass): ...@@ -726,6 +819,7 @@ class FP16Pass(AMPPass):
cast_startup_program() cast_startup_program()
if is_train: if is_train:
if self.target_dtype == "fp16":
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference # TODO (JZ-LIANG)support cast forward program only when inference
self._init_amp_var() self._init_amp_var()
...@@ -801,10 +895,12 @@ class FP16Pass(AMPPass): ...@@ -801,10 +895,12 @@ class FP16Pass(AMPPass):
# modify optimizer # modify optimizer
base_opt = self.get_attr("base_opt") base_opt = self.get_attr("base_opt")
base_opt._multi_precision = True base_opt._multi_precision = True
if self.get_attr("use_optimizer_fp16"): if self.use_optimizer_fp16:
base_opt._multi_precision = False base_opt._multi_precision = False
if self.target_dtype == "fp16":
if isinstance( if isinstance(
base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW) base_opt,
(paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
): ):
with main_program._optimized_guard([]): with main_program._optimized_guard([]):
# found_inf = paddle.tensor.creation._memcpy( # found_inf = paddle.tensor.creation._memcpy(
......
...@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS}) py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50) TIMEOUT 50)
py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass ENVS ${dist_ENVS})
set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS}) ${dist_ENVS})
set_tests_properties(test_iterable_dataset set_tests_properties(test_iterable_dataset
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.framework import core
paddle.enable_static()
def get_cuda_version():
    """
    Return the CUDA version reported by ``nvcc --version`` as an integer.

    The ``release <major>.<minor>,`` marker in the command output is encoded
    as ``major * 1000 + minor * 10`` (for example ``release 11.2,`` gives
    ``11020``).

    Returns:
        int: The encoded version, or -1 when the output contains no
        ``release ...,`` marker.
    """
    # Read through a context manager so the pipe is closed deterministically
    # instead of being left to the garbage collector.
    with os.popen("nvcc --version") as pipe:
        result = pipe.read()

    match = re.search(r'release (\S+),', result)
    if match is None:
        return -1

    integer, decimal = match.group(1).split('.')
    return int(integer) * 1000 + int(float(decimal) * 10)
def apply_pass(use_amp=False, amp_dtype="bfloat16"):
    """
    Build the auto-parallel strategy used by this test.

    Args:
        use_amp (bool): When True, enable AMP at level "o2".
        amp_dtype (str): Value assigned to ``strategy.amp.dtype`` when AMP
            is enabled.

    Returns:
        The configured ``auto.Strategy`` instance.
    """
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if not use_amp:
        return strategy

    amp_config = strategy.amp
    amp_config.enable = True
    amp_config.dtype = amp_dtype
    amp_config.level = "o2"
    # Ops listed here are added to the AMP custom black list.
    amp_config.custom_black_list = [
        'c_softmax_with_cross_entropy',
        'elementwise_div',
        'reduce_sum',
    ]
    return strategy
def reset_prog():
    """Switch the main and startup programs to freshly created ones."""
    framework = paddle.fluid.framework
    framework.switch_main_program(paddle.static.Program())
    framework.switch_startup_program(paddle.static.Program())
class TestShardingStage2WithNewEXE(unittest.TestCase):
    """Compare a plain run of the "mp" model with an AMP O2 bfloat16 run."""

    def setUp(self):
        self.batch_size = 2
        self.batch_num = 10
        # Only referenced by the disabled gradient clipping in get_engine.
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        """Seed the random generators and give ``engine`` a CUDA executor."""
        paddle.seed(2022)
        np.random.seed(2022)
        random.seed(2022)
        place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
        """
        Create an initialized engine on fresh programs.

        Args:
            use_amp (bool): Forwarded to ``apply_pass``.
            amp_dtype (str): Forwarded to ``apply_pass``.

        Returns:
            The ``auto.Engine`` built from the "mp" model, its loss and an
            AdamW optimizer.
        """
        reset_prog()

        strategy = apply_pass(use_amp, amp_dtype)
        # Gradient clipping is currently disabled:
        # clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        clip = None
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("mp")

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def _first_loss(self, engine):
        """Fit ``engine`` for one epoch and return the first recorded loss."""
        history = engine.fit(
            self.dataset,
            3,
            epochs=1,
            steps_per_epoch=self.batch_num,
            log_freq=1,
            batch_size=self.batch_size,
        )
        return history.history['loss'][0]

    def check_bf16(self, program):
        """Assert how many parameters of ``program`` have each dtype."""
        var_type = core.VarDesc.VarType
        dtypes = [p.dtype for p in program.all_parameters()]

        # NOTE(review): the expected counts are presumably specific to the
        # model returned by generate_model("mp") -- confirm when it changes.
        self.assertEqual(dtypes.count(var_type.BF16), 25)
        self.assertEqual(dtypes.count(var_type.FP16), 0)
        self.assertEqual(dtypes.count(var_type.FP32), 11)

    def test_param_grad_fuse_overlap(self):
        # Baseline run without AMP.
        mp_engine = self.get_engine(use_amp=False)
        loss0 = self._first_loss(mp_engine)

        # bf16 run.
        mp_bf16_engine = self.get_engine(use_amp=True)
        if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
            # Report the test as skipped instead of silently passing.
            self.skipTest(
                "the bf16 run needs a CUDA build and a CUDA version >= 11.0"
            )
        loss1 = self._first_loss(mp_bf16_engine)

        np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)
        self.check_bf16(mp_bf16_engine.main_program)
if __name__ == "__main__":
unittest.main()
...@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None): ...@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
] ]
amp.init_loss_scaling = 32768 amp.init_loss_scaling = 32768
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"] amp.level = level
amp.use_optimizer_fp16 = level == "o3" amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level) print("amp level: ", level)
return strategy return strategy
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestAMPO2(unittest.TestCase):
    """Run ``amp_o2_pass.py`` through ``paddle.distributed.launch``."""

    def test_bf16(self):
        """Launch the script on devices 0 and 1 and expect exit code 0."""
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        # The context manager removes the log directory even when the
        # launched process fails and the assertion below raises.
        with tempfile.TemporaryDirectory() as log_dir:
            cmd = (
                [sys.executable, "-u"]
                + coverage_args
                + [
                    "-m",
                    "paddle.distributed.launch",
                    "--devices",
                    "0,1",
                    "--log_dir",
                    log_dir,
                    launch_model_path,
                ]
            )
            # Blocks until the launcher exits; a non-zero exit code does not
            # raise here and is checked by the assertion below.
            completed = subprocess.run(cmd)
            self.assertEqual(completed.returncode, 0)
if __name__ == "__main__":
unittest.main()
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
# limitations under the License. # limitations under the License.
import os import os
# import yaml # import yaml
import unittest import unittest
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
class TestStrategy(unittest.TestCase): class TestStrategy(unittest.TestCase):
def test_default_config(self): def test_default_config(self):
strategy = auto.Strategy() strategy = auto.Strategy()
...@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase): ...@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp amp = strategy.amp
self.assertEqual(amp.enable, False) self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.dtype, "float16")
self.assertAlmostEqual(amp.level, "o1")
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0) self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000) self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2) self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
...@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase): ...@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_black_list, []) self.assertEqual(amp.custom_black_list, [])
self.assertEqual(amp.custom_white_list, []) self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, []) self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_pure_fp16, False) self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_optimizer_fp16, False) self.assertEqual(amp.use_optimizer_fp16, False)
sharding = strategy.sharding sharding = strategy.sharding
...@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase): ...@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
amp.custom_white_list = ["x"] amp.custom_white_list = ["x"]
amp.custom_black_list = ["y"] amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"] amp.custom_black_varnames = ["z"]
amp.use_pure_fp16 = True
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True) self.assertEqual(amp.enable, True)
...@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase): ...@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_white_list, ["x"]) self.assertEqual(amp.custom_white_list, ["x"])
self.assertEqual(amp.custom_black_list, ["y"]) self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"]) self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_pure_fp16, True)
self.assertEqual(amp.use_fp16_guard, False) self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True) self.assertEqual(amp.use_optimizer_fp16, True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册