Unverified commit 418edae5 authored by xu98bin, committed by GitHub

auto parallel bf16 (#49079)

* auto parallel bf16
Parent 1078e064
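The pass is switched on through the new `enable_bf16` AMP option of the auto-parallel strategy. A minimal sketch based on the MNIST test added at the bottom of this diff (the `nn.Sequential` model is only a stand-in for the test's `Model` class):

```python
import paddle
import paddle.nn as nn
from paddle.distributed.fleet import auto
from paddle.static import InputSpec

paddle.enable_static()

# New AMP options added by this commit: enable_bf16, custom_bf16_list,
# custom_fp32_list, custom_fp32_varnames, use_pure_bf16, use_bf16_guard.
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.amp.enable = True
strategy.amp.enable_bf16 = True

model = nn.Sequential(
    nn.Flatten(), nn.Linear(784, 120), nn.ReLU(), nn.Linear(120, 10)
)
loss = nn.CrossEntropyLoss()
opt = paddle.optimizer.SGD(0.001, parameters=model.parameters())

engine = auto.Engine(model, loss, opt, strategy=strategy)
engine.prepare(
    inputs_spec=[InputSpec([None, 1, 28, 28], 'float32', 'input0')],
    labels_spec=[InputSpec([None, 1], 'int64', 'label0')],
    mode="train",
)
```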
@@ -134,4 +134,7 @@ REGISTER_OP_CUDA_KERNEL(c_concat,
ops::CConcatOpCUDAKernel<double>,
ops::CConcatOpCUDAKernel<int>,
ops::CConcatOpCUDAKernel<int64_t>,
#if NCCL_VERSION_CODE >= 21000
ops::CConcatOpCUDAKernel<plat::bfloat16>,
#endif
ops::CConcatOpCUDAKernel<plat::float16>);
@@ -22,4 +22,7 @@ REGISTER_OP_CUDA_KERNEL(c_identity,
ops::CIdentityOpKernel<double>,
ops::CIdentityOpKernel<int>,
ops::CIdentityOpKernel<int64_t>,
#if NCCL_VERSION_CODE >= 21000
ops::CIdentityOpKernel<plat::bfloat16>,
#endif
ops::CIdentityOpKernel<plat::float16>);
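Both collective kernels (c_concat and c_identity) gain a bfloat16 instantiation guarded by `NCCL_VERSION_CODE >= 21000`, since NCCL only added native bfloat16 support in release 2.10 (version code 21000); builds against older NCCL keep the previous float16/float32 kernel set.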
@@ -76,6 +76,13 @@ set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_optimizer_fp16", False)
set_field_default_config(AMP, "enable_bf16", False)
set_field_default_config(AMP, "custom_bf16_list", [])
set_field_default_config(AMP, "custom_fp32_list", [])
set_field_default_config(AMP, "custom_fp32_varnames", [])
set_field_default_config(AMP, "use_pure_bf16", False)
set_field_default_config(AMP, "use_bf16_guard", False)
#########################################
# sharding configuration
#########################################
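These defaults surface on the strategy object and are exactly what the updated `test_strategy.py` at the end of this diff asserts; for example:

```python
from paddle.distributed.fleet import auto

amp = auto.Strategy().amp
assert not amp.enable_bf16            # off by default
assert amp.custom_bf16_list == []
assert amp.custom_fp32_varnames == []
assert not amp.use_pure_bf16
assert not amp.use_bf16_guard
```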
@@ -266,8 +266,12 @@ def is_parameter_related(varname, block):
varname = varname[: varname.index(".subprog_")]
if ".cast_fp" in varname:
varname = varname[: varname.index(".cast_fp")]
if ".cast_bf" in varname:
varname = varname[: varname.index(".cast_bf")]
if ".quantized" in varname:
varname = varname[: varname.index(".quantized")]
# if "@RESHARD" in varname:
# varname = varname[: varname.index("@RESHARD")]
assert block._find_var_recursive(varname)
var = block._var_recursive(varname)
return var.is_parameter
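The new `.cast_bf` branch mirrors the existing `.cast_fp` handling: the BF16 pass derives cast-variable names from the original name, so the suffix has to be stripped before the parameter lookup. A tiny illustration (the variable name is made up):

```python
# The pass names casts as in_var.name + '.cast_' + _dtype_to_str(dst_dtype),
# e.g. "linear_0.w_0.cast_bf16" for a BF16 cast of "linear_0.w_0".
varname = "linear_0.w_0.cast_bf16"
if ".cast_bf" in varname:
    varname = varname[: varname.index(".cast_bf")]
assert varname == "linear_0.w_0"
```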
@@ -376,7 +376,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
check_variable_and_dtype(
Out_grad,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
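The recurring change in this operator file is appending `'uint16'` to the dtype allow-lists: a static-graph BF16 variable reports its dtype as the string `'uint16'` (bfloat16 has no native numpy dtype), so without this entry `check_variable_and_dtype` would reject BF16 inputs. A quick way to see the mapping, assuming the usual `convert_dtype` helper:

```python
from paddle.fluid.data_feeder import convert_dtype
from paddle.framework import core

# BF16 variables are reported as 'uint16' by the static dtype checks.
print(convert_dtype(core.VarDesc.VarType.BF16))  # 'uint16'
```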
@@ -417,13 +417,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
check_variable_and_dtype(
intermediate_var_0,
'x',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
set_comm_op_dist_attr_for_program(
@@ -835,7 +835,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
check_variable_and_dtype(
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
@@ -854,12 +854,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
@@ -1183,10 +1186,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
'transpose_X': trans_x,
@@ -1731,7 +1737,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
check_variable_and_dtype(
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op(
@@ -1749,12 +1755,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
@@ -2077,10 +2086,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
'trans_x': trans_x,
@@ -2610,7 +2622,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
check_variable_and_dtype(
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op(
@@ -2628,12 +2640,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
# attrs = {'trans_x': False, 'trans_y': False}
@@ -2965,10 +2980,13 @@ class DistributedMulImpl1(DistributedOperatorImpl):
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
# attrs = {'trans_x': False, 'trans_y': False}
attrs = {
@@ -221,13 +221,21 @@ class Parallelizer:
self._dist_context.serial_feed_vars["inputs"]
+ self._dist_context.serial_feed_vars["labels"]
)
if config["use_pure_fp16"]:
if config["enable_bf16"]:
auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config)
auto_parallel_bf16_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_bf16_pass.get_loss()
elif config["use_pure_fp16"]:
config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_fp16_pass.get_loss()
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
@@ -18,6 +18,7 @@ from .auto_parallel_gradient_merge import * # noqa: F403
from .auto_parallel_sharding import * # noqa: F403
from .auto_parallel_amp import * # noqa: F403
from .auto_parallel_fp16 import * # noqa: F403
from .auto_parallel_bf16 import * # noqa: F403
from .auto_parallel_recompute import * # noqa: F403
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import static
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.passes.pass_base import PassBase, register_pass
from paddle.fluid import unique_name
from paddle.fluid.contrib.mixed_precision.bf16 import (
AutoMixedPrecisionListsBF16,
)
from paddle.fluid.contrib.mixed_precision.bf16.amp_utils import (
_dtype_to_str,
_is_in_fp32_varnames,
_valid_types,
find_op_index,
find_true_post_op,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_rename_arg,
find_true_prev_op,
)
from paddle.fluid.framework import Block
from paddle.framework import core
from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op
world_process_group = get_world_process_group()
class BF16State(object):
def __init__(self, block):
self._block: Block = block
self._op_bf16_dict = {}
self._var_name_dict = {}
def _is_bf16_op(self, op_id):
return self._op_bf16_dict.get(op_id, None)
def _build_state(self, amp_lists, dist_context):
ops = self._block.ops
dist_op_context = dist_context.dist_op_context
training = False
for op in ops:
if int(op.attr("op_role")) == 257:
training = True
if int(op.attr("op_role")) == int(OpRole.Forward):
self._mark_black_white_op(amp_lists, op, ops)
elif int(op.attr("op_role")) == int(OpRole.Backward):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
fwd_op_id = dist_op_context.grad_op_id_to_op_id[
op.desc.original_id()
]
if self._is_bf16_op(fwd_op_id) is True:
self._op_bf16_dict[op.desc.original_id()] = True
elif self._is_bf16_op(fwd_op_id) is False:
self._op_bf16_dict[op.desc.original_id()] = False
elif int(op.attr("op_role")) == int(OpRole.Optimize):
break
return training
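Note: the literal `257` checked above is `OpRole.Loss | OpRole.Backward` (0x100 | 0x1), i.e. the op that seeds the loss gradient; seeing it means the program contains a backward section, so `_build_state` reports `training = True` and the backward program is cast as well.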
def _mark_black_white_op(self, amp_lists, op, ops):
if op.type == "create_py_reader" or op.type == "read":
return
if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames(
op, amp_lists
):
self._op_bf16_dict[op.desc.original_id()] = False
return
if op.type in amp_lists.bf16_list:
self._op_bf16_dict[op.desc.original_id()] = True
elif op.type in amp_lists.gray_list:
is_fp32_op = False
is_bf16_op = False
for in_name in op.input_names:
if in_name:
for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name)
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
if (
self._op_bf16_dict.get(
prev_op.desc.original_id(), False
)
is False
or prev_op.type in amp_lists.fp32_list
):
is_fp32_op = True
elif (
self._op_bf16_dict.get(
prev_op.desc.original_id(), False
)
is True
or prev_op.type in amp_lists.bf16_list
):
is_bf16_op = True
if is_fp32_op:
self._op_bf16_dict[op.desc.original_id()] = False
elif is_bf16_op:
self._op_bf16_dict[op.desc.original_id()] = True
else:
pass
else:
self._op_bf16_dict[op.desc.original_id()] = False
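Marking is driven by the BF16 AMP lists imported above: ops in `bf16_list` run in BF16, ops in `fp32_list` (or touching `fp32_varnames`) stay in FP32, and `gray_list` ops inherit the decision from their producers. The lists can be adjusted through the new `custom_*` strategy options; a small sketch (the op names are examples and are assumed to sit in the default gray list):

```python
from paddle.fluid.contrib.mixed_precision.bf16 import AutoMixedPrecisionListsBF16

amp_lists = AutoMixedPrecisionListsBF16(
    custom_bf16_list={'elementwise_add'},   # force BF16
    custom_fp32_list={'layer_norm'},        # force FP32
)
print(sorted(amp_lists.bf16_list)[:5], sorted(amp_lists.fp32_list)[:5])
```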
def cast_forward_program(self, dist_context):
ops = self._block.ops
idx = 0
while idx < len(ops):
num_cast_ops = 0
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Backward):
break
if self._is_bf16_op(op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_forward(
op,
idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32,
dist_context,
)
elif self._is_bf16_op(op.desc.original_id()) is True:
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
op.has_attr('dtype')
and op.attr('dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_forward(
op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16,
dist_context,
)
else:
pass
idx += num_cast_ops + 1
self._block._sync_with_cpp()
def _insert_cast_op_forward(
self, op, idx, src_dtype, dst_dtype, dist_context: DistributedContext
):
num_cast_ops = 0
var_name_dict = {}
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm',
'fused_bn_add_activation',
'layer_norm',
]:
if in_name not in {'X', 'Z'}:
continue
for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = (
in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
)
var_name_dict[in_var.name] = cast_name
out_var = self._block.vars.get(cast_name)
consume_op_attr = dist_context.get_op_dist_attr_for_program(
op
)
assert consume_op_attr is not None
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name
)
if out_var is None or out_var.dtype != dst_dtype:
assert in_var_dist_attr is not None
ref_mesh = in_var_dist_attr.process_mesh
ref_mapping = in_var_dist_attr.dims_mapping
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
out_var = self._block.create_var(
name=cast_name,
dtype=dst_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient,
)
set_var_dist_attr(
dist_context, out_var, ref_mapping, ref_mesh
)
cast_op = self._block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
else:
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
_rename_arg(op, in_var_name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dst_dtype)
self._var_name_dict[op.desc.original_id()] = var_name_dict
if (
src_dtype == core.VarDesc.VarType.FP32
and dst_dtype == core.VarDesc.VarType.BF16
):
for out_name in op.output_names:
if (
op.type
in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']
and out_name != 'Y'
):
continue
for out_var_name in op.output(out_name):
out_var = self._block.var(out_var_name)
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
return num_cast_ops
def cast_backward_program(self, params_grads, dist_context):
self._block._sync_with_cpp()
ops = self._block.ops
appended_grad_times = 0
dist_op_context = dist_context.dist_op_context
loss_op = get_loss_op(self._block)
idx = find_op_index(self._block.desc, loss_op.desc) + 1
while idx < len(ops):
num_cast_ops = 0
grad_op = ops[idx]
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (
is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
if (
grad_op.desc.original_id()
in dist_op_context.grad_op_id_to_op_id
):
if self._is_bf16_op(grad_op.desc.original_id()) is False:
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32,
dist_context,
appended_grad_times,
)
elif self._is_bf16_op(grad_op.desc.original_id()) is True:
if grad_op.has_attr('use_mkldnn'):
grad_op._set_attr('use_mkldnn', True)
grad_op._set_attr('mkldnn_data_type', 'bfloat16')
elif (
grad_op.has_attr('dtype')
and grad_op.attr('dtype') == core.VarDesc.VarType.FP32
):
grad_op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = self._insert_cast_op_backward(
grad_op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16,
dist_context,
appended_grad_times,
)
elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype
for in_var_name in grad_op.desc.input_arg_names():
assert src_dtype == self._block.var(in_var_name).dtype
out_var_name = grad_op.desc.output_arg_names()[0]
out_var = self._block.var(out_var_name)
if out_var.dtype != src_dtype:
out_var.desc.set_dtype(src_dtype)
elif int(grad_op.attr("op_role")) == 257:
pass
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
grad_op.type
)
)
idx += num_cast_ops + 1
self._block._sync_with_cpp()
_update_backward_cast_ops(params_grads, dist_context)
def _insert_cast_op_backward(
self,
grad_op,
idx,
src_dtype,
dst_dtype,
dist_context,
appended_grad_times,
):
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type in ['layer_norm_grad']:
return in_name not in {'X', 'Y@GRAD'}
return False
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['layer_norm_grad']:
return out_name != 'X@GRAD'
return False
num_cast_ops = 0
original_id = grad_op.desc.original_id()
dist_op_context = dist_context.dist_op_context
fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]
for in_name in grad_op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
grad_op, in_name
):
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
assert in_var.dtype == core.VarDesc.VarType.FP32
continue
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
if in_var.dtype == src_dtype:
consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op
)
if in_var_name in self._var_name_dict[fwd_op_id]:
cast_name = self._var_name_dict[fwd_op_id][in_var_name]
grad_op.desc._rename_input(in_var_name, cast_name)
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name
)
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
else:
assert (
in_var.dtype == dst_dtype
), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type,
in_name,
dst_dtype,
in_var.dtype,
str(grad_op),
)
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
grad_op, out_name
):
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
assert out_var.dtype == core.VarDesc.VarType.FP32
continue
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
out_var_name_prefix = out_var_name[: out_var_name.find('@')]
fwd_var = self._block._find_var_recursive(out_var_name_prefix)
if out_var.dtype != fwd_var.dtype:
out_var.desc.set_dtype(fwd_var.dtype)
if out_var.dtype == src_dtype:
if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
consume_op_attr = (
dist_context.get_op_dist_attr_for_program(grad_op)
)
fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix
]
suffix = ''
if "@RENAME" in out_var_name:
suffix = out_var_name[
out_var_name.find("@RENAME") :
]
cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name)
out_var_dist_attr = (
consume_op_attr.get_output_dist_attr(
out_var_name
)
)
ref_mesh = out_var_dist_attr.process_mesh
ref_mapping = out_var_dist_attr.dims_mapping
consume_op_attr.set_output_dist_attr(
cast_name, out_var_dist_attr
)
assert ref_mapping is not None
cast_var = self._block.create_var(
name=cast_name,
shape=out_var.shape,
dtype=dst_dtype,
persistable=False,
stop_gradient=out_var.stop_gradient,
)
set_var_dist_attr(
dist_context, cast_var, ref_mapping, ref_mesh
)
dist_op_context.grad_var_to_var[
appended_grad_times
][cast_name] = fwd_cast_name
cast_op = self._block._insert_op(
idx + 1,
type="cast",
inputs={"X": cast_var},
outputs={"Out": out_var},
attrs={
"in_dtype": cast_var.dtype,
"out_dtype": out_var.dtype,
"op_role": OpRole.Backward,
},
)
cast_op._remove_attr("op_role_var")
cast_op._remove_attr("op_namescope")
cast_op._remove_attr("with_quant_attr")
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
else:
assert out_var.dtype == dst_dtype
return num_cast_ops
def _update_backward_cast_ops(params_grads, dist_context):
"""
move param grad cast to the end of backward segment
in order to enable bf16 allreduce
"""
# TODO filter optimize ops in future
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
for p, g in params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
'op_role_var'
):
op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name)
if post_ops:
raise ValueError(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0])
)
if op == main_block.ops[-1]:
continue
# add new op in the python and cpp at the same time
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = paddle.fluid.framework.Operator(
block=main_block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None,
)
main_block.ops.append(new_op)
# dist attr
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0])
)
assert param_dist_attr is not None
assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op,
param_dist_attr.process_mesh,
param_dist_attr.dims_mapping,
dist_context,
)
output_dist_attr.process_mesh = param_dist_attr.process_mesh
output_dist_attr.dims_mapping = param_dist_attr.dims_mapping
op_idx = find_op_index(main_block.desc, op.desc)
if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op))
main_block._remove_op(op_idx, sync=False)
main_block._sync_with_cpp()
@register_pass("auto_parallel_bf16")
class BF16Pass(PassBase):
def __init__(self):
super().__init__()
self.set_attr("dist_context", None)
self.set_attr("custom_bf16_list", None)
self.set_attr("custom_fp32_list", None)
self.set_attr("custom_fp32_varnames", None)
self.set_attr("input_data", [])
self.set_attr("loss", None)
self.set_attr("params_grads", [])
self.set_attr("use_bf16_guard", False)
self._loss = None
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
amp_lists = AutoMixedPrecisionListsBF16(
self.get_attr("custom_bf16_list"),
self.get_attr("custom_fp32_list"),
self.get_attr("custom_fp32_varnames"),
)
with static.program_guard(main_program, startup_program):
amp_state = BF16State(main_program.global_block())
training = amp_state._build_state(amp_lists, self.dist_context)
amp_state.cast_forward_program(self.dist_context)
if training:
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_backward_program(params_grads, self.dist_context)
self._scale_loss()
def _scale_loss(self):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op
)
if loss.dtype != core.VarDesc.VarType.FP32:
tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(
name=tmp_name, dtype=core.VarDesc.VarType.FP32
)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss
)
ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
cast_loss, loss_dist_attr
)
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op(
loss_op_idx + 1,
type='cast',
inputs={"X": [loss]},
outputs={"Out": [cast_loss]},
attrs={
"in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
"op_role": loss_op.all_attrs()[OP_ROLE_KEY],
},
)
loss_op._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context
)
first_backward_op = main_block.ops[loss_op_idx + 2]
assert (
first_backward_op.type == "fill_constant"
and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
)
cast_loss_grad = main_block.create_var(
name=unique_name.generate(tmp_name + "@GRAD"),
shape=loss.shape,
dtype=core.VarDesc.VarType.FP32,
persistable=loss.persistable,
)
set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)
pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
cast_grad_op = main_block._insert_op(
loss_op_idx + 3,
type='cast',
inputs={'X': [cast_loss_grad]},
outputs={'Out': [pre_grad_name]},
attrs={
"in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16,
'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_grad_op, ref_mesh, [-1], self.dist_context
)
loss = cast_loss
self._loss = loss
main_block._sync_with_cpp()
def get_loss(self):
if self._loss:
return self._loss
else:
return self.get_attr("loss")
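Outside the `Engine`, the pass can be driven the same way the `Parallelizer` change above does it. A hedged sketch — the config keys mirror the attributes registered in `BF16Pass.__init__`, and `dist_context`, `params_grads`, `loss`, `feed_vars`, and the programs are placeholders that must come from an auto-parallel completion and backward pass (elided here):

```python
from paddle.distributed.passes import PassContext, new_pass

config = {
    "dist_context": dist_context,    # DistributedContext of the annotated program
    "params_grads": params_grads,    # from the auto-parallel backward pass
    "loss": loss,                    # serial loss variable
    "input_data": feed_vars,         # feed inputs + labels
    "custom_bf16_list": None,
    "custom_fp32_list": None,
    "custom_fp32_varnames": None,
    "use_bf16_guard": False,
}
bf16_pass = new_pass("auto_parallel_bf16", config)
bf16_pass.apply([main_program], [startup_program], PassContext())
loss = bf16_pass.get_loss()  # FP32 view of the loss if it had to be cast
```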
@@ -128,5 +128,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_cluster_partition MODULES test_cluster_partition)
py_test_modules(test_convert_to_process_meshes MODULES
test_convert_to_process_meshes)
py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
py_test_modules(test_dist_saver MODULES test_dist_saver)
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.nn as nn
from paddle.distributed.fleet import auto
from paddle.fluid.contrib.mixed_precision.bf16.amp_utils import _valid_types
from paddle.fluid.contrib.mixed_precision.fp16_utils import find_true_prev_op
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.static import InputSpec
from paddle.vision.datasets import MNIST
paddle.enable_static()
def apply_pass(use_bf16=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_bf16:
amp = strategy.amp
amp.enable = True
amp.enable_bf16 = True
return strategy
class MnistDataset(MNIST):
def __init__(self, mode, return_label=True):
super().__init__(mode=mode)
self.return_label = return_label
def __getitem__(self, idx):
img = np.reshape(self.images[idx], [1, 28, 28])
if self.return_label:
return img, np.array(self.labels[idx]).astype('int64')
return (img,)
def __len__(self):
return len(self.images)
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class Model(nn.Layer):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(784, 120)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(120, 10)
def forward(self, input):
input.stop_gradient = True
x = self.flatten(input)
x = self.relu1(self.fc1(x))
x = self.fc2(x)
return x
class TestBF16Pass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 256
self.batch_num = 10
self.dataset = MnistDataset("train")
self.eval_dataset = MnistDataset("test")
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_bf16=False):
reset_prog()
strategy = apply_pass(use_bf16)
model = Model()
opt = paddle.optimizer.SGD(0.001, parameters=model.parameters())
loss = nn.CrossEntropyLoss()
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_program(self, program):
bf16_op_list = {
"matmul_v2",
"elementwise_add",
"relu",
"elementwise_add_grad",
"matmul_v2_grad",
"relu_grad",
}
fp32_op_list = {
"flatten_contiguous_range",
"reduce_mean",
"softmax_with_cross_entropy",
"fill_constant",
"reduce_mean_grad",
"softmax_with_cross_entropy_grad",
}
for block in program.blocks:
for op in block.ops:
if op.type not in bf16_op_list and op.type not in fp32_op_list:
continue
for in_name in op.input_names:
for in_var_name in op.input(in_name):
var = None
try:
var = block.var(in_var_name)
except ValueError as e:
var = block._var_recursive(in_var_name)
if var is None or var.type not in _valid_types:
break
if op.type in bf16_op_list:
assert var.dtype == core.VarDesc.VarType.BF16
if "cast_bf16" in in_var_name:
if "@GRAD" in in_var_name:
tmp_in_var_name = in_var_name[
: in_var_name.find("@GRAD")
]
else:
tmp_in_var_name = in_var_name
prev_op = find_true_prev_op(
block.ops, op, tmp_in_var_name
)
assert prev_op is not None
assert prev_op.type == "cast"
for in_name in prev_op.input_names:
for in_var_name in prev_op.input(in_name):
var = block.var(in_var_name)
assert (
var.dtype
== core.VarDesc.VarType.FP32
)
elif op.type in fp32_op_list:
if (
op.type == "softmax_with_cross_entropy"
or op.type == "softmax_with_cross_entropy_grad"
) and in_var_name == "label0":
continue
assert var.dtype == core.VarDesc.VarType.FP32
if "cast_fp32" in in_var_name:
prev_op = find_true_prev_op(
block.ops, op, tmp_in_var_name
)
assert prev_op is not None
assert prev_op.type == "cast"
for in_name in prev_op.input_names:
for in_var_name in prev_op.input(in_name):
var = block.var(in_var_name)
assert (
var.dtype
== core.VarDesc.VarType.BF16
)
for out_name in op.output_names:
for out_var_name in op.output(out_name):
var = None
try:
var = block.var(out_var_name)
except ValueError as e:
var = block._var_recursive(out_var_name)
if var is None or var.type not in _valid_types:
break
if op.type in bf16_op_list:
assert var.dtype == core.VarDesc.VarType.BF16
elif op.type in fp32_op_list:
assert var.dtype == core.VarDesc.VarType.FP32
def test_bf16_pass(self):
bf16_o1_engine = self.get_engine(True)
inputs_spec = [InputSpec([None, 1, 28, 28], 'float32', 'input0')]
labels_spec = [InputSpec([None, 1], 'int64', 'label0')]
bf16_o1_engine.prepare(
inputs_spec=inputs_spec, labels_spec=labels_spec, mode="train"
)
self.check_program(bf16_o1_engine._dist_main_progs["train"][0])
print("BF16!check program successfully!")
if __name__ == "__main__":
unittest.main()
@@ -41,6 +41,13 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_optimizer_fp16, False)
self.assertEqual(amp.enable_bf16, False)
self.assertEqual(amp.custom_bf16_list, [])
self.assertEqual(amp.custom_fp32_list, [])
self.assertEqual(amp.custom_fp32_varnames, [])
self.assertEqual(amp.use_pure_bf16, False)
self.assertEqual(amp.use_bf16_guard, False)
sharding = strategy.sharding
self.assertEqual(sharding.enable, False)
self.assertEqual(sharding.stage, 1)