Commit 5a1b6f5d (unverified)

add p2p (#50337)

Authored by ShenLiang on Feb 12, 2023; committed via GitHub on Feb 13, 2023.
Parent commit: 913f40ee
Showing 1 changed file with 306 additions and 194 deletions:

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py (+306, -194)
@@ -17,7 +17,11 @@ from ...utils.log_util import logger
 import numpy as np
 from paddle import _C_ops, _legacy_C_ops
 import paddle.fluid.core as core
-from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
+from paddle.fluid.framework import (
+    _in_legacy_dygraph,
+    _non_static_mode,
+    in_dygraph_mode,
+)
 from .utils import paddle_2_number, paddle_2_number, number_2_dtype

 _hcg = None
@@ -30,12 +34,23 @@ def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
     _hcg = hcg
     _use_cache = use_cache
     _enable_partial_send_recv = enable_partial_send_recv
-    send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
-    )
+    (
+        send_next_group,
+        send_prev_group,
+        recv_next_group,
+        recv_prev_group,
+    ) = _hcg.get_p2p_groups()

-    debug_str = "P2pInfo: send_next_group: %s, send_prev_group: %s, " \
-                "recv_next_group: %s, recv_prev_group: %s" % (repr(send_next_group),
-                repr(send_prev_group), repr(recv_next_group), repr(recv_prev_group))
+    debug_str = (
+        "P2pInfo: send_next_group: %s, send_prev_group: %s, "
+        "recv_next_group: %s, recv_prev_group: %s"
+        % (
+            repr(send_next_group),
+            repr(send_prev_group),
+            repr(recv_next_group),
+            repr(recv_prev_group),
+        )
+    )
     logger.info(debug_str)
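A minimal usage sketch for the function re-wrapped above, assuming the caller already holds a constructed hybrid communicate group; the variable name `hcg` below is a placeholder for that object, which this module only consumes:

# `hcg` is assumed to exist already; initialize_p2p_groups wires module state from it.
initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True)
# Afterwards the directional groups used later in this file are reachable as
# _hcg.send_next_group / _hcg.send_prev_group / _hcg.recv_next_group / _hcg.recv_prev_group.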
@@ -150,9 +165,15 @@ class SendRecvMeta:
             self.send_dtype_message = paddle_2_number(tensor.dtype)
         elif isinstance(tensor, tuple):
             self.send_shape_message = tuple(
-                [d.shape for d in tensor if not d.stop_gradient])
-            self.send_dtype_message = tuple(
-                [paddle_2_number(d.dtype) for d in tensor])
+                [d.shape for d in tensor if not d.stop_gradient]
+            )
+            self.send_dtype_message = tuple(
+                [paddle_2_number(d.dtype) for d in tensor if not d.stop_gradient]
+            )

 _send_recv_meta = SendRecvMeta()
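The one behavioral change in this hunk is the added `if not d.stop_gradient` filter on the dtype comprehension, which keeps `send_dtype_message` index-aligned with `send_shape_message`. A minimal sketch of the effect, using a hypothetical stand-in tensor class and dtype table (only the comprehension shape comes from this file):

# Hypothetical illustration only: FakeTensor and fake_paddle_2_number are stand-ins.
class FakeTensor:
    def __init__(self, shape, dtype, stop_gradient):
        self.shape, self.dtype, self.stop_gradient = shape, dtype, stop_gradient

def fake_paddle_2_number(dtype):
    # stand-in for the real paddle_2_number mapping from dtypes to integers
    return {"float32": 0, "int64": 1}[dtype]

tensors = (
    FakeTensor([2, 3], "float32", stop_gradient=False),
    FakeTensor([2], "int64", stop_gradient=True),  # e.g. an integer mask
)

shapes = tuple([d.shape for d in tensors if not d.stop_gradient])
dtypes = tuple([fake_paddle_2_number(d.dtype) for d in tensors if not d.stop_gradient])

# With the filter applied to both comprehensions, shape i and dtype i describe the
# same tensor; without it, dtypes would also cover the stop_gradient entries and
# the two tuples would go out of step.
assert len(shapes) == len(dtypes) == 1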
@@ -166,84 +187,117 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
     return mp_degree > 1 and tensor_numel % mp_degree == 0

-def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
-                     rank_id):
+def _partial_send_op(
+    tensor, group, use_calc_stream, ring_id, dst, nranks, rank_id
+):
     dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
     if _in_legacy_dygraph():
-        return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream',
-                                          use_calc_stream, 'ring_id', ring_id,
-                                          'peer', dst_rank_in_group, 'num',
-                                          nranks, 'id', rank_id)
+        return _legacy_C_ops.partial_send(
+            tensor.detach(),
+            'use_calc_stream',
+            use_calc_stream,
+            'ring_id',
+            ring_id,
+            'peer',
+            dst_rank_in_group,
+            'num',
+            nranks,
+            'id',
+            rank_id,
+        )
     elif in_dygraph_mode():
-        group = paddle.distributed.collective._get_default_group(
-        ) if group is None else group
-        comm_op = group.process_group.send_partial_on_calc_stream \
-            if use_calc_stream else group.process_group.send_partial
+        group = (
+            paddle.distributed.collective._get_default_group()
+            if group is None
+            else group
+        )
+        comm_op = (
+            group.process_group.send_partial_on_calc_stream
+            if use_calc_stream
+            else group.process_group.send_partial
+        )
         return comm_op(tensor, dst_rank_in_group, nranks, rank_id)

-def send_partial(tensor,
-                 dst=0,
-                 nranks=1,
-                 rank_id=0,
-                 group=None,
-                 use_calc_stream=True):
+def send_partial(
+    tensor, dst=0, nranks=1, rank_id=0, group=None, use_calc_stream=True
+):
     # dst: local rank in group
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id

-    dst_rank = _hcg._get_p2p_next_rank(
-    ) if dst == 1 else _hcg._get_p2p_prev_rank()
+    dst_rank = (
+        _hcg._get_p2p_next_rank() if dst == 1 else _hcg._get_p2p_prev_rank()
+    )

     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_send_op(tensor, group, use_calc_stream, ring_id,
-                                dst_rank, nranks, rank_id)
+        return _partial_send_op(
+            tensor, group, use_calc_stream, ring_id, dst_rank, nranks, rank_id
+        )
     else:
         if _in_legacy_dygraph():
-            send_op = lambda x, dst, group: \
-                paddle.distributed.send(x, dst, group, use_calc_stream)
+            send_op = lambda x, dst, group: paddle.distributed.send(
+                x, dst, group, use_calc_stream
+            )
         elif in_dygraph_mode():
             send_op = paddle.distributed.isend
         return send_op(tensor.detach(), dst=dst_rank, group=group)

-def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
-                     rank_id):
+def _partial_recv_op(
+    tensor, group, use_calc_stream, ring_id, src, nranks, rank_id
+):
     src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
         assert use_calc_stream
-        return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
-                                          use_calc_stream, 'ring_id', ring_id,
-                                          'peer', src_rank_in_group, 'num',
-                                          nranks, 'id', rank_id, 'dtype',
-                                          tensor.dtype, 'out_shape',
-                                          tensor.shape)
+        return _legacy_C_ops.partial_recv(
+            tensor.detach(),
+            'use_calc_stream',
+            use_calc_stream,
+            'ring_id',
+            ring_id,
+            'peer',
+            src_rank_in_group,
+            'num',
+            nranks,
+            'id',
+            rank_id,
+            'dtype',
+            tensor.dtype,
+            'out_shape',
+            tensor.shape,
+        )
     elif in_dygraph_mode():
-        group = paddle.distributed.collective._get_default_group(
-        ) if group is None else group
-        comm_op = group.process_group.recv_partial_on_calc_stream \
-            if use_calc_stream else group.process_group.recv_partial
+        group = (
+            paddle.distributed.collective._get_default_group()
+            if group is None
+            else group
+        )
+        comm_op = (
+            group.process_group.recv_partial_on_calc_stream
+            if use_calc_stream
+            else group.process_group.recv_partial
+        )
         return comm_op(tensor, src_rank_in_group, nranks, rank_id)

-def recv_partial(tensor,
-                 src=0,
-                 nranks=1,
-                 rank_id=0,
-                 group=None,
-                 use_calc_stream=True):
+def recv_partial(
+    tensor, src=0, nranks=1, rank_id=0, group=None, use_calc_stream=True
+):
     # src: local rank in group
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id

-    src_rank = _hcg._get_p2p_prev_rank(
-    ) if src == 0 else _hcg._get_p2p_next_rank()
+    src_rank = (
+        _hcg._get_p2p_prev_rank() if src == 0 else _hcg._get_p2p_next_rank()
+    )

     if _is_valid_send_recv_partial(tensor, nranks):
-        return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
-                                src_rank, nranks, rank_id)
+        return _partial_recv_op(
+            tensor, group, use_calc_stream, ring_id, src_rank, nranks, rank_id
+        )
     else:
         if _in_legacy_dygraph() or use_calc_stream:
             recv_op = paddle.distributed.recv
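For orientation, the partial path in the hunk above is gated by the condition shown in its context line, `mp_degree > 1 and tensor_numel % mp_degree == 0`. The real helper reads `tensor_numel` from the tensor argument; the condensed restatement below takes it as a plain integer so the gating can be sanity-checked in isolation:

# Condensed restatement of _is_valid_send_recv_partial's condition; the function
# in this file computes tensor_numel from the tensor before applying it.
def is_valid_send_recv_partial(tensor_numel, mp_degree):
    return mp_degree > 1 and tensor_numel % mp_degree == 0

assert is_valid_send_recv_partial(4096, 8)       # evenly divisible: partial send/recv
assert not is_valid_send_recv_partial(4097, 8)   # falls back to a full send/recv
assert not is_valid_send_recv_partial(4096, 1)   # no tensor parallelism: full path

As the hunk also shows, `send_partial(dst=1)` targets `_hcg._get_p2p_next_rank()` and `dst=0` the previous rank, with `recv_partial` mirroring that choice through `src`.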
@@ -252,42 +306,52 @@ def recv_partial(tensor,
         return recv_op(tensor.detach(), src=src_rank, group=group)

-def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
-                          rank_id):
+def _partial_allgather_op(
+    tensor, group, use_calc_stream, ring_id, nranks, rank_id
+):
     if _in_legacy_dygraph():
-        return _legacy_C_ops.partial_allgather_(tensor.detach(),
-                                                'use_calc_stream',
-                                                use_calc_stream, 'ring_id',
-                                                ring_id, 'nranks', nranks,
-                                                'rank', rank_id)
+        return _legacy_C_ops.partial_allgather_(
+            tensor.detach(),
+            'use_calc_stream',
+            use_calc_stream,
+            'ring_id',
+            ring_id,
+            'nranks',
+            nranks,
+            'rank',
+            rank_id,
+        )
     elif in_dygraph_mode():
-        group = paddle.distributed.collective._get_default_group(
-        ) if group is None else group
-        comm_op = group.process_group.all_gather_partial_on_calc_stream \
-            if use_calc_stream else group.process_group.all_gather_partial
+        group = (
+            paddle.distributed.collective._get_default_group()
+            if group is None
+            else group
+        )
+        comm_op = (
+            group.process_group.all_gather_partial_on_calc_stream
+            if use_calc_stream
+            else group.process_group.all_gather_partial
+        )
         return comm_op(tensor, tensor, nranks, rank_id)

-def allgather_partial(tensor,
-                      nranks=1,
-                      rank_id=0,
-                      group=None,
-                      use_calc_stream=True):
+def allgather_partial(
+    tensor, nranks=1, rank_id=0, group=None, use_calc_stream=True
+):
     if not _is_valid_send_recv_partial(tensor, nranks):
         return tensor
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id

-    return _partial_allgather_op(tensor, group, use_calc_stream, ring_id,
-                                 nranks, rank_id)
+    return _partial_allgather_op(
+        tensor, group, use_calc_stream, ring_id, nranks, rank_id
+    )

-def _p2p_helper(tensor_send_next,
-                tensor_send_prev,
-                recv_prev,
-                recv_next,
-                sync_recv=True):
+def _p2p_helper(
+    tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True
+):
     global _hcg

     tensor_recv_prev = None
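The re-wrapped `_p2p_helper` signature above is always called with keyword arguments by the wrappers later in this diff. A representative sketch of the receive-from-previous-stage pattern (it mirrors the `recv_forward` hunk further down and is shown here only for orientation; the receive buffers are allocated inside `_p2p_helper` from the cached shape/dtype messages):

# Mirrors the call made by recv_forward later in this diff.
input_tensor, _ = _p2p_helper(
    tensor_send_next=None,   # nothing to push to the next pipeline stage
    tensor_send_prev=None,   # nothing to push to the previous stage
    recv_prev=True,          # allocate and fill a buffer from the previous stage
    recv_next=False,
    sync_recv=True,
)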
@@ -310,15 +374,17 @@ def _p2p_helper(tensor_send_next,
         if isinstance(recv_shape_msg, tuple):
             tensor_recv_prev = []
             for idx, shape in enumerate(recv_shape_msg):
-                tmp = paddle.empty(shape=shape,
-                                   dtype=number_2_dtype(recv_dtype_msg[idx]))
+                tmp = paddle.empty(
+                    shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])
+                )
                 tmp.stop_gradient = recv_stop_gradient[idx]
                 tensor_recv_prev.append(tmp)
             tensor_recv_prev = tuple(tensor_recv_prev)
         else:
             tensor_recv_prev = paddle.empty(
-                shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg))
+                shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)
+            )
             tensor_recv_prev.stop_gradient = recv_stop_gradient

     if recv_next:
@@ -326,12 +392,15 @@ def _p2p_helper(tensor_send_next,
             tensor_recv_next = []
             for idx, shape in enumerate(send_shape_msg):
                 tensor_recv_next.append(
-                    paddle.empty(shape=shape,
-                                 dtype=number_2_dtype(send_dtype_msg[idx])))
+                    paddle.empty(
+                        shape=shape, dtype=number_2_dtype(send_dtype_msg[idx])
+                    )
+                )
             tensor_recv_next = tuple(tensor_recv_next)
         else:
             tensor_recv_next = paddle.empty(
-                shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg))
+                shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)
+            )

     # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
     tasks = []
@@ -340,51 +409,63 @@ def _p2p_helper(tensor_send_next,
         if isinstance(tensor_send_prev, tuple):
             for d in tensor_send_prev:
                 paddle.distributed.wait(d, use_calc_stream=True)
-                send_partial(d,
-                             dst=0,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.send_prev_group,
-                             use_calc_stream=False)
+                send_partial(
+                    d,
+                    dst=0,
+                    nranks=mp_degree,
+                    rank_id=mp_rank,
+                    group=_hcg.send_prev_group,
+                    use_calc_stream=False,
+                )
         else:
             paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
-            send_partial(tensor_send_prev,
-                         dst=0,
-                         nranks=mp_degree,
-                         rank_id=mp_rank,
-                         group=_hcg.send_prev_group,
-                         use_calc_stream=False)
+            send_partial(
+                tensor_send_prev,
+                dst=0,
+                nranks=mp_degree,
+                rank_id=mp_rank,
+                group=_hcg.send_prev_group,
+                use_calc_stream=False,
+            )

     if tensor_recv_prev is not None:
         if isinstance(tensor_recv_prev, tuple):
             for d in tensor_recv_prev:
-                task = recv_partial(d,
-                                    src=0,
-                                    nranks=mp_degree,
-                                    rank_id=mp_rank,
-                                    group=_hcg.recv_prev_group,
-                                    use_calc_stream=sync_recv)
+                task = recv_partial(
+                    d,
+                    src=0,
+                    nranks=mp_degree,
+                    rank_id=mp_rank,
+                    group=_hcg.recv_prev_group,
+                    use_calc_stream=sync_recv,
+                )
                 if sync_recv:
-                    allgather_partial(d,
-                                      nranks=mp_degree,
-                                      rank_id=mp_rank,
-                                      group=mp_group,
-                                      use_calc_stream=True)
+                    allgather_partial(
+                        d,
+                        nranks=mp_degree,
+                        rank_id=mp_rank,
+                        group=mp_group,
+                        use_calc_stream=True,
+                    )
                 else:
                     tasks.append(task)
         else:
-            task = recv_partial(tensor_recv_prev,
-                                src=0,
-                                nranks=mp_degree,
-                                rank_id=mp_rank,
-                                group=_hcg.recv_prev_group,
-                                use_calc_stream=sync_recv)
+            task = recv_partial(
+                tensor_recv_prev,
+                src=0,
+                nranks=mp_degree,
+                rank_id=mp_rank,
+                group=_hcg.recv_prev_group,
+                use_calc_stream=sync_recv,
+            )
             if sync_recv:
-                allgather_partial(tensor_recv_prev,
-                                  nranks=mp_degree,
-                                  rank_id=mp_rank,
-                                  group=mp_group,
-                                  use_calc_stream=True)
+                allgather_partial(
+                    tensor_recv_prev,
+                    nranks=mp_degree,
+                    rank_id=mp_rank,
+                    group=mp_group,
+                    use_calc_stream=True,
+                )
             else:
                 tasks.append(task)
@@ -392,52 +473,64 @@ def _p2p_helper(tensor_send_next,
         if isinstance(tensor_send_next, tuple):
             for d in tensor_send_next:
                 paddle.distributed.wait(d, use_calc_stream=True)
-                send_partial(d,
-                             dst=1,
-                             nranks=mp_degree,
-                             rank_id=mp_rank,
-                             group=_hcg.send_next_group,
-                             use_calc_stream=False)
+                send_partial(
+                    d,
+                    dst=1,
+                    nranks=mp_degree,
+                    rank_id=mp_rank,
+                    group=_hcg.send_next_group,
+                    use_calc_stream=False,
+                )
         else:
             paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
-            send_partial(tensor_send_next,
-                         dst=1,
-                         nranks=mp_degree,
-                         rank_id=mp_rank,
-                         group=_hcg.send_next_group,
-                         use_calc_stream=False)
+            send_partial(
+                tensor_send_next,
+                dst=1,
+                nranks=mp_degree,
+                rank_id=mp_rank,
+                group=_hcg.send_next_group,
+                use_calc_stream=False,
+            )

     if tensor_recv_next is not None:
         if isinstance(tensor_recv_next, tuple):
             for d in tensor_recv_next:
-                task = recv_partial(d,
-                                    src=1,
-                                    nranks=mp_degree,
-                                    rank_id=mp_rank,
-                                    group=_hcg.recv_next_group,
-                                    use_calc_stream=sync_recv)
+                task = recv_partial(
+                    d,
+                    src=1,
+                    nranks=mp_degree,
+                    rank_id=mp_rank,
+                    group=_hcg.recv_next_group,
+                    use_calc_stream=sync_recv,
+                )
                 if sync_recv:
-                    allgather_partial(d,
-                                      nranks=mp_degree,
-                                      rank_id=mp_rank,
-                                      group=mp_group,
-                                      use_calc_stream=True)
+                    allgather_partial(
+                        d,
+                        nranks=mp_degree,
+                        rank_id=mp_rank,
+                        group=mp_group,
+                        use_calc_stream=True,
+                    )
                 else:
                     tasks.append(task)
         else:
-            task = recv_partial(tensor_recv_next,
-                                src=1,
-                                nranks=mp_degree,
-                                rank_id=mp_rank,
-                                group=_hcg.recv_next_group,
-                                use_calc_stream=sync_recv)
+            task = recv_partial(
+                tensor_recv_next,
+                src=1,
+                nranks=mp_degree,
+                rank_id=mp_rank,
+                group=_hcg.recv_next_group,
+                use_calc_stream=sync_recv,
+            )
             if sync_recv:
-                allgather_partial(tensor_recv_next,
-                                  nranks=mp_degree,
-                                  rank_id=mp_rank,
-                                  group=mp_group,
-                                  use_calc_stream=True)
+                allgather_partial(
+                    tensor_recv_next,
+                    nranks=mp_degree,
+                    rank_id=mp_rank,
+                    group=mp_group,
+                    use_calc_stream=True,
+                )
             else:
                 tasks.append(task)
@@ -463,11 +556,13 @@ def _p2p_helper(tensor_send_next,
             tensors_for_all_gather.append(tensor_recv_next)

         for tensor in tensors_for_all_gather:
-            allgather_partial(tensor,
-                              nranks=mp_degree,
-                              rank_id=mp_rank,
-                              group=mp_group,
-                              use_calc_stream=True)
+            allgather_partial(
+                tensor,
+                nranks=mp_degree,
+                rank_id=mp_rank,
+                group=mp_group,
+                use_calc_stream=True,
+            )

     return tensor_recv_prev, tensor_recv_next
@@ -480,11 +575,13 @@ def recv_forward(pp_first_stage, sync_recv=True):
             _send_recv_meta.recv_meta(_hcg.recv_prev_group)
             _send_recv_meta.has_recv_meta = _use_cache

-        input_tensor, _ = _p2p_helper(tensor_send_next=None,
-                                      tensor_send_prev=None,
-                                      recv_prev=True,
-                                      recv_next=False,
-                                      sync_recv=sync_recv)
+        input_tensor, _ = _p2p_helper(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_prev=True,
+            recv_next=False,
+            sync_recv=sync_recv,
+        )
     return input_tensor
@@ -492,11 +589,13 @@ def recv_backward(pp_last_stage, sync_recv=True):
     if pp_last_stage:
         output_tensor_grad = None
     else:
-        _, output_tensor_grad = _p2p_helper(tensor_send_next=None,
-                                            tensor_send_prev=None,
-                                            recv_prev=False,
-                                            recv_next=True,
-                                            sync_recv=sync_recv)
+        _, output_tensor_grad = _p2p_helper(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_prev=False,
+            recv_next=True,
+            sync_recv=sync_recv,
+        )
     return output_tensor_grad
@@ -507,28 +606,34 @@ def send_forward(output_tensor, pp_last_stage):
             _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
             _send_recv_meta.has_send_meta = _use_cache

-        _p2p_helper(tensor_send_next=output_tensor,
-                    tensor_send_prev=None,
-                    recv_prev=False,
-                    recv_next=False)
+        _p2p_helper(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=None,
+            recv_prev=False,
+            recv_next=False,
+        )

 def send_backward(input_tensor_grad, pp_first_stage):
     if not pp_first_stage:
-        _p2p_helper(tensor_send_next=None,
-                    tensor_send_prev=input_tensor_grad,
-                    recv_prev=False,
-                    recv_next=False)
+        _p2p_helper(
+            tensor_send_next=None,
+            tensor_send_prev=input_tensor_grad,
+            recv_prev=False,
+            recv_next=False,
+        )

 def send_forward_recv_backward(output_tensor, pp_last_stage):
     if pp_last_stage:
         output_tensor_grad = None
     else:
-        _, output_tensor_grad = _p2p_helper(tensor_send_next=output_tensor,
-                                            tensor_send_prev=None,
-                                            recv_prev=False,
-                                            recv_next=True)
+        _, output_tensor_grad = _p2p_helper(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=None,
+            recv_prev=False,
+            recv_next=True,
+        )
     return output_tensor_grad
@@ -536,16 +641,18 @@ def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
     if pp_first_stage:
         input_tensor = None
     else:
-        input_tensor, _ = _p2p_helper(tensor_send_next=None,
-                                      tensor_send_prev=input_tensor_grad,
-                                      recv_prev=True,
-                                      recv_next=False)
+        input_tensor, _ = _p2p_helper(
+            tensor_send_next=None,
+            tensor_send_prev=input_tensor_grad,
+            recv_prev=True,
+            recv_next=False,
+        )
     return input_tensor

-def send_forward_backward_recv_forward_backward(output_tensor,
-                                                input_tensor_grad, recv_prev,
-                                                recv_next):
+def send_forward_backward_recv_forward_backward(
+    output_tensor, input_tensor_grad, recv_prev, recv_next
+):
     # always have to send dytpe info to downstream
     if not _send_recv_meta.has_send_meta:
         _send_recv_meta.set_send_message(output_tensor)
@@ -559,7 +666,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
         recv_next=recv_next,
-        sync_recv=False)
+        sync_recv=False,
+    )
     return input_tensor, output_tensor_grad
@@ -573,19 +681,23 @@ def send_forward_recv_forward(output_tensor, recv_prev):
         _send_recv_meta.recv_meta(_hcg.recv_prev_group)
         _send_recv_meta.has_recv_meta = _use_cache

-    input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor,
-                                  tensor_send_prev=None,
-                                  recv_prev=recv_prev,
-                                  recv_next=False,
-                                  sync_recv=False)
+    input_tensor, _ = _p2p_helper(
+        tensor_send_next=output_tensor,
+        tensor_send_prev=None,
+        recv_prev=recv_prev,
+        recv_next=False,
+        sync_recv=False,
+    )
     return input_tensor

 def send_backward_recv_backward(input_tensor_grad, recv_next):
-    _, output_tensor_grad = _p2p_helper(tensor_send_next=None,
-                                        tensor_send_prev=input_tensor_grad,
-                                        recv_prev=False,
-                                        recv_next=recv_next,
-                                        sync_recv=False)
+    _, output_tensor_grad = _p2p_helper(
+        tensor_send_next=None,
+        tensor_send_prev=input_tensor_grad,
+        recv_prev=False,
+        recv_next=recv_next,
+        sync_recv=False,
+    )
     return output_tensor_grad
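Taken together, the wrappers re-wrapped above form the pipeline-parallel P2P surface of this file. The schematic below shows how a schedule might string them together for one micro-batch; the function `one_microbatch`, the `forward_step`/`backward_step` callables, and the stage flags are illustrative placeholders, not code from this diff:

# Schematic only: one forward/backward exchange for a single micro-batch.
# forward_step and backward_step stand in for the user's compute functions.
def one_microbatch(is_first_stage, is_last_stage, forward_step, backward_step):
    input_tensor = recv_forward(pp_first_stage=is_first_stage)
    output_tensor = forward_step(input_tensor)
    send_forward(output_tensor, pp_last_stage=is_last_stage)

    output_tensor_grad = recv_backward(pp_last_stage=is_last_stage)
    input_tensor_grad = backward_step(output_tensor, output_tensor_grad)
    send_backward(input_tensor_grad, pp_first_stage=is_first_stage)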