Unverified Commit 6fee5a3e authored by JZ-LIANG, committed by GitHub

auto tp sync (#53234)

* support tp sync for auto parallel

* support tp sync for auto parallel1

* support tp sync for auto parallel1

* support tp sync for auto parallel1
Parent efeeb6fb
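For context, this change threads a new `sync_master_param` flag through `insert_synchronization` and `add_extra_synchronization`, so that the fp32 master weights kept by mixed-precision optimizers are broadcast along with the regular parameters. Below is a minimal usage sketch; the module path and the static-graph setup are assumptions for illustration, and only the keyword arguments come from the signature shown in this diff:

```python
# Hedged sketch: the module path and program setup are assumptions, not
# part of this commit; the keyword arguments mirror the diff's signature.
import paddle
from paddle.distributed.fleet.utils import tensor_parallel_utils as tp_utils

paddle.enable_static()
main_program = paddle.static.default_main_program()

# Broadcast parameters (and, with sync_master_param=True, their fp32
# master copies) from src_rank so all tensor-parallel ranks stay consistent.
tp_utils.add_extra_synchronization(
    main_program,
    sync_param=True,
    sync_grad=False,
    sync_moment=False,
    sync_master_param=True,  # new flag added by this commit
    src_rank=0,
    sync_ring_id=None,       # resolved via resolute_tensor_parallel_ring_id
)
```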
@@ -168,15 +168,20 @@ def insert_synchronization(
sync_param,
sync_grad,
sync_moment,
sync_master_param,
sync_mode,
src_rank,
):
unsync_param_names = [p.name for p in params_to_sync]
is_opt_block = False
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in _supported_optimizer_type:
is_opt_block = True
assert "Param" in op.input_names
assert len(op.input("Param")) == 1
param_name = op.input("Param")[0]
@@ -203,6 +208,7 @@ insert_synchronization(
op_role,
)
if sync_master_param:
if (
"MasterParamOut" in op.output_names
and len(op.output("MasterParamOut")) == 1
@@ -270,6 +276,7 @@ insert_synchronization(
op_role,
)
if is_opt_block:
assert (
len(unsync_param_names) == 0
), "The following param is unsync by some error: {}".format(
@@ -285,6 +292,7 @@ def add_extra_synchronization(
sync_param=True,
sync_grad=False,
sync_moment=False,
sync_master_param=False,
src_rank=0,
sync_ring_id=None,
):
@@ -317,11 +325,8 @@
)
)
# adopt for pipeline opt
if program._pipeline_opt is not None:
assert (
program._pipeline_opt['section_program'] is not None
), "Pipeline is enable but section_program is None"
# adapt for static pipeline opt
if 'section_program' in program._pipeline_opt:
program = program._pipeline_opt['section_program']
# step1: collect the param that need to be sync
@@ -341,8 +346,10 @@
sync_ring_id = resolute_tensor_parallel_ring_id(program)
# step3: insert synchronization
# TODO support gradient merge with different update block
block = program.global_block()
# NOTE AutoParallel passes like gradient merge may move optimizer ops to another block and add unused blocks to the program.
# But those blocks would not be executed, therefore we insert extra synchronization into every block that has an optimizer operator.
# TODO support resolving the tp degree for autoparallel.
for block in program.blocks:
insert_synchronization(
block,
params_to_sync,
@@ -351,6 +358,7 @@
sync_param,
sync_grad,
sync_moment,
sync_master_param,
sync_mode,
src_rank,
)
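To make the NOTE in step 3 concrete: passes such as gradient merge can relocate optimizer ops out of the global block and leave behind blocks that are never executed, so the loop above visits every block and inserts synchronization only where optimizer ops actually live. A standalone sketch of that scan (the optimizer-type list here is a truncated assumption, not the full list from the file):

```python
# Standalone sketch of the multi-block scan added above; the optimizer
# type list is a truncated assumption.
_supported_optimizer_type = ["sgd", "momentum", "adam", "adamw"]

def blocks_needing_sync(program):
    """Yield every block that contains at least one optimizer op,
    since passes like gradient merge may move them off the global block."""
    for block in program.blocks:
        if any(op.type in _supported_optimizer_type for op in block.ops):
            yield block
```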