Unverified commit 5cb0f3aa, authored by zhaoyingli, committed by GitHub

[AutoParallel] BF16-o1/FP16-o1 PASS support training and generation (#51147)

* [AutoParallel] support bloom

* fix import

* align amp and bf16

* update func name

* clipbyglobalnorm and add_n support bf16

* upgrade amp strategy api

* update bf16 unittest

* fix static clip

---------
Co-authored-by: liangjianzhong <liangjianzhong@baidu.com>
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Parent 32baca93
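For context on the "upgrade amp strategy api" item: the per-dtype BF16 fields removed from the AMP config below (use_optimizer_fp16, custom_bf16_list, custom_fp32_list, custom_fp32_varnames, use_bf16_guard) are replaced by a single dtype/level switch on the unified AMP strategy, as exercised by the updated unit test at the end of this diff. A minimal sketch of the intended configuration, assuming the paddle.distributed.fleet.auto entry points; the list contents are illustrative:

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"  # or "float16"; selects the BF16-o1 / FP16-o1 path
amp.level = "o1"
amp.custom_white_list = ["softmax", "layer_norm", "gelu"]
amp.custom_black_list = ["c_softmax_with_cross_entropy"]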
......@@ -723,6 +723,14 @@ class Completer:
tensor_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
for node in while_op_node.inputs:
if node.var().name() == tensor_name:
node_dist_attr = (
self._dist_context.get_dist_attr_for_graph(node)
)
node_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
# Step 4: set the process meshes of the outputs in while_op to the process meshes of the outside output nodes
while_op_outputs_dist_attrs = while_op_dist_attr.outputs_dist_attrs
......@@ -749,6 +757,14 @@ class Completer:
tensor_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
for node in while_op_node.outputs:
if node.var().name() == tensor_name:
node_dist_attr = (
self._dist_context.get_dist_attr_for_graph(node)
)
node_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
# Amend the process meshes related to array
for array_node_list in self._array_nodes.values():
......
......@@ -75,11 +75,6 @@ set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_optimizer_fp16", False)
set_field_default_config(AMP, "custom_bf16_list", [])
set_field_default_config(AMP, "custom_fp32_list", [])
set_field_default_config(AMP, "custom_fp32_varnames", [])
set_field_default_config(AMP, "use_bf16_guard", False)
#########################################
......
......@@ -1557,6 +1557,19 @@ class Engine:
cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict)
for name, param in program.state_dict().items():
param_array = np.array(param)
if name not in state_dict:
continue
if param_array.dtype != state_dict[name].dtype:
self._logger.info(
"cast {}'s dtype from '{}' to '{}'".format(
name,
str(state_dict[name].dtype),
str(param_array.dtype),
)
)
state_dict[name] = state_dict[name].astype(param_array.dtype)
program.set_state_dict(state_dict)
def save(self, path, training=True):
......
......@@ -272,7 +272,8 @@ def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
return best_compatible_impl
def is_parameter_related(varname, block):
def is_parameter_related(varname, block, dist_context=None):
# TODO(zhaoyingli): maintain a dict in dist_context to record all variables which are renamed
if ".subprog_" in varname:
varname = varname[: varname.index(".subprog_")]
if ".cast_fp" in varname:
......@@ -281,10 +282,17 @@ def is_parameter_related(varname, block):
varname = varname[: varname.index(".cast_bf")]
if ".quantized" in varname:
varname = varname[: varname.index(".quantized")]
# if "@RESHARD" in varname:
# varname = varname[: varname.index("@RESHARD")]
assert block._find_var_recursive(varname)
assert block._find_var_recursive(
varname
), "cannot find var {} in cur block".format(varname)
var = block._var_recursive(varname)
# NOTE(workaround): find the param that has been resharded
if dist_context and "@RESHARD" in varname:
varname = varname[: varname.index("@RESHARD")]
serial_program = dist_context.serial_main_program
var = serial_program.global_block()._find_var_recursive(varname)
if var is None:
return False
return var.is_parameter
......
......@@ -28,6 +28,9 @@ class DistributedScale(DistributedOperatorImplContainer):
register_distributed_operator_impl_container(DistributedScale("scale"))
register_distributed_operator_impl_container(DistributedScale("fill_any_like"))
register_distributed_operator_impl_container(DistributedScale("where"))
register_distributed_operator_impl_container(DistributedScale("tanh"))
class DistributedScaleImpl(DistributedOperatorImpl):
......@@ -50,13 +53,17 @@ class DistributedScaleImpl(DistributedOperatorImpl):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if x_dims_mapping != out_dims_mapping:
return False
in_dims_mappings = []
for in_name in op_desc.input_arg_names():
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
in_dims_mappings.append(in_dims_mapping)
for x_dims_mapping in in_dims_mappings:
if x_dims_mapping != out_dims_mapping:
return False
return True
......@@ -78,10 +85,6 @@ class DistributedScaleImpl(DistributedOperatorImpl):
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
changed = True
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed
@staticmethod
......@@ -94,3 +97,8 @@ class DistributedScaleImpl(DistributedOperatorImpl):
register_distributed_operator_impl("scale", DistributedScaleImpl("scale"))
register_distributed_operator_impl(
"fill_any_like", DistributedScaleImpl("fill_any_like")
)
register_distributed_operator_impl("where", DistributedScaleImpl("where"))
register_distributed_operator_impl("tanh", DistributedScaleImpl("tanh"))
......@@ -2213,7 +2213,11 @@ class Resharder:
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)
assert op_input_attrs
assert (
op_input_attrs
), "The input '{}' of op '{}' has no distribution attributes in subblock".format(
var_name, op.type
)
return op_input_attrs
......
......@@ -18,7 +18,6 @@ from .auto_parallel_gradient_merge import * # noqa: F403
from .auto_parallel_sharding import * # noqa: F403
from .auto_parallel_amp import * # noqa: F403
from .auto_parallel_fp16 import * # noqa: F403
from .auto_parallel_bf16 import * # noqa: F403
from .auto_parallel_recompute import * # noqa: F403
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
......
......@@ -18,16 +18,18 @@ from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core
from paddle.static.amp.bf16.amp_utils import (
AutoMixedPrecisionListsBF16,
_is_in_fp32_varnames,
)
from paddle.static.amp.fp16_utils import (
AutoMixedPrecisionLists,
_dtype_to_str,
_is_in_black_varnames,
_keep_fp32_input,
_keep_fp32_output,
......@@ -40,142 +42,370 @@ from paddle.static.amp.fp16_utils import (
from paddle.utils import unique_name
from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op
from ..auto_parallel.utils import (
is_backward_op,
is_forward_op,
is_loss_grad_op,
is_loss_op,
is_optimize_op,
)
from .pass_base import PassBase, register_pass
world_process_group = get_world_process_group()
__amp_skip_ops__ = [
'create_py_reader',
'create_double_buffer_reader',
'cast',
'while',
]
def _dtype_to_str(dtype):
if dtype == core.VarDesc.VarType.FP16:
return 'fp16'
elif dtype == core.VarDesc.VarType.BF16:
return 'bf16'
else:
return 'fp32'
def _str_to_dtype(dstr):
if dstr == 'float16':
return core.VarDesc.VarType.FP16
elif dstr == 'bfloat16':
return core.VarDesc.VarType.BF16
else:
return core.VarDesc.VarType.FP32
class AMPLists:
def __init__(
self,
white_list=None,
black_list=None,
black_varnames=None,
dtype="float16",
):
self._amp_list = None
if dtype == "float16":
self._amp_list = AutoMixedPrecisionLists(
set(white_list), set(black_list), set(black_varnames)
)
elif dtype == "bfloat16":
self._amp_list = AutoMixedPrecisionListsBF16(
set(white_list), set(black_list), set(black_varnames)
)
assert self._amp_list is not None
self._dtype = dtype
self._is_float16 = dtype == "float16"
@property
def white_list(self):
if self._is_float16:
return self._amp_list.white_list
else:
return self._amp_list.bf16_list
@property
def black_list(self):
if self._is_float16:
return self._amp_list.black_list
else:
return self._amp_list.fp32_list
@property
def gray_list(self):
return self._amp_list.gray_list
@property
def black_varnames(self):
if self._is_float16:
return self._amp_list.black_varnames
else:
return self._amp_list.fp32_varnames
@property
def is_fp16(self):
return self._is_float16
@property
def dtype(self):
return self._dtype
@property
def amp_list(self):
return self._amp_list
def _is_in_black_fp32_varnames(self, op):
if self._is_float16:
return _is_in_black_varnames(op, self._amp_list)
else:
return _is_in_fp32_varnames(op, self._amp_list)
def _op_keep_fp32_input(self, op, in_name):
if self._is_float16:
return _keep_fp32_input(op, in_name)
else:
if op.type in ['batch_norm', 'layer_norm']:
return in_name != 'X'
if op.type == 'fused_bn_add_activation':
return in_name not in {'X', 'Z'}
return False
def _op_keep_fp32_output(self, op, out_name):
if self._is_float16:
return _keep_fp32_output(op, out_name)
else:
if op.type in [
'batch_norm',
'fused_bn_add_activation',
'layer_norm',
]:
return out_name != 'Y'
return False
class AMPState:
def __init__(self, block):
self._block = block
self._op_fp16_dict = (
{}
) # op_id --> True/False. 'True' means that the current op is in fp16 mode.
self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name}
self.is_train = False
def __init__(self, program, amp_lists, amp_dtype, dist_context):
self.program = program
self.dist_context = dist_context
self.amp_lists = amp_lists
self.amp_dtype = amp_dtype
self.grad_op_to_op_map = (
dist_context.dist_op_context.grad_op_id_to_op_id
)
# op_id --> True/False. 'True' means that the current op is in fp16/bf16 mode.
self._op_fp16_dict = {}
# fwd_op_id --> {old_name: cast_name}
self._var_name_dict = {}
# out_var_name --> [op_ids]
self.out_var_op_deps = {}
def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None)
def _build_state(self, amp_lists, dist_context):
ops = self._block.ops
dist_op_context = dist_context.dist_op_context
for op in ops:
if int(op.attr('op_role')) == 257:
self.is_train = True
if int(op.attr('op_role')) == int(OpRole.Forward):
self._mark_black_white_ops(amp_lists)
elif int(op.attr('op_role')) == int(OpRole.Backward):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
fwd_op_id = dist_op_context.grad_op_id_to_op_id[
op.desc.original_id()
]
if self._is_fp16_op(fwd_op_id) is True:
self._op_fp16_dict[op.desc.original_id()] = True
elif self._is_fp16_op(fwd_op_id) is False:
self._op_fp16_dict[op.desc.original_id()] = False
elif int(op.attr('op_role')) == int(OpRole.Optimize):
break
return self.is_train
def _mark_black_white_ops(self, amp_lists):
"""
this function is modified from paddle.static.amp
"""
self._block._sync_with_cpp()
ops = self._block.ops
def build_state(self):
is_train = False
for block in self.program.blocks:
for op in block.ops:
# record inplace operations and their outputs
for name in op.output_arg_names:
if name not in self.out_var_op_deps:
self.out_var_op_deps[name] = [op.desc.original_id()]
else:
self.out_var_op_deps[name].extend(
[op.desc.original_id()]
)
for op in ops:
if int(op.attr('op_role')) == int(OpRole.Backward):
break
if op.type == 'create_py_reader' or op.type == 'read':
continue
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists
):
self._op_fp16_dict[op.desc.original_id()] = False
continue
if op.type in amp_lists.black_list:
self._op_fp16_dict[op.desc.original_id()] = False
elif op.type in amp_lists.white_list:
self._op_fp16_dict[op.desc.original_id()] = True
elif op.type in amp_lists.gray_list:
is_black_op = False
is_white_op = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name)
# this in_var isn't the output of another op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(
ops, op, in_var_name
)
if prev_op is None:
continue
else:
prev_op = in_var.op
# if it's one of inputs
if (
self._is_fp16_op(prev_op.desc.original_id())
is False
or prev_op.type in amp_lists.black_list
):
is_black_op = True
elif (
self._is_fp16_op(prev_op.desc.original_id())
is True
or prev_op.type in amp_lists.white_list
):
is_white_op = True
if is_black_op:
if is_loss_grad_op(op):
is_train = True
if op.type in __amp_skip_ops__:
continue
if is_forward_op(op):
self._mark_black_white_ops(op, block.ops, block)
elif is_backward_op(op):
if op.desc.original_id() in self.grad_op_to_op_map:
fwd_op_id = self.grad_op_to_op_map[
op.desc.original_id()
]
assert fwd_op_id in self._op_fp16_dict, "{}".format(
str(op)
)
self._op_fp16_dict[
op.desc.original_id()
] = self._is_fp16_op(fwd_op_id)
elif is_optimize_op(op):
break
# insert cast ops
for block in self.program.blocks:
self._cast_block(block)
return is_train
def _mark_black_white_ops(self, op, ops, block):
# ernie inference trick
if op.type == "assign" and "array_" in op.input_arg_names[0]:
self._op_fp16_dict[op.desc.original_id()] = False
return
# If the assign op is an inplace operation, its execution mode should match that of the op that created the output_var.
if op.type == "assign":
out_name = op.output_arg_names[0]
if len(self.out_var_op_deps[out_name]) > 1:
if not self._is_fp16_op(self.out_var_op_deps[out_name][0]):
self._op_fp16_dict[op.desc.original_id()] = False
elif is_white_op:
self._op_fp16_dict[op.desc.original_id()] = True
else:
pass
else:
# For numerical safety, we apply fp32 computation to ops whose
# list membership cannot be determined.
self._op_fp16_dict[op.desc.original_id()] = True
return
if (
self.amp_lists.black_varnames is not None
and self.amp_lists._is_in_black_fp32_varnames(op)
):
self._op_fp16_dict[op.desc.original_id()] = False
return
if op.type in self.amp_lists.black_list:
self._op_fp16_dict[op.desc.original_id()] = False
elif op.type in self.amp_lists.white_list:
self._op_fp16_dict[op.desc.original_id()] = True
elif op.type in self.amp_lists.gray_list:
is_black_op = False
is_white_op = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = block._var_recursive(in_var_name)
# this in_var isn't the output of another op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
# if it's one of inputs
if (
self._is_fp16_op(prev_op.desc.original_id())
is False
or prev_op.type in self.amp_lists.black_list
):
is_black_op = True
elif (
self._is_fp16_op(prev_op.desc.original_id()) is True
or prev_op.type in self.amp_lists.white_list
):
is_white_op = True
if is_black_op:
self._op_fp16_dict[op.desc.original_id()] = False
elif is_white_op:
self._op_fp16_dict[op.desc.original_id()] = True
else:
pass
else:
# For numerical safety, we apply fp32 computation to ops whose
# list membership cannot be determined.
self._op_fp16_dict[op.desc.original_id()] = False
def cast_forward_program(self, dist_context):
ops = self._block.ops
def _cast_block(self, block):
idx = 0
while idx < len(ops):
op = ops[idx]
appended_grad_times = 0
while idx < len(block.ops):
op = block.ops[idx]
num_cast_ops = 0
if int(op.attr('op_role')) == int(OpRole.Backward):
break
if self._is_fp16_op(op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_forward(
op,
idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
dist_context,
)
elif self._is_fp16_op(op.desc.original_id()) is True:
num_cast_ops = self._insert_cast_op_forward(
op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
dist_context,
if op.type in __amp_skip_ops__:
idx += 1
continue
elif is_forward_op(op):
if self._is_fp16_op(op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_forward(
block,
op,
idx,
_str_to_dtype(self.amp_dtype),
core.VarDesc.VarType.FP32,
self.dist_context,
)
elif self._is_fp16_op(op.desc.original_id()) is True:
if self.amp_dtype == "bfloat16":
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
op.has_attr('dtype')
and op.attr('dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_forward(
block,
op,
idx,
core.VarDesc.VarType.FP32,
_str_to_dtype(self.amp_dtype),
self.dist_context,
)
elif is_backward_op(op):
# NOTE: the map in `grad_var_to_var` may be changed when the var is cast,
# which affects how the dist_op inserts the allreduce_sum op.
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
op
)
else:
pass
if is_backward_op(op) and (
is_forward_op(block.ops[idx - 1])
or is_loss_op(block.ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
if op.desc.original_id() in self.grad_op_to_op_map:
if self._is_fp16_op(op.desc.original_id()) is False: # fp32
num_cast_ops = self._insert_cast_op_backward(
block,
op,
idx,
_str_to_dtype(self.amp_dtype),
core.VarDesc.VarType.FP32,
self.dist_context,
appended_grad_times,
)
elif (
self._is_fp16_op(op.desc.original_id()) is True
): # fp16/bf16
if self.amp_dtype == "bfloat16":
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
op.has_attr('dtype')
and op.attr('dtype')
== core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_backward(
block,
op,
idx,
core.VarDesc.VarType.FP32,
_str_to_dtype(self.amp_dtype),
self.dist_context,
appended_grad_times,
)
elif op.type == "sum":
# all input dtypes of sum should be equal, and the output dtype should follow the input
out_var_name = op.desc.output_arg_names()[0]
in_var_name = op.desc.input_arg_names()[0]
out_var = block.var(out_var_name)
in_var = block._find_var_recursive(in_var_name)
for in_var_name in op.input_arg_names:
assert (
in_var.dtype == block.var(in_var_name).dtype
), "{}, {}, {}".format(
in_var, block.var(in_var_name), str(op)
)
out_var.desc.set_dtype(in_var.dtype)
elif int(op.attr('op_role')) == 257:
pass
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
op.type
)
)
idx += num_cast_ops + 1
self._block._sync_with_cpp()
block._sync_with_cpp()
def _insert_cast_op_forward(
self, op, idx, src_dtype, dst_dtype, dist_context
self, block, op, idx, src_dtype, dst_dtype, dist_context
):
"""
only for forward cast
......@@ -184,25 +414,26 @@ class AMPState:
num_cast_ops = 0
var_name_dict = {}
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name
if (
src_dtype == core.VarDesc.VarType.FP32
and self.amp_lists._op_keep_fp32_input(op, in_name)
):
continue
for in_var_name in op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
in_var = block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = (
in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
)
out_var = self._block.vars.get(cast_name)
cast_var = block.vars.get(cast_name)
var_name_dict[in_var.name] = cast_name
consume_op_attr = dist_context.get_op_dist_attr_for_program(
op
)
assert consume_op_attr is not None
if out_var is None or out_var.dtype != dst_dtype:
if cast_var is None or cast_var.dtype != dst_dtype:
# NOTE: we take the dist attr of the cast op and cast var from the op that
# consumes the cast var instead of the op that generates the var
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
......@@ -215,27 +446,27 @@ class AMPState:
cast_name, in_var_dist_attr
)
out_var = self._block.create_var(
cast_var = block.create_var(
name=cast_name,
dtype=dst_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient,
)
set_var_dist_attr(
dist_context, out_var, ref_mapping, ref_mesh
dist_context, cast_var, ref_mapping, ref_mesh
)
op_namescope = "/"
if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope')
cast_op = self._block._insert_op_without_sync(
cast_op = block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
outputs={"Out": cast_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype,
"out_dtype": cast_var.dtype,
},
)
cast_op._set_attr(
......@@ -260,89 +491,27 @@ class AMPState:
if (
src_dtype == core.VarDesc.VarType.FP32
and dst_dtype == core.VarDesc.VarType.FP16
and dst_dtype == _str_to_dtype(self.amp_dtype)
):
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
if self.amp_lists._op_keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
out_var = self._block.var(out_var_name)
out_var = block._var_recursive(out_var_name)
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
out_var.desc.set_dtype(_str_to_dtype(self.amp_dtype))
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
op._set_attr(
'out_dtype', _str_to_dtype(self.amp_dtype)
)
return num_cast_ops
def cast_backward_program(self, params_grads, dist_context):
self._block._sync_with_cpp()
ops = self._block.ops
loss_op = get_loss_op(self._block)
loss_op_index = find_op_index(self._block.desc, loss_op.desc)
appended_grad_times = 0
idx = loss_op_index + 1
while idx < len(ops):
num_cast_ops = 0
grad_op = ops[idx]
# NOTE: the map in `grad_var_to_var` may be changed when the var is cast,
# which affects how the dist_op inserts the allreduce_sum op.
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (
is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
grad_op_orig_id = grad_op.desc.original_id()
dist_op_context = dist_context.dist_op_context
if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(grad_op_orig_id) is False: # fp32
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
dist_context,
appended_grad_times,
)
elif self._is_fp16_op(grad_op_orig_id) is True: # fp16
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
dist_context,
appended_grad_times,
)
elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype
for in_var_name in grad_op.desc.input_arg_names():
assert src_dtype == self._block.var(in_var_name).dtype
out_var_name = grad_op.desc.output_arg_names()[0]
out_var = self._block.var(out_var_name)
if out_var.dtype != src_dtype:
out_var.desc.set_dtype(src_dtype)
elif int(grad_op.attr('op_role')) == 257:
pass
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
grad_op.type
)
)
idx += num_cast_ops + 1
self._block._sync_with_cpp()
_update_backward_cast_ops(params_grads, dist_context)
def _insert_cast_op_backward(
self,
grad_op,
block,
op,
idx,
src_dtype,
dst_dtype,
......@@ -364,30 +533,30 @@ class AMPState:
return False
num_cast_ops = 0
original_id = grad_op.desc.original_id()
original_id = op.desc.original_id()
dist_op_context = dist_context.dist_op_context
fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]
fwd_op_id = self.grad_op_to_op_map[original_id]
for in_name in grad_op.input_names:
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
grad_op, in_name
op, in_name
):
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
for in_var_name in op.input(in_name):
in_var = block._var_recursive(in_var_name)
assert in_var.dtype == core.VarDesc.VarType.FP32
continue
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
for in_var_name in op.input(in_name):
in_var = block._var_recursive(in_var_name)
if in_var.dtype == src_dtype:
consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op
op
)
if in_var_name in self._var_name_dict[fwd_op_id]:
# NOTE: if the in_var of the consuming grad_op has been cast before,
# it should be renamed and its dist_attr reset.
cast_name = self._var_name_dict[fwd_op_id][in_var_name]
grad_op.desc._rename_input(in_var_name, cast_name)
op.desc._rename_input(in_var_name, cast_name)
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name
)
......@@ -398,26 +567,26 @@ class AMPState:
assert (
in_var.dtype == dst_dtype
), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type,
op.type,
in_name,
dst_dtype,
in_var.dtype,
str(grad_op),
str(op),
)
for out_name in grad_op.output_names:
for out_name in op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
grad_op, out_name
op, out_name
):
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
for out_var_name in op.output(out_name):
out_var = block._var_recursive(out_var_name)
assert out_var.dtype == core.VarDesc.VarType.FP32
continue
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
for out_var_name in op.output(out_name):
out_var = block._var_recursive(out_var_name)
out_var_name_prefix = out_var_name[: out_var_name.find("@")]
fwd_var = self._block._find_var_recursive(out_var_name_prefix)
fwd_var = block._var_recursive(out_var_name_prefix)
# NOTE: the out_var dtype of the consuming grad_op should equal the fwd_var's dtype
if out_var.dtype != fwd_var.dtype:
out_var.desc.set_dtype(fwd_var.dtype)
......@@ -428,7 +597,7 @@ class AMPState:
# it should be renamed and its dist_attr reset; then we insert a cast op to
# convert the cast_var back to the original dtype
consume_op_attr = (
dist_context.get_op_dist_attr_for_program(grad_op)
dist_context.get_op_dist_attr_for_program(op)
)
fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix
......@@ -439,9 +608,9 @@ class AMPState:
out_var_name.find("@RENAME") :
]
cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name)
cast_var = block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name)
op.desc._rename_output(out_var_name, cast_name)
out_var_dist_attr = (
consume_op_attr.get_output_dist_attr(
out_var_name
......@@ -453,7 +622,7 @@ class AMPState:
cast_name, out_var_dist_attr
)
assert ref_mapping is not None
cast_var = self._block.create_var(
cast_var = block.create_var(
name=cast_name,
shape=out_var.shape,
dtype=dst_dtype,
......@@ -467,7 +636,7 @@ class AMPState:
appended_grad_times
][cast_name] = fwd_cast_name
cast_op = self._block._insert_op(
cast_op = block._insert_op(
idx + 1,
type="cast",
inputs={"X": cast_var},
......@@ -491,134 +660,11 @@ class AMPState:
return num_cast_ops
def _update_backward_cast_ops(params_grads, dist_context):
"""
move the param-grad cast to the end of the backward segment
in order to enable fp16 allreduce
"""
# TODO filter optimize ops in future
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
for p, g in params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
'op_role_var'
):
op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name)
if post_ops:
raise ValueError(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0])
)
if op == main_block.ops[-1]:
continue
# add the new op on both the Python and C++ sides at the same time
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = paddle.static.Operator(
block=main_block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None,
)
main_block.ops.append(new_op)
# dist attr
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0])
)
assert param_dist_attr is not None
assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op,
param_dist_attr.process_mesh,
param_dist_attr.dims_mapping,
dist_context,
)
output_dist_attr.process_mesh = param_dist_attr.process_mesh
output_dist_attr.dims_mapping = param_dist_attr.dims_mapping
op_idx = find_op_index(main_block.desc, op.desc)
if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op))
main_block._remove_op(op_idx, sync=False)
main_block._sync_with_cpp()
def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
grads = [g for _, g in params_grads]
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads:
check_variable_and_dtype(
e,
"x",
['float16', 'float32', 'float64'],
'check_finite_and_unscale',
)
found_inf = main_block.create_var(
name=unique_name.generate_with_ignorable_key(
".".join(['find_infinite_scale', 'tmp'])
),
shape=[1],
dtype='bool',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(
type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
attrs=attrs,
)
# Constructing dist attr from op_desc can
# give all inputs and outputs default dist attrs
new_op_dist_attr = OperatorDistAttr(new_op.desc)
new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
new_op_dist_attr.impl_idx = 0
if len(world_process_group.ranks) > 1:
new_op_dist_attr.impl_type = "check_finite_and_unscale"
for g in grads:
g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf
@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
def __init__(self):
super().__init__()
self.set_attr("dtype", "") # fp16/bf16
self.set_attr("loss", None)
self.set_attr("dist_context", None)
self.set_attr("custom_white_list", None)
......@@ -637,7 +683,6 @@ class AMPPass(PassBase):
self._loss_scaling = None
self._num_good_steps = None
self._num_bad_steps = None
self._loss = None
def _check_self(self):
if self.get_attr("dtype") not in ["float16", "bfloat16"]:
......@@ -657,7 +702,6 @@ class AMPPass(PassBase):
return True
def _check_conflict(self, other_pass):
return True
# NOTE: why AMPBackwardPass can override apply_single_impl instead of
......@@ -665,37 +709,166 @@ class AMPPass(PassBase):
# in distributed scenario, all ranks should have the same modification.
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
self.params_grads = self.get_attr("params_grads")
self.amp_dtype = self.get_attr("dtype")
amp_lists = AutoMixedPrecisionLists(
amp_lists = AMPLists(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames")),
self.amp_dtype,
)
with paddle.static.program_guard(main_program, startup_program):
amp_state = AMPState(main_program.global_block())
is_train = amp_state._build_state(amp_lists, self.dist_context)
amp_state = AMPState(
main_program, amp_lists, self.amp_dtype, self.dist_context
)
is_train = amp_state.build_state()
amp_state.cast_forward_program(self.dist_context)
if is_train:
self._update_backward_cast_ops()
self._cast_loss()
if is_train:
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_backward_program(params_grads, self.dist_context)
if is_train and self.amp_dtype == "float16":
self._init_amp_var()
self._scale_loss()
if (
self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
grads, found_inf = _check_and_update_gradient(
params_grads, self._loss_scaling, self.dist_context
)
grads, found_inf = self._check_and_update_gradient()
if self.get_attr("use_dynamic_loss_scaling"):
self._update_loss_scaling(grads, found_inf)
def _update_backward_cast_ops(self):
"""
move the param-grad cast to the end of the backward segment
in order to enable fp16 allreduce
"""
# TODO filter optimize ops in future
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
for p, g in self.params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int(
OpRole.Backward
) and op.has_attr('op_role_var'):
op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name)
if post_ops:
raise ValueError(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0])
)
if op == main_block.ops[-1]:
continue
# add the new op on both the Python and C++ sides at the same time
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = paddle.static.Operator(
block=main_block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None,
)
main_block.ops.append(new_op)
# dist attr
param_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(p)
)
output_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0])
)
)
assert param_dist_attr is not None
assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op,
param_dist_attr.process_mesh,
param_dist_attr.dims_mapping,
self.dist_context,
)
output_dist_attr.process_mesh = param_dist_attr.process_mesh
output_dist_attr.dims_mapping = param_dist_attr.dims_mapping
op_idx = find_op_index(main_block.desc, op.desc)
if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op))
main_block._remove_op(op_idx, sync=False)
main_block._sync_with_cpp()
def _check_and_update_gradient(self):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
grads = [g for _, g in self.params_grads]
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads:
check_variable_and_dtype(
e,
"x",
['float16', 'float32', 'float64'],
'check_finite_and_unscale',
)
found_inf = main_block.create_var(
name=unique_name.generate_with_ignorable_key(
".".join(['find_infinite_scale', 'tmp'])
),
shape=[1],
dtype='bool',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
set_var_dist_attr(
self.dist_context, found_inf, [-1], world_process_group.ranks
)
inputs = {'X': grads, 'Scale': self._loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(
type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
attrs=attrs,
)
# Constructing dist attr from op_desc can
# give all inputs and outputs default dist attrs
new_op_dist_attr = OperatorDistAttr(new_op.desc)
new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
new_op_dist_attr.impl_idx = 0
if len(world_process_group.ranks) > 1:
new_op_dist_attr.impl_type = "check_finite_and_unscale"
for g in grads:
g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf
def _init_amp_var(self):
self._loss_scaling = paddle.static.create_global_var(
name=unique_name.generate("loss_scaling"),
......@@ -740,11 +913,10 @@ class AMPPass(PassBase):
world_process_group.ranks,
)
def _scale_loss(self):
def _cast_loss(self):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
loss = self.get_attr("loss")
assert loss is not None
......@@ -777,13 +949,11 @@ class AMPPass(PassBase):
attrs={
"in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
"op_role": loss_op.all_attrs()[OP_ROLE_KEY],
},
)
loss_op._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context
)
......@@ -814,8 +984,8 @@ class AMPPass(PassBase):
outputs={'Out': [pre_grad_name]},
attrs={
"in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16,
'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
"out_dtype": _str_to_dtype(self.amp_dtype),
"op_role": OpRole.Backward,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
......@@ -823,6 +993,18 @@ class AMPPass(PassBase):
)
loss_op = cast_op
loss = cast_loss
self._loss = loss
main_block._sync_with_cpp()
def _scale_loss(self):
main_block = paddle.static.default_main_program().global_block()
loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op
)
if (
self.get_attr("use_dynamic_loss_scaling")
......@@ -833,28 +1015,24 @@ class AMPPass(PassBase):
# forward
ref_mesh = loss_op_dist_attr.process_mesh
self._scaled_loss = main_block.create_var(
scaled_loss = main_block.create_var(
name=unique_name.generate("scaled_loss"),
shape=loss.shape,
dtype=loss.dtype,
persistable=loss.persistable,
)
set_var_dist_attr(
self.dist_context, self._scaled_loss, [-1], ref_mesh
)
set_var_dist_attr(self.dist_context, scaled_loss, [-1], ref_mesh)
elementwise_mul_op = main_block._insert_op(
loss_op_idx + 1,
type='elementwise_mul',
inputs={'X': [loss], 'Y': [self._loss_scaling]},
outputs={'Out': [self._scaled_loss]},
outputs={'Out': [scaled_loss]},
attrs={
'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
},
)
loss_op._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_op, ref_mesh, [-1], self.dist_context
)
......@@ -865,23 +1043,23 @@ class AMPPass(PassBase):
first_backward_op.type == "fill_constant"
and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
)
self._scaled_loss_grad = main_block.create_var(
scaled_loss_grad = main_block.create_var(
name=unique_name.generate("scaled_loss") + "@GRAD",
shape=loss.shape,
dtype=loss.dtype,
persistable=loss.persistable,
)
set_var_dist_attr(
self.dist_context, self._scaled_loss_grad, [-1], ref_mesh
self.dist_context, scaled_loss_grad, [-1], ref_mesh
)
pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output(
pre_grad_name, self._scaled_loss_grad.name
pre_grad_name, scaled_loss_grad.name
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
first_backward_op, ref_mesh, [-1], self.dist_context
)
self._scaled_loss_grad.op = first_backward_op
scaled_loss_grad.op = first_backward_op
# FIXME(JZ-LIANG) a trick to insert backward op
main_block._sync_with_cpp()
elementwise_mul_grad_op_desc = main_block.desc._insert_op(
......@@ -889,7 +1067,7 @@ class AMPPass(PassBase):
)
elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
elementwise_mul_grad_op_desc.set_input(
'Out@GRAD', [self._scaled_loss_grad.name]
'Out@GRAD', [scaled_loss_grad.name]
)
elementwise_mul_grad_op_desc.set_input('X', [loss.name])
elementwise_mul_grad_op_desc.set_input(
......@@ -897,9 +1075,7 @@ class AMPPass(PassBase):
)
elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
elementwise_mul_grad_op_desc._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward
)
elementwise_mul_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
elementwise_mul_grad_op_desc._set_attr('axis', -1)
elementwise_mul_grad_op = paddle.static.Operator(
main_block, elementwise_mul_grad_op_desc
......@@ -911,10 +1087,9 @@ class AMPPass(PassBase):
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
)
else:
self._scaled_loss = loss
self._loss = loss
scaled_loss = loss
self._loss = scaled_loss
main_block._sync_with_cpp()
def _update_loss_scaling(self, grads, found_inf):
......@@ -994,9 +1169,9 @@ class AMPPass(PassBase):
main_block._sync_with_cpp()
def get_loss(self):
# the amp / fp16 might change the effective loss variable for network and
# the amp pass might change the effective loss variable for the network and
# therefore would affect the subsequent passes that rely on the loss.
# return the effective loss after amp / fp16 pass.
# return the effective loss after the amp pass.
if self._loss:
return self._loss
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import static
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.passes.pass_base import PassBase, register_pass
from paddle.framework import Block, core
from paddle.static.amp.bf16 import AutoMixedPrecisionListsBF16
from paddle.static.amp.bf16.amp_utils import (
_dtype_to_str,
_is_in_fp32_varnames,
_valid_types,
find_true_post_op,
)
from paddle.static.amp.fp16_utils import (
_rename_arg,
find_op_index,
find_true_prev_op,
)
from paddle.utils import unique_name
from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op
world_process_group = get_world_process_group()
class BF16State:
def __init__(self, block):
self._block: Block = block
self._op_bf16_dict = {}
self._var_name_dict = {}
def _is_bf16_op(self, op_id):
return self._op_bf16_dict.get(op_id, None)
def _build_state(self, amp_lists, dist_context):
ops = self._block.ops
dist_op_context = dist_context.dist_op_context
training = False
for op in ops:
if int(op.attr("op_role")) == 257:
training = True
if int(op.attr("op_role")) == int(OpRole.Forward):
self._mark_black_white_op(amp_lists, op, ops)
elif int(op.attr("op_role")) == int(OpRole.Backward):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
fwd_op_id = dist_op_context.grad_op_id_to_op_id[
op.desc.original_id()
]
if self._is_bf16_op(fwd_op_id) is True:
self._op_bf16_dict[op.desc.original_id()] = True
elif self._is_bf16_op(fwd_op_id) is False:
self._op_bf16_dict[op.desc.original_id()] = False
elif int(op.attr("op_role")) == int(OpRole.Optimize):
break
return training
def _mark_black_white_op(self, amp_lists, op, ops):
if op.type == "create_py_reader" or op.type == "read":
return
if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames(
op, amp_lists
):
self._op_bf16_dict[op.desc.original_id()] = False
return
if op.type in amp_lists.bf16_list:
self._op_bf16_dict[op.desc.original_id()] = True
elif op.type in amp_lists.gray_list:
is_fp32_op = False
is_bf16_op = False
for in_name in op.input_names:
if in_name:
for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name)
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
if (
self._op_bf16_dict.get(
prev_op.desc.original_id(), False
)
is False
or prev_op.type in amp_lists.fp32_list
):
is_fp32_op = True
elif (
self._op_bf16_dict.get(
prev_op.desc.original_id(), False
)
is True
or prev_op.type in amp_lists.bf16_list
):
is_bf16_op = True
if is_fp32_op:
self._op_bf16_dict[op.desc.original_id()] = False
elif is_bf16_op:
self._op_bf16_dict[op.desc.original_id()] = True
else:
pass
else:
self._op_bf16_dict[op.desc.original_id()] = False
def cast_forward_program(self, dist_context):
ops = self._block.ops
idx = 0
while idx < len(ops):
num_cast_ops = 0
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Backward):
break
if self._is_bf16_op(op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_forward(
op,
idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32,
dist_context,
)
elif self._is_bf16_op(op.desc.original_id()) is True:
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
op.has_attr('dtype')
and op.attr('dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_forward(
op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16,
dist_context,
)
else:
pass
idx += num_cast_ops + 1
self._block._sync_with_cpp()
def _insert_cast_op_forward(
self, op, idx, src_dtype, dst_dtype, dist_context: DistributedContext
):
num_cast_ops = 0
var_name_dict = {}
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm',
'fused_bn_add_activation',
'layer_norm',
]:
if in_name not in {'X', 'Z'}:
continue
for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = (
in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
)
var_name_dict[in_var.name] = cast_name
out_var = self._block.vars.get(cast_name)
consume_op_attr = dist_context.get_op_dist_attr_for_program(
op
)
assert consume_op_attr is not None
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name
)
if out_var is None or out_var.dtype != dst_dtype:
assert in_var_dist_attr is not None
ref_mesh = in_var_dist_attr.process_mesh
ref_mapping = in_var_dist_attr.dims_mapping
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
out_var = self._block.create_var(
name=cast_name,
dtype=dst_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient,
)
set_var_dist_attr(
dist_context, out_var, ref_mapping, ref_mesh
)
cast_op = self._block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
else:
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
_rename_arg(op, in_var_name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dst_dtype)
self._var_name_dict[op.desc.original_id()] = var_name_dict
if (
src_dtype == core.VarDesc.VarType.FP32
and dst_dtype == core.VarDesc.VarType.BF16
):
for out_name in op.output_names:
if (
op.type
in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']
and out_name != 'Y'
):
continue
for out_var_name in op.output(out_name):
out_var = self._block.var(out_var_name)
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
return num_cast_ops
def cast_backward_program(self, params_grads, dist_context):
self._block._sync_with_cpp()
ops = self._block.ops
appended_grad_times = 0
dist_op_context = dist_context.dist_op_context
loss_op = get_loss_op(self._block)
idx = find_op_index(self._block.desc, loss_op.desc) + 1
while idx < len(ops):
num_cast_ops = 0
grad_op = ops[idx]
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (
is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
if (
grad_op.desc.original_id()
in dist_op_context.grad_op_id_to_op_id
):
if self._is_bf16_op(grad_op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32,
dist_context,
appended_grad_times,
)
elif self._is_bf16_op(grad_op.desc.original_id()) is True:
if grad_op.has_attr('use_mkldnn'):
grad_op._set_attr('use_mkldnn', True)
grad_op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
grad_op.has_attr('dtype')
and grad_op.attr('dtype') == core.VarDesc.VarType.FP32
):
grad_op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16,
dist_context,
appended_grad_times,
)
elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype
for in_var_name in grad_op.desc.input_arg_names():
assert src_dtype == self._block.var(in_var_name).dtype
out_var_name = grad_op.desc.output_arg_names()[0]
out_var = self._block.var(out_var_name)
if out_var.dtype != src_dtype:
out_var.desc.set_dtype(src_dtype)
elif int(grad_op.attr("op_role")) == 257:
pass
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
grad_op.type
)
)
idx += num_cast_ops + 1
self._block._sync_with_cpp()
_update_backward_cast_ops(params_grads, dist_context)
def _insert_cast_op_backward(
self,
grad_op,
idx,
src_dtype,
dst_dtype,
dist_context,
appended_grad_times,
):
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type in ['layer_norm_grad']:
return in_name not in {'X', 'Y@GRAD'}
return False
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['layer_norm_grad']:
return out_name != 'X@GRAD'
return False
num_cast_ops = 0
original_id = grad_op.desc.original_id()
dist_op_context = dist_context.dist_op_context
fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]
for in_name in grad_op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
grad_op, in_name
):
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
assert in_var.dtype == core.VarDesc.VarType.FP32
continue
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
if in_var.dtype == src_dtype:
consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op
)
if in_var_name in self._var_name_dict[fwd_op_id]:
cast_name = self._var_name_dict[fwd_op_id][in_var_name]
grad_op.desc._rename_input(in_var_name, cast_name)
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name
)
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
else:
assert (
in_var.dtype == dst_dtype
), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type,
in_name,
dst_dtype,
in_var.dtype,
str(grad_op),
)
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
grad_op, out_name
):
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
assert out_var.dtype == core.VarDesc.VarType.FP32
continue
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
out_var_name_prefix = out_var_name[: out_var_name.find('@')]
fwd_var = self._block._find_var_recursive(out_var_name_prefix)
if out_var.dtype != fwd_var.dtype:
out_var.desc.set_dtype(fwd_var.dtype)
if out_var.dtype == src_dtype:
if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
consume_op_attr = (
dist_context.get_op_dist_attr_for_program(grad_op)
)
fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix
]
suffix = ''
if "@RENAME" in out_var_name:
suffix = out_var_name[
out_var_name.find("@RENAME") :
]
cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name)
out_var_dist_attr = (
consume_op_attr.get_output_dist_attr(
out_var_name
)
)
ref_mesh = out_var_dist_attr.process_mesh
ref_mapping = out_var_dist_attr.dims_mapping
consume_op_attr.set_output_dist_attr(
cast_name, out_var_dist_attr
)
assert ref_mapping is not None
cast_var = self._block.create_var(
name=cast_name,
shape=out_var.shape,
dtype=dst_dtype,
persistable=False,
stop_gradient=out_var.stop_gradient,
)
set_var_dist_attr(
dist_context, cast_var, ref_mapping, ref_mesh
)
dist_op_context.grad_var_to_var[
appended_grad_times
][cast_name] = fwd_cast_name
cast_op = self._block._insert_op(
idx + 1,
type="cast",
inputs={"X": cast_var},
outputs={"Out": out_var},
attrs={
"in_dtype": cast_var.dtype,
"out_dtype": out_var.dtype,
"op_role": OpRole.Backward,
},
)
cast_op._remove_attr("op_role_var")
cast_op._remove_attr("op_namescope")
cast_op._remove_attr("with_quant_attr")
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
else:
assert out_var.dtype == dst_dtype
return num_cast_ops
def _update_backward_cast_ops(params_grads, dist_context):
"""
move the param-grad cast to the end of the backward segment
in order to enable fp16 allreduce
"""
# TODO filter optimize ops in future
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
for p, g in params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
'op_role_var'
):
op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name)
if post_ops:
raise ValueError(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0])
)
if op == main_block.ops[-1]:
continue
# add new op in the python and cpp at the same time
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = paddle.static.Operator(
block=main_block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None,
)
main_block.ops.append(new_op)
# dist attr
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0])
)
assert param_dist_attr is not None
assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op,
param_dist_attr.process_mesh,
param_dist_attr.dims_mapping,
dist_context,
)
output_dist_attr.process_mesh = param_dist_attr.process_mesh
output_dist_attr.dims_mapping = param_dist_attr.dims_mapping
op_idx = find_op_index(main_block.desc, op.desc)
if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op))
main_block._remove_op(op_idx, sync=False)
main_block._sync_with_cpp()
@register_pass("auto_parallel_bf16")
class BF16Pass(PassBase):
def __init__(self):
super().__init__()
self.set_attr("dist_context", None)
self.set_attr("custom_bf16_list", None)
self.set_attr("custom_fp32_list", None)
self.set_attr("custom_fp32_varnames", None)
self.set_attr("input_data", [])
self.set_attr("loss", None)
self.set_attr("params_grads", [])
self.set_attr("use_bf16_guard", False)
self._loss = None
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
amp_lists = AutoMixedPrecisionListsBF16(
self.get_attr("custom_bf16_list"),
self.get_attr("custom_fp32_list"),
self.get_attr("custom_fp32_varnames"),
)
with static.program_guard(main_program, startup_program):
amp_state = BF16State(main_program.global_block())
training = amp_state._build_state(amp_lists, self.dist_context)
amp_state.cast_forward_program(self.dist_context)
if training:
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_backward_program(params_grads, self.dist_context)
self._scale_loss()
def _scale_loss(self):
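"""
Keep the loss in full precision: if the loss was cast to bf16, insert a
cast back to fp32 right after the loss op, and cast the initial gradient
created by fill_constant back to the loss's original dtype. Note that,
despite the name, no loss-scaling op is inserted here; bf16's dynamic
range makes scaling unnecessary.
"""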
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op
)
if loss.dtype != core.VarDesc.VarType.FP32:
tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(
name=tmp_name, dtype=core.VarDesc.VarType.FP32
)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss
)
ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
cast_loss, loss_dist_attr
)
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op(
loss_op_idx + 1,
type='cast',
inputs={"X": [loss]},
outputs={"Out": [cast_loss]},
attrs={
"in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
"op_role": loss_op.all_attrs()[OP_ROLE_KEY],
},
)
loss_op._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context
)
first_backward_op = main_block.ops[loss_op_idx + 2]
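# 257 == OpRole.Backward | OpRole.Loss: the fill_constant that creates the
# initial loss gradient carries both roles.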
assert (
first_backward_op.type == "fill_constant"
and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
)
cast_loss_grad = main_block.create_var(
name=unique_name.generate(tmp_name + "@GRAD"),
shape=loss.shape,
dtype=core.VarDesc.VarType.FP32,
persistable=loss.persistable,
)
set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)
pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
cast_grad_op = main_block._insert_op(
loss_op_idx + 3,
type='cast',
inputs={'X': [cast_loss_grad]},
outputs={'Out': [pre_grad_name]},
attrs={
"in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16,
'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_grad_op, ref_mesh, [-1], self.dist_context
)
loss = cast_loss
self._loss = loss
main_block._sync_with_cpp()
def get_loss(self):
if self._loss:
return self._loss
else:
return self.get_attr("loss")
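# A minimal configuration sketch, mirroring the unit tests in this change
# (the pass is normally enabled through the distributed strategy rather than
# instantiated directly):
#
#     strategy = auto.Strategy()
#     amp = strategy.amp
#     amp.enable = True
#     amp.dtype = "bfloat16"
#     amp.level = "o1"
#     engine = auto.Engine(model, loss, optimizer, strategy=strategy)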
......@@ -1611,7 +1611,9 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
dp_group = None
for input_name in op.input_arg_names:
if not is_parameter_related(input_name, op.block):
# TODO(zhaoyingli): maintain a dict in dist_context to record all variables that are renamed,
# so that params renamed with the @RESHARD suffix can still be identified.
if not is_parameter_related(input_name, op.block, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(op)
process_mesh = dist_attr.process_mesh
input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
......
......@@ -126,6 +126,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
test_convert_to_process_meshes)
py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
py_test_modules(test_dist_saver MODULES test_dist_saver)
py_test_modules(test_engine_save_load MODULES test_engine_save_load)
# End of unittests WITH single card WITHOUT timeout
endif()
......@@ -29,6 +29,8 @@ def apply_pass(use_amp=False, level=None):
if use_amp:
amp = strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = level
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
......@@ -37,8 +39,6 @@ def apply_pass(use_amp=False, level=None):
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.level = level
amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level)
return strategy
......
......@@ -31,6 +31,8 @@ def apply_pass():
amp = dist_strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = "o2"
amp.custom_white_list = ["lookup_table", "lookup_table_v2"]
amp.custom_black_list = [
"reduce_sum",
......@@ -38,8 +40,6 @@ def apply_pass():
"elementwise_div",
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.level = "o2"
qat = dist_strategy.qat
qat.enable = True
......@@ -119,9 +119,6 @@ class TestQuantizationPassExport(unittest.TestCase):
def test_qat_pass_2(self):
batch_size = 1
batch_num = 10
strategy = apply_pass()
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, strategy=strategy)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet import auto
paddle.enable_static()
batch_size = 2
hidden_size = 1024
# sequence_len = 512
image_size = hidden_size
class_num = 10
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super().__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
auto.shard_tensor(input, auto.ProcessMesh([0]), [None, None])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
return out
class TestSaveLoad(unittest.TestCase):
def test_fp32_save_fp16_load(self):
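"""
Prepare an fp32 training program and save its parameters, then load them
into an amp-o2 (fp16) program: layer_norm parameters are expected to stay
in float32, everything else in float16, and values must match the fp32
checkpoint within atol=1e-4.
"""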
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
inputs_spec = [
paddle.static.InputSpec(
shape=[batch_size, image_size], name="input", dtype="float32"
)
]
labels_spec = [
paddle.static.InputSpec(
shape=[batch_size, 1], name="label", dtype="int64"
)
]
# build fp32 model
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine_fp32 = auto.Engine(
mlp, loss, optimizer, metric, strategy=strategy
)
engine_fp32.prepare(inputs_spec, labels_spec, mode="train")
fp32_state = {
k: np.array(v)
for k, v in engine_fp32.main_program.state_dict("param").items()
}
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine_fp32.save(model_filename)
# build fp16 model
strategy = auto.Strategy()
strategy.auto_mode = "semi"
amp = strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = "o2"
engine_fp16 = auto.Engine(
mlp, loss, optimizer, metric, strategy=strategy
)
engine_fp16.load(model_filename)
engine_fp16.prepare(inputs_spec, labels_spec, mode="train")
fp16_state = {
k: np.array(v)
for k, v in engine_fp16.main_program.state_dict("param").items()
}
# check params: layer_norm params are expected to stay float32 under amp o2,
# all others float16, with values matching the fp32 checkpoint
for name, fp32_param in fp32_state.items():
fp16_param = fp16_state[name]
if "layer_norm" in name:
assert fp16_param.dtype == np.float32
else:
assert fp16_param.dtype == np.float16
np.testing.assert_allclose(fp32_param, fp16_param, atol=1e-4)
temp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -36,7 +36,8 @@ def apply_pass(use_bf16=False):
if use_bf16:
amp = strategy.amp
amp.enable = True
amp.enable_bf16 = True
amp.dtype = "bfloat16"
amp.level = "o1"
return strategy
......
......@@ -28,8 +28,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp
self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.dtype, "float16")
self.assertAlmostEqual(amp.level, "o1")
self.assertEqual(amp.dtype, "float16")
self.assertEqual(amp.level, "o1")
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
......@@ -40,10 +40,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, False)
self.assertEqual(amp.custom_bf16_list, [])
self.assertEqual(amp.custom_fp32_list, [])
self.assertEqual(amp.custom_fp32_varnames, [])
self.assertEqual(amp.use_bf16_guard, False)
sharding = strategy.sharding
......@@ -91,6 +87,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = "o2"
amp.init_loss_scaling = 16384.0
amp.incr_every_n_steps = 2000
amp.decr_every_n_nan_or_inf = 4
......@@ -101,8 +99,9 @@ class TestStrategy(unittest.TestCase):
amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"]
amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True)
self.assertEqual(amp.dtype, "float16")
self.assertEqual(amp.level, "o2")
self.assertAlmostEqual(amp.init_loss_scaling, 16384.0)
self.assertEqual(amp.incr_every_n_steps, 2000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 4)
......@@ -113,7 +112,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True)
sharding = strategy.sharding
sharding.enable = True
......