From 6fee5a3e7af253b43faae267f3ec59ee7adb254b Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Tue, 25 Apr 2023 21:02:16 +0800
Subject: [PATCH] auto tp sync (#53234)

* support tp sync for auto parallel

* support tp sync for auto parallel1

* support tp sync for auto parallel1

* support tp sync for auto parallel1
---
 .../fleet/utils/tensor_parallel_utils.py      | 54 +++++++++++--------
 1 file changed, 31 insertions(+), 23 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/tensor_parallel_utils.py b/python/paddle/distributed/fleet/utils/tensor_parallel_utils.py
index 85eab5e26fb..aeb6fa69df3 100644
--- a/python/paddle/distributed/fleet/utils/tensor_parallel_utils.py
+++ b/python/paddle/distributed/fleet/utils/tensor_parallel_utils.py
@@ -168,15 +168,20 @@ def insert_synchronization(
     sync_param,
     sync_grad,
     sync_moment,
+    sync_master_param,
     sync_mode,
     src_rank,
 ):
     unsync_param_names = [p.name for p in params_to_sync]
 
+    is_opt_block = False
+
     for idx, op in reversed(list(enumerate(block.ops))):
         if op.type in _supported_optimizer_type:
+            is_opt_block = True
+
             assert "Param" in op.input_names
             assert len(op.input("Param")) == 1
             param_name = op.input("Param")[0]
@@ -203,6 +208,7 @@
                 op_role,
             )
 
+            if sync_master_param:
             if (
                 "MasterParamOut" in op.output_names
                 and len(op.output("MasterParamOut")) == 1
@@ -270,11 +276,12 @@
                 op_role,
             )
 
-    assert (
-        len(unsync_param_names) == 0
-    ), "The following param is unsync by some error: {}".format(
-        unsync_param_names
-    )
+    if is_opt_block:
+        assert (
+            len(unsync_param_names) == 0
+        ), "The following params were not synchronized: {}".format(
+            unsync_param_names
+        )
 
 
 def add_extra_synchronization(
@@ -285,6 +292,7 @@
     sync_param=True,
     sync_grad=False,
     sync_moment=False,
+    sync_master_param=False,
     src_rank=0,
     sync_ring_id=None,
 ):
@@ -317,11 +325,8 @@
         )
     )
 
-    # adopt for pipeline opt
-    if program._pipeline_opt is not None:
-        assert (
-            program._pipeline_opt['section_program'] is not None
-        ), "Pipeline is enable but section_program is None"
+    # adapt for static pipeline opt
+    if 'section_program' in program._pipeline_opt:
         program = program._pipeline_opt['section_program']
 
     # step1: collect the param that need to be sync
@@ -341,16 +346,19 @@
         sync_ring_id = resolute_tensor_parallel_ring_id(program)
 
     # step3: insert synchronization
-    # TODO support gradient merge with different update block
-    block = program.global_block()
-    insert_synchronization(
-        block,
-        params_to_sync,
-        tp_degree,
-        sync_ring_id,
-        sync_param,
-        sync_grad,
-        sync_moment,
-        sync_mode,
-        src_rank,
-    )
+    # NOTE: auto-parallel passes such as gradient merge may move the optimizer ops into another block and leave unused blocks in the program.
+    # Those blocks are never executed, so the extra synchronization is added to every block that contains optimizer operators.
+    # TODO: support resolving the tp degree under auto parallel.
+    for block in program.blocks:
+        insert_synchronization(
+            block,
+            params_to_sync,
+            tp_degree,
+            sync_ring_id,
+            sync_param,
+            sync_grad,
+            sync_moment,
+            sync_master_param,
+            sync_mode,
+            src_rank,
+        )
--
GitLab
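Usage sketch (not part of the patch): the snippet below illustrates how the new `sync_master_param` switch might be passed to `add_extra_synchronization`. It assumes the function's first positional argument is the static `Program` to patch (only the keyword arguments `sync_param`, `sync_grad`, `sync_moment`, `sync_master_param`, `src_rank`, and `sync_ring_id` are visible in this diff) and that the program already contains the optimizer ops and a tensor-parallel communication ring so that `resolute_tensor_parallel_ring_id` can infer the ring id; the model and optimizer setup are omitted.

```python
# Hypothetical invocation, sketched from the keyword arguments visible in this
# diff; the positional `program` argument and the surrounding setup are assumptions.
import paddle
from paddle.distributed.fleet.utils.tensor_parallel_utils import (
    add_extra_synchronization,
)

paddle.enable_static()

# Assumed to already hold the tensor-parallel model, its optimizer ops, and the
# tp communication ops used to resolve the sync ring.
main_program = paddle.static.default_main_program()

add_extra_synchronization(
    main_program,
    sync_param=True,         # broadcast duplicated parameters
    sync_grad=False,         # leave gradients untouched
    sync_moment=False,       # leave optimizer moments untouched
    sync_master_param=True,  # new flag in this patch: also sync FP32 master weights
    src_rank=0,              # rank whose values are broadcast
)
```

Because the pass now walks `program.blocks` instead of only the global block, the same call also covers programs where an auto-parallel pass has moved the optimizer ops into a sub-block.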