From 418edae569b0385cedde82b06067d37aed9ec80d Mon Sep 17 00:00:00 2001 From: xu98bin <78574951+xu98bin@users.noreply.github.com> Date: Thu, 29 Dec 2022 15:43:38 +0800 Subject: [PATCH] auto parallel bf16 (#49079) * auto parallel bf16 --- .../operators/collective/c_concat_op.cu.cc | 3 + .../operators/collective/c_identity_op.cu.cc | 3 + .../distributed/auto_parallel/constants.py | 7 + .../auto_parallel/operators/common.py | 4 + .../auto_parallel/operators/dist_matmul.py | 54 +- .../auto_parallel/parallelizer_v2.py | 10 +- python/paddle/distributed/passes/__init__.py | 1 + .../distributed/passes/auto_parallel_bf16.py | 661 ++++++++++++++++++ .../unittests/auto_parallel/CMakeLists.txt | 1 + .../unittests/auto_parallel/test_pass_bf16.py | 211 ++++++ .../unittests/auto_parallel/test_strategy.py | 7 + 11 files changed, 943 insertions(+), 19 deletions(-) create mode 100644 python/paddle/distributed/passes/auto_parallel_bf16.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc index cb52ac0479f..96282ebde36 100644 --- a/paddle/fluid/operators/collective/c_concat_op.cu.cc +++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc @@ -134,4 +134,7 @@ REGISTER_OP_CUDA_KERNEL(c_concat, ops::CConcatOpCUDAKernel, ops::CConcatOpCUDAKernel, ops::CConcatOpCUDAKernel, +#if NCCL_VERSION_CODE >= 21000 + ops::CConcatOpCUDAKernel, +#endif ops::CConcatOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/c_identity_op.cu.cc b/paddle/fluid/operators/collective/c_identity_op.cu.cc index 0b2f5b7eb1a..0ba98ab315d 100644 --- a/paddle/fluid/operators/collective/c_identity_op.cu.cc +++ b/paddle/fluid/operators/collective/c_identity_op.cu.cc @@ -22,4 +22,7 @@ REGISTER_OP_CUDA_KERNEL(c_identity, ops::CIdentityOpKernel, ops::CIdentityOpKernel, ops::CIdentityOpKernel, +#if NCCL_VERSION_CODE >= 21000 + ops::CIdentityOpKernel, +#endif ops::CIdentityOpKernel); diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 044bc78887b..f0c9655c81e 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -76,6 +76,13 @@ set_field_default_config(AMP, "use_pure_fp16", False) set_field_default_config(AMP, "use_fp16_guard", True) set_field_default_config(AMP, "use_optimizer_fp16", False) +set_field_default_config(AMP, "enable_bf16", False) +set_field_default_config(AMP, "custom_bf16_list", []) +set_field_default_config(AMP, "custom_fp32_list", []) +set_field_default_config(AMP, "custom_fp32_varnames", []) +set_field_default_config(AMP, "use_pure_bf16", False) +set_field_default_config(AMP, "use_bf16_guard", False) + ######################################### # sharding configuration ######################################### diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index ff1518f9b8f..ef865dc13bb 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -266,8 +266,12 @@ def is_parameter_related(varname, block): varname = varname[: varname.index(".subprog_")] if ".cast_fp" in varname: varname = varname[: varname.index(".cast_fp")] + if ".cast_bf" in varname: + varname = varname[: varname.index(".cast_bf")] if ".quantized" in varname: varname = varname[: 
varname.index(".quantized")] + # if "@RESHARD" in varname: + # varname = varname[: varname.index("@RESHARD")] assert block._find_var_recursive(varname) var = block._var_recursive(varname) return var.is_parameter diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index e3d17b26b68..dca8e24bc59 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -376,7 +376,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): check_variable_and_dtype( Out_grad, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_identity', ) @@ -417,13 +417,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): check_variable_and_dtype( intermediate_var_0, 'x', - ['float16', 'float32', 'float64'], + ['float16', 'float32', 'float64', 'uint16'], 'linear', ) check_dtype( intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], + ['float16', 'float32', 'float64', 'uint16'], 'linear', ) set_comm_op_dist_attr_for_program( @@ -835,7 +835,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): check_variable_and_dtype( X_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_identity', ) @@ -854,12 +854,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): intermediate_var_0.desc.set_shape(ref_shape_x) check_variable_and_dtype( - intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' + intermediate_var_0, + 'x', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', ) check_dtype( intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], + ['float16', 'float32', 'float64', 'uint16'], 'linear', ) attrs = { @@ -1183,10 +1186,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): group = new_process_group(group_ranks) check_variable_and_dtype( - X_var, 'x', ['float16', 'float32', 'float64'], 'linear' + X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' ) check_dtype( - X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear' + X_var.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', ) attrs = { 'transpose_X': trans_x, @@ -1731,7 +1737,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): check_variable_and_dtype( X_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_identity', ) c_identity_op = main_block.append_op( @@ -1749,12 +1755,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): intermediate_var_0.desc.set_shape(ref_shape_x) check_variable_and_dtype( - intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' + intermediate_var_0, + 'x', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', ) check_dtype( intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], + ['float16', 'float32', 'float64', 'uint16'], 'linear', ) attrs = { @@ -2077,10 +2086,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): group = new_process_group(group_ranks) check_variable_and_dtype( - X_var, 'x', ['float16', 'float32', 'float64'], 'linear' + X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' ) check_dtype( - X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear' + X_var.dtype, 
+ 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', ) attrs = { 'trans_x': trans_x, @@ -2610,7 +2622,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): check_variable_and_dtype( X_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_identity', ) c_identity_op = main_block.append_op( @@ -2628,12 +2640,15 @@ class DistributedMulImpl0(DistributedOperatorImpl): intermediate_var_0.desc.set_shape(ref_shape_x) check_variable_and_dtype( - intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' + intermediate_var_0, + 'x', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', ) check_dtype( intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], + ['float16', 'float32', 'float64', 'uint16'], 'linear', ) # attrs = {'trans_x': False, 'trans_y': False} @@ -2965,10 +2980,13 @@ class DistributedMulImpl1(DistributedOperatorImpl): group = new_process_group(group_ranks) check_variable_and_dtype( - X_var, 'x', ['float16', 'float32', 'float64'], 'linear' + X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' ) check_dtype( - X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear' + X_var.dtype, + 'dtype', + ['float16', 'float32', 'float64', 'uint16'], + 'linear', ) # attrs = {'trans_x': False, 'trans_y': False} attrs = { diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 0a05c716ded..2ff8f0ee7d1 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -221,13 +221,21 @@ class Parallelizer: self._dist_context.serial_feed_vars["inputs"] + self._dist_context.serial_feed_vars["labels"] ) - if config["use_pure_fp16"]: + if config["enable_bf16"]: + auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config) + auto_parallel_bf16_pass.apply( + [main_program], [startup_program], self._pass_context + ) + loss = auto_parallel_bf16_pass.get_loss() + + elif config["use_pure_fp16"]: config["base_opt"] = optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass.apply( [main_program], [startup_program], self._pass_context ) loss = auto_parallel_fp16_pass.get_loss() + else: auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass.apply( diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 056540a4a15..886d29a30b4 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -18,6 +18,7 @@ from .auto_parallel_gradient_merge import * # noqa: F403 from .auto_parallel_sharding import * # noqa: F403 from .auto_parallel_amp import * # noqa: F403 from .auto_parallel_fp16 import * # noqa: F403 +from .auto_parallel_bf16 import * # noqa: F403 from .auto_parallel_recompute import * # noqa: F403 from .auto_parallel_quantization import * # noqa: F403 from .auto_parallel_data_parallel_optimization import * # noqa: F403 diff --git a/python/paddle/distributed/passes/auto_parallel_bf16.py b/python/paddle/distributed/passes/auto_parallel_bf16.py new file mode 100644 index 00000000000..3344c648244 --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_bf16.py @@ -0,0 +1,661 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import static +from paddle.distributed.auto_parallel.dist_context import DistributedContext +from paddle.distributed.auto_parallel.process_group import ( + get_world_process_group, +) +from paddle.distributed.auto_parallel.utils import ( + get_loss_op, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, + set_var_dist_attr, +) +from paddle.distributed.fleet.meta_optimizers.common import OpRole +from paddle.distributed.passes.pass_base import PassBase, register_pass +from paddle.fluid import unique_name +from paddle.fluid.contrib.mixed_precision.bf16 import ( + AutoMixedPrecisionListsBF16, +) +from paddle.fluid.contrib.mixed_precision.bf16.amp_utils import ( + _dtype_to_str, + _is_in_fp32_varnames, + _valid_types, + find_op_index, + find_true_post_op, +) +from paddle.fluid.contrib.mixed_precision.fp16_utils import ( + _rename_arg, + find_true_prev_op, +) +from paddle.fluid.framework import Block +from paddle.framework import core + +from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op + +world_process_group = get_world_process_group() + + +class BF16State(object): + def __init__(self, block): + self._block: Block = block + self._op_bf16_dict = {} + self._var_name_dict = {} + + def _is_bf16_op(self, op_id): + return self._op_bf16_dict.get(op_id, None) + + def _build_state(self, amp_lists, dist_context): + ops = self._block.ops + dist_op_context = dist_context.dist_op_context + training = False + for op in ops: + if int(op.attr("op_role")) == 257: + training = True + + if int(op.attr("op_role")) == int(OpRole.Forward): + self._mark_black_white_op(amp_lists, op, ops) + elif int(op.attr("op_role")) == int(OpRole.Backward): + if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: + fwd_op_id = dist_op_context.grad_op_id_to_op_id[ + op.desc.original_id() + ] + if self._is_bf16_op(fwd_op_id) is True: + self._op_bf16_dict[op.desc.original_id()] = True + elif self._is_bf16_op(fwd_op_id) is False: + self._op_bf16_dict[op.desc.original_id()] = False + elif int(op.attr("op_role")) == int(OpRole.Optimize): + break + return training + + def _mark_black_white_op(self, amp_lists, op, ops): + if op.type == "create_py_reader" or op.type == "read": + return + if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames( + op, amp_lists + ): + self._op_bf16_dict[op.desc.original_id()] = False + return + if op.type in amp_lists.bf16_list: + self._op_bf16_dict[op.desc.original_id()] = True + elif op.type in amp_lists.gray_list: + is_fp32_op = False + is_bf16_op = False + for in_name in op.input_names: + if in_name: + for in_var_name in op.input(in_name): + in_var = self._block.var(in_var_name) + if in_var.op is None: + continue + elif in_var.op is op: + prev_op = find_true_prev_op(ops, op, in_var_name) + if prev_op is None: + continue + else: + prev_op = in_var.op + if ( + self._op_bf16_dict.get( + prev_op.desc.original_id(), False + ) + is False + or prev_op.type in amp_lists.fp32_list + ): + 
is_fp32_op = True + elif ( + self._op_bf16_dict.get( + prev_op.desc.original_id(), False + ) + is True + or prev_op.type in amp_lists.bf16_list + ): + is_bf16_op = True + if is_fp32_op: + self._op_bf16_dict[op.desc.original_id()] = False + elif is_bf16_op: + self._op_bf16_dict[op.desc.original_id()] = True + else: + pass + else: + self._op_bf16_dict[op.desc.original_id()] = False + + def cast_forward_program(self, dist_context): + ops = self._block.ops + idx = 0 + while idx < len(ops): + num_cast_ops = 0 + op = ops[idx] + if int(op.attr('op_role')) == int(OpRole.Backward): + break + if self._is_bf16_op(op.desc.original_id()) is False: + num_cast_ops = self._insert_cast_op_forward( + op, + idx, + core.VarDesc.VarType.BF16, + core.VarDesc.VarType.FP32, + dist_context, + ) + elif self._is_bf16_op(op.desc.original_id()) is True: + if op.has_attr('use_mkldnn'): + op._set_attr('use_mkldnn', True) + op._set_attr('mkldnn_data_type', 'bfloat16') + elif ( + op.has_attr('dtype') + and op.attr('dtype') == core.VarDesc.VarType.FP32 + ): + op._set_attr('dtype', core.VarDesc.VarType.BF16) + + num_cast_ops = self._insert_cast_op_forward( + op, + idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16, + dist_context, + ) + else: + pass + + idx += num_cast_ops + 1 + self._block._sync_with_cpp() + + def _insert_cast_op_forward( + self, op, idx, src_dtype, dst_dtype, dist_context: DistributedContext + ): + num_cast_ops = 0 + var_name_dict = {} + for in_name in op.input_names: + if src_dtype == core.VarDesc.VarType.FP32 and op.type in [ + 'batch_norm', + 'fused_bn_add_activation', + 'layer_norm', + ]: + if in_name not in {'X', 'Z'}: + continue + for in_var_name in op.input(in_name): + in_var = self._block.var(in_var_name) + if in_var.type not in _valid_types or in_var.dtype == dst_dtype: + continue + if in_var.dtype == src_dtype: + cast_name = ( + in_var.name + '.cast_' + _dtype_to_str(dst_dtype) + ) + var_name_dict[in_var.name] = cast_name + out_var = self._block.vars.get(cast_name) + consume_op_attr = dist_context.get_op_dist_attr_for_program( + op + ) + assert consume_op_attr is not None + in_var_dist_attr = consume_op_attr.get_input_dist_attr( + in_var_name + ) + if out_var is None or out_var.dtype != dst_dtype: + assert in_var_dist_attr is not None + ref_mesh = in_var_dist_attr.process_mesh + ref_mapping = in_var_dist_attr.dims_mapping + consume_op_attr.set_input_dist_attr( + cast_name, in_var_dist_attr + ) + + out_var = self._block.create_var( + name=cast_name, + dtype=dst_dtype, + persistable=False, + stop_gradient=in_var.stop_gradient, + ) + set_var_dist_attr( + dist_context, out_var, ref_mapping, ref_mesh + ) + + cast_op = self._block._insert_op_without_sync( + idx, + type="cast", + inputs={"X": in_var}, + outputs={"Out": out_var}, + attrs={ + "in_dtype": in_var.dtype, + "out_dtype": out_var.dtype, + }, + ) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_mesh, ref_mapping, dist_context + ) + num_cast_ops += 1 + else: + consume_op_attr.set_input_dist_attr( + cast_name, in_var_dist_attr + ) + _rename_arg(op, in_var_name, out_var.name) + else: + if op.has_attr('in_dtype'): + op._set_attr('in_dtype', dst_dtype) + self._var_name_dict[op.desc.original_id()] = var_name_dict + + if ( + src_dtype == core.VarDesc.VarType.FP32 + and dst_dtype == core.VarDesc.VarType.BF16 + ): + for out_name in op.output_names: + if ( + op.type + in ['batch_norm', 'fused_bn_add_activation', 'layer_norm'] + and out_name != 'Y' + ): + continue + for out_var_name in op.output(out_name): + out_var = 
self._block.var(out_var_name) + if out_var.type not in _valid_types: + continue + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(core.VarDesc.VarType.BF16) + if op.has_attr('out_dtype'): + op._set_attr('out_dtype', core.VarDesc.VarType.BF16) + return num_cast_ops + + def cast_backward_program(self, params_grads, dist_context): + self._block._sync_with_cpp() + ops = self._block.ops + appended_grad_times = 0 + dist_op_context = dist_context.dist_op_context + loss_op = get_loss_op(self._block) + idx = find_op_index(self._block.desc, loss_op.desc) + 1 + while idx < len(ops): + num_cast_ops = 0 + grad_op = ops[idx] + op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op) + if is_backward_op(grad_op) and ( + is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1]) + ): + if not op_dist_attr.is_recompute: + appended_grad_times += 1 + if ( + grad_op.desc.original_id() + in dist_op_context.grad_op_id_to_op_id + ): + if self._is_bf16_op(grad_op.desc.original_id()) is False: + num_cast_ops = self._insert_cast_op_backward( + grad_op, + idx, + core.VarDesc.VarType.BF16, + core.VarDesc.VarType.FP32, + dist_context, + appended_grad_times, + ) + elif self._is_bf16_op(grad_op.desc.original_id()) is True: + if grad_op.has_attr('use_mkldnn'): + grad_op._set_attr('use_mkldnn', True) + grad_op._set_attr('mkldnn_data_type', 'bfloat16') + elif ( + grad_op.has_attr('dtype') + and grad_op.attr('dtype') == core.VarDesc.VarType.FP32 + ): + grad_op._set_attr('dtype', core.VarDesc.VarType.BF16) + num_cast_ops = self._insert_cast_op_backward( + grad_op, + idx, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.BF16, + dist_context, + appended_grad_times, + ) + elif grad_op.type == "sum": + in_var_name = grad_op.desc.input_arg_names()[0] + src_dtype = self._block.var(in_var_name).dtype + for in_var_name in grad_op.desc.input_arg_names(): + assert src_dtype == self._block.var(in_var_name).dtype + out_var_name = grad_op.desc.output_arg_names()[0] + out_var = self._block.var(out_var_name) + if out_var.dtype != src_dtype: + out_var.desc.set_dtype(src_dtype) + elif int(grad_op.attr("op_role")) == 257: + pass + else: + raise ValueError( + "'{}' op is not supported in the complete amp pass.".format( + grad_op.type + ) + ) + idx += num_cast_ops + 1 + self._block._sync_with_cpp() + _update_backward_cast_ops(params_grads, dist_context) + + def _insert_cast_op_backward( + self, + grad_op, + idx, + src_dtype, + dst_dtype, + dist_context, + appended_grad_times, + ): + def _keep_fp32_input(op, in_name): + op_type = op.type + if op_type in ['layer_norm_grad']: + return in_name not in {'X', 'Y@GRAD'} + return False + + def _keep_fp32_output(op, out_name): + op_type = op.type + if op_type in ['layer_norm_grad']: + return out_name != 'X@GRAD' + return False + + num_cast_ops = 0 + original_id = grad_op.desc.original_id() + dist_op_context = dist_context.dist_op_context + fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id] + for in_name in grad_op.input_names: + if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( + grad_op, in_name + ): + for in_var_name in grad_op.input(in_name): + in_var = self._block._find_var_recursive(in_var_name) + assert in_var.dtype == core.VarDesc.VarType.FP32 + continue + for in_var_name in grad_op.input(in_name): + in_var = self._block._find_var_recursive(in_var_name) + if in_var.dtype == src_dtype: + consume_op_attr = dist_context.get_op_dist_attr_for_program( + grad_op + ) + if in_var_name in self._var_name_dict[fwd_op_id]: + cast_name = 
self._var_name_dict[fwd_op_id][in_var_name] + grad_op.desc._rename_input(in_var_name, cast_name) + in_var_dist_attr = consume_op_attr.get_input_dist_attr( + in_var_name + ) + consume_op_attr.set_input_dist_attr( + cast_name, in_var_dist_attr + ) + else: + assert ( + in_var.dtype == dst_dtype + ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format( + grad_op.type, + in_name, + dst_dtype, + in_var.dtype, + str(grad_op), + ) + + for out_name in grad_op.output_names: + if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output( + grad_op, out_name + ): + for out_var_name in grad_op.output(out_name): + out_var = self._block._find_var_recursive(out_var_name) + assert out_var.dtype == core.VarDesc.VarType.FP32 + continue + + for out_var_name in grad_op.output(out_name): + out_var = self._block._find_var_recursive(out_var_name) + out_var_name_prefix = out_var_name[: out_var_name.find('@')] + fwd_var = self._block._find_var_recursive(out_var_name_prefix) + if out_var.dtype != fwd_var.dtype: + out_var.desc.set_dtype(fwd_var.dtype) + + if out_var.dtype == src_dtype: + if out_var_name_prefix in self._var_name_dict[fwd_op_id]: + consume_op_attr = ( + dist_context.get_op_dist_attr_for_program(grad_op) + ) + fwd_cast_name = self._var_name_dict[fwd_op_id][ + out_var_name_prefix + ] + suffix = '' + if "@RENAME" in out_var_name: + suffix = out_var_name[ + out_var_name.find("@RENAME") : + ] + cast_name = fwd_cast_name + "@GRAD" + suffix + cast_var = self._block.vars.get(cast_name) + if cast_var is None or cast_var.dtype != dst_dtype: + grad_op.desc._rename_output(out_var_name, cast_name) + out_var_dist_attr = ( + consume_op_attr.get_output_dist_attr( + out_var_name + ) + ) + ref_mesh = out_var_dist_attr.process_mesh + ref_mapping = out_var_dist_attr.dims_mapping + consume_op_attr.set_output_dist_attr( + cast_name, out_var_dist_attr + ) + assert ref_mapping is not None + cast_var = self._block.create_var( + name=cast_name, + shape=out_var.shape, + dtype=dst_dtype, + persistable=False, + stop_gradient=out_var.stop_gradient, + ) + set_var_dist_attr( + dist_context, cast_var, ref_mapping, ref_mesh + ) + dist_op_context.grad_var_to_var[ + appended_grad_times + ][cast_name] = fwd_cast_name + + cast_op = self._block._insert_op( + idx + 1, + type="cast", + inputs={"X": cast_var}, + outputs={"Out": out_var}, + attrs={ + "in_dtype": cast_var.dtype, + "out_dtype": out_var.dtype, + "op_role": OpRole.Backward, + }, + ) + cast_op._remove_attr("op_role_var") + cast_op._remove_attr("op_namescope") + cast_op._remove_attr("with_quant_attr") + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_mesh, ref_mapping, dist_context + ) + num_cast_ops += 1 + else: + assert out_var.dtype == dst_dtype + return num_cast_ops + + +def _update_backward_cast_ops(params_grads, dist_context): + """ + move param grad cast to the end of backward segment + in order to enabel fp16 allreduce + """ + # TODO filter optimize ops in future + + main_block = paddle.static.default_main_program().global_block() + main_block._sync_with_cpp() + + for p, g in params_grads: + op = g.op + if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast': + if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr( + 'op_role_var' + ): + op._remove_attr("op_role_var") + + post_ops = find_true_post_op(main_block.ops, op, g.name) + if post_ops: + raise ValueError( + "The cast op {0}'s output should not be" + "used by a non-optimize op, however, it" + "is used by {1}".format(op, post_ops[0]) + ) + + if op == 
main_block.ops[-1]: + continue + + # add new op in the python and cpp at the same time + new_op_desc = main_block.desc.append_op() + new_op_desc.copy_from(op.desc) + new_op = paddle.fluid.framework.Operator( + block=main_block, + desc=new_op_desc, + type=None, + inputs=None, + outputs=None, + attrs=None, + ) + main_block.ops.append(new_op) + + # dist attr + param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p) + output_dist_attr = dist_context.get_tensor_dist_attr_for_program( + main_block.var(op.output_arg_names[0]) + ) + assert param_dist_attr is not None + assert output_dist_attr is not None + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + new_op, + param_dist_attr.process_mesh, + param_dist_attr.dims_mapping, + dist_context, + ) + + output_dist_attr.process_mesh = param_dist_attr.process_mesh + output_dist_attr.dims_mapping = param_dist_attr.dims_mapping + + op_idx = find_op_index(main_block.desc, op.desc) + if op_idx == -1: + raise ValueError("The op {0} is not in program".format(op)) + main_block._remove_op(op_idx, sync=False) + + main_block._sync_with_cpp() + + +@register_pass("auto_parallel_bf16") +class BF16Pass(PassBase): + def __init__(self): + super().__init__() + self.set_attr("dist_context", None) + self.set_attr("custom_bf16_list", None) + self.set_attr("custom_fp32_list", None) + self.set_attr("custom_fp32_varnames", None) + self.set_attr("input_data", []) + self.set_attr("loss", None) + self.set_attr("params_grads", []) + self.set_attr("use_bf16_guard", False) + self._loss = None + + def _check_self(self): + if self.get_attr("dist_context") is None: + return False + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + self.dist_context = self.get_attr("dist_context") + params_grads = self.get_attr("params_grads") + + amp_lists = AutoMixedPrecisionListsBF16( + self.get_attr("custom_bf16_list"), + self.get_attr("custom_fp32_list"), + self.get_attr("custom_fp32_varnames"), + ) + + with static.program_guard(main_program, startup_program): + amp_state = BF16State(main_program.global_block()) + training = amp_state._build_state(amp_lists, self.dist_context) + amp_state.cast_forward_program(self.dist_context) + + if training: + with paddle.static.program_guard(main_program, startup_program): + amp_state.cast_backward_program(params_grads, self.dist_context) + self._scale_loss() + + def _scale_loss(self): + + main_block = paddle.static.default_main_program().global_block() + main_block._sync_with_cpp() + OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() + + loss = self.get_attr("loss") + assert loss is not None + loss_op = loss.op + loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program( + loss_op + ) + if loss.dtype != core.VarDesc.VarType.FP32: + tmp_name = unique_name.generate(loss.name + ".cast_fp32") + cast_loss = main_block.create_var( + name=tmp_name, dtype=core.VarDesc.VarType.FP32 + ) + loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program( + loss + ) + ref_mesh = loss_op_dist_attr.process_mesh + self.dist_context.set_tensor_dist_attr_for_program( + cast_loss, loss_dist_attr + ) + + loss_op_idx = find_op_index(main_block.desc, loss_op.desc) + cast_op = main_block._insert_op( + loss_op_idx + 1, + type='cast', + inputs={"X": [loss]}, + outputs={"Out": [cast_loss]}, + attrs={ + "in_dtype": loss.dtype, + "out_dtype": core.VarDesc.VarType.FP32, + "op_role": loss_op.all_attrs()[OP_ROLE_KEY], + }, + ) + + loss_op._set_attr( + 
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward + ) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_mesh, [-1], self.dist_context + ) + first_backward_op = main_block.ops[loss_op_idx + 2] + assert ( + first_backward_op.type == "fill_constant" + and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257 + ) + cast_loss_grad = main_block.create_var( + name=unique_name.generate(tmp_name + "@GRAD"), + shape=loss.shape, + dtype=core.VarDesc.VarType.FP32, + persistable=loss.persistable, + ) + set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh) + pre_grad_name = first_backward_op.output_arg_names[0] + first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name) + cast_grad_op = main_block._insert_op( + loss_op_idx + 3, + type='cast', + inputs={'X': [cast_loss_grad]}, + outputs={'Out': [pre_grad_name]}, + attrs={ + "in_dtype": core.VarDesc.VarType.FP32, + "out_dtype": core.VarDesc.VarType.FP16, + 'op_role': core.op_proto_and_checker_maker.OpRole.Backward, + }, + ) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_grad_op, ref_mesh, [-1], self.dist_context + ) + loss = cast_loss + self._loss = loss + main_block._sync_with_cpp() + + def get_loss(self): + if self._loss: + return self._loss + else: + return self.get_attr("loss") diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 90394ce24d0..249387a0781 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -128,5 +128,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_cluster_partition MODULES test_cluster_partition) py_test_modules(test_convert_to_process_meshes MODULES test_convert_to_process_meshes) + py_test_modules(test_pass_bf16 MODULES test_pass_bf16) py_test_modules(test_dist_saver MODULES test_dist_saver) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py new file mode 100644 index 00000000000..f26908df2cf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py @@ -0,0 +1,211 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np + +import paddle +import paddle.fluid.core as core +import paddle.nn as nn +from paddle.distributed.fleet import auto +from paddle.fluid.contrib.mixed_precision.bf16.amp_utils import _valid_types +from paddle.fluid.contrib.mixed_precision.fp16_utils import find_true_prev_op +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.static import InputSpec +from paddle.vision.datasets import MNIST + +paddle.enable_static() + + +def apply_pass(use_bf16=False): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + if use_bf16: + amp = strategy.amp + amp.enable = True + amp.enable_bf16 = True + return strategy + + +class MnistDataset(MNIST): + def __init__(self, mode, return_label=True): + super().__init__(mode=mode) + self.return_label = return_label + + def __getitem__(self, idx): + img = np.reshape(self.images[idx], [1, 28, 28]) + if self.return_label: + return img, np.array(self.labels[idx]).astype('int64') + return (img,) + + def __len__(self): + return len(self.images) + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class Model(nn.Layer): + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(784, 120) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(120, 10) + + def forward(self, input): + input.stop_gradient = True + x = self.flatten(input) + x = self.relu1(self.fc1(x)) + x = self.fc2(x) + return x + + +class TestBF16Pass(unittest.TestCase): + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 256 + self.batch_num = 10 + self.dataset = MnistDataset("train") + self.eval_dataset = MnistDataset("test") + + def init(self, engine): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_bf16=False): + reset_prog() + + strategy = apply_pass(use_bf16) + model = Model() + opt = paddle.optimizer.SGD(0.001, parameters=model.parameters()) + loss = nn.CrossEntropyLoss() + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_program(self, program): + bf16_op_list = { + "matmul_v2", + "elementwise_add", + "relu", + "elementwise_add_grad", + "matmul_v2_grad", + "relu_grad", + } + + fp32_op_list = { + "flatten_contiguous_range", + "reduce_mean", + "softmax_with_cross_entropy", + "fill_constant", + "reduce_mean_grad", + "softmax_with_cross_entropy_grad", + } + + for block in program.blocks: + for op in block.ops: + if op not in bf16_op_list and op not in fp32_op_list: + continue + + for in_name in op.input_names: + for in_var_name in op.input(in_name): + var = None + try: + var = block.var(in_var_name) + except ValueError as e: + var = block._var_recursive(in_var_name) + if var is None or var.type not in _valid_types: + break + + if op.type in bf16_op_list: + assert var.dtype == core.VarDesc.VarType.BF16 + if "cast_bf16" in in_var_name: + if "@GRAD" in in_var_name: + tmp_in_var_name = in_var_name[ + : in_var_name.find("@GRAD") + ] + else: + tmp_in_var_name = in_var_name + prev_op = find_true_prev_op( + block.ops, op, tmp_in_var_name + ) + assert prev_op is not None + assert prev_op.type == "cast" + for in_name in prev_op.input_names: + for in_var_name in prev_op.input(in_name): + var = block.var(in_var_name) + assert ( + var.dtype 
+ == core.VarDesc.VarType.FP32 + ) + + elif op.type in fp32_op_list: + if ( + op.type == "softmax_with_cross_entropy" + or op.type == "softmax_with_cross_entropy_grad" + ) and in_var_name == "label0": + continue + assert var.dtype == core.VarDesc.VarType.FP32 + if "cast_fp32" in in_var_name: + prev_op = find_true_prev_op( + block.ops, op, tmp_in_var_name + ) + assert prev_op is not None + assert prev_op.type == "cast" + for in_name in prev_op.input_names: + for in_var_name in prev_op.input(in_name): + var = block.var(in_var_name) + assert ( + var.dtype + == core.VarDesc.VarType.BF16 + ) + + for out_name in op.output_names: + for out_var_name in op.output(out_name): + var = None + try: + var = block.var(out_var_name) + except ValueError as e: + var = block._var_recursive(out_var_name) + + if var is None or var.type not in _valid_types: + break + if op.type in bf16_op_list: + assert var.dtype == core.VarDesc.VarType.BF16 + elif op.type in fp32_op_list: + assert var.dtype == core.VarDesc.VarType.FP32 + + def test_bf16_pass(self): + bf16_o1_engine = self.get_engine(True) + inputs_spec = [InputSpec([None, 1, 28, 28], 'float32', 'input0')] + labels_spec = [InputSpec([None, 1], 'int64', 'label0')] + bf16_o1_engine.prepare( + inputs_spec=inputs_spec, labels_spec=labels_spec, mode="train" + ) + self.check_program(bf16_o1_engine._dist_main_progs["train"][0]) + print("BF16!check program successfully!") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index 529d1d5f625..0b41e323ffd 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -41,6 +41,13 @@ class TestStrategy(unittest.TestCase): self.assertEqual(amp.use_fp16_guard, True) self.assertEqual(amp.use_optimizer_fp16, False) + self.assertEqual(amp.enable_bf16, False) + self.assertEqual(amp.custom_bf16_list, []) + self.assertEqual(amp.custom_fp32_list, []) + self.assertEqual(amp.custom_fp32_varnames, []) + self.assertEqual(amp.use_pure_bf16, False) + self.assertEqual(amp.use_bf16_guard, False) + sharding = strategy.sharding self.assertEqual(sharding.enable, False) self.assertEqual(sharding.stage, 1) -- GitLab
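
Editor's note: the new pass is driven entirely through the auto-parallel Strategy object. constants.py registers the AMP options enable_bf16, custom_bf16_list, custom_fp32_list, custom_fp32_varnames, use_pure_bf16 and use_bf16_guard, and parallelizer_v2.py dispatches to the new auto_parallel_bf16 pass whenever enable_bf16 is set. The sketch below is distilled from test_pass_bf16.py in this patch and shows how a user would switch the pass on; the toy model and input specs are illustrative placeholders only, and the Engine/prepare call signature may differ across Paddle releases.

    # Minimal usage sketch, assuming the Strategy/amp options added by this patch.
    # Only the amp settings are taken from the patch; the model is a placeholder.
    import paddle
    import paddle.nn as nn
    from paddle.distributed.fleet import auto
    from paddle.static import InputSpec

    paddle.enable_static()

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    amp = strategy.amp
    amp.enable = True
    amp.enable_bf16 = True            # routes compilation through auto_parallel_bf16
    # amp.custom_bf16_list = [...]    # optional: op types forced to run in bf16
    # amp.custom_fp32_list = [...]    # optional: op types kept in fp32

    # Toy MLP matching the shapes used in test_pass_bf16.py (illustrative only).
    model = nn.Sequential(
        nn.Flatten(), nn.Linear(784, 120), nn.ReLU(), nn.Linear(120, 10)
    )
    opt = paddle.optimizer.SGD(0.001, parameters=model.parameters())
    loss = nn.CrossEntropyLoss()

    engine = auto.Engine(model, loss, opt, strategy=strategy)
    engine.prepare(
        inputs_spec=[InputSpec([None, 1, 28, 28], 'float32', 'input0')],
        labels_spec=[InputSpec([None, 1], 'int64', 'label0')],
        mode="train",
    )

After prepare(), the test in this patch inspects engine._dist_main_progs["train"] and checks that matmul_v2/elementwise_add/relu (and their grads) run on BF16 (uint16) tensors while reduce_mean, softmax_with_cross_entropy and fill_constant stay in FP32, which is the observable effect of the pass.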