Unverified · Commit 5cb0f3aa authored by zhaoyingli, committed by GitHub

[AutoParallel] BF16-o1/FP16-o1 PASS support training and generation (#51147)

* [AutoParallel] support bloom

* fix import

* align amp and bf16

* update func name

* clipbyglobalnorm and add_n support bf16

* upgrade amp strategy api

* update bf16 unittest

* fix static clip

---------
Co-authored-by: liangjianzhong <liangjianzhong@baidu.com>
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Parent 32baca93
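For context, a minimal usage sketch (not part of this diff) of how the upgraded AMP strategy might be configured for BF16-O1 training with the auto-parallel Engine. The `dtype` and `level` field names are assumptions inferred from the pass attributes changed below rather than confirmed API, so check the released Strategy documentation before relying on them.

    from paddle.distributed.fleet import auto

    # Sketch only: enable the unified AMP pass and pick the low-precision dtype.
    strategy = auto.Strategy()
    strategy.amp.enable = True
    strategy.amp.dtype = "bfloat16"      # assumed field: "float16" would select the FP16-O1 path
    strategy.amp.level = "o1"            # assumed field: O1 inserts per-op casts as in AMPPass/AMPState
    strategy.amp.custom_black_list = []  # ops forced to stay in fp32
    strategy.amp.custom_white_list = []  # ops allowed to run in low precision

    # model, loss_fn, optimizer and train_dataset are user-defined placeholders.
    # engine = auto.Engine(model, loss_fn, optimizer, strategy=strategy)
    # engine.fit(train_dataset, epochs=1, batch_size=8)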
...@@ -723,6 +723,14 @@ class Completer: ...@@ -723,6 +723,14 @@ class Completer:
tensor_dist_attr.process_mesh = ( tensor_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh nearest_tensor_dist_attr.process_mesh
) )
for node in while_op_node.inputs:
if node.var().name() == tensor_name:
node_dist_attr = (
self._dist_context.get_dist_attr_for_graph(node)
)
node_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
# Step 4: set the process meshes of the outputs in while_op to the process meshes of the outside output nodes # Step 4: set the process meshes of the outputs in while_op to the process meshes of the outside output nodes
while_op_outputs_dist_attrs = while_op_dist_attr.outputs_dist_attrs while_op_outputs_dist_attrs = while_op_dist_attr.outputs_dist_attrs
...@@ -749,6 +757,14 @@ class Completer: ...@@ -749,6 +757,14 @@ class Completer:
tensor_dist_attr.process_mesh = ( tensor_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh nearest_tensor_dist_attr.process_mesh
) )
for node in while_op_node.outputs:
if node.var().name() == tensor_name:
node_dist_attr = (
self._dist_context.get_dist_attr_for_graph(node)
)
node_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
# Amend the process meshes related to array # Amend the process meshes related to array
for array_node_list in self._array_nodes.values(): for array_node_list in self._array_nodes.values():
......
...@@ -75,11 +75,6 @@ set_field_default_config(AMP, "custom_white_list", []) ...@@ -75,11 +75,6 @@ set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", []) set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", []) set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_fp16_guard", False) set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_optimizer_fp16", False)
set_field_default_config(AMP, "custom_bf16_list", [])
set_field_default_config(AMP, "custom_fp32_list", [])
set_field_default_config(AMP, "custom_fp32_varnames", [])
set_field_default_config(AMP, "use_bf16_guard", False) set_field_default_config(AMP, "use_bf16_guard", False)
######################################### #########################################
......
...@@ -1557,6 +1557,19 @@ class Engine: ...@@ -1557,6 +1557,19 @@ class Engine:
cur_dist_attr = auto_utils.get_dist_attr(program, dist_context) cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr) converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict) state_dict = converter.convert(strict=strict)
for name, param in program.state_dict().items():
param_array = np.array(param)
if name not in state_dict:
continue
if param_array.dtype != state_dict[name].dtype:
self._logger.info(
"cast {}'s dtype from '{}' to '{}'".format(
name,
str(state_dict[name].dtype),
str(param_array.dtype),
)
)
state_dict[name] = state_dict[name].astype(param_array.dtype)
program.set_state_dict(state_dict) program.set_state_dict(state_dict)
def save(self, path, training=True): def save(self, path, training=True):
......
...@@ -272,7 +272,8 @@ def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True): ...@@ -272,7 +272,8 @@ def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
return best_compatible_impl return best_compatible_impl
def is_parameter_related(varname, block): def is_parameter_related(varname, block, dist_context=None):
# TODO(zhaoyingli): maintain a dict in dist_context to record all variables which are renamed
if ".subprog_" in varname: if ".subprog_" in varname:
varname = varname[: varname.index(".subprog_")] varname = varname[: varname.index(".subprog_")]
if ".cast_fp" in varname: if ".cast_fp" in varname:
...@@ -281,10 +282,17 @@ def is_parameter_related(varname, block): ...@@ -281,10 +282,17 @@ def is_parameter_related(varname, block):
varname = varname[: varname.index(".cast_bf")] varname = varname[: varname.index(".cast_bf")]
if ".quantized" in varname: if ".quantized" in varname:
varname = varname[: varname.index(".quantized")] varname = varname[: varname.index(".quantized")]
# if "@RESHARD" in varname: assert block._find_var_recursive(
# varname = varname[: varname.index("@RESHARD")] varname
assert block._find_var_recursive(varname) ), "cannot find var {} in cur block".format(varname)
var = block._var_recursive(varname) var = block._var_recursive(varname)
# NOTE(hack method): find the param which has been resharded
if dist_context and "@RESHARD" in varname:
varname = varname[: varname.index("@RESHARD")]
serial_program = dist_context.serial_main_program
var = serial_program.global_block()._find_var_recursive(varname)
if var is None:
return False
return var.is_parameter return var.is_parameter
......
...@@ -28,6 +28,9 @@ class DistributedScale(DistributedOperatorImplContainer): ...@@ -28,6 +28,9 @@ class DistributedScale(DistributedOperatorImplContainer):
register_distributed_operator_impl_container(DistributedScale("scale")) register_distributed_operator_impl_container(DistributedScale("scale"))
register_distributed_operator_impl_container(DistributedScale("fill_any_like"))
register_distributed_operator_impl_container(DistributedScale("where"))
register_distributed_operator_impl_container(DistributedScale("tanh"))
class DistributedScaleImpl(DistributedOperatorImpl): class DistributedScaleImpl(DistributedOperatorImpl):
...@@ -50,11 +53,15 @@ class DistributedScaleImpl(DistributedOperatorImpl): ...@@ -50,11 +53,15 @@ class DistributedScaleImpl(DistributedOperatorImpl):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
in_dims_mappings = []
for in_name in op_desc.input_arg_names():
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
in_dims_mappings.append(in_dims_mapping)
for x_dims_mapping in in_dims_mappings:
if x_dims_mapping != out_dims_mapping: if x_dims_mapping != out_dims_mapping:
return False return False
...@@ -78,10 +85,6 @@ class DistributedScaleImpl(DistributedOperatorImpl): ...@@ -78,10 +85,6 @@ class DistributedScaleImpl(DistributedOperatorImpl):
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping) op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
changed = True changed = True
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed return changed
@staticmethod @staticmethod
...@@ -94,3 +97,8 @@ class DistributedScaleImpl(DistributedOperatorImpl): ...@@ -94,3 +97,8 @@ class DistributedScaleImpl(DistributedOperatorImpl):
register_distributed_operator_impl("scale", DistributedScaleImpl("scale")) register_distributed_operator_impl("scale", DistributedScaleImpl("scale"))
register_distributed_operator_impl(
"fill_any_like", DistributedScaleImpl("fill_any_like")
)
register_distributed_operator_impl("where", DistributedScaleImpl("where"))
register_distributed_operator_impl("tanh", DistributedScaleImpl("tanh"))
...@@ -2213,7 +2213,11 @@ class Resharder: ...@@ -2213,7 +2213,11 @@ class Resharder:
else: else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name) op_input_attrs = self._get_common_op_input_attrs(op, var_name)
assert op_input_attrs assert (
op_input_attrs
), "The input '{}' of op '{}' has no distibution attributes in subblock".format(
op.name, var_name
)
return op_input_attrs return op_input_attrs
......
...@@ -18,7 +18,6 @@ from .auto_parallel_gradient_merge import * # noqa: F403 ...@@ -18,7 +18,6 @@ from .auto_parallel_gradient_merge import * # noqa: F403
from .auto_parallel_sharding import * # noqa: F403 from .auto_parallel_sharding import * # noqa: F403
from .auto_parallel_amp import * # noqa: F403 from .auto_parallel_amp import * # noqa: F403
from .auto_parallel_fp16 import * # noqa: F403 from .auto_parallel_fp16 import * # noqa: F403
from .auto_parallel_bf16 import * # noqa: F403
from .auto_parallel_recompute import * # noqa: F403 from .auto_parallel_recompute import * # noqa: F403
from .auto_parallel_quantization import * # noqa: F403 from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403 from .auto_parallel_data_parallel_optimization import * # noqa: F403
......
...@@ -18,16 +18,18 @@ from paddle.distributed.auto_parallel.process_group import ( ...@@ -18,16 +18,18 @@ from paddle.distributed.auto_parallel.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.utils import (
get_loss_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr, set_var_dist_attr,
) )
from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core from paddle.framework import core
from paddle.static.amp.bf16.amp_utils import (
AutoMixedPrecisionListsBF16,
_is_in_fp32_varnames,
)
from paddle.static.amp.fp16_utils import ( from paddle.static.amp.fp16_utils import (
AutoMixedPrecisionLists, AutoMixedPrecisionLists,
_dtype_to_str,
_is_in_black_varnames, _is_in_black_varnames,
_keep_fp32_input, _keep_fp32_input,
_keep_fp32_output, _keep_fp32_output,
...@@ -40,83 +42,232 @@ from paddle.static.amp.fp16_utils import ( ...@@ -40,83 +42,232 @@ from paddle.static.amp.fp16_utils import (
from paddle.utils import unique_name from paddle.utils import unique_name
from ..auto_parallel.process_mesh import ProcessMesh from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op from ..auto_parallel.utils import (
is_backward_op,
is_forward_op,
is_loss_grad_op,
is_loss_op,
is_optimize_op,
)
from .pass_base import PassBase, register_pass from .pass_base import PassBase, register_pass
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
__amp_skip_ops__ = [
'create_py_reader',
'create_double_buffer_reader',
'cast',
'while',
]
def _dtype_to_str(dtype):
if dtype == core.VarDesc.VarType.FP16:
return 'fp16'
elif dtype == core.VarDesc.VarType.BF16:
return 'bf16'
else:
return 'fp32'
def _str_to_dtype(dstr):
if dstr == 'float16':
return core.VarDesc.VarType.FP16
elif dstr == 'bfloat16':
return core.VarDesc.VarType.BF16
else:
return core.VarDesc.VarType.FP32
class AMPLists:
def __init__(
self,
white_list=None,
black_list=None,
black_varnames=None,
dtype="float16",
):
self._amp_list = None
if dtype == "float16":
self._amp_list = AutoMixedPrecisionLists(
set(white_list), set(black_list), set(black_varnames)
)
elif dtype == "bfloat16":
self._amp_list = AutoMixedPrecisionListsBF16(
set(white_list), set(black_list), set(black_varnames)
)
assert self._amp_list is not None
self._dtype = dtype
self._is_float16 = dtype == "float16"
@property
def white_list(self):
if self._is_float16:
return self._amp_list.white_list
else:
return self._amp_list.bf16_list
@property
def black_list(self):
if self._is_float16:
return self._amp_list.black_list
else:
return self._amp_list.fp32_list
@property
def gray_list(self):
return self._amp_list.gray_list
@property
def black_varnames(self):
if self._is_float16:
return self._amp_list.black_varnames
else:
return self._amp_list.fp32_varnames
@property
def is_fp16(self):
return self._is_float16
@property
def dtype(self):
return self._dtype
@property
def amp_list(self):
return self._amp_list
def _is_in_black_fp32_varnames(self, op):
if self._is_float16:
return _is_in_black_varnames(op, self._amp_list)
else:
return _is_in_fp32_varnames(op, self._amp_list)
def _op_keep_fp32_input(self, op, in_name):
if self._is_float16:
return _keep_fp32_input(op, in_name)
else:
if op.type in ['batch_norm', 'layer_norm']:
return in_name != 'X'
if op.type == 'fused_bn_add_activation':
return in_name not in {'X', 'Z'}
return False
def _op_keep_fp32_output(self, op, out_name):
if self._is_float16:
return _keep_fp32_output(op, out_name)
else:
if op.type in [
'batch_norm',
'fused_bn_add_activation',
'layer_norm',
]:
return out_name != 'Y'
return False
class AMPState: class AMPState:
def __init__(self, block): def __init__(self, program, amp_lists, amp_dtype, dist_context):
self._block = block self.program = program
self._op_fp16_dict = ( self.dist_context = dist_context
{} self.amp_lists = amp_lists
) # op_id --> True/False. 'True' means that the current op is in fp16 mode. self.amp_dtype = amp_dtype
self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name} self.grad_op_to_op_map = (
self.is_train = False dist_context.dist_op_context.grad_op_id_to_op_id
)
# op_id --> True/False. 'True' means that the current op is in fp16/bf16 mode.
self._op_fp16_dict = {}
# fwd_op_id --> {old_name: cast_name}
self._var_name_dict = {}
# out_var_name --> [op_ids]
self.out_var_op_deps = {}
def _is_fp16_op(self, op_id): def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None) return self._op_fp16_dict.get(op_id, None)
def _build_state(self, amp_lists, dist_context): def build_state(self):
ops = self._block.ops is_train = False
dist_op_context = dist_context.dist_op_context for block in self.program.blocks:
for op in ops: for op in block.ops:
if int(op.attr('op_role')) == 257: # to record the inplace operation and their outputs
self.is_train = True for name in op.output_arg_names:
if name not in self.out_var_op_deps:
if int(op.attr('op_role')) == int(OpRole.Forward): self.out_var_op_deps[name] = [op.desc.original_id()]
self._mark_black_white_ops(amp_lists) else:
elif int(op.attr('op_role')) == int(OpRole.Backward): self.out_var_op_deps[name].extend(
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: [op.desc.original_id()]
fwd_op_id = dist_op_context.grad_op_id_to_op_id[ )
if is_loss_grad_op(op):
is_train = True
if op.type in __amp_skip_ops__:
continue
if is_forward_op(op):
self._mark_black_white_ops(op, block.ops, block)
elif is_backward_op(op):
if op.desc.original_id() in self.grad_op_to_op_map:
fwd_op_id = self.grad_op_to_op_map[
op.desc.original_id() op.desc.original_id()
] ]
if self._is_fp16_op(fwd_op_id) is True: assert fwd_op_id in self._op_fp16_dict, "{}".format(
self._op_fp16_dict[op.desc.original_id()] = True str(op)
elif self._is_fp16_op(fwd_op_id) is False: )
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[
elif int(op.attr('op_role')) == int(OpRole.Optimize): op.desc.original_id()
] = self._is_fp16_op(fwd_op_id)
elif is_optimize_op(op):
break break
return self.is_train # insert cast ops
for block in self.program.blocks:
self._cast_block(block)
def _mark_black_white_ops(self, amp_lists): return is_train
"""
this function is modified from paddle.static.amp
"""
self._block._sync_with_cpp()
ops = self._block.ops
for op in ops: def _mark_black_white_ops(self, op, ops, block):
if int(op.attr('op_role')) == int(OpRole.Backward):
break # ernie inference trick
if op.type == 'create_py_reader' or op.type == 'read': if op.type == "assign" and "array_" in op.input_arg_names[0]:
continue self._op_fp16_dict[op.desc.original_id()] = False
if amp_lists.black_varnames is not None and _is_in_black_varnames( return
op, amp_lists
# If the assign op is an in-place operation, its exec mode should match the exec mode of the op that created its output var.
if op.type == "assign":
out_name = op.output_arg_names[0]
if len(self.out_var_op_deps[out_name]) > 1:
if not self._is_fp16_op(self.out_var_op_deps[out_name][0]):
self._op_fp16_dict[op.desc.original_id()] = False
else:
self._op_fp16_dict[op.desc.original_id()] = True
return
if (
self.amp_lists.black_varnames is not None
and self.amp_lists._is_in_black_fp32_varnames(op)
): ):
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
continue return
if op.type in amp_lists.black_list: if op.type in self.amp_lists.black_list:
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
elif op.type in amp_lists.white_list: elif op.type in self.amp_lists.white_list:
self._op_fp16_dict[op.desc.original_id()] = True self._op_fp16_dict[op.desc.original_id()] = True
elif op.type in amp_lists.gray_list: elif op.type in self.amp_lists.gray_list:
is_black_op = False is_black_op = False
is_white_op = False is_white_op = False
for in_name in op.input_names: for in_name in op.input_names:
# if this op has inputs # if this op has inputs
if in_name: if in_name:
for in_var_name in op.input(in_name): for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name) in_var = block._var_recursive(in_var_name)
# this in_var isn't the output of other op # this in_var isn't the output of other op
if in_var.op is None: if in_var.op is None:
continue continue
elif in_var.op is op: elif in_var.op is op:
prev_op = find_true_prev_op( prev_op = find_true_prev_op(ops, op, in_var_name)
ops, op, in_var_name
)
if prev_op is None: if prev_op is None:
continue continue
else: else:
...@@ -125,13 +276,12 @@ class AMPState: ...@@ -125,13 +276,12 @@ class AMPState:
if ( if (
self._is_fp16_op(prev_op.desc.original_id()) self._is_fp16_op(prev_op.desc.original_id())
is False is False
or prev_op.type in amp_lists.black_list or prev_op.type in self.amp_lists.black_list
): ):
is_black_op = True is_black_op = True
elif ( elif (
self._is_fp16_op(prev_op.desc.original_id()) self._is_fp16_op(prev_op.desc.original_id()) is True
is True or prev_op.type in self.amp_lists.white_list
or prev_op.type in amp_lists.white_list
): ):
is_white_op = True is_white_op = True
if is_black_op: if is_black_op:
...@@ -145,37 +295,117 @@ class AMPState: ...@@ -145,37 +295,117 @@ class AMPState:
# are not determined which list they should stay. # are not determined which list they should stay.
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
def cast_forward_program(self, dist_context): def _cast_block(self, block):
ops = self._block.ops
idx = 0 idx = 0
while idx < len(ops): appended_grad_times = 0
op = ops[idx] while idx < len(block.ops):
op = block.ops[idx]
num_cast_ops = 0 num_cast_ops = 0
if int(op.attr('op_role')) == int(OpRole.Backward):
break if op.type in __amp_skip_ops__:
idx += 1
continue
elif is_forward_op(op):
if self._is_fp16_op(op.desc.original_id()) is False: if self._is_fp16_op(op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_forward( num_cast_ops = self._insert_cast_op_forward(
block,
op, op,
idx, idx,
core.VarDesc.VarType.FP16, _str_to_dtype(self.amp_dtype),
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
dist_context, self.dist_context,
) )
elif self._is_fp16_op(op.desc.original_id()) is True: elif self._is_fp16_op(op.desc.original_id()) is True:
if self.amp_dtype == "bfloat16":
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
op.has_attr('dtype')
and op.attr('dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_forward( num_cast_ops = self._insert_cast_op_forward(
block,
op, op,
idx, idx,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, _str_to_dtype(self.amp_dtype),
dist_context, self.dist_context,
) )
else: elif is_backward_op(op):
# NOTE: the map in `grad_var_to_var` may change when a var is cast,
# which affects how the dist_op inserts the allreduce_sum op.
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
op
)
if is_backward_op(op) and (
is_forward_op(block.ops[idx - 1])
or is_loss_op(block.ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
if op.desc.original_id() in self.grad_op_to_op_map:
if self._is_fp16_op(op.desc.original_id()) is False: # fp32
num_cast_ops = self._insert_cast_op_backward(
block,
op,
idx,
_str_to_dtype(self.amp_dtype),
core.VarDesc.VarType.FP32,
self.dist_context,
appended_grad_times,
)
elif (
self._is_fp16_op(op.desc.original_id()) is True
): # fp16/bf16
if self.amp_dtype == "bfloat16":
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
op.has_attr('dtype')
and op.attr('dtype')
== core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_backward(
block,
op,
idx,
core.VarDesc.VarType.FP32,
_str_to_dtype(self.amp_dtype),
self.dist_context,
appended_grad_times,
)
elif op.type == "sum":
# all input dtypes of sum should be equal, and the output dtype should follow the input dtype
out_var_name = op.desc.output_arg_names()[0]
in_var_name = op.desc.input_arg_names()[0]
out_var = block.var(out_var_name)
in_var = block._find_var_recursive(in_var_name)
for in_var_name in op.input_arg_names:
assert (
in_var.dtype == block.var(in_var_name).dtype
), "{}, {}, {}".format(
in_var, block.var(in_var_name), str(op)
)
out_var.desc.set_dtype(in_var.dtype)
elif int(op.attr('op_role')) == 257:
pass pass
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
op.type
)
)
idx += num_cast_ops + 1 idx += num_cast_ops + 1
self._block._sync_with_cpp() block._sync_with_cpp()
def _insert_cast_op_forward( def _insert_cast_op_forward(
self, op, idx, src_dtype, dst_dtype, dist_context self, block, op, idx, src_dtype, dst_dtype, dist_context
): ):
""" """
only for forward cast only for forward cast
...@@ -184,25 +414,26 @@ class AMPState: ...@@ -184,25 +414,26 @@ class AMPState:
num_cast_ops = 0 num_cast_ops = 0
var_name_dict = {} var_name_dict = {}
for in_name in op.input_names: for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( if (
op, in_name src_dtype == core.VarDesc.VarType.FP32
and self.amp_lists._op_keep_fp32_input(op, in_name)
): ):
continue continue
for in_var_name in op.input(in_name): for in_var_name in op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name) in_var = block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dst_dtype: if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue continue
if in_var.dtype == src_dtype: if in_var.dtype == src_dtype:
cast_name = ( cast_name = (
in_var.name + '.cast_' + _dtype_to_str(dst_dtype) in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
) )
out_var = self._block.vars.get(cast_name) cast_var = block.vars.get(cast_name)
var_name_dict[in_var.name] = cast_name var_name_dict[in_var.name] = cast_name
consume_op_attr = dist_context.get_op_dist_attr_for_program( consume_op_attr = dist_context.get_op_dist_attr_for_program(
op op
) )
assert consume_op_attr is not None assert consume_op_attr is not None
if out_var is None or out_var.dtype != dst_dtype: if cast_var is None or cast_var.dtype != dst_dtype:
# NOTE we make the cast op and var's dist attr as the op that consume the # NOTE we make the cast op and var's dist attr as the op that consume the
# cast var instead of the op which generates the var # cast var instead of the op which generates the var
in_var_dist_attr = consume_op_attr.get_input_dist_attr( in_var_dist_attr = consume_op_attr.get_input_dist_attr(
...@@ -215,27 +446,27 @@ class AMPState: ...@@ -215,27 +446,27 @@ class AMPState:
cast_name, in_var_dist_attr cast_name, in_var_dist_attr
) )
out_var = self._block.create_var( cast_var = block.create_var(
name=cast_name, name=cast_name,
dtype=dst_dtype, dtype=dst_dtype,
persistable=False, persistable=False,
stop_gradient=in_var.stop_gradient, stop_gradient=in_var.stop_gradient,
) )
set_var_dist_attr( set_var_dist_attr(
dist_context, out_var, ref_mapping, ref_mesh dist_context, cast_var, ref_mapping, ref_mesh
) )
op_namescope = "/" op_namescope = "/"
if op.has_attr('op_namescope'): if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope') op_namescope = op.attr('op_namescope')
cast_op = self._block._insert_op_without_sync( cast_op = block._insert_op_without_sync(
idx, idx,
type="cast", type="cast",
inputs={"X": in_var}, inputs={"X": in_var},
outputs={"Out": out_var}, outputs={"Out": cast_var},
attrs={ attrs={
"in_dtype": in_var.dtype, "in_dtype": in_var.dtype,
"out_dtype": out_var.dtype, "out_dtype": cast_var.dtype,
}, },
) )
cast_op._set_attr( cast_op._set_attr(
...@@ -260,89 +491,27 @@ class AMPState: ...@@ -260,89 +491,27 @@ class AMPState:
if ( if (
src_dtype == core.VarDesc.VarType.FP32 src_dtype == core.VarDesc.VarType.FP32
and dst_dtype == core.VarDesc.VarType.FP16 and dst_dtype == _str_to_dtype(self.amp_dtype)
): ):
for out_name in op.output_names: for out_name in op.output_names:
if _keep_fp32_output(op, out_name): if self.amp_lists._op_keep_fp32_output(op, out_name):
continue continue
for out_var_name in op.output(out_name): for out_var_name in op.output(out_name):
out_var = self._block.var(out_var_name) out_var = block._var_recursive(out_var_name)
if out_var.type not in _valid_types: if out_var.type not in _valid_types:
continue continue
if out_var.dtype == core.VarDesc.VarType.FP32: if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16) out_var.desc.set_dtype(_str_to_dtype(self.amp_dtype))
if op.has_attr('out_dtype'): if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16) op._set_attr(
return num_cast_ops 'out_dtype', _str_to_dtype(self.amp_dtype)
def cast_backward_program(self, params_grads, dist_context):
self._block._sync_with_cpp()
ops = self._block.ops
loss_op = get_loss_op(self._block)
loss_op_index = find_op_index(self._block.desc, loss_op.desc)
appended_grad_times = 0
idx = loss_op_index + 1
while idx < len(ops):
num_cast_ops = 0
grad_op = ops[idx]
# NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
# which will affect the dist_op to insert allreduce_sum op.
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (
is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
grad_op_orig_id = grad_op.desc.original_id()
dist_op_context = dist_context.dist_op_context
if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(grad_op_orig_id) is False: # fp32
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
dist_context,
appended_grad_times,
) )
elif self._is_fp16_op(grad_op_orig_id) is True: # fp16 return num_cast_ops
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
dist_context,
appended_grad_times,
)
elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype
for in_var_name in grad_op.desc.input_arg_names():
assert src_dtype == self._block.var(in_var_name).dtype
out_var_name = grad_op.desc.output_arg_names()[0]
out_var = self._block.var(out_var_name)
if out_var.dtype != src_dtype:
out_var.desc.set_dtype(src_dtype)
elif int(grad_op.attr('op_role')) == 257:
pass
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
grad_op.type
)
)
idx += num_cast_ops + 1
self._block._sync_with_cpp()
_update_backward_cast_ops(params_grads, dist_context)
def _insert_cast_op_backward( def _insert_cast_op_backward(
self, self,
grad_op, block,
op,
idx, idx,
src_dtype, src_dtype,
dst_dtype, dst_dtype,
...@@ -364,30 +533,30 @@ class AMPState: ...@@ -364,30 +533,30 @@ class AMPState:
return False return False
num_cast_ops = 0 num_cast_ops = 0
original_id = grad_op.desc.original_id() original_id = op.desc.original_id()
dist_op_context = dist_context.dist_op_context dist_op_context = dist_context.dist_op_context
fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id] fwd_op_id = self.grad_op_to_op_map[original_id]
for in_name in grad_op.input_names: for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
grad_op, in_name op, in_name
): ):
for in_var_name in grad_op.input(in_name): for in_var_name in op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name) in_var = block._var_recursive(in_var_name)
assert in_var.dtype == core.VarDesc.VarType.FP32 assert in_var.dtype == core.VarDesc.VarType.FP32
continue continue
for in_var_name in grad_op.input(in_name): for in_var_name in op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name) in_var = block._var_recursive(in_var_name)
if in_var.dtype == src_dtype: if in_var.dtype == src_dtype:
consume_op_attr = dist_context.get_op_dist_attr_for_program( consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op op
) )
if in_var_name in self._var_name_dict[fwd_op_id]: if in_var_name in self._var_name_dict[fwd_op_id]:
# NOTE: if in_var of consume grad_op has been casted before, # NOTE: if in_var of consume grad_op has been casted before,
# it should be renamed and reset dist_attr. # it should be renamed and reset dist_attr.
cast_name = self._var_name_dict[fwd_op_id][in_var_name] cast_name = self._var_name_dict[fwd_op_id][in_var_name]
grad_op.desc._rename_input(in_var_name, cast_name) op.desc._rename_input(in_var_name, cast_name)
in_var_dist_attr = consume_op_attr.get_input_dist_attr( in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name in_var_name
) )
...@@ -398,26 +567,26 @@ class AMPState: ...@@ -398,26 +567,26 @@ class AMPState:
assert ( assert (
in_var.dtype == dst_dtype in_var.dtype == dst_dtype
), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format( ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type, op.type,
in_name, in_name,
dst_dtype, dst_dtype,
in_var.dtype, in_var.dtype,
str(grad_op), str(op),
) )
for out_name in grad_op.output_names: for out_name in op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output( if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
grad_op, out_name op, out_name
): ):
for out_var_name in grad_op.output(out_name): for out_var_name in op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name) out_var = block._var_recursive(out_var_name)
assert out_var.dtype == core.VarDesc.VarType.FP32 assert out_var.dtype == core.VarDesc.VarType.FP32
continue continue
for out_var_name in grad_op.output(out_name): for out_var_name in op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name) out_var = block._var_recursive(out_var_name)
out_var_name_prefix = out_var_name[: out_var_name.find("@")] out_var_name_prefix = out_var_name[: out_var_name.find("@")]
fwd_var = self._block._find_var_recursive(out_var_name_prefix) fwd_var = block._var_recursive(out_var_name_prefix)
# NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype # NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
if out_var.dtype != fwd_var.dtype: if out_var.dtype != fwd_var.dtype:
out_var.desc.set_dtype(fwd_var.dtype) out_var.desc.set_dtype(fwd_var.dtype)
...@@ -428,7 +597,7 @@ class AMPState: ...@@ -428,7 +597,7 @@ class AMPState:
# it should be renamed and reset dist_attr, then we insert cast op to # it should be renamed and reset dist_attr, then we insert cast op to
# convert the cast_var to original dtype # convert the cast_var to original dtype
consume_op_attr = ( consume_op_attr = (
dist_context.get_op_dist_attr_for_program(grad_op) dist_context.get_op_dist_attr_for_program(op)
) )
fwd_cast_name = self._var_name_dict[fwd_op_id][ fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix out_var_name_prefix
...@@ -439,9 +608,9 @@ class AMPState: ...@@ -439,9 +608,9 @@ class AMPState:
out_var_name.find("@RENAME") : out_var_name.find("@RENAME") :
] ]
cast_name = fwd_cast_name + "@GRAD" + suffix cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name) cast_var = block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype: if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name) op.desc._rename_output(out_var_name, cast_name)
out_var_dist_attr = ( out_var_dist_attr = (
consume_op_attr.get_output_dist_attr( consume_op_attr.get_output_dist_attr(
out_var_name out_var_name
...@@ -453,7 +622,7 @@ class AMPState: ...@@ -453,7 +622,7 @@ class AMPState:
cast_name, out_var_dist_attr cast_name, out_var_dist_attr
) )
assert ref_mapping is not None assert ref_mapping is not None
cast_var = self._block.create_var( cast_var = block.create_var(
name=cast_name, name=cast_name,
shape=out_var.shape, shape=out_var.shape,
dtype=dst_dtype, dtype=dst_dtype,
...@@ -467,7 +636,7 @@ class AMPState: ...@@ -467,7 +636,7 @@ class AMPState:
appended_grad_times appended_grad_times
][cast_name] = fwd_cast_name ][cast_name] = fwd_cast_name
cast_op = self._block._insert_op( cast_op = block._insert_op(
idx + 1, idx + 1,
type="cast", type="cast",
inputs={"X": cast_var}, inputs={"X": cast_var},
...@@ -491,7 +660,88 @@ class AMPState: ...@@ -491,7 +660,88 @@ class AMPState:
return num_cast_ops return num_cast_ops
def _update_backward_cast_ops(params_grads, dist_context): @register_pass("auto_parallel_amp")
class AMPPass(PassBase):
def __init__(self):
super().__init__()
self.set_attr("dtype", "") # fp16/bf16
self.set_attr("loss", None)
self.set_attr("dist_context", None)
self.set_attr("custom_white_list", None)
self.set_attr("custom_black_list", None)
self.set_attr("custom_black_varnames", None)
self.set_attr("init_loss_scaling", 32768.0)
self.set_attr("incr_every_n_steps", 1000)
self.set_attr("decr_every_n_nan_or_inf", 2)
self.set_attr("incr_ratio", 2.0)
self.set_attr("decr_ratio", 0.8)
self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", [])
self.set_attr("params_grads", [])
self.set_attr("dtype", "") # fp16/bf16
self._loss = None
self._loss_scaling = None
self._num_good_steps = None
self._num_bad_steps = None
def _check_self(self):
if self.get_attr("dtype") not in ["float16", "bfloat16"]:
return False
if self.get_attr("init_loss_scaling") < 0:
return False
if self.get_attr("incr_every_n_steps") < 0:
return False
if self.get_attr("decr_every_n_nan_or_inf") < 0:
return False
if self.get_attr("incr_ratio") < 0:
return False
if self.get_attr("decr_ratio") < 0:
return False
if self.get_attr("dist_context") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
# NOTE: AMPPass overrides apply_single_impl instead of apply_impl because AMP
# is an optimization pass for the serial program; in the distributed scenario,
# all ranks should apply the same modification.
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
self.params_grads = self.get_attr("params_grads")
self.amp_dtype = self.get_attr("dtype")
amp_lists = AMPLists(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames")),
self.amp_dtype,
)
with paddle.static.program_guard(main_program, startup_program):
amp_state = AMPState(
main_program, amp_lists, self.amp_dtype, self.dist_context
)
is_train = amp_state.build_state()
if is_train:
self._update_backward_cast_ops()
self._cast_loss()
if is_train and self.amp_dtype == "float16":
self._init_amp_var()
self._scale_loss()
if (
self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
grads, found_inf = self._check_and_update_gradient()
if self.get_attr("use_dynamic_loss_scaling"):
self._update_loss_scaling(grads, found_inf)
def _update_backward_cast_ops(self):
""" """
move param grad cast to the end of backward segment move param grad cast to the end of backward segment
in order to enable fp16 allreduce in order to enable fp16 allreduce
...@@ -501,12 +751,12 @@ def _update_backward_cast_ops(params_grads, dist_context): ...@@ -501,12 +751,12 @@ def _update_backward_cast_ops(params_grads, dist_context):
main_block = paddle.static.default_main_program().global_block() main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp() main_block._sync_with_cpp()
for p, g in params_grads: for p, g in self.params_grads:
op = g.op op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast': if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr( if int(op.attr('op_role')) == int(
'op_role_var' OpRole.Backward
): ) and op.has_attr('op_role_var'):
op._remove_attr("op_role_var") op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name) post_ops = find_true_post_op(main_block.ops, op, g.name)
...@@ -534,17 +784,21 @@ def _update_backward_cast_ops(params_grads, dist_context): ...@@ -534,17 +784,21 @@ def _update_backward_cast_ops(params_grads, dist_context):
main_block.ops.append(new_op) main_block.ops.append(new_op)
# dist attr # dist attr
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p) param_dist_attr = (
output_dist_attr = dist_context.get_tensor_dist_attr_for_program( self.dist_context.get_tensor_dist_attr_for_program(p)
)
output_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0]) main_block.var(op.output_arg_names[0])
) )
)
assert param_dist_attr is not None assert param_dist_attr is not None
assert output_dist_attr is not None assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, new_op,
param_dist_attr.process_mesh, param_dist_attr.process_mesh,
param_dist_attr.dims_mapping, param_dist_attr.dims_mapping,
dist_context, self.dist_context,
) )
output_dist_attr.process_mesh = param_dist_attr.process_mesh output_dist_attr.process_mesh = param_dist_attr.process_mesh
...@@ -557,13 +811,12 @@ def _update_backward_cast_ops(params_grads, dist_context): ...@@ -557,13 +811,12 @@ def _update_backward_cast_ops(params_grads, dist_context):
main_block._sync_with_cpp() main_block._sync_with_cpp()
def _check_and_update_gradient(self):
def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
main_block = paddle.static.default_main_program().global_block() main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp() main_block._sync_with_cpp()
grads = [g for _, g in params_grads] grads = [g for _, g in self.params_grads]
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale') check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads: for e in grads:
check_variable_and_dtype( check_variable_and_dtype(
...@@ -583,9 +836,11 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): ...@@ -583,9 +836,11 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
persistable=False, persistable=False,
stop_gradient=False, stop_gradient=False,
) )
set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks) set_var_dist_attr(
self.dist_context, found_inf, [-1], world_process_group.ranks
)
inputs = {'X': grads, 'Scale': loss_scaling} inputs = {'X': grads, 'Scale': self._loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf} outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Optimize} attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op( new_op = main_block.append_op(
...@@ -603,7 +858,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): ...@@ -603,7 +858,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
if len(world_process_group.ranks) > 1: if len(world_process_group.ranks) > 1:
new_op_dist_attr.impl_type = "check_finite_and_unscale" new_op_dist_attr.impl_type = "check_finite_and_unscale"
for g in grads: for g in grads:
g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g) g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping( new_op_dist_attr.set_input_dims_mapping(
g.name, g_dist_attr.dims_mapping g.name, g_dist_attr.dims_mapping
...@@ -611,91 +866,9 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): ...@@ -611,91 +866,9 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
new_op_dist_attr.set_output_dims_mapping( new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping g.name, g_dist_attr.dims_mapping
) )
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf return grads, found_inf
@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
def __init__(self):
super().__init__()
self.set_attr("loss", None)
self.set_attr("dist_context", None)
self.set_attr("custom_white_list", None)
self.set_attr("custom_black_list", None)
self.set_attr("custom_black_varnames", None)
self.set_attr("init_loss_scaling", 32768.0)
self.set_attr("incr_every_n_steps", 1000)
self.set_attr("decr_every_n_nan_or_inf", 2)
self.set_attr("incr_ratio", 2.0)
self.set_attr("decr_ratio", 0.8)
self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", [])
self.set_attr("params_grads", [])
self.set_attr("dtype", "") # fp16/bf16
self._loss = None
self._loss_scaling = None
self._num_good_steps = None
self._num_bad_steps = None
self._loss = None
def _check_self(self):
if self.get_attr("dtype") not in ["float16", "bfloat16"]:
return False
if self.get_attr("init_loss_scaling") < 0:
return False
if self.get_attr("incr_every_n_steps") < 0:
return False
if self.get_attr("decr_every_n_nan_or_inf") < 0:
return False
if self.get_attr("incr_ratio") < 0:
return False
if self.get_attr("decr_ratio") < 0:
return False
if self.get_attr("dist_context") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
# NOTE: why AMPBackwardPass can override apply_single_impl instead of
# apply_impl? AMP is an optimization pass for serial program,
# in distributed scenario, all ranks should have the same modification.
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
amp_lists = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames")),
)
with paddle.static.program_guard(main_program, startup_program):
amp_state = AMPState(main_program.global_block())
is_train = amp_state._build_state(amp_lists, self.dist_context)
amp_state.cast_forward_program(self.dist_context)
if is_train:
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_backward_program(params_grads, self.dist_context)
self._init_amp_var()
self._scale_loss()
if (
self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
grads, found_inf = _check_and_update_gradient(
params_grads, self._loss_scaling, self.dist_context
)
if self.get_attr("use_dynamic_loss_scaling"):
self._update_loss_scaling(grads, found_inf)
def _init_amp_var(self): def _init_amp_var(self):
self._loss_scaling = paddle.static.create_global_var( self._loss_scaling = paddle.static.create_global_var(
name=unique_name.generate("loss_scaling"), name=unique_name.generate("loss_scaling"),
...@@ -740,11 +913,10 @@ class AMPPass(PassBase): ...@@ -740,11 +913,10 @@ class AMPPass(PassBase):
world_process_group.ranks, world_process_group.ranks,
) )
def _scale_loss(self): def _cast_loss(self):
main_block = paddle.static.default_main_program().global_block() main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp() main_block._sync_with_cpp()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
loss = self.get_attr("loss") loss = self.get_attr("loss")
assert loss is not None assert loss is not None
...@@ -777,13 +949,11 @@ class AMPPass(PassBase): ...@@ -777,13 +949,11 @@ class AMPPass(PassBase):
attrs={ attrs={
"in_dtype": loss.dtype, "in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32, "out_dtype": core.VarDesc.VarType.FP32,
'op_role': loss_op.all_attrs()[OP_ROLE_KEY], "op_role": loss_op.all_attrs()[OP_ROLE_KEY],
}, },
) )
loss_op._set_attr( loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context cast_op, ref_mesh, [-1], self.dist_context
) )
...@@ -814,8 +984,8 @@ class AMPPass(PassBase): ...@@ -814,8 +984,8 @@ class AMPPass(PassBase):
outputs={'Out': [pre_grad_name]}, outputs={'Out': [pre_grad_name]},
attrs={ attrs={
"in_dtype": core.VarDesc.VarType.FP32, "in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16, "out_dtype": _str_to_dtype(self.amp_dtype),
'op_role': core.op_proto_and_checker_maker.OpRole.Backward, "op_role": OpRole.Backward,
}, },
) )
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
...@@ -823,6 +993,18 @@ class AMPPass(PassBase): ...@@ -823,6 +993,18 @@ class AMPPass(PassBase):
) )
loss_op = cast_op loss_op = cast_op
loss = cast_loss loss = cast_loss
self._loss = loss
main_block._sync_with_cpp()
def _scale_loss(self):
main_block = paddle.static.default_main_program().global_block()
loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op
)
if ( if (
self.get_attr("use_dynamic_loss_scaling") self.get_attr("use_dynamic_loss_scaling")
...@@ -833,28 +1015,24 @@ class AMPPass(PassBase): ...@@ -833,28 +1015,24 @@ class AMPPass(PassBase):
# forward # forward
ref_mesh = loss_op_dist_attr.process_mesh ref_mesh = loss_op_dist_attr.process_mesh
self._scaled_loss = main_block.create_var( scaled_loss = main_block.create_var(
name=unique_name.generate("scaled_loss"), name=unique_name.generate("scaled_loss"),
shape=loss.shape, shape=loss.shape,
dtype=loss.dtype, dtype=loss.dtype,
persistable=loss.persistable, persistable=loss.persistable,
) )
set_var_dist_attr( set_var_dist_attr(self.dist_context, scaled_loss, [-1], ref_mesh)
self.dist_context, self._scaled_loss, [-1], ref_mesh
)
elementwise_mul_op = main_block._insert_op( elementwise_mul_op = main_block._insert_op(
loss_op_idx + 1, loss_op_idx + 1,
type='elementwise_mul', type='elementwise_mul',
inputs={'X': [loss], 'Y': [self._loss_scaling]}, inputs={'X': [loss], 'Y': [self._loss_scaling]},
outputs={'Out': [self._scaled_loss]}, outputs={'Out': [scaled_loss]},
attrs={ attrs={
'op_role': loss_op.all_attrs()[OP_ROLE_KEY], 'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
}, },
) )
loss_op._set_attr( loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_op, ref_mesh, [-1], self.dist_context elementwise_mul_op, ref_mesh, [-1], self.dist_context
) )
...@@ -865,23 +1043,23 @@ class AMPPass(PassBase): ...@@ -865,23 +1043,23 @@ class AMPPass(PassBase):
first_backward_op.type == "fill_constant" first_backward_op.type == "fill_constant"
and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257 and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
) )
self._scaled_loss_grad = main_block.create_var( scaled_loss_grad = main_block.create_var(
name=unique_name.generate("scaled_loss") + "@GRAD", name=unique_name.generate("scaled_loss") + "@GRAD",
shape=loss.shape, shape=loss.shape,
dtype=loss.dtype, dtype=loss.dtype,
persistable=loss.persistable, persistable=loss.persistable,
) )
set_var_dist_attr( set_var_dist_attr(
self.dist_context, self._scaled_loss_grad, [-1], ref_mesh self.dist_context, scaled_loss_grad, [-1], ref_mesh
) )
pre_grad_name = first_backward_op.output_arg_names[0] pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output( first_backward_op._rename_output(
pre_grad_name, self._scaled_loss_grad.name pre_grad_name, scaled_loss_grad.name
) )
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
first_backward_op, ref_mesh, [-1], self.dist_context first_backward_op, ref_mesh, [-1], self.dist_context
) )
self._scaled_loss_grad.op = first_backward_op scaled_loss_grad.op = first_backward_op
# FIXME(JZ-LIANG) a trick to insert backward op # FIXME(JZ-LIANG) a trick to insert backward op
main_block._sync_with_cpp() main_block._sync_with_cpp()
elementwise_mul_grad_op_desc = main_block.desc._insert_op( elementwise_mul_grad_op_desc = main_block.desc._insert_op(
...@@ -889,7 +1067,7 @@ class AMPPass(PassBase): ...@@ -889,7 +1067,7 @@ class AMPPass(PassBase):
) )
elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad") elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
elementwise_mul_grad_op_desc.set_input( elementwise_mul_grad_op_desc.set_input(
'Out@GRAD', [self._scaled_loss_grad.name] 'Out@GRAD', [scaled_loss_grad.name]
) )
elementwise_mul_grad_op_desc.set_input('X', [loss.name]) elementwise_mul_grad_op_desc.set_input('X', [loss.name])
elementwise_mul_grad_op_desc.set_input( elementwise_mul_grad_op_desc.set_input(
...@@ -897,9 +1075,7 @@ class AMPPass(PassBase): ...@@ -897,9 +1075,7 @@ class AMPPass(PassBase):
) )
elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name]) elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
elementwise_mul_grad_op_desc.set_output('Y@GRAD', []) elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
elementwise_mul_grad_op_desc._set_attr( elementwise_mul_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward
)
elementwise_mul_grad_op_desc._set_attr('axis', -1) elementwise_mul_grad_op_desc._set_attr('axis', -1)
elementwise_mul_grad_op = paddle.static.Operator( elementwise_mul_grad_op = paddle.static.Operator(
main_block, elementwise_mul_grad_op_desc main_block, elementwise_mul_grad_op_desc
...@@ -911,10 +1087,9 @@ class AMPPass(PassBase): ...@@ -911,10 +1087,9 @@ class AMPPass(PassBase):
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
) )
else: else:
self._scaled_loss = loss scaled_loss = loss
self._loss = loss self._loss = scaled_loss
main_block._sync_with_cpp() main_block._sync_with_cpp()
def _update_loss_scaling(self, grads, found_inf): def _update_loss_scaling(self, grads, found_inf):
...@@ -994,9 +1169,9 @@ class AMPPass(PassBase): ...@@ -994,9 +1169,9 @@ class AMPPass(PassBase):
main_block._sync_with_cpp() main_block._sync_with_cpp()
def get_loss(self): def get_loss(self):
# the amp / fp16 might change the effective loss variable for network and # the amp might change the effective loss variable for network and
# therefore would affect the subsequent passes that rely on the loss. # therefore would affect the subsequent passes that rely on the loss.
# return the effective loss after amp / fp16 pass. # return the effective loss after amp pass.
if self._loss: if self._loss:
return self._loss return self._loss
......
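The AMPLists wrapper introduced above hides whether the fp16 or bf16 op lists are in use behind one interface. A small hedged sketch of how it could be exercised in isolation (assuming a Paddle build that already contains this change and that the class stays importable from the pass module):

    from paddle.distributed.passes.auto_parallel_amp import AMPLists

    # dtype selects the backing implementation:
    #   "float16"  -> AutoMixedPrecisionLists      (white_list / black_list)
    #   "bfloat16" -> AutoMixedPrecisionListsBF16  (bf16_list / fp32_list)
    fp16_lists = AMPLists([], [], [], dtype="float16")
    bf16_lists = AMPLists([], [], [], dtype="bfloat16")
    assert fp16_lists.is_fp16 and not bf16_lists.is_fp16

    # Both flavors expose the same white_list / black_list / gray_list properties,
    # so AMPState can stay agnostic of the low-precision dtype.
    print(len(fp16_lists.white_list), len(bf16_lists.white_list))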
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import static
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.passes.pass_base import PassBase, register_pass
from paddle.framework import Block, core
from paddle.static.amp.bf16 import AutoMixedPrecisionListsBF16
from paddle.static.amp.bf16.amp_utils import (
_dtype_to_str,
_is_in_fp32_varnames,
_valid_types,
find_true_post_op,
)
from paddle.static.amp.fp16_utils import (
_rename_arg,
find_op_index,
find_true_prev_op,
)
from paddle.utils import unique_name
from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op
world_process_group = get_world_process_group()
class BF16State:
def __init__(self, block):
self._block: Block = block
self._op_bf16_dict = {}
self._var_name_dict = {}
def _is_bf16_op(self, op_id):
return self._op_bf16_dict.get(op_id, None)
def _build_state(self, amp_lists, dist_context):
ops = self._block.ops
dist_op_context = dist_context.dist_op_context
training = False
for op in ops:
if int(op.attr("op_role")) == 257:
training = True
if int(op.attr("op_role")) == int(OpRole.Forward):
self._mark_black_white_op(amp_lists, op, ops)
elif int(op.attr("op_role")) == int(OpRole.Backward):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
fwd_op_id = dist_op_context.grad_op_id_to_op_id[
op.desc.original_id()
]
if self._is_bf16_op(fwd_op_id) is True:
self._op_bf16_dict[op.desc.original_id()] = True
elif self._is_bf16_op(fwd_op_id) is False:
self._op_bf16_dict[op.desc.original_id()] = False
elif int(op.attr("op_role")) == int(OpRole.Optimize):
break
return training
def _mark_black_white_op(self, amp_lists, op, ops):
if op.type == "create_py_reader" or op.type == "read":
return
if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames(
op, amp_lists
):
self._op_bf16_dict[op.desc.original_id()] = False
return
if op.type in amp_lists.bf16_list:
self._op_bf16_dict[op.desc.original_id()] = True
elif op.type in amp_lists.gray_list:
is_fp32_op = False
is_bf16_op = False
for in_name in op.input_names:
if in_name:
for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name)
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
if (
self._op_bf16_dict.get(
prev_op.desc.original_id(), False
)
is False
or prev_op.type in amp_lists.fp32_list
):
is_fp32_op = True
elif (
self._op_bf16_dict.get(
prev_op.desc.original_id(), False
)
is True
or prev_op.type in amp_lists.bf16_list
):
is_bf16_op = True
if is_fp32_op:
self._op_bf16_dict[op.desc.original_id()] = False
elif is_bf16_op:
self._op_bf16_dict[op.desc.original_id()] = True
else:
pass
else:
self._op_bf16_dict[op.desc.original_id()] = False
def cast_forward_program(self, dist_context):
ops = self._block.ops
idx = 0
while idx < len(ops):
num_cast_ops = 0
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Backward):
break
if self._is_bf16_op(op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_forward(
op,
idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32,
dist_context,
)
elif self._is_bf16_op(op.desc.original_id()) is True:
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
op.has_attr('dtype')
and op.attr('dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_forward(
op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16,
dist_context,
)
else:
pass
idx += num_cast_ops + 1
self._block._sync_with_cpp()
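    # Insert (or reuse) a cast op in front of `op` for every input whose dtype equals
    # `src_dtype`, record the original->cast name mapping for the backward pass, and
    # propagate the input's dist_attr to the new cast op and its output.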
def _insert_cast_op_forward(
self, op, idx, src_dtype, dst_dtype, dist_context: DistributedContext
):
num_cast_ops = 0
var_name_dict = {}
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm',
'fused_bn_add_activation',
'layer_norm',
]:
if in_name not in {'X', 'Z'}:
continue
for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = (
in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
)
var_name_dict[in_var.name] = cast_name
out_var = self._block.vars.get(cast_name)
consume_op_attr = dist_context.get_op_dist_attr_for_program(
op
)
assert consume_op_attr is not None
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name
)
if out_var is None or out_var.dtype != dst_dtype:
assert in_var_dist_attr is not None
ref_mesh = in_var_dist_attr.process_mesh
ref_mapping = in_var_dist_attr.dims_mapping
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
out_var = self._block.create_var(
name=cast_name,
dtype=dst_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient,
)
set_var_dist_attr(
dist_context, out_var, ref_mapping, ref_mesh
)
cast_op = self._block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
else:
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
_rename_arg(op, in_var_name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dst_dtype)
self._var_name_dict[op.desc.original_id()] = var_name_dict
if (
src_dtype == core.VarDesc.VarType.FP32
and dst_dtype == core.VarDesc.VarType.BF16
):
for out_name in op.output_names:
if (
op.type
in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']
and out_name != 'Y'
):
continue
for out_var_name in op.output(out_name):
out_var = self._block.var(out_var_name)
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
return num_cast_ops
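    # Mirror the forward decisions on the backward block: each grad op reuses the
    # bf16/fp32 choice of its forward op, and `sum` ops used for gradient accumulation
    # simply follow the dtype of their inputs.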
def cast_backward_program(self, params_grads, dist_context):
self._block._sync_with_cpp()
ops = self._block.ops
appended_grad_times = 0
dist_op_context = dist_context.dist_op_context
loss_op = get_loss_op(self._block)
idx = find_op_index(self._block.desc, loss_op.desc) + 1
while idx < len(ops):
num_cast_ops = 0
grad_op = ops[idx]
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (
is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
if (
grad_op.desc.original_id()
in dist_op_context.grad_op_id_to_op_id
):
if self._is_bf16_op(grad_op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32,
dist_context,
appended_grad_times,
)
elif self._is_bf16_op(grad_op.desc.original_id()) is True:
if grad_op.has_attr('use_mkldnn'):
grad_op._set_attr('use_mkldnn', True)
grad_op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
grad_op.has_attr('dtype')
and grad_op.attr('dtype') == core.VarDesc.VarType.FP32
):
grad_op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16,
dist_context,
appended_grad_times,
)
elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype
for in_var_name in grad_op.desc.input_arg_names():
assert src_dtype == self._block.var(in_var_name).dtype
out_var_name = grad_op.desc.output_arg_names()[0]
out_var = self._block.var(out_var_name)
if out_var.dtype != src_dtype:
out_var.desc.set_dtype(src_dtype)
elif int(grad_op.attr("op_role")) == 257:
pass
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
grad_op.type
)
)
idx += num_cast_ops + 1
self._block._sync_with_cpp()
_update_backward_cast_ops(params_grads, dist_context)
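    # Backward counterpart of _insert_cast_op_forward: rename grad inputs/outputs to
    # the cast names recorded in the forward pass and insert cast ops so that the
    # original grad vars still receive values in their original dtype.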
def _insert_cast_op_backward(
self,
grad_op,
idx,
src_dtype,
dst_dtype,
dist_context,
appended_grad_times,
):
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type in ['layer_norm_grad']:
return in_name not in {'X', 'Y@GRAD'}
return False
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['layer_norm_grad']:
return out_name != 'X@GRAD'
return False
num_cast_ops = 0
original_id = grad_op.desc.original_id()
dist_op_context = dist_context.dist_op_context
fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]
for in_name in grad_op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
grad_op, in_name
):
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
assert in_var.dtype == core.VarDesc.VarType.FP32
continue
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
if in_var.dtype == src_dtype:
consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op
)
if in_var_name in self._var_name_dict[fwd_op_id]:
cast_name = self._var_name_dict[fwd_op_id][in_var_name]
grad_op.desc._rename_input(in_var_name, cast_name)
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name
)
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
else:
assert (
in_var.dtype == dst_dtype
), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type,
in_name,
dst_dtype,
in_var.dtype,
str(grad_op),
)
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
grad_op, out_name
):
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
assert out_var.dtype == core.VarDesc.VarType.FP32
continue
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
out_var_name_prefix = out_var_name[: out_var_name.find('@')]
fwd_var = self._block._find_var_recursive(out_var_name_prefix)
if out_var.dtype != fwd_var.dtype:
out_var.desc.set_dtype(fwd_var.dtype)
if out_var.dtype == src_dtype:
if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
consume_op_attr = (
dist_context.get_op_dist_attr_for_program(grad_op)
)
fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix
]
suffix = ''
if "@RENAME" in out_var_name:
suffix = out_var_name[
out_var_name.find("@RENAME") :
]
cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name)
out_var_dist_attr = (
consume_op_attr.get_output_dist_attr(
out_var_name
)
)
ref_mesh = out_var_dist_attr.process_mesh
ref_mapping = out_var_dist_attr.dims_mapping
consume_op_attr.set_output_dist_attr(
cast_name, out_var_dist_attr
)
assert ref_mapping is not None
cast_var = self._block.create_var(
name=cast_name,
shape=out_var.shape,
dtype=dst_dtype,
persistable=False,
stop_gradient=out_var.stop_gradient,
)
set_var_dist_attr(
dist_context, cast_var, ref_mapping, ref_mesh
)
dist_op_context.grad_var_to_var[
appended_grad_times
][cast_name] = fwd_cast_name
cast_op = self._block._insert_op(
idx + 1,
type="cast",
inputs={"X": cast_var},
outputs={"Out": out_var},
attrs={
"in_dtype": cast_var.dtype,
"out_dtype": out_var.dtype,
"op_role": OpRole.Backward,
},
)
cast_op._remove_attr("op_role_var")
cast_op._remove_attr("op_namescope")
cast_op._remove_attr("with_quant_attr")
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
else:
assert out_var.dtype == dst_dtype
return num_cast_ops
def _update_backward_cast_ops(params_grads, dist_context):
"""
    Move the cast op of each param grad to the end of the backward segment
    in order to enable bf16 allreduce.
"""
# TODO filter optimize ops in future
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
for p, g in params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
'op_role_var'
):
op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name)
if post_ops:
                    raise ValueError(
                        "The cast op {0}'s output should not be "
                        "used by a non-optimize op, however, it "
                        "is used by {1}".format(op, post_ops[0])
                    )
if op == main_block.ops[-1]:
continue
# add new op in the python and cpp at the same time
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = paddle.static.Operator(
block=main_block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None,
)
main_block.ops.append(new_op)
# dist attr
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0])
)
assert param_dist_attr is not None
assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op,
param_dist_attr.process_mesh,
param_dist_attr.dims_mapping,
dist_context,
)
output_dist_attr.process_mesh = param_dist_attr.process_mesh
output_dist_attr.dims_mapping = param_dist_attr.dims_mapping
op_idx = find_op_index(main_block.desc, op.desc)
if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op))
main_block._remove_op(op_idx, sync=False)
main_block._sync_with_cpp()
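# Entry point of the pass: build the per-op bf16 state, cast the forward program, and,
# when a backward pass exists, cast the backward program and keep the loss in fp32.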
@register_pass("auto_parallel_bf16")
class BF16Pass(PassBase):
def __init__(self):
super().__init__()
self.set_attr("dist_context", None)
self.set_attr("custom_bf16_list", None)
self.set_attr("custom_fp32_list", None)
self.set_attr("custom_fp32_varnames", None)
self.set_attr("input_data", [])
self.set_attr("loss", None)
self.set_attr("params_grads", [])
self.set_attr("use_bf16_guard", False)
self._loss = None
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
amp_lists = AutoMixedPrecisionListsBF16(
self.get_attr("custom_bf16_list"),
self.get_attr("custom_fp32_list"),
self.get_attr("custom_fp32_varnames"),
)
with static.program_guard(main_program, startup_program):
amp_state = BF16State(main_program.global_block())
training = amp_state._build_state(amp_lists, self.dist_context)
amp_state.cast_forward_program(self.dist_context)
if training:
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_backward_program(params_grads, self.dist_context)
self._scale_loss()
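    # If the loss is produced in bf16, cast it to fp32 right after the loss op, and cast
    # the (fp32) initial loss gradient back down for the rest of the backward pass.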
def _scale_loss(self):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op
)
if loss.dtype != core.VarDesc.VarType.FP32:
tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(
name=tmp_name, dtype=core.VarDesc.VarType.FP32
)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss
)
ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
cast_loss, loss_dist_attr
)
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op(
loss_op_idx + 1,
type='cast',
inputs={"X": [loss]},
outputs={"Out": [cast_loss]},
attrs={
"in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
"op_role": loss_op.all_attrs()[OP_ROLE_KEY],
},
)
loss_op._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context
)
first_backward_op = main_block.ops[loss_op_idx + 2]
assert (
first_backward_op.type == "fill_constant"
and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
)
cast_loss_grad = main_block.create_var(
name=unique_name.generate(tmp_name + "@GRAD"),
shape=loss.shape,
dtype=core.VarDesc.VarType.FP32,
persistable=loss.persistable,
)
set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)
pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
cast_grad_op = main_block._insert_op(
loss_op_idx + 3,
type='cast',
inputs={'X': [cast_loss_grad]},
outputs={'Out': [pre_grad_name]},
                attrs={
                    "in_dtype": core.VarDesc.VarType.FP32,
                    "out_dtype": core.VarDesc.VarType.BF16,
                    'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
                },
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_grad_op, ref_mesh, [-1], self.dist_context
)
loss = cast_loss
self._loss = loss
main_block._sync_with_cpp()
def get_loss(self):
if self._loss:
return self._loss
else:
return self.get_attr("loss")
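For context, here is a minimal sketch (not part of this diff) of how the BF16-o1 path is typically reached from user code, matching the strategy settings exercised by the unit tests below; `mlp_model`, `loss_fn`, `opt` and `train_dataset` are hypothetical placeholder objects:

# Illustrative sketch only: enable the bf16 pass through the AMP section of the
# auto-parallel strategy, as the unit tests below do with dtype="bfloat16", level="o1".
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"  # select the bf16 op lists instead of the fp16 ones
amp.level = "o1"        # o1: only insert casts around fp32/bf16 boundaries

# `mlp_model`, `loss_fn`, `opt` and `train_dataset` are placeholders for illustration.
# engine = auto.Engine(mlp_model, loss_fn, opt, strategy=strategy)
# engine.fit(train_dataset, epochs=1, batch_size=2)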
...@@ -1611,7 +1611,9 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
     dp_group = None

     for input_name in op.input_arg_names:
-        if not is_parameter_related(input_name, op.block):
+        # TODO(zhaoyingli): maintain a dict in dist_context to record all variables which are renamed,
+        # to solve the problem that param@RESHARD cannot be identified.
+        if not is_parameter_related(input_name, op.block, dist_context):
             dist_attr = dist_context.get_op_dist_attr_for_program(op)
             process_mesh = dist_attr.process_mesh
             input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
......
...@@ -126,6 +126,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
                   test_convert_to_process_meshes)
   py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
   py_test_modules(test_dist_saver MODULES test_dist_saver)
+  py_test_modules(test_engine_save_load MODULES test_engine_save_load)
   # End of unittests WITH single card WITHOUT timeout
 endif()
...@@ -29,6 +29,8 @@ def apply_pass(use_amp=False, level=None):
     if use_amp:
         amp = strategy.amp
         amp.enable = True
+        amp.dtype = "float16"
+        amp.level = level
         amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
         amp.custom_black_list = [
             'c_softmax_with_cross_entropy',
...@@ -37,8 +39,6 @@ def apply_pass(use_amp=False, level=None):
         ]
         amp.init_loss_scaling = 32768
         amp.use_fp16_guard = False
-        amp.level = level
-        amp.use_optimizer_fp16 = level == "o3"
         print("amp level: ", level)
     return strategy
......
...@@ -31,6 +31,8 @@ def apply_pass():
     amp = dist_strategy.amp
     amp.enable = True
+    amp.dtype = "float16"
+    amp.level = "o2"
     amp.custom_white_list = ["lookup_table", "lookup_table_v2"]
     amp.custom_black_list = [
         "reduce_sum",
...@@ -38,8 +40,6 @@ def apply_pass():
         "elementwise_div",
     ]
     amp.init_loss_scaling = 32768
-    amp.use_fp16_guard = False
-    amp.level = "o2"
     qat = dist_strategy.qat
     qat.enable = True
...@@ -119,9 +119,6 @@ class TestQuantizationPassExport(unittest.TestCase):
     def test_qat_pass_2(self):
-        batch_size = 1
-        batch_num = 10
         strategy = apply_pass()
         model, loss = generate_model("mp")
         engine = auto.Engine(model, loss, strategy=strategy)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet import auto
paddle.enable_static()
batch_size = 2
hidden_size = 1024
# sequence_len = 512
image_size = hidden_size
class_num = 10
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super().__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
auto.shard_tensor(input, auto.ProcessMesh([0]), [None, None])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
return out
class TestSaveLoad(unittest.TestCase):
def test_fp32_save_fp16_load(self):
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
inputs_spec = [
paddle.static.InputSpec(
shape=[batch_size, image_size], name="input", dtype="float32"
)
]
labels_spec = [
paddle.static.InputSpec(
shape=[batch_size, 1], name="label", dtype="int64"
)
]
# build fp32 model
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine_fp32 = auto.Engine(
mlp, loss, optimizer, metric, strategy=strategy
)
engine_fp32.prepare(inputs_spec, labels_spec, mode="train")
fp32_state = {
k: np.array(v)
for k, v in engine_fp32.main_program.state_dict("param").items()
}
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine_fp32.save(model_filename)
# build fp16 model
strategy = auto.Strategy()
strategy.auto_mode = "semi"
amp = strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = "o2"
engine_fp16 = auto.Engine(
mlp, loss, optimizer, metric, strategy=strategy
)
engine_fp16.load(model_filename)
engine_fp16.prepare(inputs_spec, labels_spec, mode="train")
fp16_state = {
k: np.array(v)
for k, v in engine_fp16.main_program.state_dict("param").items()
}
# check param
for name, fp32_param in fp32_state.items():
fp16_param = fp16_state[name]
if "layer_norm" in name:
assert fp16_param.dtype == np.float32
else:
assert fp16_param.dtype == np.float16
np.testing.assert_allclose(fp32_param, fp16_param, atol=1e-4)
temp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
...@@ -36,7 +36,8 @@ def apply_pass(use_bf16=False):
     if use_bf16:
         amp = strategy.amp
         amp.enable = True
-        amp.enable_bf16 = True
+        amp.dtype = "bfloat16"
+        amp.level = "o1"
     return strategy
......
...@@ -28,8 +28,8 @@ class TestStrategy(unittest.TestCase):
         amp = strategy.amp
         self.assertEqual(amp.enable, False)
-        self.assertAlmostEqual(amp.dtype, "float16")
-        self.assertAlmostEqual(amp.level, "o1")
+        self.assertEqual(amp.dtype, "float16")
+        self.assertEqual(amp.level, "o1")
         self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
         self.assertEqual(amp.incr_every_n_steps, 1000)
         self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
...@@ -40,10 +40,6 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(amp.custom_white_list, [])
         self.assertEqual(amp.custom_black_varnames, [])
         self.assertEqual(amp.use_fp16_guard, False)
-        self.assertEqual(amp.use_optimizer_fp16, False)
-        self.assertEqual(amp.custom_bf16_list, [])
-        self.assertEqual(amp.custom_fp32_list, [])
-        self.assertEqual(amp.custom_fp32_varnames, [])
         self.assertEqual(amp.use_bf16_guard, False)

         sharding = strategy.sharding
...@@ -91,6 +87,8 @@ class TestStrategy(unittest.TestCase):
         amp = strategy.amp
         amp.enable = True
+        amp.dtype = "float16"
+        amp.level = "o2"
         amp.init_loss_scaling = 16384.0
         amp.incr_every_n_steps = 2000
         amp.decr_every_n_nan_or_inf = 4
...@@ -101,8 +99,9 @@ class TestStrategy(unittest.TestCase):
         amp.custom_black_list = ["y"]
         amp.custom_black_varnames = ["z"]
         amp.use_fp16_guard = False
-        amp.use_optimizer_fp16 = True
         self.assertEqual(amp.enable, True)
+        self.assertEqual(amp.dtype, "float16")
+        self.assertEqual(amp.level, "o2")
         self.assertAlmostEqual(amp.init_loss_scaling, 16384.0)
         self.assertEqual(amp.incr_every_n_steps, 2000)
         self.assertEqual(amp.decr_every_n_nan_or_inf, 4)
...@@ -113,7 +112,6 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(amp.custom_black_list, ["y"])
         self.assertEqual(amp.custom_black_varnames, ["z"])
         self.assertEqual(amp.use_fp16_guard, False)
-        self.assertEqual(amp.use_optimizer_fp16, True)

         sharding = strategy.sharding
         sharding.enable = True
......