Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9be2b721
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9be2b721
编写于
10月 21, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
10月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix virtualpp with mp/recompute bugs (#47242)
上级
a9ac608f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
16 deletion
+22
-16
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+2
-1
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
...ributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+20
-15
未找到文件。
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
浏览文件 @
9be2b721
...
...
@@ -598,11 +598,12 @@ class PipelineLayer(Layer):
return
run_function
def
forward_function
(
self
,
start
,
end
):
run_function
=
self
.
run_function
def
execute_func
(
*
x
):
if
len
(
x
)
==
1
:
x
=
x
[
0
]
for
idx
,
layer
in
enumerate
(
self
.
run_function
[
start
:
end
]):
for
idx
,
layer
in
enumerate
(
run_function
[
start
:
end
]):
x
=
layer
(
x
)
return
x
...
...
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
浏览文件 @
9be2b721
...
...
@@ -168,17 +168,18 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
def
_partial_send_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
dst
,
nranks
,
rank_id
):
dst_rank_in_group
=
dst
if
group
is
None
else
group
.
get_group_rank
(
dst
)
if
_in_legacy_dygraph
():
return
_legacy_C_ops
.
partial_send
(
tensor
.
detach
(),
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
'peer'
,
dst
,
'num'
,
nranks
,
'id
'
,
rank_id
)
'peer'
,
dst
_rank_in_group
,
'num
'
,
nranks
,
'id'
,
rank_id
)
elif
in_dygraph_mode
():
group
=
paddle
.
distributed
.
collective
.
_get_default_group
(
)
if
group
is
None
else
group
comm_op
=
group
.
process_group
.
send_partial_on_calc_stream
\
if
use_calc_stream
else
group
.
process_group
.
send_partial
return
comm_op
(
tensor
,
dst
,
nranks
,
rank_id
)
return
comm_op
(
tensor
,
dst
_rank_in_group
,
nranks
,
rank_id
)
def
send_partial
(
tensor
,
...
...
@@ -192,12 +193,13 @@ def send_partial(tensor,
return
ring_id
=
0
if
group
is
None
else
group
.
id
dst_rank
=
_hcg
.
_get_p2p_next_rank
(
)
if
dst
==
1
else
_hcg
.
_get_p2p_prev_rank
()
if
_is_valid_send_recv_partial
(
tensor
,
nranks
):
return
_partial_send_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
dst
,
nranks
,
rank_id
)
return
_partial_send_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
dst_rank
,
nranks
,
rank_id
)
else
:
dst_rank
=
_hcg
.
_get_p2p_next_rank
(
)
if
dst
==
1
else
_hcg
.
_get_p2p_prev_rank
()
if
_in_legacy_dygraph
():
send_op
=
lambda
x
,
dst
,
group
:
\
paddle
.
distributed
.
send
(
x
,
dst
,
group
,
use_calc_stream
)
...
...
@@ -208,19 +210,21 @@ def send_partial(tensor,
def
_partial_recv_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
src
,
nranks
,
rank_id
):
src_rank_in_group
=
src
if
group
is
None
else
group
.
get_group_rank
(
src
)
if
_in_legacy_dygraph
():
assert
use_calc_stream
return
_legacy_C_ops
.
partial_recv
(
tensor
.
detach
(),
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
'peer'
,
src
,
'num'
,
nranks
,
'id'
,
rank_id
,
'dtype'
,
tensor
.
dtype
,
'out_shape'
,
tensor
.
shape
)
'peer'
,
src_rank_in_group
,
'num'
,
nranks
,
'id'
,
rank_id
,
'dtype'
,
tensor
.
dtype
,
'out_shape'
,
tensor
.
shape
)
elif
in_dygraph_mode
():
group
=
paddle
.
distributed
.
collective
.
_get_default_group
(
)
if
group
is
None
else
group
comm_op
=
group
.
process_group
.
recv_partial_on_calc_stream
\
if
use_calc_stream
else
group
.
process_group
.
recv_partial
return
comm_op
(
tensor
,
src
,
nranks
,
rank_id
)
return
comm_op
(
tensor
,
src
_rank_in_group
,
nranks
,
rank_id
)
def
recv_partial
(
tensor
,
...
...
@@ -234,12 +238,13 @@ def recv_partial(tensor,
return
ring_id
=
0
if
group
is
None
else
group
.
id
src_rank
=
_hcg
.
_get_p2p_prev_rank
(
)
if
src
==
0
else
_hcg
.
_get_p2p_next_rank
()
if
_is_valid_send_recv_partial
(
tensor
,
nranks
):
return
_partial_recv_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
src
,
nranks
,
rank_id
)
return
_partial_recv_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
src_rank
,
nranks
,
rank_id
)
else
:
src_rank
=
_hcg
.
_get_p2p_prev_rank
(
)
if
src
==
0
else
_hcg
.
_get_p2p_next_rank
()
if
_in_legacy_dygraph
()
or
use_calc_stream
:
recv_op
=
paddle
.
distributed
.
recv
elif
in_dygraph_mode
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录