From 658387b0f22b16dbe0e87ba7b86ec4315cc9c4e4 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Thu, 10 Nov 2022 10:38:38 +0800
Subject: [PATCH] [AutoParallel] fix insert concat op (#47710)

* fix insert concat op

* fix fp16 assert
---
 .../distributed/auto_parallel/dist_loader.py  |  2 +-
 .../distributed/passes/auto_parallel_fp16.py  | 55 +++++++++++++++----
 2 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py
index f16c55a011..f982f74589 100644
--- a/python/paddle/distributed/auto_parallel/dist_loader.py
+++ b/python/paddle/distributed/auto_parallel/dist_loader.py
@@ -134,7 +134,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             raise StopIteration
 
     def _infer_steps(self):
-        if self.steps_per_epoch is not None:
+        if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 1:
             return self.steps_per_epoch
         try:
             if isinstance(self.dataset, IterableDataset):
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index a952986c21..4cb0c361fe 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -487,9 +487,8 @@ class FP16State:
 
             # create cast grad
             grad_slot_name = slot_name + "@GRAD"
-            assert (
-                grad_slot_name in op.output_names
-            ), "[{}], Current Op: {}".format(grad_slot_name, str(op))
+            if grad_slot_name not in op.output_names:
+                continue
 
             # some forward input maybe stop_gradient=True, e.g. input_mask
             if len(op.output(grad_slot_name)) == 0:
@@ -785,33 +784,67 @@ class FP16Pass(AMPPass):
                     with main_program._optimized_guard([]):
                         block = main_program.global_block()
 
-                        all_infs = paddle.fluid.layers.concat(found_infs)
+                        # all_infs = paddle.fluid.layers.concat(found_infs)
+                        all_infs = block.create_var(
+                            name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                                ".".join(['concat', 'tmp'])
+                            ),
+                            dtype=found_infs[0].dtype,
+                            shape=None,
+                            lod_level=found_infs[0].lod_level,
+                            type=found_infs[0].type,
+                            persistable=False,
+                            stop_gradient=False,
+                        )
+                        concat_op = block.append_op(
+                            type='concat',
+                            inputs={'X': found_infs},
+                            outputs={'Out': [all_infs]},
+                            attrs={'axis': 0},
+                        )
                         set_var_dist_attr(
                             self.dist_context,
                             all_infs,
                             [-1],
                             world_process_group.ranks,
                         )
-                        new_op = block.ops[-1]
-                        assert new_op.type == "concat"
                         _set_op_dist_attr_with_ranks(
-                            new_op,
+                            concat_op,
                             world_process_group.ranks,
                             block,
                             self.dist_context,
                         )
 
-                        found_inf = paddle.fluid.layers.reduce_any(all_infs)
+                        # found_inf = paddle.fluid.layers.reduce_any(all_infs)
+                        found_inf = block.create_var(
+                            name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                                ".".join(['reduce_any', 'tmp'])
+                            ),
+                            dtype=all_infs.dtype,
+                            shape=None,
+                            lod_level=all_infs.lod_level,
+                            type=all_infs.type,
+                            persistable=False,
+                            stop_gradient=False,
+                        )
+                        reduce_any_op = block.append_op(
+                            type='reduce_any',
+                            inputs={'X': all_infs},
+                            outputs={'Out': found_inf},
+                            attrs={
+                                'dim': [0],
+                                'keep_dim': False,
+                                'reduce_all': True,
+                            },
+                        )
                         set_var_dist_attr(
                             self.dist_context,
                             found_inf,
                             [-1],
                             world_process_group.ranks,
                         )
-                        new_op = block.ops[-1]
-                        assert new_op.type == "reduce_any"
                         _set_op_dist_attr_with_ranks(
-                            new_op,
+                            reduce_any_op,
                             world_process_group.ranks,
                             block,
                             self.dist_context,
--
GitLab
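
The core pattern the patch switches to, creating the output variable with block.create_var and inserting the operator with block.append_op so the returned Operator handle can be annotated directly instead of fetching block.ops[-1] and asserting on its type, can be exercised on its own. The following is a minimal sketch assuming the Paddle 2.4-era static-graph API this patch targets; the two boolean flag variables are hypothetical stand-ins for the per-dtype found_inf flags the pass collects, not code from the patch.

import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()

with paddle.static.program_guard(main_program, startup_program):
    block = main_program.global_block()

    # Hypothetical stand-ins for the per-dtype overflow flags that FP16Pass
    # gathers into `found_infs`.
    flag_fp32 = paddle.full(shape=[1], fill_value=False, dtype='bool')
    flag_fp16 = paddle.full(shape=[1], fill_value=False, dtype='bool')
    found_infs = [flag_fp32, flag_fp16]

    # Create the concat output variable by hand, mirroring the patch, instead
    # of going through the fluid.layers.concat wrapper.
    all_infs = block.create_var(
        name=fluid.unique_name.generate_with_ignorable_key("concat.tmp"),
        dtype=found_infs[0].dtype,
        shape=None,
        type=found_infs[0].type,
        persistable=False,
        stop_gradient=False,
    )

    # block.append_op returns the newly inserted Operator, so a pass can attach
    # distributed attributes to this handle directly rather than relying on
    # block.ops[-1] being the op it just created.
    concat_op = block.append_op(
        type='concat',
        inputs={'X': found_infs},
        outputs={'Out': [all_infs]},
        attrs={'axis': 0},
    )

    print(concat_op.type)               # concat
    print(block.ops[-1] is concat_op)   # True, but no longer assumed

The reduce_any insertion in the patch follows the same shape, only with a 'reduce_any' op type and its 'dim', 'keep_dim', and 'reduce_all' attributes.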