diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index da2147f40363e898987f22e4edc643c49ca7f1da..9baee026a7b6c1bdec9291ace0696f391fa53d14 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -67,6 +67,8 @@ message AMPConfig { repeated string custom_black_varnames = 9; optional bool use_pure_fp16 = 10 [ default = false ]; optional bool use_fp16_guard = 11 [ default = true ]; + optional bool use_optimizer_fp16 = 12 + [ default = false ]; // auto parallel effective only } message LocalSGDConfig { diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 31539550f1c62a3f59c31e9bd9b8fa46af15dd85..fc5f1686d0f8c91ac16644e67380084a9cc74933 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -105,9 +105,15 @@ class AutoParallelizer: config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["loss"] = loss - auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) - auto_parallel_amp_pass.apply([main_program], [startup_program], - self._pass_context) + if config["use_pure_fp16"]: + config["base_opt"] = self._optimizer + auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) + auto_parallel_fp16_pass.apply( + [main_program], [startup_program], self._pass_context) + else: + auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) + auto_parallel_amp_pass.apply([main_program], [startup_program], + self._pass_context) # apply recompute pass if self._dist_strategy.recompute: diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index e27c859c4e2c53fc8de2af5d513a91dcb855ce59..c03ef9c06d80fd6a9f49c4bcbd03864c62d4b949 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -357,10 +357,11 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, src_var = src_block.var(src_varname) if src_var.type in __not_shape_var_type__: + persist = getattr(src_var, 'persistable', False) new_var = dst_block.create_var( type=src_var.type, name=dst_varname, - persistable=True, + persistable=persist, stop_gradient=True) target_shape = None else: diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 642fefc621a9dc36305cbd970de38455f1b65f90..a7b5f3a2fd0d0f0cb07a912a6898da69693b7b8d 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1047,8 +1047,7 @@ def set_grad_var_shape(program, dist_context): forward_input_dist_attr = op_dist_attr.get_input_dist_attr( forward_var_name) - - assert forward_input_dist_attr is not None, f"{forward_var_name}" + assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}" forward_var = vars[forward_var_name] forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( forward_var) diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py old mode 100755 new mode 100644 index 295e743512d4926b5b90c98c3b4b6abe0e96f6fe..c12e2138287fdbf5fd9f22a924b58a6dfc375d08 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -17,6 +17,7 @@ from .fuse_all_reduce 
import * from .auto_parallel_gradient_merge import * from .auto_parallel_sharding import * from .auto_parallel_amp import * +from .auto_parallel_fp16 import * from .auto_parallel_recompute import * from .cpp_pass import * import os diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index d69d6d4ab3286368d242651652b34c4d11c853fb..5fdd88ac1de8afde402b4d82ab7c8ecf9e4ab68e 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -503,8 +503,6 @@ class AMPPass(PassBase): return False if self.get_attr("decr_ratio") < 0: return False - if len(self.get_attr("params_grads")) <= 0: - return False if self.get_attr("dist_context") is None: return False return True @@ -576,6 +574,8 @@ class AMPPass(PassBase): main_block = paddle.static.default_main_program().global_block() main_block._sync_with_cpp() + OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() + loss = self.get_attr("loss") assert loss is not None loss_op = loss.op @@ -583,6 +583,37 @@ class AMPPass(PassBase): loss_op) if loss.dtype != core.VarDesc.VarType.FP32: + # cast loss here will change the effective loss tensor for the computation graph + # and therefore will effect all following passes whose logic is based on the loss tensor(Recompute & Gradient Merge), + # so we it is not allowed by now. fixed it in future. + raise NotImplementedError( + "Loss's generator op is not support in FP16 in Auto Parallel by now, please put that op into your black-list." + ) + + tmp_name = unique_name.generate(loss.name + ".cast_fp32") + cast_loss = main_block.create_var(name=tmp_name, dtype=dtype) + loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program( + loss) + ref_mesh = loss_op_dist_attr.process_mesh + self.dist_context.set_tensor_dist_attr_for_program(cast_loss, + loss_dist_attr) + + loss_op_idx = find_op_index(main_block.desc, loss_op.desc) + cast_op = main_block._insert_op( + loss_op_idx + 1, + type='cast', + inputs={'X': [loss]}, + outputs={'Out': [cast_loss]}, + attrs={ + "in_dtype": loss.dtype, + "out_dtype": core.VarDesc.VarType.FP32, + 'op_role': loss_op.all_attrs()[OP_ROLE_KEY], + }) + + loss_op._set_attr(OP_ROLE_KEY, + core.op_proto_and_checker_maker.OpRole.Forward) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_mesh, [-1], self.dist_context) loss = loss.astype('float32') if self.get_attr("use_dynamic_loss_scaling") or self.get_attr( @@ -600,7 +631,6 @@ class AMPPass(PassBase): set_var_dist_attr(self.dist_context, self._scaled_loss, [-1], ref_mesh) - OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() elementwise_mul_op = main_block._insert_op( loss_op_idx + 1, type='elementwise_mul', @@ -667,8 +697,11 @@ class AMPPass(PassBase): for e in grads: check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling') - assert self._loss_scaling.dtype == e.dtype, \ - "The dtype of prev_loss_scaling should be equal to the dtype of x." + if e.dtype == core.VarDesc.VarType.FP16: + assert self._loss_scaling.dtype == core.VarDesc.VarType.FP32, \ + "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." + else: + assert self._loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." 
inputs = { 'X': grads, diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..725b4459d7d212a99381abf728dee8c229cd7c3d --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -0,0 +1,570 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict + +import paddle +from paddle.framework import core +from paddle.fluid import unique_name +from .pass_base import register_pass +from paddle.distributed.fleet.meta_optimizers.common import OpRole +from paddle.fluid.data_feeder import check_variable_and_dtype, check_type +from paddle.distributed.auto_parallel.utils import set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping +from paddle.distributed.auto_parallel.process_group import get_world_process_group +from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists +from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, _dtype_to_str +from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute +from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op +from .auto_parallel_amp import AMPPass + +world_process_group = get_world_process_group() +# if user use python "+, -, * /" for network, there might be cast in vanilla program +__amp_skip_ops__ = [ + 'create_py_reader', + 'create_double_buffer_reader', + 'while', + 'cast', +] + + +def set_op_dtype_to_fp16(op): + if op.has_attr('in_dtype') and op.attr( + 'in_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('in_dtype', core.VarDesc.VarType.FP16) + if op.has_attr('out_dtype') and op.attr( + 'out_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('out_dtype', core.VarDesc.VarType.FP16) + if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.FP16) + + +# adapot for backward op +def _keep_fp32_input(op, in_name): + op_type = op.type + if op_type == 'batch_norm': + # Scale, Bias, Mean, Variance should be float32. 
+ return in_name != 'X' + if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32(): + return in_name != 'X' + if op_type == 'fused_bn_add_activation': + return in_name not in {'X', 'Z'} + if op_type == 'resnet_unit': + return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'} + if op_type in ['fused_attention', 'fused_feedforward']: + return in_name in { + 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" + } + # backward + if op_type in ['batch_norm_grad']: + return in_name not in {'X', 'Y@GRAD'} + if op_type in ['layer_norm_grad']: + return in_name not in {'X', 'Y@GRAD'} + return False + + +def _keep_fp32_output(op, out_name): + op_type = op.type + if op_type in ['batch_norm', 'fused_bn_add_activation']: + return out_name != 'Y' + if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32(): + return out_name != 'Y' + if op_type == 'resnet_unit': + return out_name not in {'Y', 'ConvX', 'ConvZ'} + if op_type in ['fused_attention', 'fused_feedforward']: + return out_name in { + 'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean', + 'Ln1Variance' + } + # backward + if op_type in ['layer_norm_grad']: + return out_name != 'X@GRAD' + if op_type in ['batch_norm_grad']: + return out_name != 'X@GRAD' + return False + + +class FP16State(object): + def __init__(self, program, amp_list, dist_context, use_fp16_guard): + self.program = program + self.amp_list = amp_list + self.use_fp16_guard = use_fp16_guard + self.dist_context = dist_context + self.grad_op_to_op_map = self.dist_context.dist_op_context.grad_op_id_to_op_id + self._op_fp16_dict = { + } # op_id --> True/False. 'True' means that the op is should run in fp16 mode. + # a trick to determine leaf tensor node in program {varname: generator_op_id} + self.forward_non_leaf_tensors = {} + # record the cast ops that are inserted for a forward + self.forward_input_cast_ops = defaultdict( + list + ) # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]} + self.is_train = False + + def _is_fp16_op(self, op_id): + return self._op_fp16_dict.get(op_id, None) + + def _build_state(self): + """ + mark the execution mode (fp16 or fp32) for ops in all blocks + include forward ops & backward ops + """ + # mark op dtype + # assume all backward block are behind forward blocks + for block in self.program.blocks: + for op in block.ops: + self._mark_op(op) + + # set forward tensor dtype + for block in self.program.blocks: + self.resolute_tensor_dtype(block) + + # insert cast ops + for block in self.program.blocks: + self.cast_block(block) + + return self.is_train + + def _mark_op(self, op): + + if op.type in __amp_skip_ops__: + return + + if is_forward_op(op): + + # ernie inference trick + if op.type == "assign" and "array_" in op.input_arg_names[0]: + self._op_fp16_dict[op.desc.id()] = False + return + if _need_keep_fp32(op, self.amp_list.unsupported_list, + self.use_fp16_guard): + self._op_fp16_dict[op.desc.id()] = False + else: + self._op_fp16_dict[op.desc.id()] = True + for var_name in op.output_arg_names: + # assert var_name not in self.forward_non_leaf_tensors, "{}".format(var_name) + self.forward_non_leaf_tensors[var_name] = op.desc.id() + + elif is_backward_op(op) == int(OpRole.Backward): + + if op.desc.id() in self.grad_op_to_op_map: + fwd_op_id = self.grad_op_to_op_map[op.desc.id()] + assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op)) + self._op_fp16_dict[op.desc.id()] = self._op_fp16_dict[fwd_op_id] + + if int(op.attr('op_role')) == 257: + self.is_train = True + + def 
set_var_to_fp16(self, var_name, block): + var = None + try: + var = block.var(var_name) + except ValueError as e: + var = self.program.global_block().var(var_name) + + # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is + # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY + if var is None or var.type not in _valid_types or "array_" in var_name: + return + + if var.dtype == core.VarDesc.VarType.FP32: + var.desc.set_dtype(core.VarDesc.VarType.FP16) + + def resolute_tensor_dtype(self, block): + + for op in block.ops: + op_id = op.desc.id() + if is_forward_op(op): + # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python + if self._is_fp16_op(op_id) == True or op.type == "cast": + for in_name in op.input_names: + if _keep_fp32_input(op, in_name): + continue + for in_var_name in op.input(in_name): + if in_var_name not in self.forward_non_leaf_tensors: + self.set_var_to_fp16(in_var_name, block) + for out_name in op.output_names: + if _keep_fp32_output(op, out_name): + continue + for out_var_name in op.output(out_name): + self.set_var_to_fp16(out_var_name, block) + set_op_dtype_to_fp16(op) + # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python + elif self._is_fp16_op(op_id) == False: + for out_var_name in op.output_arg_names: + out_var = block.vars.get(out_var_name) + if out_var is None or out_var.type not in _valid_types: + continue + if out_var.dtype == core.VarDesc.VarType.FP16: + out_var.desc.set_dtype(core.VarDesc.VarType.FP32) + elif is_backward_op(op): + if self._is_fp16_op(op_id) == True: + for out_name in op.output_names: + if _keep_fp32_output(op, out_name): + continue + for out_var_name in op.output(out_name): + self.set_var_to_fp16(out_var_name, block) + set_op_dtype_to_fp16(op) + # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python + elif self._is_fp16_op(op_id) == False: + for out_var_name in op.output_arg_names: + out_var = block.vars.get(out_var_name) + if out_var is None or out_var.type not in _valid_types: + continue + if out_var.dtype == core.VarDesc.VarType.FP16: + out_var.desc.set_dtype(core.VarDesc.VarType.FP32) + + def cast_block(self, block): + dist_op_context = self.dist_context.dist_op_context + idx = 0 + while idx < len(block.ops): + op = block.ops[idx] + op_id = op.desc.id() + num_cast_ops = 0 + + if op.type in __amp_skip_ops__: + idx += 1 + continue + elif is_forward_op(op): + if self._is_fp16_op(op_id) == False: + num_cast_ops = self._insert_forward_cast_ops( + op, idx, block, core.VarDesc.VarType.FP16, + core.VarDesc.VarType.FP32, self.dist_context) + elif self._is_fp16_op(op_id) == True: + num_cast_ops = self._insert_forward_cast_ops( + op, idx, block, core.VarDesc.VarType.FP32, + core.VarDesc.VarType.FP16, self.dist_context) + elif is_backward_op(op): + if op_id in dist_op_context.grad_op_id_to_op_id: + if self._is_fp16_op(op_id) == False: + num_cast_ops = self._insert_backward_cast_ops( + op, idx, block, core.VarDesc.VarType.FP16, + core.VarDesc.VarType.FP32, self.dist_context) + elif self._is_fp16_op(op_id) == True: + num_cast_ops = self._insert_backward_cast_ops( + op, idx, block, core.VarDesc.VarType.FP32, + core.VarDesc.VarType.FP16, self.dist_context) + elif op.type == "sum": + # all inputs dtype of sum should be equal and output dtype should follow input + out_var_name = op.output_arg_names[0] + in_var_name = op.input_arg_names[0] + out_var = block.var(out_var_name) + in_var = 
block._find_var_recursive(in_var_name) + for in_var_name in op.input_arg_names: + assert in_var.dtype == block.var( + in_var_name).dtype, "{}, {}, {}".format( + in_var, block.var(in_var_name), str(op)) + out_var.desc.set_dtype(in_var.dtype) + + idx += num_cast_ops + 1 + block._sync_with_cpp() + + def _insert_forward_cast_ops(self, op, idx, block, src_dtype, dst_dtype, + dist_context): + + num_cast_ops = 0 + op_id = op.desc.id() + + for in_name in op.input_names: + if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( + op, in_name): + continue + + consume_op_attr = dist_context.get_op_dist_attr_for_program(op) + assert consume_op_attr is not None + for in_var_name in op.input(in_name): + in_var = block._find_var_recursive(in_var_name) + if in_var is None or in_var.type not in _valid_types or in_var.dtype == dst_dtype: + continue + + if in_var.dtype == src_dtype: + cast_name = in_var.name + '.cast_' + _dtype_to_str( + dst_dtype) + cast_var = block.vars.get(cast_name) + self.forward_input_cast_ops[op_id] += [( + cast_name, in_var.name, dst_dtype, src_dtype, in_name)] + + in_var_dist_attr = consume_op_attr.get_input_dist_attr( + in_var.name) + assert in_var_dist_attr is not None + # truely insert cast op + if cast_var is None or cast_var.dtype != dst_dtype: + # NOTE we make the cast op and var's dist attr as the op that consume the + # cast var instead of the op which generates the var + # refine op's dist_attr + ref_mesh = in_var_dist_attr.process_mesh + ref_mapping = in_var_dist_attr.dims_mapping + + cast_var = block.create_var( + name=cast_name, + dtype=dst_dtype, + persistable=False, + stop_gradient=in_var.stop_gradient) + set_var_dist_attr(dist_context, cast_var, ref_mapping, + ref_mesh) + + cast_op = block._insert_op_without_sync( + idx, + type="cast", + inputs={"X": in_var}, + outputs={"Out": cast_var}, + attrs={ + "in_dtype": in_var.dtype, + "out_dtype": cast_var.dtype, + }) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_mesh, ref_mapping, dist_context) + num_cast_ops += 1 + + op._rename_input(in_var.name, cast_name) + consume_op_attr.set_input_dist_attr(cast_name, + in_var_dist_attr) + + if op.has_attr('out_dtype') and op.attr('out_dtype') != -1: + assert op.attr('out_dtype') == dst_dtype + + return num_cast_ops + + def _insert_backward_cast_ops(self, op, idx, block, src_dtype, dst_dtype, + dist_context): + + num_cast_ops = 0 + op_id = op.desc.id() + dist_op_context = dist_context.dist_op_context + forward_op_id = dist_op_context.grad_op_id_to_op_id[op_id] + + grad_op_attr = dist_context.get_op_dist_attr_for_program(op) + assert grad_op_attr is not None + + for out_var_name in op.output_arg_names: + out_var = block.var(out_var_name) + if _keep_fp32_output(op, out_var.name): + continue + assert out_var.dtype == dst_dtype, "{}, {}".format( + str(out_var), dst_dtype) + + for cast_name, src_name, dst_dtype, src_dtype, slot_name in self.forward_input_cast_ops[ + forward_op_id]: + + # rename input + assert src_name in op.input( + slot_name), "var: {} not in op's {}. 
{}".format(src_name, + slot_name, + str(op)) + src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) + assert src_var_dist_attr is not None + op._rename_input(src_name, cast_name) + grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr) + + # create cast grad + grad_slot_name = slot_name + "@GRAD" + assert grad_slot_name in op.output_names + assert len(op.output(grad_slot_name)) == 1 + grad_name = op.output(grad_slot_name)[0] + grad = block.var(grad_name) + grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name) + assert grad_dist_attr is not None, "{}".format(grad_name) + ref_mesh = grad_dist_attr.process_mesh + ref_mapping = grad_dist_attr.dims_mapping + + cast_grad = block.create_var( + name=unique_name.generate_with_ignorable_key("".join( + [cast_name, '@GRAD'])), + dtype=dst_dtype, + shape=grad.shape, + type=grad.type, + persistable=grad.persistable, + stop_gradient=grad.stop_gradient) + dist_context.set_tensor_dist_attr_for_program(cast_grad, + grad_dist_attr) + op._rename_output(grad_name, cast_grad.name) + grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr) + + # add cast + cast_op = block._insert_op_without_sync( + idx + 1, + type="cast", + inputs={"X": [cast_grad.name]}, + outputs={"Out": [grad.name]}, + attrs={ + "in_dtype": dst_dtype, + "out_dtype": src_dtype, + }) + grad.desc.set_dtype(src_dtype) + + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_mesh, ref_mapping, dist_context) + num_cast_ops += 1 + + return num_cast_ops + + +def _check_and_update_gradient(grads, loss_scaling, name, dist_context): + + main_block = paddle.static.default_main_program().global_block() + main_block._sync_with_cpp() + + check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale') + for e in grads: + check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], + 'check_finite_and_unscale') + + found_inf = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ['find_infinite_scale', name])), + shape=[1], + dtype='bool', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks) + + inputs = {'X': grads, 'Scale': loss_scaling} + outputs = {'Out': grads, 'FoundInfinite': found_inf} + attrs = {'op_role': OpRole.Backward} + new_op = main_block.append_op( + type='check_finite_and_unscale', + inputs=inputs, + outputs=outputs, + attrs=attrs) + + new_op_dist_attr = OperatorDistributedAttribute() + new_op_dist_attr.process_mesh = world_process_group.ranks + new_op_dist_attr.impl_idx = 0 + if len(world_process_group.ranks) > 1: + new_op_dist_attr.impl_type = "check_finite_and_unscale" + for g in grads: + g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g) + assert g_dist_attr is not None + new_op_dist_attr.set_input_dims_mapping(g.name, + g_dist_attr.dims_mapping) + new_op_dist_attr.set_output_dims_mapping(g.name, + g_dist_attr.dims_mapping) + dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) + return grads, found_inf + + +def _split_grads(params_grads): + grads = [g for _, g in params_grads] + fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] + fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] + assert len(fp32_grads) + len(fp16_grads) == len(grads), \ + "Data types of all grads must be either fp16 or fp32." 
+ return grads, fp32_grads, fp16_grads + + +def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context): + new_op_dist_attr = OperatorDistributedAttribute() + new_op_dist_attr.process_mesh = ranks + new_op_dist_attr.impl_idx = 0 + for var_name in new_op.input_arg_names: + var = block.var(var_name) + var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) + assert var_dist_attr is not None + new_op_dist_attr.set_input_dims_mapping(var_name, + var_dist_attr.dims_mapping) + for var_name in new_op.output_arg_names: + var = block.var(var_name) + var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) + assert var_dist_attr is not None + new_op_dist_attr.set_output_dims_mapping(var_name, + var_dist_attr.dims_mapping) + dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) + + +@register_pass("auto_parallel_fp16") +class FP16Pass(AMPPass): + def __init__(self): + super(FP16Pass, self).__init__() + + # NOTE: why FP16Pass can override apply_single_impl instead of + # apply_impl? AMP is an optimization pass for serial program, + # in distributed scenario, all ranks should have the same modification. + def _apply_single_impl(self, main_program, startup_program, context): + self.dist_context = self.get_attr("dist_context") + params_grads = self.get_attr("params_grads") + + amp_list = AutoMixedPrecisionLists( + set(self.get_attr("custom_white_list")), + set(self.get_attr("custom_black_list")), None) + + # TODO support multiple blocks + with paddle.static.program_guard(main_program, startup_program): + fp16_state = FP16State(main_program, amp_list, self.dist_context, + self.get_attr("use_fp16_guard")) + is_train = fp16_state._build_state() + + if is_train: + with paddle.static.program_guard(main_program, startup_program): + # TODO (JZ-LIANG)support cast forward program only when inference + self._init_amp_var() + self._scale_loss() + + grads, fp32_grads, fp16_grads = _split_grads(params_grads) + + if self.get_attr("use_dynamic_loss_scaling") or self.get_attr( + "init_loss_scaling") != 1.0: + found_infs = [] + if fp32_grads: + with main_program._backward_role_guard(): + _, found_inf_fp32 = _check_and_update_gradient( + fp32_grads, self._loss_scaling, "@fp32", + self.dist_context) + found_infs.append(found_inf_fp32) + if fp16_grads: + with main_program._backward_role_guard(): + _, found_inf_fp16 = _check_and_update_gradient( + fp16_grads, self._loss_scaling, "@fp16", + self.dist_context) + found_infs.append(found_inf_fp16) + with main_program._backward_role_guard(): + block = main_program.global_block() + + all_infs = paddle.fluid.layers.concat(found_infs) + set_var_dist_attr(self.dist_context, all_infs, [-1], + world_process_group.ranks) + new_op = block.ops[-1] + assert new_op.type == "concat" + _set_op_dist_attr_with_ranks(new_op, + world_process_group.ranks, + block, self.dist_context) + + found_inf = paddle.fluid.layers.reduce_any(all_infs) + set_var_dist_attr(self.dist_context, found_inf, [-1], + world_process_group.ranks) + new_op = block.ops[-1] + assert new_op.type == "reduce_any" + _set_op_dist_attr_with_ranks(new_op, + world_process_group.ranks, + block, self.dist_context) + + if self.get_attr("use_dynamic_loss_scaling"): + with main_program._backward_role_guard(): + if fp32_grads: + self._update_loss_scaling(fp32_grads, found_inf) + if fp16_grads: + self._update_loss_scaling(fp16_grads, found_inf) + + # modify optimizer + base_opt = self.get_attr("base_opt") + base_opt._multi_precision = True + if self.get_attr("use_optimizer_fp16"): + 
base_opt._multi_precision = False + if isinstance(base_opt, (paddle.fluid.optimizer.Adam, + paddle.optimizer.AdamW)): + # with main_program._optimized_guard([]): + # found_inf = paddle.tensor.creation._memcpy( + # found_inf, paddle.CPUPlace()) + base_opt._set_auxiliary_var('found_inf', found_inf.name) + elif hasattr(base_opt, "_set_auxiliary_var"): + base_opt._set_auxiliary_var('found_inf', found_inf.name) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt old mode 100755 new mode 100644 index 729c9c46b4f0cab2374d951b54deeaffe9cb0c1d..764a862d30f5526abc596a416fefd4f8d2648a30 --- a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt @@ -14,6 +14,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_XPU) AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) list(REMOVE_ITEM TEST_OPS "test_auto_parallel_amp_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_recompute_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass") + list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass") endif() foreach(TEST_OP ${TEST_OPS}) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc60bc6782ea78b15d9241db59dd3e338a07235 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import random +import numpy as np + +import unittest +import paddle +import paddle.distributed.fleet as fleet +from auto_parallel_pass_test_base import AutoPallelPassTestBase +from test_auto_parallel_amp_pass import TestAMPPass + + +class TestPF16Pass(TestAMPPass): + def apply_passes(self): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.amp = True + dist_strategy.amp_configs = { + "custom_white_list": [ + 'softmax', + 'layer_norm', + 'gelu', + ], + "custom_black_list": ['c_softmax_with_cross_entropy'], + "init_loss_scaling": 32768, + "use_dynamic_loss_scaling": True, + "use_pure_fp16": True + } + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + + +if __name__ == "__main__": + unittest.main()
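
Note on usage: the new `auto_parallel_fp16` pass is not invoked directly by users. `AutoParallelizer` selects it instead of `auto_parallel_amp` whenever `use_pure_fp16` is set in the AMP config (see the `parallelizer.py` hunk above). A minimal sketch of the user-side setup, modeled on the unit test added in this diff, is shown below; the network construction and optimizer calls are placeholders, not part of this change.

```python
# Sketch: enabling the pure-fp16 auto-parallel path added in this PR.
# build_model() and the commented optimizer calls are placeholders for the user's program.
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True                    # enable auto parallel
dist_strategy.amp = True
dist_strategy.amp_configs = {
    "custom_white_list": ["softmax", "layer_norm", "gelu"],
    "custom_black_list": ["c_softmax_with_cross_entropy"],
    "init_loss_scaling": 32768,
    "use_dynamic_loss_scaling": True,
    "use_pure_fp16": True,                        # routes AMP through auto_parallel_fp16
    # "use_optimizer_fp16": True,                 # optional: keep optimizer states in fp16
}
fleet.init(is_collective=True, strategy=dist_strategy)

# loss, optimizer = build_model()                        # placeholder
# optimizer = fleet.distributed_optimizer(optimizer)     # AutoParallelizer applies the pass
# optimizer.minimize(loss)
```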
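
For anyone exercising the pass outside of `AutoParallelizer` (for example in a standalone pass test), the dispatch in `parallelizer.py` translates roughly to the sketch below. It only illustrates the config keys the pass reads, including the new `base_opt` entry; `main_program`, `startup_program`, `dist_context`, `params_grads`, `loss` and `optimizer` are assumed to come from an already-completed auto-parallel build, and `new_pass`/`PassContext` are assumed to be importable from `paddle.distributed.passes` as they are used by the parallelizer.

```python
# Rough sketch of applying the registered pass directly, mirroring AutoParallelizer.
# All program/optimizer objects referenced here are assumed to exist in the caller.
from paddle.distributed.passes import new_pass, PassContext

config = {
    "custom_white_list": ["softmax", "layer_norm", "gelu"],
    "custom_black_list": ["c_softmax_with_cross_entropy"],
    "init_loss_scaling": 32768,
    "use_dynamic_loss_scaling": True,
    "use_pure_fp16": True,
    "use_fp16_guard": False,
    "use_optimizer_fp16": False,
    "dist_context": dist_context,     # DistributedContext of the program
    "params_grads": params_grads,     # (param, grad) pairs from backward
    "loss": loss,
    "base_opt": optimizer,            # only read by auto_parallel_fp16
}

pass_name = "auto_parallel_fp16" if config["use_pure_fp16"] else "auto_parallel_amp"
fp16_pass = new_pass(pass_name, config)
fp16_pass.apply([main_program], [startup_program], PassContext())
```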
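
The per-op fp16/fp32 decision in `FP16State._mark_op` reuses the serial AMP helpers (`AutoMixedPrecisionLists`, `_need_keep_fp32`) imported at the top of `auto_parallel_fp16.py`. As a quick standalone illustration of how the custom white/black lists interact with the unsupported-op list, the snippet below runs the same check over a toy program; the tiny network is an arbitrary example, not part of this diff.

```python
# Illustration only: which forward ops would be kept in fp32 for a given amp_list.
# The small network here is an arbitrary example, not from this PR.
import paddle
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    AutoMixedPrecisionLists, _need_keep_fp32)

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", [8, 16], "float32")
    y = paddle.static.nn.fc(x, 16)
    y = paddle.nn.functional.softmax(y)

amp_list = AutoMixedPrecisionLists({"softmax"}, {"reduce_mean"}, None)
for op in main.global_block().ops:
    keep_fp32 = _need_keep_fp32(op, amp_list.unsupported_list, False)
    print(op.type, "fp32" if keep_fp32 else "fp16")
```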