diff --git a/paddle/fluid/operators/collective/c_embedding_op.cu b/paddle/fluid/operators/collective/c_embedding_op.cu index 3a8294f52fced927ac052253daa82dcf2545322a..81f142c9e23628fcbde738db36295f9470efa37e 100644 --- a/paddle/fluid/operators/collective/c_embedding_op.cu +++ b/paddle/fluid/operators/collective/c_embedding_op.cu @@ -198,8 +198,14 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(c_embedding, ops::CEmbeddingCUDAKernel, ops::CEmbeddingCUDAKernel, +#if NCCL_VERSION_CODE >= 21000 + ops::CEmbeddingCUDAKernel, +#endif ops::CEmbeddingCUDAKernel); REGISTER_OP_CUDA_KERNEL(c_embedding_grad, ops::CEmbeddingGradCUDAKernel, ops::CEmbeddingGradCUDAKernel, +#if NCCL_VERSION_CODE >= 21000 + ops::CEmbeddingGradCUDAKernel, +#endif ops::CEmbeddingGradCUDAKernel); diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 3f7d2e7a4948d631647848fa7761fff081fcfd9c..d38ae1ff017b4ef61e9200275b525f6428b4c4b8 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -63,6 +63,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False) ######################################### AMP = "amp" set_field_default_config(AMP, "enable", False) +set_field_default_config(AMP, "dtype", "float16") +set_field_default_config(AMP, "level", "o1") set_field_default_config(AMP, "init_loss_scaling", 32768.0) set_field_default_config(AMP, "incr_every_n_steps", 1000) set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2) @@ -72,15 +74,12 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True) set_field_default_config(AMP, "custom_white_list", []) set_field_default_config(AMP, "custom_black_list", []) set_field_default_config(AMP, "custom_black_varnames", []) -set_field_default_config(AMP, "use_pure_fp16", False) -set_field_default_config(AMP, "use_fp16_guard", True) +set_field_default_config(AMP, "use_fp16_guard", False) set_field_default_config(AMP, "use_optimizer_fp16", False) -set_field_default_config(AMP, "enable_bf16", False) set_field_default_config(AMP, "custom_bf16_list", []) set_field_default_config(AMP, "custom_fp32_list", []) set_field_default_config(AMP, "custom_fp32_varnames", []) -set_field_default_config(AMP, "use_pure_bf16", False) set_field_default_config(AMP, "use_bf16_guard", False) ######################################### diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index cb4060b2593eedbda9bc627927e48909703ec87b..08b00a5c7f63bbb83061ff798d96a057a37e6916 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -455,7 +455,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): check_variable_and_dtype( Out_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'c_allreduce_sum', ) @@ -645,7 +645,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): check_variable_and_dtype( Out_grad, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], '_c_identity', ) @@ -687,12 +687,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): }, ) check_variable_and_dtype( - intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' + intermediate_var_0, + 'x', + 
['float16', 'float32', 'float64', 'uint16'], + 'linear', ) check_dtype( intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], + ['float16', 'float32', 'float64', 'uint16'], 'linear', ) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 175a1d263498671a879c2dabc7d9898f7232a491..3f7c3999cebef1e0a16c67f911c42583a110bc05 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -220,27 +220,26 @@ class Parallelizer: self._dist_context.serial_feed_vars["inputs"] + self._dist_context.serial_feed_vars["labels"] ) - if config["enable_bf16"]: - auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config) - auto_parallel_bf16_pass.apply( + self._logger.info( + "Applying AMP-{}-{} ...".format( + config["dtype"], config['level'] + ), + ) + if config['level'] == "o1": + auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) + auto_parallel_amp_pass.apply( [main_program], [startup_program], self._pass_context ) - loss = auto_parallel_bf16_pass.get_loss() - - elif config["use_pure_fp16"]: + loss = auto_parallel_amp_pass.get_loss() + elif config['level'] in ['o2', 'o3']: config["base_opt"] = optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass.apply( [main_program], [startup_program], self._pass_context ) loss = auto_parallel_fp16_pass.get_loss() - else: - auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) - auto_parallel_amp_pass.apply( - [main_program], [startup_program], self._pass_context - ) - loss = auto_parallel_amp_pass.get_loss() + raise ValueError("AMP level should be one of o1, o2, o3") # apply quantization pass # The pass can be applied when mode must be 'train' diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index c2e5933dbf965701689eb197b5974f37d4c92695..05f044b48de244cffbe7671c9caf6f45f02ddd9b 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -632,6 +632,7 @@ class AMPPass(PassBase): self.set_attr("use_dynamic_loss_scaling", False) self.set_attr("input_data", []) self.set_attr("params_grads", []) + self.set_attr("dtype", "") # fp16/bf16 self._loss = None self._loss_scaling = None self._num_good_steps = None @@ -639,6 +640,8 @@ class AMPPass(PassBase): self._loss = None def _check_self(self): + if self.get_attr("dtype") not in ["float16", "bfloat16"]: + return False if self.get_attr("init_loss_scaling") < 0: return False if self.get_attr("incr_every_n_steps") < 0: diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 50307ce22f8428b08bc5413ededd95517b0dd16f..6613200f4a43106327e896251408097e17af915f 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -29,13 +29,9 @@ from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.framework import core from paddle.static import default_main_program, default_startup_program -from paddle.static.amp.fp16_utils import ( - AutoMixedPrecisionLists, - _dtype_to_str, - _keep_layer_norm_scale_bias_to_fp32, - _need_keep_fp32, - _valid_types, -) + +# NOTE bf16 and fp16 may have diff logic for 
_keep_layer_norm_scale_bias_to_fp32 +from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32 from paddle.utils import unique_name from ..auto_parallel.process_mesh import ProcessMesh @@ -50,6 +46,8 @@ __amp_skip_ops__ = [ 'while', 'cast', ] +__target_dtype__ = None +__amp_utils__ = None def set_op_dtype_to_fp16(op): @@ -57,17 +55,24 @@ def set_op_dtype_to_fp16(op): op.has_attr('in_dtype') and op.attr('in_dtype') == core.VarDesc.VarType.FP32 ): - op._set_attr('in_dtype', core.VarDesc.VarType.FP16) + op._set_attr('in_dtype', __target_dtype__) if ( op.has_attr('out_dtype') and op.attr('out_dtype') == core.VarDesc.VarType.FP32 ): - op._set_attr('out_dtype', core.VarDesc.VarType.FP16) + op._set_attr('out_dtype', __target_dtype__) if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.FP16) + op._set_attr('dtype', __target_dtype__) + + if __target_dtype__ == core.VarDesc.VarType.BF16: + if op.has_attr('use_mkldnn'): + op._set_attr('use_mkldnn', True) + if op.has_attr('mkldnn_data_type'): + op._set_attr('mkldnn_data_type', 'bfloat16') # adapot for backward op +# TODO check if bf16 and fp16 still share the same logic def _keep_fp32_input(op, in_name): op_type = op.type if op_type == 'batch_norm': @@ -96,6 +101,7 @@ def _keep_fp32_input(op, in_name): return False +# TODO check if bf16 and fp16 still share the same logic def _keep_fp32_output(op, out_name): op_type = op.type if op_type in ['batch_norm', 'fused_bn_add_activation']: @@ -208,7 +214,7 @@ class FP16State: self._op_fp16_dict[op.desc.original_id()] = True return - if _need_keep_fp32( + if __amp_utils__._need_keep_fp32( op, self.amp_list.unsupported_list, self.use_fp16_guard ): self._op_fp16_dict[op.desc.original_id()] = False @@ -240,11 +246,15 @@ class FP16State: # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY - if var is None or var.type not in _valid_types or "array_" in var_name: + if ( + var is None + or var.type not in __amp_utils__._valid_types + or "array_" in var_name + ): return if var.dtype == core.VarDesc.VarType.FP32: - var.desc.set_dtype(core.VarDesc.VarType.FP16) + var.desc.set_dtype(__target_dtype__) def resolute_tensor_dtype(self, block): @@ -274,9 +284,12 @@ class FP16State: elif self._is_fp16_op(op.desc.original_id()) is False: for out_var_name in op.output_arg_names: out_var = block.vars.get(out_var_name) - if out_var is None or out_var.type not in _valid_types: + if ( + out_var is None + or out_var.type not in __amp_utils__._valid_types + ): continue - if out_var.dtype == core.VarDesc.VarType.FP16: + if out_var.dtype == __target_dtype__: out_var.desc.set_dtype(core.VarDesc.VarType.FP32) elif is_backward_op(op): if self._is_fp16_op(op.desc.original_id()) is True: @@ -290,9 +303,12 @@ class FP16State: elif self._is_fp16_op(op.desc.original_id()) is False: for out_var_name in op.output_arg_names: out_var = block.vars.get(out_var_name) - if out_var is None or out_var.type not in _valid_types: + if ( + out_var is None + or out_var.type not in __amp_utils__._valid_types + ): continue - if out_var.dtype == core.VarDesc.VarType.FP16: + if out_var.dtype == __target_dtype__: out_var.desc.set_dtype(core.VarDesc.VarType.FP32) def cast_block(self, block): @@ -311,7 +327,7 @@ class FP16State: op, idx, block, - core.VarDesc.VarType.FP16, + __target_dtype__, core.VarDesc.VarType.FP32, self.dist_context, ) @@ -321,7 
+337,7 @@ class FP16State: idx, block, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, + __target_dtype__, self.dist_context, ) elif is_backward_op(op): @@ -331,7 +347,7 @@ class FP16State: op, idx, block, - core.VarDesc.VarType.FP16, + __target_dtype__, core.VarDesc.VarType.FP32, self.dist_context, ) @@ -341,7 +357,7 @@ class FP16State: idx, block, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, + __target_dtype__, self.dist_context, ) elif op.type == "sum": @@ -379,14 +395,16 @@ class FP16State: in_var = block._find_var_recursive(in_var_name) if ( in_var is None - or in_var.type not in _valid_types + or in_var.type not in __amp_utils__._valid_types or in_var.dtype == dst_dtype ): continue if in_var.dtype == src_dtype: cast_name = ( - in_var.name + '.cast_' + _dtype_to_str(dst_dtype) + in_var.name + + '.cast_' + + __amp_utils__._dtype_to_str(dst_dtype) ) cast_var = block.vars.get(cast_name) self.forward_input_cast_ops[op.desc.original_id()] += [ @@ -476,71 +494,72 @@ class FP16State: slot_name, ) in self.forward_input_cast_ops[forward_op_id]: + # rename input # some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy - if slot_name not in op.input_names: - continue + if slot_name in op.input_names: - # rename input - assert src_name in op.input( - slot_name - ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op)) - src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) - assert src_var_dist_attr is not None - op._rename_input(src_name, cast_name) - grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr) + assert src_name in op.input( + slot_name + ), "var: {} not in op's {}. {}".format( + src_name, slot_name, str(op) + ) + src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) + assert src_var_dist_attr is not None + op._rename_input(src_name, cast_name) + grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr) # create cast grad grad_slot_name = slot_name + "@GRAD" - if grad_slot_name not in op.output_names: - continue + if grad_slot_name in op.output_names: + # some forward input maybe stop_gradient=True, e.g. input_mask + if len(op.output(grad_slot_name)) == 0: + continue + assert ( + len(op.output(grad_slot_name)) == 1 + ), "[{}], Current Op: {}".format(grad_slot_name, str(op)) + grad_name = op.output(grad_slot_name)[0] + grad = block.var(grad_name) + grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name) + assert grad_dist_attr is not None, "{}".format(grad_name) + ref_mesh = grad_dist_attr.process_mesh + ref_mapping = grad_dist_attr.dims_mapping + + cast_grad = block.create_var( + name=unique_name.generate_with_ignorable_key( + "".join([cast_name, '@GRAD']) + ), + dtype=dst_dtype, + shape=grad.shape, + type=grad.type, + persistable=grad.persistable, + stop_gradient=grad.stop_gradient, + ) + dist_context.set_tensor_dist_attr_for_program( + cast_grad, grad_dist_attr + ) + op._rename_output(grad_name, cast_grad.name) + grad_op_attr.set_output_dist_attr( + cast_grad.name, grad_dist_attr + ) - # some forward input maybe stop_gradient=True, e.g. 
input_mask - if len(op.output(grad_slot_name)) == 0: - continue - assert ( - len(op.output(grad_slot_name)) == 1 - ), "[{}], Current Op: {}".format(grad_slot_name, str(op)) - grad_name = op.output(grad_slot_name)[0] - grad = block.var(grad_name) - grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name) - assert grad_dist_attr is not None, "{}".format(grad_name) - ref_mesh = grad_dist_attr.process_mesh - ref_mapping = grad_dist_attr.dims_mapping - - cast_grad = block.create_var( - name=unique_name.generate_with_ignorable_key( - "".join([cast_name, '@GRAD']) - ), - dtype=dst_dtype, - shape=grad.shape, - type=grad.type, - persistable=grad.persistable, - stop_gradient=grad.stop_gradient, - ) - dist_context.set_tensor_dist_attr_for_program( - cast_grad, grad_dist_attr - ) - op._rename_output(grad_name, cast_grad.name) - grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr) - - # add cast - cast_op = block._insert_op_without_sync( - idx + 1, - type="cast", - inputs={"X": [cast_grad.name]}, - outputs={"Out": [grad.name]}, - attrs={ - "in_dtype": dst_dtype, - "out_dtype": src_dtype, - OP_ROLE_KEY: OpRole.Backward, - }, - ) - grad.desc.set_dtype(src_dtype) + # add cast + cast_op = block._insert_op_without_sync( + idx + 1, + type="cast", + inputs={"X": [cast_grad.name]}, + outputs={"Out": [grad.name]}, + attrs={ + "in_dtype": dst_dtype, + "out_dtype": src_dtype, + OP_ROLE_KEY: OpRole.Backward, + }, + ) + grad.desc.set_dtype(src_dtype) - naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - cast_op, ref_mesh, ref_mapping, dist_context - ) - num_cast_ops += 1 + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_mesh, ref_mapping, dist_context + ) + num_cast_ops += 1 return num_cast_ops @@ -604,7 +623,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): def _split_grads(params_grads): grads = [g for _, g in params_grads] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] - fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] + fp16_grads = [g for g in grads if g.dtype == __target_dtype__] assert len(fp32_grads) + len(fp16_grads) == len( grads ), "Data types of all grads must be either fp16 or fp32." @@ -707,17 +726,17 @@ def cast_startup_program(): for op in startup_program.global_block().ops: if is_initialization_op(op): output_name = op.output_arg_names[0] - if ( - param_to_dtype.get(output_name, None) - == core.VarDesc.VarType.FP16 - ): + if param_to_dtype.get(output_name, None) == __target_dtype__: assert op.has_attr( 'dtype' ), "initialization op is supported to has dtype attribute but got {}.".format( str(op) ) + out_var = startup_program.global_block().var(output_name) + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(__target_dtype__) if op.attr('dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.FP16) + op._set_attr('dtype', __target_dtype__) @register_pass("auto_parallel_fp16") @@ -730,9 +749,37 @@ class FP16Pass(AMPPass): # in distributed scenario, all ranks should have the same modification. 
def _apply_single_impl(self, main_program, startup_program, context): self.dist_context = self.get_attr("dist_context") + self.target_dtype = self.get_attr("dtype") params_grads = self.get_attr("params_grads") - amp_list = AutoMixedPrecisionLists( + self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None) + if self.use_optimizer_fp16 is None: + self.use_optimizer_fp16 = self.get_attr("level", None) == "o3" + + # swith enviroment for fp16 / bf16. + if self.target_dtype == "float16": + import paddle.static.amp.fp16_utils as amp_utils + + AMPList = amp_utils.AutoMixedPrecisionLists + __target_dtype = core.VarDesc.VarType.FP16 + + elif self.target_dtype == "bfloat16": + import paddle.static.amp.bf16.amp_utils as amp_utils + + AMPList = amp_utils.AutoMixedPrecisionListsBF16 + __target_dtype = core.VarDesc.VarType.BF16 + + else: + raise NotImplementedError( + "target dtype [{}] is for amp o2 not supported yet.".format( + self.target_dtype + ) + ) + global __target_dtype__ + __target_dtype__ = __target_dtype + global __amp_utils__ + __amp_utils__ = amp_utils + amp_list = AMPList( set(self.get_attr("custom_white_list")), set(self.get_attr("custom_black_list")), None, @@ -747,7 +794,9 @@ class FP16Pass(AMPPass): main_program, amp_list, self.dist_context, - self.get_attr("use_fp16_guard"), + self.get_attr( + "use_fp16_guard" + ), # TODO unify to use_amp_guard to be compatible with amp o1 input_data_var_names, ) is_train = fp16_state._build_state() @@ -755,128 +804,130 @@ class FP16Pass(AMPPass): cast_startup_program() if is_train: - with paddle.static.program_guard(main_program, startup_program): - # TODO (JZ-LIANG)support cast forward program only when inference - self._init_amp_var() - self._scale_loss() - - grads, fp32_grads, fp16_grads = _split_grads(params_grads) - - if ( - self.get_attr("use_dynamic_loss_scaling") - or self.get_attr("init_loss_scaling") != 1.0 - ): - found_infs = [] - if fp32_grads: + if self.target_dtype == "fp16": + with paddle.static.program_guard(main_program, startup_program): + # TODO (JZ-LIANG)support cast forward program only when inference + self._init_amp_var() + self._scale_loss() + + grads, fp32_grads, fp16_grads = _split_grads(params_grads) + + if ( + self.get_attr("use_dynamic_loss_scaling") + or self.get_attr("init_loss_scaling") != 1.0 + ): + found_infs = [] + if fp32_grads: + with main_program._optimized_guard([]): + _, found_inf_fp32 = _check_and_update_gradient( + fp32_grads, + self._loss_scaling, + "@fp32", + self.dist_context, + ) + found_infs.append(found_inf_fp32) + if fp16_grads: + with main_program._optimized_guard([]): + _, found_inf_fp16 = _check_and_update_gradient( + fp16_grads, + self._loss_scaling, + "@fp16", + self.dist_context, + ) + found_infs.append(found_inf_fp16) with main_program._optimized_guard([]): - _, found_inf_fp32 = _check_and_update_gradient( - fp32_grads, - self._loss_scaling, - "@fp32", + block = main_program.global_block() + + # all_infs = paddle.fluid.layers.concat(found_infs) + all_infs = block.create_var( + name=paddle.utils.unique_name.generate_with_ignorable_key( + ".".join(['concat', 'tmp']) + ), + dtype=found_infs[0].dtype, + shape=None, + lod_level=found_infs[0].lod_level, + type=found_infs[0].type, + persistable=False, + stop_gradient=False, + ) + concat_op = block.append_op( + type='concat', + inputs={'X': found_infs}, + outputs={'Out': [all_infs]}, + attrs={'axis': 0}, + ) + set_var_dist_attr( self.dist_context, + all_infs, + [-1], + world_process_group.ranks, ) - found_infs.append(found_inf_fp32) - if 
fp16_grads: - with main_program._optimized_guard([]): - _, found_inf_fp16 = _check_and_update_gradient( - fp16_grads, - self._loss_scaling, - "@fp16", + _set_op_dist_attr_with_ranks( + concat_op, + world_process_group.ranks, + block, self.dist_context, ) - found_infs.append(found_inf_fp16) - with main_program._optimized_guard([]): - block = main_program.global_block() - - # all_infs = paddle.fluid.layers.concat(found_infs) - all_infs = block.create_var( - name=paddle.utils.unique_name.generate_with_ignorable_key( - ".".join(['concat', 'tmp']) - ), - dtype=found_infs[0].dtype, - shape=None, - lod_level=found_infs[0].lod_level, - type=found_infs[0].type, - persistable=False, - stop_gradient=False, - ) - concat_op = block.append_op( - type='concat', - inputs={'X': found_infs}, - outputs={'Out': [all_infs]}, - attrs={'axis': 0}, - ) - set_var_dist_attr( - self.dist_context, - all_infs, - [-1], - world_process_group.ranks, - ) - _set_op_dist_attr_with_ranks( - concat_op, - world_process_group.ranks, - block, - self.dist_context, - ) - # found_inf = paddle.fluid.layers.reduce_any(all_infs) - found_inf = block.create_var( - name=paddle.utils.unique_name.generate_with_ignorable_key( - ".".join(['reduce_any', 'tmp']) - ), - dtype=all_infs.dtype, - shape=None, - lod_level=all_infs.lod_level, - type=all_infs.type, - persistable=False, - stop_gradient=False, - ) - reduce_any_op = block.append_op( - type='reduce_any', - inputs={'X': all_infs}, - outputs={'Out': found_inf}, - attrs={ - 'dim': [0], - 'keep_dim': False, - 'reduce_all': True, - }, - ) - set_var_dist_attr( - self.dist_context, - found_inf, - [-1], - world_process_group.ranks, - ) - _set_op_dist_attr_with_ranks( - reduce_any_op, - world_process_group.ranks, - block, - self.dist_context, - ) + # found_inf = paddle.fluid.layers.reduce_any(all_infs) + found_inf = block.create_var( + name=paddle.utils.unique_name.generate_with_ignorable_key( + ".".join(['reduce_any', 'tmp']) + ), + dtype=all_infs.dtype, + shape=None, + lod_level=all_infs.lod_level, + type=all_infs.type, + persistable=False, + stop_gradient=False, + ) + reduce_any_op = block.append_op( + type='reduce_any', + inputs={'X': all_infs}, + outputs={'Out': found_inf}, + attrs={ + 'dim': [0], + 'keep_dim': False, + 'reduce_all': True, + }, + ) + set_var_dist_attr( + self.dist_context, + found_inf, + [-1], + world_process_group.ranks, + ) + _set_op_dist_attr_with_ranks( + reduce_any_op, + world_process_group.ranks, + block, + self.dist_context, + ) - if self.get_attr("use_dynamic_loss_scaling"): - with main_program._optimized_guard([]): - if fp32_grads: - self._update_loss_scaling(fp32_grads, found_inf) - if fp16_grads: - self._update_loss_scaling(fp16_grads, found_inf) + if self.get_attr("use_dynamic_loss_scaling"): + with main_program._optimized_guard([]): + if fp32_grads: + self._update_loss_scaling(fp32_grads, found_inf) + if fp16_grads: + self._update_loss_scaling(fp16_grads, found_inf) # modify optimizer base_opt = self.get_attr("base_opt") base_opt._multi_precision = True - if self.get_attr("use_optimizer_fp16"): + if self.use_optimizer_fp16: base_opt._multi_precision = False - if isinstance( - base_opt, - (paddle.static.Adam, paddle.optimizer.AdamW), - ): - with main_program._optimized_guard([]): - # found_inf = paddle.tensor.creation._memcpy( - # found_inf, paddle.CPUPlace()) - insert_idx = _get_memcopy_idx(block, found_inf) - found_inf = _insert_memcopy( - block, insert_idx, found_inf, self.dist_context - ) - base_opt._set_auxiliary_var('found_inf', found_inf.name) - elif 
hasattr(base_opt, "_set_auxiliary_var"): - base_opt._set_auxiliary_var('found_inf', found_inf.name) + + if self.target_dtype == "fp16": + if isinstance( + base_opt, (paddle.static.Adam, paddle.optimizer.AdamW) + ): + with main_program._optimized_guard([]): + # found_inf = paddle.tensor.creation._memcpy( + # found_inf, paddle.CPUPlace()) + insert_idx = _get_memcopy_idx(block, found_inf) + found_inf = _insert_memcopy( + block, insert_idx, found_inf, self.dist_context + ) + base_opt._set_auxiliary_var('found_inf', found_inf.name) + elif hasattr(base_opt, "_set_auxiliary_var"): + base_opt._set_auxiliary_var('found_inf', found_inf.name) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index cd9817f62bdaf9edcad8a5b65930fb83edf3a7be..ac0d834dd239e7a05fe124750e35504c81938b78 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -49,6 +49,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_pass_amp MODULES test_pass_amp) set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass) + set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" + TIMEOUT 50) py_test_modules(test_engine_callbacks MODULES test_engine_callbacks) set_tests_properties(test_engine_callbacks PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py b/python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..767b95c808330625148433983d3d6c893f12edda --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py @@ -0,0 +1,141 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import random +import re +import unittest + +import numpy as np +from get_gpt_model import FakeDataset, generate_model + +import paddle +from paddle.distributed.fleet import auto +from paddle.framework import core + +paddle.enable_static() + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +def apply_pass(use_amp=False, amp_dtype="bfloat16"): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + + if use_amp: + amp = strategy.amp + amp.enable = True + amp.dtype = amp_dtype + amp.level = "o2" + amp.custom_black_list = [ + 'c_softmax_with_cross_entropy', + 'elementwise_div', + 'reduce_sum', + ] + + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestShardingStage2WithNewEXE(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2022) + np.random.seed(2022) + random.seed(2022) + place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_amp=False, amp_dtype="bfloat16"): + reset_prog() + + strategy = apply_pass(use_amp, amp_dtype) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("mp") + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_bf16(self, program): + num_bf16 = 0 + num_fp16 = 0 + num_fp32 = 0 + + for p in program.all_parameters(): + if p.dtype == core.VarDesc.VarType.FP32: + num_fp32 += 1 + if p.dtype == core.VarDesc.VarType.FP16: + num_fp16 += 1 + if p.dtype == core.VarDesc.VarType.BF16: + num_bf16 += 1 + + self.assertEqual(num_bf16, 26) + self.assertEqual(num_fp16, 0) + self.assertEqual(num_fp32, 10) + + def test_param_grad_fuse_overlap(self): + # std + mp_engine = self.get_engine(use_amp=False) + mp_history = mp_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + loss0 = mp_history.history['loss'][0] + + # bf16 + mp_bf16_engine = self.get_engine(use_amp=True) + if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000: + return + + mp_bf16_history = mp_bf16_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + loss1 = mp_bf16_history.history['loss'][0] + np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2) + + self.check_bf16(mp_bf16_engine.main_program) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py index 388ab592e99330a532f0d0504e9c35bab7729b93..861747120b6d63ca429443b808ac91c8c00687b0 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py @@ -37,7 +37,7 @@ def apply_pass(use_amp=False, level=None): ] 
amp.init_loss_scaling = 32768 amp.use_fp16_guard = False - amp.use_pure_fp16 = level in ["o2", "o3"] + amp.level = level amp.use_optimizer_fp16 = level == "o3" print("amp level: ", level) return strategy diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/quantization_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/quantization_pass_unittest.py index 8b941e7a7fc6cc68e48d24828b81a462737bea10..e6788365ac2793b22734eb6ae5c177f4960eed78 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/quantization_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/quantization_pass_unittest.py @@ -39,7 +39,7 @@ def apply_pass(): ] amp.init_loss_scaling = 32768 amp.use_fp16_guard = False - amp.use_pure_fp16 = True + amp.level = "o2" qat = dist_strategy.qat qat.enable = True diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3d9797bc6799b8912b9160d68f3ba001868539 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import subprocess +import sys +import tempfile +import unittest + + +class TestAMPO2(unittest.TestCase): + def test_bf16(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "amp_o2_pass.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = ( + [sys.executable, "-u"] + + coverage_args + + [ + "-m", + "paddle.distributed.launch", + "--devices", + "0,1", + "--log_dir", + tmp_dir.name, + launch_model_path, + ] + ) + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index 0cc83160908e2898edfea829acd7d5142c2118c2..4d1990a25b1158dff3ca6f14af2c6bc427ace1c4 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -28,6 +28,8 @@ class TestStrategy(unittest.TestCase): amp = strategy.amp self.assertEqual(amp.enable, False) + self.assertAlmostEqual(amp.dtype, "float16") + self.assertAlmostEqual(amp.level, "o1") self.assertAlmostEqual(amp.init_loss_scaling, 32768.0) self.assertEqual(amp.incr_every_n_steps, 1000) self.assertEqual(amp.decr_every_n_nan_or_inf, 2) @@ -37,15 +39,11 @@ class TestStrategy(unittest.TestCase): self.assertEqual(amp.custom_black_list, []) self.assertEqual(amp.custom_white_list, []) self.assertEqual(amp.custom_black_varnames, []) - self.assertEqual(amp.use_pure_fp16, False) - self.assertEqual(amp.use_fp16_guard, True) + self.assertEqual(amp.use_fp16_guard, False) self.assertEqual(amp.use_optimizer_fp16, False) - - self.assertEqual(amp.enable_bf16, False) self.assertEqual(amp.custom_bf16_list, []) self.assertEqual(amp.custom_fp32_list, []) self.assertEqual(amp.custom_fp32_varnames, []) - self.assertEqual(amp.use_pure_bf16, False) self.assertEqual(amp.use_bf16_guard, False) sharding = strategy.sharding @@ -102,7 +100,6 @@ class TestStrategy(unittest.TestCase): amp.custom_white_list = ["x"] amp.custom_black_list = ["y"] amp.custom_black_varnames = ["z"] - amp.use_pure_fp16 = True amp.use_fp16_guard = False amp.use_optimizer_fp16 = True self.assertEqual(amp.enable, True) @@ -115,7 +112,6 @@ class TestStrategy(unittest.TestCase): self.assertEqual(amp.custom_white_list, ["x"]) self.assertEqual(amp.custom_black_list, ["y"]) self.assertEqual(amp.custom_black_varnames, ["z"]) - self.assertEqual(amp.use_pure_fp16, True) self.assertEqual(amp.use_fp16_guard, False) self.assertEqual(amp.use_optimizer_fp16, True) diff --git a/python/paddle/nn/clip.py b/python/paddle/nn/clip.py index c3e00894eb4a9f0080032235a4c56d95780d9706..4b7e6215b2d3bdc230e88211ce8c35cf1466749e 100644 --- a/python/paddle/nn/clip.py +++ b/python/paddle/nn/clip.py @@ -14,6 +14,7 @@ import copy import warnings +from sqlite3 import NotSupportedError import paddle import paddle.autograd as imperative_base @@ -217,7 +218,9 @@ def _squared_l2_norm(x): return _C_ops.squared_l2_norm(x) op_type = 'squared_l2_norm' - check_variable_and_dtype(x, 'x', ['float32', 'float64', 'float16'], op_type) + check_variable_and_dtype( + x, 'x', ['float32', 'float64', 'float16', 'uint16'], op_type + ) helper = LayerHelper(op_type, **locals()) out = 
helper.create_variable_for_type_inference(x.dtype) @@ -557,6 +560,20 @@ def _allow_pure_fp16_global_norm_clip(*args): return old_value +_allow_pure_bf16_global_norm_clip_flag = False + + +def _allow_pure_bf16_global_norm_clip(*args): + global _allow_pure_bf16_global_norm_clip_flag + if len(args) == 0: + return _allow_pure_bf16_global_norm_clip_flag + else: + assert len(args) == 1 and isinstance(args[0], bool) + old_value = _allow_pure_bf16_global_norm_clip_flag + _allow_pure_bf16_global_norm_clip_flag = args[0] + return old_value + + class ClipGradByGlobalNorm(ClipGradBase): r""" Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in @@ -720,6 +737,7 @@ class ClipGradByGlobalNorm(ClipGradBase): params_and_grads = [] sum_square_list = [] sum_square_list_fp16 = [] + sum_square_list_bf16 = [] sum_square_list_fp32 = [] with framework.name_scope('gradient_clip'): for p, g in params_grads: @@ -735,17 +753,29 @@ class ClipGradByGlobalNorm(ClipGradBase): sum_square = _squared_l2_norm(merge_grad) if sum_square.dtype == core.VarDesc.VarType.FP16: sum_square_list_fp16.append(sum_square) + elif sum_square.dtype == core.VarDesc.VarType.BF16: + sum_square_list_bf16.append(sum_square) elif sum_square.dtype == core.VarDesc.VarType.FP32: sum_square_list_fp32.append(sum_square) else: sum_square_list.append(sum_square) + if len(sum_square_list_fp16) > 0 and len(sum_square_list_bf16) > 0: + raise NotSupportedError( + 'FP16 and BF16 are not supported at the same time.' + ) + # all parameters have been filterd out if ( len(sum_square_list) + len(sum_square_list_fp16) + len(sum_square_list_fp32) == 0 + ) and ( + len(sum_square_list) + + len(sum_square_list_bf16) + + len(sum_square_list_fp32) + == 0 ): return params_grads @@ -765,6 +795,18 @@ class ClipGradByGlobalNorm(ClipGradBase): ) else: global_norm_var.append(global_norm_var_fp16) + if len(sum_square_list_bf16) > 0: + global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16) + if ( + sum_square_list_fp32 + or sum_square_list + or not _allow_pure_bf16_global_norm_clip() + ): + global_norm_var.append( + global_norm_var_bf16.astype(sum_dtype) + ) + else: + global_norm_var.append(global_norm_var_bf16) if len(sum_square_list_fp32) > 0: global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32) if sum_dtype == 'float32': @@ -804,12 +846,18 @@ class ClipGradByGlobalNorm(ClipGradBase): with p.block.program._optimized_guard([p, g]): new_g = _cast_to_mp_type_if_enabled(g) # inplace - scale_input = ( - scale_var.astype('float16') - if new_g.dtype == core.VarDesc.VarType.FP16 + if ( + new_g.dtype == core.VarDesc.VarType.FP16 and scale_var.dtype != core.VarDesc.VarType.FP16 - else scale_var - ) + ): + scale_input = scale_var.astype('float16') + elif ( + new_g.dtype == core.VarDesc.VarType.BF16 + and scale_var.dtype != core.VarDesc.VarType.BF16 + ): + scale_input = scale_var.astype('bfloat16') + else: + scale_input = scale_var # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g # will be in different blocks with the gradient clip related ops. 
# We need to handle the correct block, otherwise will encounter diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e254884c3eacf27315c17750e733f343fb4ce9e0..3fbc72c01f6bdfa61ec69a889a07c05a438232f1 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1657,14 +1657,21 @@ def add_n(inputs, name=None): check_variable_and_dtype( input, "inputs", - ['float16', 'float32', 'float64', 'int32', 'int64'], + [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'uint16', + ], 'add_n', ) else: check_variable_and_dtype( inputs, "inputs", - ['float16', 'float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n', )
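A minimal usage sketch of the reworked AMP knobs, mirroring the apply_pass helper in the new amp_o2_pass.py test above (model, loss, optimizer and dataset construction are elided; the old use_pure_fp16 / enable_bf16 / use_pure_bf16 flags are folded into amp.dtype and amp.level):

import paddle
from paddle.distributed.fleet import auto

paddle.enable_static()

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"  # "float16" (default) or "bfloat16"
amp.level = "o2"        # "o1" -> auto_parallel_amp pass, "o2"/"o3" -> auto_parallel_fp16 pass
amp.custom_black_list = [
    'c_softmax_with_cross_entropy',
    'elementwise_div',
    'reduce_sum',
]

# engine = auto.Engine(model, loss, optimizer, strategy=strategy)
# history = engine.fit(dataset, epochs=1, batch_size=2)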
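The Parallelizer hunk above routes config['level'] to one of two passes; condensed into a free function it reads as follows (a sketch only: in the patch this logic sits inside the Parallelizer class, and the placement of the final else/raise is assumed from the surrounding hunk):

from paddle.distributed.passes import new_pass


def apply_amp_pass(config, main_program, startup_program, pass_context, optimizer):
    # "o1": list-driven cast insertion via the auto_parallel_amp pass.
    if config["level"] == "o1":
        amp_pass = new_pass("auto_parallel_amp", config)
        amp_pass.apply([main_program], [startup_program], pass_context)
        return amp_pass.get_loss()
    # "o2"/"o3": program-wide low-precision cast via the auto_parallel_fp16 pass,
    # which also receives the base optimizer for multi-precision handling.
    if config["level"] in ["o2", "o3"]:
        config["base_opt"] = optimizer
        fp16_pass = new_pass("auto_parallel_fp16", config)
        fp16_pass.apply([main_program], [startup_program], pass_context)
        return fp16_pass.get_loss()
    raise ValueError("AMP level should be one of o1, o2, o3")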
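In auto_parallel_fp16.py, the pass now picks its helper module and target VarType from the requested dtype instead of hard-coding FP16; the selection added in FP16Pass._apply_single_impl boils down to the following sketch (names taken from the diff):

from paddle.framework import core


def select_amp_backend(target_dtype):
    """Return (AMP list class, target VarType, amp_utils module) for the dtype."""
    if target_dtype == "float16":
        import paddle.static.amp.fp16_utils as amp_utils

        return amp_utils.AutoMixedPrecisionLists, core.VarDesc.VarType.FP16, amp_utils
    if target_dtype == "bfloat16":
        import paddle.static.amp.bf16.amp_utils as amp_utils

        return amp_utils.AutoMixedPrecisionListsBF16, core.VarDesc.VarType.BF16, amp_utils
    raise NotImplementedError(
        "target dtype [{}] is not supported by amp o2 yet.".format(target_dtype)
    )

The chosen pair is stored in the module-level __target_dtype__ / __amp_utils__ globals, which the FP16State casting logic and cast_startup_program then use in place of the previously hard-coded core.VarDesc.VarType.FP16 and fp16_utils references; for bfloat16, ops additionally get use_mkldnn / mkldnn_data_type='bfloat16' attributes where those attributes exist.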