Unverified commit 9b5e0154
Authored by JZ-LIANG on Sep 02, 2022; committed by GitHub on Sep 02, 2022
[Auto Parallel] DP Calc-Comm Overlapping Support Weight Sharing (#45443)
* bugfix (#45332)
* customize wait_comm
Parent: a4d2878a
Showing 1 changed file with 57 additions and 22 deletions:
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
```diff
@@ -93,8 +93,8 @@ class DataParallelOptimizationPass(PassBase):
     def _calc_comm_overlap(self):
         if not self._could_be_overlap():
             return
-        self._calc_overlap_comms()
-        self._update_wait_comms()
+        self._comms_overlap_calc()
+        self._calc_wait_comms()
 
     def _fuse_allreduce(self):
         pass
```
```diff
@@ -227,7 +227,7 @@ class DataParallelOptimizationPass(PassBase):
 
         return True
 
-    def _calc_overlap_comms(self):
+    def _comms_overlap_calc(self):
         # TODO support InterpreterCore executor for overlap.
         # InterpreterCore has a different logic for overlapping
         # which is different from use_calc_stream
```
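Taken together with the first hunk, the renamed methods describe a two-step scheme: `_comms_overlap_calc` moves each data-parallel allreduce off the calc stream so backward kernels keep running while gradients communicate, and `_calc_wait_comms` decides where a `c_wait_comm` must be inserted. A schematic before/after op sequence, written as plain strings for illustration only (the weight names and the exact sequence are made up; `c_allreduce_sum`, `c_wait_comm`, and the `use_calc_stream` attribute are the real Paddle ops and attribute involved):

```python
# Before the pass: each allreduce runs on the calc stream, serializing
# backward compute with gradient communication.
before = [
    "w2@GRAD = matmul_grad(...)",
    "c_allreduce_sum(w2@GRAD, ring_id=0, use_calc_stream=True)",  # compute stalls
    "w1@GRAD = matmul_grad(...)",
    "c_allreduce_sum(w1@GRAD, ring_id=0, use_calc_stream=True)",
    "sgd(w1, w1@GRAD)",
    "sgd(w2, w2@GRAD)",
]

# After the pass: allreduce is asynchronous on the comm stream, and a single
# c_wait_comm is inserted before the first op that consumes a gradient that
# is still in flight.
after = [
    "w2@GRAD = matmul_grad(...)",
    "c_allreduce_sum(w2@GRAD, ring_id=0, use_calc_stream=False)",  # async
    "w1@GRAD = matmul_grad(...)",  # overlaps with the w2@GRAD allreduce
    "c_allreduce_sum(w1@GRAD, ring_id=0, use_calc_stream=False)",
    "c_wait_comm(ring_id=0)",  # calc stream blocks until ring 0 drains
    "sgd(w1, w1@GRAD)",
    "sgd(w2, w2@GRAD)",
]
```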
```diff
@@ -254,27 +254,62 @@ class DataParallelOptimizationPass(PassBase):
 
         block._sync_with_cpp()
 
-    def _update_wait_comms(self):
+    def _calc_wait_comms(self):
 
         block = default_main_program().global_block()
         ops = block.ops
 
-        # update wait comm to finish
-        first_optimize_op_idx = -1
-        for idx, op in enumerate(ops):
-            if is_optimize_op(op):
-                first_optimize_op_idx = idx
-                break
-        assert first_optimize_op_idx > -1, "Unexception: not found optimizer op in program"
-
+        # NOTE the naive overlap implement in static hybird parallel only sync comm stream
+        # at the end of Backward phase, based on a strong constraint that
+        # all communicating gradient would NOT be used after communication in Backward phase.
+        # BUT this constraint will fail for scenario like Weight-Sharing and Higher-Order Differentiation,
+        # where gradient will be involved in other calculation between data-parallel allreduce kernel submmited
+        # into comm streams and the synchronization of comm stream at the end of Backward phase.
+        # synchronization of comm stream should add according to the usage of communicating gradients
+        # to support Overlapping for Weight-Sharing and Higher-Order Differentiation.
+
+        ring_id_to_un_sync_grad_map = {}
+        op_idx_to_sync_ring_id_map = {}
         for group in self._group_to_grad_name_map.keys():
-            ring_id = group.id
-            block._insert_op_without_sync(first_optimize_op_idx,
-                                          type='c_wait_comm',
-                                          inputs={'X': []},
-                                          outputs={'Out': []},
-                                          attrs={
-                                              'op_role': OpRole.Backward,
-                                              'ring_id': ring_id
-                                          })
+            ring_id_to_un_sync_grad_map[group.id] = []
+
+        # analyze the where need to sync
+        for i, op in enumerate(ops):
+            if is_data_parallel_reduce_op(op):
+                ring_id = op.attr("ring_id")
+                grad_name = op.output_arg_names[0]
+                ring_id_to_un_sync_grad_map[ring_id].append(grad_name)
+            elif is_data_parallel_scale_op(op):
+                continue
+            # other ops that might use communicating grad
+            else:
+                for input_var_name in op.input_arg_names:
+                    for ring_id, unsync_grad_names in ring_id_to_un_sync_grad_map.items(
+                    ):
+                        if input_var_name in unsync_grad_names:
+                            # need to sync before op_i
+                            if i in op_idx_to_sync_ring_id_map:
+                                op_idx_to_sync_ring_id_map[i].append(ring_id)
+                            else:
+                                op_idx_to_sync_ring_id_map[i] = [ring_id]
+                            # all grads in this comm stream are synced
+                            ring_id_to_un_sync_grad_map[ring_id] = []
+
+        # insert synchronization
+        indices = list(op_idx_to_sync_ring_id_map.keys())
+        # TODO the synchronization could be optimized
+        # we should record the event of a gradient is communicating and
+        # only wait for that event to be completed.
+        # BUT paddle static currently not support op api for event record only, so
+        # here we try to wait for all kernel in that comm stream to be finish which is not that optimized.
+        for i in sorted(indices, reverse=True):
+            for ring_id in op_idx_to_sync_ring_id_map[i]:
+                block._insert_op_without_sync(i,
+                                              type='c_wait_comm',
+                                              inputs={'X': []},
+                                              outputs={'Out': []},
+                                              attrs={
+                                                  'op_role': OpRole.Backward,
+                                                  'ring_id': ring_id
+                                              })
```
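The new `_calc_wait_comms` is the heart of the weight-sharing fix: instead of one blanket sync at the end of the Backward phase, it tracks, per comm ring, which gradients still have an allreduce in flight, and schedules a wait immediately before the first op that reads one of them. A minimal, self-contained re-enactment of that scan (plain Python; the `Op` record and the op sequence are invented for illustration, and the real pass additionally skips gradient-scale ops):

```python
from collections import namedtuple

# Hypothetical stand-in for a program op; not a Paddle API.
Op = namedtuple("Op", "type inputs outputs ring_id")

# A weight-sharing style backward: W@GRAD is allreduced, then read again
# by a later op before the Backward phase ends.
ops = [
    Op("matmul_grad", ["loss@GRAD"], ["W@GRAD"], None),
    Op("c_allreduce_sum", ["W@GRAD"], ["W@GRAD"], 0),     # async on ring 0
    Op("sum", ["W@GRAD", "Emb@GRAD"], ["W@GRAD"], None),  # reads in-flight grad
    Op("sgd", ["W", "W@GRAD"], ["W"], None),
]

unsynced = {0: []}  # ring_id -> grads whose allreduce has not been waited on
waits = {}          # op index -> ring_ids to wait on right before that op

for i, op in enumerate(ops):
    if op.type == "c_allreduce_sum":
        unsynced[op.ring_id].append(op.outputs[0])
    else:
        for name in op.inputs:
            for ring_id, grads in unsynced.items():
                if name in grads:
                    waits.setdefault(i, []).append(ring_id)
                    # one c_wait_comm drains the whole stream, so every grad
                    # on this ring counts as synced from here on
                    unsynced[ring_id] = []

print(waits)  # {2: [0]} -> insert c_wait_comm(ring_id=0) before the `sum` op
```

Note also why the real pass inserts in reverse index order (`sorted(indices, reverse=True)`): splicing an op into the block shifts every later index, so working back-to-front keeps the recorded indices valid.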