Unverified commit 6fee5a3e, authored by JZ-LIANG, committed via GitHub

auto tp sync (#53234)

* support tp sync for auto parallel

* support tp sync for auto parallel1

* support tp sync for auto parallel1

* support tp sync for auto parallel1
Parent efeeb6fb
@@ -168,15 +168,20 @@ def insert_synchronization(
     sync_param,
     sync_grad,
     sync_moment,
+    sync_master_param,
     sync_mode,
     src_rank,
 ):
     unsync_param_names = [p.name for p in params_to_sync]
+    is_opt_block = False
     for idx, op in reversed(list(enumerate(block.ops))):
         if op.type in _supported_optimizer_type:
+            is_opt_block = True
             assert "Param" in op.input_names
             assert len(op.input("Param")) == 1
             param_name = op.input("Param")[0]
@@ -203,6 +208,7 @@ def insert_synchronization(
                         op_role,
                     )
+                if sync_master_param:
                     if (
                         "MasterParamOut" in op.output_names
                         and len(op.output("MasterParamOut")) == 1
@@ -270,6 +276,7 @@ def insert_synchronization(
                         op_role,
                     )
-    assert (
-        len(unsync_param_names) == 0
-    ), "The following param is unsync by some error: {}".format(
+    if is_opt_block:
+        assert (
+            len(unsync_param_names) == 0
+        ), "The following param is unsync by some error: {}".format(
@@ -285,6 +292,7 @@ def add_extra_synchronization(
     sync_param=True,
     sync_grad=False,
     sync_moment=False,
+    sync_master_param=False,
     src_rank=0,
     sync_ring_id=None,
 ):
@@ -317,11 +325,8 @@ def add_extra_synchronization(
         )
     )
 
-    # adopt for pipeline opt
-    if program._pipeline_opt is not None:
-        assert (
-            program._pipeline_opt['section_program'] is not None
-        ), "Pipeline is enable but section_program is None"
+    # adopt for static pipeline opt
+    if 'section_program' in program._pipeline_opt:
         program = program._pipeline_opt['section_program']
 
     # step1: collect the param that need to be sync
@@ -341,8 +346,10 @@ def add_extra_synchronization(
         sync_ring_id = resolute_tensor_parallel_ring_id(program)
 
     # step3: insert synchronization
-    # TODO support gradient merge with different update block
-    block = program.global_block()
-    insert_synchronization(
-        block,
-        params_to_sync,
+    # NOTE: AutoParallel passes such as gradient merge may move the optimizer ops to another block and add unused blocks to the program.
+    # Those unused blocks are never executed, so we add the extra synchronization to every block that contains optimizer operators.
+    # TODO: support resolving the tp degree under auto parallel.
+    for block in program.blocks:
+        insert_synchronization(
+            block,
+            params_to_sync,
@@ -351,6 +358,7 @@ def add_extra_synchronization(
-        sync_param,
-        sync_grad,
-        sync_moment,
-        sync_mode,
-        src_rank,
-    )
+            sync_param,
+            sync_grad,
+            sync_moment,
+            sync_master_param,
+            sync_mode,
+            src_rank,
+        )
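
For orientation, here is a minimal usage sketch of the patched entry point. It is not part of the commit: the import path, the assumption that the first positional argument of add_extra_synchronization is the static main program, and the surrounding setup are illustrative guesses based on the hunks above.

# Hedged sketch: the module path below is assumed (the changed file's path is not
# shown in this excerpt) and may need to be adjusted to your Paddle checkout.
import paddle
from paddle.distributed.fleet.utils import tensor_parallel_utils as tp_utils

paddle.enable_static()
main_program = paddle.static.default_main_program()  # a tensor-parallel static program

# Insert synchronization ops after every optimizer update so that all
# tensor-parallel ranks keep identical parameter values. With this commit the
# helper also covers the optimizer's fp32 master weights (sync_master_param)
# and visits every block that contains optimizer ops, not only the global block.
tp_utils.add_extra_synchronization(
    main_program,              # assumed first positional argument
    sync_param=True,           # sync parameters after the optimizer step
    sync_grad=False,           # leave gradients untouched
    sync_moment=False,         # leave optimizer moments untouched
    sync_master_param=True,    # new flag added by this commit
    src_rank=0,                # source rank for the inserted sync ops
)

sync_ring_id is left at its None default here; judging from the last hunks, the helper then appears to resolve the tensor-parallel ring id from the program via resolute_tensor_parallel_ring_id.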