BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit 3cbf0e93 (unverified)
[dygraph pp] all sync for allgather partial (#46483)

Author:    Yuang Liu
Committer: GitHub
Date:      Sep 28, 2022
Parent:    cee2b12d
Showing 1 changed file with 109 additions and 92 deletions:

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py  (+109, -92)
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -165,17 +165,15 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
 def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
                      rank_id):
-    dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
     if _in_legacy_dygraph():
         return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', dst_rank_in_group, 'num',
-                                          nranks, 'id', rank_id)
+                                          'peer', dst, 'num', nranks, 'id',
+                                          rank_id)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        return group.process_group.send_partial(tensor, dst_rank_in_group,
-                                                nranks, rank_id)
+        return group.process_group.send_partial(tensor, dst, nranks, rank_id)


 def send_partial(tensor,
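Background for this and the following hunks: a "partial" p2p transfer moves only the rank_id-th of nranks equal chunks of a tensor between pipeline stages, and the full tensor is reassembled afterwards with a partial allgather over the model-parallel group. A minimal NumPy sketch of that arithmetic (hypothetical helper names, not Paddle's API):

import numpy as np

def partial_slice(tensor, nranks, rank_id):
    # Keep only the rank_id-th of nranks equal chunks; mirrors the
    # requirement that the element count divide evenly by nranks.
    flat = tensor.reshape(-1)
    assert flat.size % nranks == 0, "numel must divide by nranks"
    chunk = flat.size // nranks
    return flat[rank_id * chunk:(rank_id + 1) * chunk]

def allgather_concat(slices):
    # Stand-in for the partial allgather: every rank contributes its
    # slice and all ranks end up with the concatenation.
    return np.concatenate(slices)

full = np.arange(8.0)
slices = [partial_slice(full, 4, r) for r in range(4)]
assert np.array_equal(allgather_concat(slices), full)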
@@ -189,13 +187,12 @@ def send_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
+    dst_rank = _hcg._get_p2p_next_rank(
+    ) if dst == 1 else _hcg._get_p2p_prev_rank()
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
-                                nranks, rank_id)
+        return _partial_send_op(tensor, group, use_calc_stream, ring_id,
+                                dst_rank, nranks, rank_id)
     else:
-        dst_rank = _hcg._get_p2p_next_rank(
-        ) if dst == 1 else _hcg._get_p2p_prev_rank()
         if _in_legacy_dygraph():
             send_op = paddle.distributed.send
         elif in_dygraph_mode():
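In send_partial, dst is a direction flag (1 = next pipeline stage, otherwise previous); this hunk resolves it to a concrete peer rank up front via _hcg, so the partial-send path no longer receives the raw flag. A sketch of the mapping under an assumed linear ring pipeline (FakeHCG is a hypothetical stand-in for Paddle's hybrid communicate group):

class FakeHCG:
    # Hypothetical stand-in for _hcg: a linear pipeline arranged as a
    # ring, where stage i talks to stages i - 1 and i + 1.
    def __init__(self, rank, world_size):
        self.rank, self.world_size = rank, world_size

    def _get_p2p_prev_rank(self):
        return (self.rank - 1) % self.world_size

    def _get_p2p_next_rank(self):
        return (self.rank + 1) % self.world_size

hcg = FakeHCG(rank=2, world_size=4)
for dst in (0, 1):
    # Same ternary as the one added to send_partial above.
    dst_rank = hcg._get_p2p_next_rank() if dst == 1 else hcg._get_p2p_prev_rank()
    print(dst, "->", dst_rank)  # 0 -> 1, 1 -> 3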
@@ -205,22 +202,21 @@ def send_partial(tensor,
 def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
                      rank_id):
-    src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
         assert use_calc_stream
         return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
                                           use_calc_stream, 'ring_id', ring_id,
-                                          'peer', src_rank_in_group, 'num',
-                                          nranks, 'id', rank_id, 'dtype',
-                                          tensor.dtype, 'out_shape',
-                                          tensor.shape)
+                                          'peer', src, 'num', nranks, 'id',
+                                          rank_id, 'dtype', tensor.dtype,
+                                          'out_shape', tensor.shape)
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        task = group.process_group.recv_partial(tensor, src_rank_in_group,
-                                                nranks, rank_id)
+        task = group.process_group.recv_partial(tensor, src, nranks, rank_id)
         if use_calc_stream:
             task.wait()
-        return task
+            return None
+        else:
+            return task
@@ -235,13 +231,12 @@ def recv_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
+    src_rank = _hcg._get_p2p_prev_rank(
+    ) if src == 0 else _hcg._get_p2p_next_rank()
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
-                                nranks, rank_id)
+        return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
+                                src_rank, nranks, rank_id)
     else:
-        src_rank = _hcg._get_p2p_prev_rank(
-        ) if src == 0 else _hcg._get_p2p_next_rank()
         if _in_legacy_dygraph() or use_calc_stream:
             recv_op = paddle.distributed.recv
         elif in_dygraph_mode():
@@ -260,8 +255,13 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        return group.process_group.all_gather_partial(tensor, tensor, nranks,
-                                                      rank_id)
+        task = group.process_group.all_gather_partial(tensor, tensor, nranks,
+                                                      rank_id)
+        if use_calc_stream:
+            task.wait()
+            return None
+        else:
+            return task


 def allgather_partial(tensor,
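This hunk carries the commit's core rule: when use_calc_stream is true the collective blocks and returns None, otherwise the caller gets a waitable task. A minimal sketch of that contract with a dummy task object (a stand-in, not Paddle's ProcessGroup API):

class FakeTask:
    # Dummy handle mimicking the task returned by eager-mode collectives.
    def __init__(self):
        self.done = False

    def wait(self):
        self.done = True

def collective_op(use_calc_stream):
    task = FakeTask()  # pretend the collective was launched here
    if use_calc_stream:
        task.wait()    # synchronous: block now ...
        return None    # ... and give the caller nothing to wait on
    else:
        return task    # asynchronous: the caller owns the wait

assert collective_op(True) is None
task = collective_op(False)
task.wait()
assert task.done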
@@ -270,9 +270,9 @@ def allgather_partial(tensor,
                        group=None,
                        use_calc_stream=True):
     if not _is_valid_send_recv_partial(tensor, nranks):
-        return None
+        return tensor
     if group is not None and not group.is_member():
-        return None
+        return
     ring_id = 0 if group is None else group.id
     return _partial_allgather_op(tensor, group, use_calc_stream, ring_id,
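allgather_partial now returns the untouched tensor, rather than None, when a partial transfer does not apply, so callers can use the result unconditionally. A sketch of that guard; the divisibility condition is an assumption based on _is_valid_send_recv_partial's name, not a copy of it:

def maybe_allgather(tensor, nranks, gather):
    # Assumed guard: a partial gather only applies when the tensor splits
    # into nranks equal chunks; otherwise hand the input back unchanged.
    if nranks <= 1 or len(tensor) % nranks != 0:
        return tensor
    return gather(tensor)

# 3 elements cannot be split across 2 ranks, so the input comes back as-is.
assert maybe_allgather([1, 2, 3], 2, gather=lambda t: t * 2) == [1, 2, 3]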
@@ -335,7 +335,6 @@ def _p2p_helper(tensor_send_next,
     if tensor_send_prev is not None:
         if isinstance(tensor_send_prev, tuple):
             for d in tensor_send_prev:
-                if _in_legacy_dygraph():
-                    paddle.distributed.wait(d, use_calc_stream=True)
+                paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(d,
                              dst=0,
@@ -344,7 +343,6 @@ def _p2p_helper(tensor_send_next,
                              group=_hcg.send_prev_group,
                              use_calc_stream=False)
         else:
-            if _in_legacy_dygraph():
-                paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
+            paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
             send_partial(tensor_send_prev,
                          dst=0,
@@ -356,26 +354,39 @@ def _p2p_helper(tensor_send_next,
     if tensor_recv_prev is not None:
         if isinstance(tensor_recv_prev, tuple):
             for d in tensor_recv_prev:
-                tasks.append(
-                    recv_partial(d,
-                                 src=0,
-                                 nranks=mp_degree,
-                                 rank_id=mp_rank,
-                                 group=_hcg.recv_prev_group,
-                                 use_calc_stream=sync_recv))
+                task = recv_partial(d,
+                                    src=0,
+                                    nranks=mp_degree,
+                                    rank_id=mp_rank,
+                                    group=_hcg.recv_prev_group,
+                                    use_calc_stream=sync_recv)
+                if sync_recv:
+                    allgather_partial(d,
+                                      nranks=mp_degree,
+                                      rank_id=mp_rank,
+                                      group=mp_group,
+                                      use_calc_stream=True)
+                else:
+                    tasks.append(task)
         else:
-            tasks.append(
-                recv_partial(tensor_recv_prev,
-                             src=0,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.recv_prev_group,
-                             use_calc_stream=sync_recv))
+            task = recv_partial(tensor_recv_prev,
+                                src=0,
+                                nranks=mp_degree,
+                                rank_id=mp_rank,
+                                group=_hcg.recv_prev_group,
+                                use_calc_stream=sync_recv)
+            if sync_recv:
+                allgather_partial(tensor_recv_prev,
+                                  nranks=mp_degree,
+                                  rank_id=mp_rank,
+                                  group=mp_group,
+                                  use_calc_stream=True)
+            else:
+                tasks.append(task)

     if tensor_send_next is not None:
         if isinstance(tensor_send_next, tuple):
             for d in tensor_send_next:
-                if _in_legacy_dygraph():
-                    paddle.distributed.wait(d, use_calc_stream=True)
+                paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(d,
                              dst=1,
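The receive path now branches on sync_recv: a synchronous recv is followed immediately by a synchronous allgather, while the asynchronous path defers the task and gathers after the waits. A small simulation of the two orderings (stub strings instead of real communication):

def recv_path(bufs, sync_recv):
    # Stub model of the control flow above (strings instead of tensors).
    log, tasks = [], []
    for d in bufs:
        log.append("recv(%s, sync=%s)" % (d, sync_recv))
        if sync_recv:
            log.append("allgather(%s)" % d)  # gather right away
        else:
            tasks.append(d)                  # defer: remember the task
    if not sync_recv:
        log += ["wait(%s)" % t for t in tasks]
        log += ["allgather(%s)" % d for d in bufs]
    return log

print(recv_path(["a", "b"], sync_recv=True))
print(recv_path(["a", "b"], sync_recv=False))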
@@ -384,7 +395,6 @@ def _p2p_helper(tensor_send_next,
                              group=_hcg.send_next_group,
                              use_calc_stream=False)
         else:
-            if _in_legacy_dygraph():
-                paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
+            paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
             send_partial(tensor_send_next,
                          dst=1,
@@ -396,24 +406,39 @@ def _p2p_helper(tensor_send_next,
     if tensor_recv_next is not None:
         if isinstance(tensor_recv_next, tuple):
             for d in tensor_recv_next:
-                tasks.append(
-                    recv_partial(d,
-                                 src=1,
-                                 nranks=mp_degree,
-                                 rank_id=mp_rank,
-                                 group=_hcg.recv_next_group,
-                                 use_calc_stream=sync_recv))
+                task = recv_partial(d,
+                                    src=1,
+                                    nranks=mp_degree,
+                                    rank_id=mp_rank,
+                                    group=_hcg.recv_next_group,
+                                    use_calc_stream=sync_recv)
+                if sync_recv:
+                    allgather_partial(d,
+                                      nranks=mp_degree,
+                                      rank_id=mp_rank,
+                                      group=mp_group,
+                                      use_calc_stream=True)
+                else:
+                    tasks.append(task)
         else:
-            tasks.append(
-                recv_partial(tensor_recv_next,
-                             src=1,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.recv_next_group,
-                             use_calc_stream=sync_recv))
+            task = recv_partial(tensor_recv_next,
+                                src=1,
+                                nranks=mp_degree,
+                                rank_id=mp_rank,
+                                group=_hcg.recv_next_group,
+                                use_calc_stream=sync_recv)
+            if sync_recv:
+                allgather_partial(tensor_recv_next,
+                                  nranks=mp_degree,
+                                  rank_id=mp_rank,
+                                  group=mp_group,
+                                  use_calc_stream=True)
+            else:
+                tasks.append(task)

-    if not sync_recv and in_dygraph_mode():
-        # wait irecv tasks in eager dygraph mode with new comm library
-        for task in tasks:
-            assert task is not None
-            task.wait()
+    if not sync_recv:
+        if in_dygraph_mode():
+            # wait irecv tasks in eager dygraph mode with new comm library
+            for task in tasks:
+                assert task is not None
+                task.wait()
@@ -433,20 +458,12 @@ def _p2p_helper(tensor_send_next,
         else:
             tensors_for_all_gather.append(tensor_recv_next)

-    tasks = []
     for tensor in tensors_for_all_gather:
-        tasks.append(
-            allgather_partial(tensor,
-                              nranks=mp_degree,
-                              rank_id=mp_rank,
-                              group=mp_group,
-                              use_calc_stream=True))
-
-    if in_dygraph_mode():
-        for task in tasks:
-            # wait partial all gather tasks
-            if task is not None:
-                task.wait()
+        allgather_partial(tensor,
+                          nranks=mp_degree,
+                          rank_id=mp_rank,
+                          group=mp_group,
+                          use_calc_stream=True)

     return tensor_recv_prev, tensor_recv_next
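Since allgather_partial with use_calc_stream=True now waits internally (see the _partial_allgather_op hunk above), the append-then-wait bookkeeping is dead weight and collapses to a plain loop. A toy illustration of the call-site change (FakeTask again, purely illustrative):

results = []

class FakeTask:
    def __init__(self, x):
        self.x = x

    def wait(self):
        results.append(self.x)

def allgather_async(x):
    return FakeTask(x)    # old contract: caller collects and waits

def allgather_sync(x):
    FakeTask(x).wait()    # new contract: waits internally, returns None

# Old call sites: append handles, then a separate wait loop.
tasks = [allgather_async(x) for x in range(3)]
for t in tasks:
    t.wait()

# New call sites: one plain loop, no task bookkeeping.
for x in range(3):
    allgather_sync(x)

assert results == [0, 1, 2, 0, 1, 2]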