PaddlePaddle / Paddle
Commit ff806111 (unverified), authored by LiYuRio on Jun 16, 2023; committed via GitHub on Jun 16, 2023.
separate four directions p2p communication to a new file (#54664)
Parent: aac91e82
Showing 3 changed files with 931 additions and 4 deletions (+931 -4).
python/paddle/distributed/fleet/base/topology.py  (+54 -3)
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py  (+11 -1)
python/paddle/distributed/fleet/meta_parallel/pp_utils/four_directions_p2p_communication.py  (+866 -0)
python/paddle/distributed/fleet/base/topology.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import collections
+import os
 from functools import reduce
 from itertools import product
@@ -24,6 +25,9 @@ from ..utils.log_util import logger
 __all__ = ['CommunicateTopology', 'HybridCommunicateGroup']
 
 _HYBRID_PARALLEL_GROUP = None
+_use_four_directions = os.environ.get(
+    'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.fluid.core.is_compiled_with_xpu()
+)
 
 
 class ParallelMode:
@@ -191,7 +195,9 @@ class HybridCommunicateGroup:
         if self._pp_degree > 1:
             if paddle.framework.core.is_compiled_with_nccl():
                 check_nccl_version_for_p2p()
-            self._set_p2p_group()
+            self._set_p2p_prev_next()
+            if _use_four_directions:
+                self._set_four_directions_p2p_group()
 
         debug_str = (
             "HybridParallelInfo: rank_id: %d, mp_degree: %d, "
@@ -291,7 +297,7 @@ class HybridCommunicateGroup:
         assert hasattr(self, 'prev_rank'), "prev_rank has not been inited"
         return self.prev_rank
 
-    def _set_p2p_group(self):
+    def _set_p2p_prev_next(self):
         comm_lists = self._topo.get_comm_list('pipe')
 
         for comm_ranks in comm_lists:
@@ -305,6 +311,43 @@ class HybridCommunicateGroup:
                     self.next_rank = next_rank
                     self.prev_rank = prev_rank
 
+    def _set_four_directions_p2p_group(self):
+        comm_lists = self._topo.get_comm_list('pipe')
+
+        self.send_next_group = None
+        self.send_prev_group = None
+        self.recv_next_group = None
+        self.recv_prev_group = None
+        for comm_ranks in comm_lists:
+            assert len(comm_ranks) == self._pp_degree
+            for idx, rank in enumerate(comm_ranks):
+                curr_rank = rank
+                next_rank = comm_ranks[(idx + 1) % self._pp_degree]
+                prev_rank = comm_ranks[(idx - 1) % self._pp_degree]
+
+                next_group = paddle.distributed.new_group(
+                    ranks=[curr_rank, next_rank]
+                )
+                if self.global_rank == curr_rank:
+                    self.send_next_group = next_group
+                elif self.global_rank == next_rank:
+                    self.recv_prev_group = next_group
+
+                prev_group = paddle.distributed.new_group(
+                    ranks=[prev_rank, curr_rank]
+                )
+
+                if self.global_rank == curr_rank:
+                    self.send_prev_group = prev_group
+                elif self.global_rank == prev_rank:
+                    self.recv_next_group = prev_group
+
+        assert self.send_next_group is not None
+        assert self.send_prev_group is not None
+        assert self.recv_next_group is not None
+        assert self.recv_prev_group is not None
+
     def topology(self):
         return self._topo
@@ -357,7 +400,15 @@ class HybridCommunicateGroup:
         return self._pp_comm_group
 
     def get_p2p_groups(self):
-        return None
+        assert (
+            _use_four_directions
+        ), "If you want to use four directions p2p group, set the environment variable PADDLE_USE_FOUR_DIRECTIONS_P2P to True."
+        return (
+            self.send_next_group,
+            self.send_prev_group,
+            self.recv_next_group,
+            self.recv_prev_group,
+        )
 
     # sharding parallel message:
     def _get_sharding_parallel_id(self):
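The pairing rule in _set_four_directions_p2p_group above is easiest to see on a small ring. The following standalone sketch (plain Python, not part of the commit; the helper name four_direction_peers and the printed example are illustrative only) reproduces the loop and reports which two-rank groups a given rank ends up holding:

def four_direction_peers(comm_ranks, global_rank):
    # Mirror of the pairing loop: every adjacent pair of pipeline ranks gets
    # its own two-rank group, and each rank keeps the four groups it joins.
    pp_degree = len(comm_ranks)
    groups = {}
    for idx, curr_rank in enumerate(comm_ranks):
        next_rank = comm_ranks[(idx + 1) % pp_degree]
        prev_rank = comm_ranks[(idx - 1) % pp_degree]
        # group [curr_rank, next_rank]: curr sends forward, next receives from its predecessor
        if global_rank == curr_rank:
            groups['send_next_group'] = (curr_rank, next_rank)
        elif global_rank == next_rank:
            groups['recv_prev_group'] = (curr_rank, next_rank)
        # group [prev_rank, curr_rank]: curr sends backward, prev receives from its successor
        if global_rank == curr_rank:
            groups['send_prev_group'] = (prev_rank, curr_rank)
        elif global_rank == prev_rank:
            groups['recv_next_group'] = (prev_rank, curr_rank)
    return groups

# For a 4-stage pipeline, rank 1 holds send_next=(1, 2), recv_prev=(0, 1),
# send_prev=(0, 1) and recv_next=(1, 2).
print(four_direction_peers([0, 1, 2, 3], global_rank=1))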
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -13,6 +13,8 @@
 import time
 import warnings
+import os
 
+import paddle
 from paddle import framework
@@ -26,7 +28,15 @@ from ..utils.hybrid_parallel_util import (
 from ..utils.log_util import logger
 from .meta_parallel_base import MetaParallelBase
 from .parallel_layers.pp_layers import PipelineLayer
-from .pp_utils import p2p_communication as p2p
+
+_use_four_directions = os.environ.get(
+    'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.fluid.core.is_compiled_with_xpu()
+)
+if _use_four_directions:
+    from .pp_utils import four_directions_p2p_communication as p2p
+else:
+    from .pp_utils import p2p_communication as p2p
+
 from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
 
 __all__ = []
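Both changed files pick the p2p implementation once, at import time: _use_four_directions reads PADDLE_USE_FOUR_DIRECTIONS_P2P, falls back to paddle.fluid.core.is_compiled_with_xpu() when the variable is unset, and is then used as a plain truthy check. A hedged sketch of how a caller would opt in explicitly (the surrounding launch flow is an assumption, not something this commit prescribes):

import os

# Export the switch before paddle.distributed.fleet is imported, since the
# check runs when the module is first loaded. Any non-empty string is truthy
# here, so use '1' to opt in and leave the variable unset to fall back to the
# is_compiled_with_xpu() default.
os.environ['PADDLE_USE_FOUR_DIRECTIONS_P2P'] = '1'

from paddle.distributed import fleet  # modules that import .pp_utils now resolve p2p to four_directions_p2p_communication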
python/paddle/distributed/fleet/meta_parallel/pp_utils/four_directions_p2p_communication.py  (new file, 0 → 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
from paddle import framework

from ...utils import timer_helper as timer
from ...utils.log_util import logger
from .utils import number_2_dtype, paddle_2_number

_hcg = None
_use_cache = False
_enable_partial_send_recv = True
_timers = None

_xpu_comm_group_started = False

_sync_send = os.environ.get("PADDLE_P2P_SYNC_SEND", "0")
_sync_send = _sync_send.lower() in ['1', 'true']


def _xpu_comm_group_start():
    if not paddle.is_compiled_with_xpu():
        return
    global _xpu_comm_group_started
    assert not _xpu_comm_group_started
    framework.core.ProcessGroupBKCL.group_start()
    _xpu_comm_group_started = True


def _xpu_comm_group_end():
    if not paddle.is_compiled_with_xpu():
        return
    global _xpu_comm_group_started
    if _xpu_comm_group_started:
        framework.core.ProcessGroupBKCL.group_end()
        _xpu_comm_group_started = False


def initialize_p2p_groups(
    hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False
):
    global _hcg, _use_cache, _enable_partial_send_recv, _timers
    _hcg = hcg
    _use_cache = use_cache
    _enable_partial_send_recv = enable_partial_send_recv
    if enable_timer:
        _timers = timer.get_timers()
    (
        send_next_group,
        send_prev_group,
        recv_next_group,
        recv_prev_group,
    ) = _hcg.get_p2p_groups()

    debug_str = (
        "P2pInfo: send_next_group: {}, send_prev_group: {}, "
        "recv_next_group: {}, recv_prev_group: {}".format(
            repr(send_next_group),
            repr(send_prev_group),
            repr(recv_next_group),
            repr(recv_prev_group),
        )
    )
    logger.info(debug_str)
class SendRecvMeta:
    """Mainly used to help p2p communication context information"""

    def __init__(self):
        self.send_shape_message = None
        self.send_dtype_message = None

        self.recv_shape_message = None
        self.recv_dtype_message = None
        self.recv_stop_gradient = None

        self.has_send_meta = False
        self.has_recv_meta = False

    def _recv_shape_dtype(self, group):
        # recv len(shape)
        dims = paddle.to_tensor([0])
        src_rank = _hcg._get_p2p_prev_rank()

        paddle.distributed.recv(dims, src=src_rank, group=group)
        dims = dims.item()

        # recv shape
        shape = paddle.to_tensor([0] * dims)
        paddle.distributed.recv(shape, src=src_rank, group=group)

        # recv dtype
        dtype = paddle.to_tensor([0])
        paddle.distributed.recv(dtype, src=src_rank, group=group)

        # recv stop_gradient
        stop_grad = paddle.to_tensor([0])
        paddle.distributed.recv(stop_grad, src=src_rank, group=group)
        return shape.tolist(), dtype.item(), stop_grad.item()

    def recv_meta(self, group):
        tensor_type = paddle.to_tensor([0])
        src_rank = _hcg._get_p2p_prev_rank()

        paddle.distributed.recv(tensor_type, src=src_rank, group=group)
        tensor_type = tensor_type.item()

        if tensor_type == 0:
            shape, dtype, stop_grad = self._recv_shape_dtype(group)

            self.recv_shape_message = shape
            self.recv_dtype_message = dtype
            self.recv_stop_gradient = bool(stop_grad)

        elif tensor_type == 1:
            num = paddle.to_tensor([0])
            paddle.distributed.recv(num, src=src_rank, group=group)
            num = num.item()
            shapes = []
            dtypes = []
            stop_grads = []
            for i in range(num):
                shape, dtype, stop_grad = self._recv_shape_dtype(group)
                shapes.append(shape)
                dtypes.append(dtype)
                stop_grads.append(bool(stop_grad))

            self.recv_shape_message = tuple(shapes)
            self.recv_dtype_message = tuple(dtypes)
            self.recv_stop_gradient = tuple(stop_grads)

    def _send_dims_shape_dtype(self, tensor, group):
        # send len(shape)
        dims = paddle.to_tensor([len(tensor.shape)])
        dst_rank = _hcg._get_p2p_next_rank()

        paddle.distributed.send(dims, dst=dst_rank, group=group)

        # send shape
        shape = paddle.to_tensor(tensor.shape)
        paddle.distributed.send(shape, dst=dst_rank, group=group)

        # send dtype
        dtype = paddle.to_tensor([paddle_2_number(tensor.dtype)])
        paddle.distributed.send(dtype, dst=dst_rank, group=group)

        # send trainable
        stop_grad = paddle.to_tensor([int(tensor.stop_gradient)])
        paddle.distributed.send(stop_grad, dst=dst_rank, group=group)

    def send_meta(self, tensor, group):
        dst_rank = _hcg._get_p2p_next_rank()

        if isinstance(tensor, (paddle.Tensor, framework.core.eager.Tensor)):
            tensor_type = paddle.to_tensor([0])
            # send tensor type
            paddle.distributed.send(tensor_type, dst=dst_rank, group=group)

            self._send_dims_shape_dtype(tensor, group)
        elif isinstance(tensor, tuple):
            tensor_type = paddle.to_tensor([1])
            # send tensor type
            paddle.distributed.send(tensor_type, dst=dst_rank, group=group)

            nums = paddle.to_tensor([len(tensor)])
            paddle.distributed.send(nums, dst=dst_rank, group=group)

            for d in tensor:
                assert isinstance(
                    d, (paddle.Tensor, framework.core.eager.Tensor)
                )
                self._send_dims_shape_dtype(d, group=group)

    def set_send_message(self, tensor):
        if isinstance(tensor, (paddle.Tensor, framework.core.eager.Tensor)):
            self.send_shape_message = tensor.shape
            self.send_dtype_message = paddle_2_number(tensor.dtype)
        elif isinstance(tensor, tuple):
            self.send_shape_message = tuple(
                [d.shape for d in tensor if not d.stop_gradient]
            )
            self.send_dtype_message = tuple(
                [
                    paddle_2_number(d.dtype)
                    for d in tensor
                    if not d.stop_gradient
                ]
            )


_send_recv_meta = SendRecvMeta()
def _is_valid_send_recv_partial(tensor, mp_degree):
    if not _enable_partial_send_recv:
        return False
    tensor_numel = np.prod(tensor.shape)
    assert tensor_numel != 0, "can't send/recv zero element"
    return mp_degree > 1 and tensor_numel % mp_degree == 0


def _partial_send_op(
    tensor, group, use_calc_stream, ring_id, dst, nranks, rank_id
):
    dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
    if framework.in_dynamic_mode():
        group = (
            paddle.distributed.collective._get_default_group()
            if group is None
            else group
        )
        comm_op = (
            group.process_group.send_partial_on_calc_stream
            if use_calc_stream
            else group.process_group.send_partial
        )
        return comm_op(tensor, dst_rank_in_group, nranks, rank_id)


def send_partial(
    tensor, dst=0, nranks=1, rank_id=0, group=None, use_calc_stream=True
):
    # dst: local rank in group
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    dst_rank = (
        _hcg._get_p2p_next_rank() if dst == 1 else _hcg._get_p2p_prev_rank()
    )

    if _is_valid_send_recv_partial(tensor, nranks):
        return _partial_send_op(
            tensor, group, use_calc_stream, ring_id, dst_rank, nranks, rank_id
        )
    else:
        send_op = paddle.distributed.isend
        return send_op(tensor.detach(), dst=dst_rank, group=group)


def _partial_recv_op(
    tensor, group, use_calc_stream, ring_id, src, nranks, rank_id
):
    src_rank_in_group = src if group is None else group.get_group_rank(src)
    group = (
        paddle.distributed.collective._get_default_group()
        if group is None
        else group
    )
    comm_op = (
        group.process_group.recv_partial_on_calc_stream
        if use_calc_stream
        else group.process_group.recv_partial
    )
    return comm_op(tensor, src_rank_in_group, nranks, rank_id)


def recv_partial(
    tensor, src=0, nranks=1, rank_id=0, group=None, use_calc_stream=True
):
    # src: local rank in group
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    src_rank = (
        _hcg._get_p2p_prev_rank() if src == 0 else _hcg._get_p2p_next_rank()
    )

    if _is_valid_send_recv_partial(tensor, nranks):
        return _partial_recv_op(
            tensor, group, use_calc_stream, ring_id, src_rank, nranks, rank_id
        )
    else:
        if use_calc_stream:
            recv_op = paddle.distributed.recv
        elif framework.in_dynamic_mode():
            recv_op = paddle.distributed.irecv
        return recv_op(tensor.detach(), src=src_rank, group=group)


def _partial_allgather_op(
    tensor, group, use_calc_stream, ring_id, nranks, rank_id
):
    group = (
        paddle.distributed.collective._get_default_group()
        if group is None
        else group
    )
    comm_op = (
        group.process_group.all_gather_partial_on_calc_stream
        if use_calc_stream
        else group.process_group.all_gather_partial
    )
    return comm_op(tensor, tensor, nranks, rank_id)


def allgather_partial(
    tensor, nranks=1, rank_id=0, group=None, use_calc_stream=True
):
    if not _is_valid_send_recv_partial(tensor, nranks):
        return tensor
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    return _partial_allgather_op(
        tensor, group, use_calc_stream, ring_id, nranks, rank_id
    )
def _p2p_helper(
    tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True
):
    global _hcg

    tensor_recv_prev = None
    tensor_recv_next = None

    # send / recv message
    recv_shape_msg = _send_recv_meta.recv_shape_message
    recv_dtype_msg = _send_recv_meta.recv_dtype_message
    recv_stop_gradient = _send_recv_meta.recv_stop_gradient

    send_shape_msg = _send_recv_meta.send_shape_message
    send_dtype_msg = _send_recv_meta.send_dtype_message

    # model parallel message
    mp_group = _hcg.get_model_parallel_group()
    mp_degree = _hcg.get_model_parallel_world_size()
    mp_rank = _hcg.get_model_parallel_rank()

    if recv_prev:
        if isinstance(recv_shape_msg, tuple):
            tensor_recv_prev = []
            for idx, shape in enumerate(recv_shape_msg):
                tmp = paddle.empty(
                    shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])
                )
                tmp.stop_gradient = recv_stop_gradient[idx]
                tensor_recv_prev.append(tmp)
            tensor_recv_prev = tuple(tensor_recv_prev)
        else:
            tensor_recv_prev = paddle.empty(
                shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)
            )
            tensor_recv_prev.stop_gradient = recv_stop_gradient

    if recv_next:
        if isinstance(send_shape_msg, tuple):
            tensor_recv_next = []
            for idx, shape in enumerate(send_shape_msg):
                tensor_recv_next.append(
                    paddle.empty(
                        shape=shape, dtype=number_2_dtype(send_dtype_msg[idx])
                    )
                )
            tensor_recv_next = tuple(tensor_recv_next)
        else:
            tensor_recv_next = paddle.empty(
                shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)
            )

    # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
    tasks = []
    # start to p2p communicate
    if _sync_send:
        # Some devices(NPU for example) do not support asynchronized send op, So the order is
        # recv_prev -> send_next -> recv_next -> send_prev
        # When using this order, the environment variable
        # 'PADDLE_P2P_SYNC_SEND' should be set True
        if tensor_recv_prev is not None:
            if isinstance(tensor_recv_prev, tuple):
                for d in tensor_recv_prev:
                    task = recv_partial(
                        d,
                        src=0,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=_hcg.recv_prev_group,
                        use_calc_stream=sync_recv,
                    )
                    if sync_recv:
                        allgather_partial(
                            d,
                            nranks=mp_degree,
                            rank_id=mp_rank,
                            group=mp_group,
                            use_calc_stream=True,
                        )
                    else:
                        tasks.append(task)
            else:
                task = recv_partial(
                    tensor_recv_prev,
                    src=0,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.recv_prev_group,
                    use_calc_stream=sync_recv,
                )
                if sync_recv:
                    allgather_partial(
                        tensor_recv_prev,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=mp_group,
                        use_calc_stream=True,
                    )
                else:
                    tasks.append(task)

        if tensor_send_next is not None:
            if isinstance(tensor_send_next, tuple):
                for d in tensor_send_next:
                    paddle.distributed.wait(d, use_calc_stream=True)
                    send_partial(
                        d,
                        dst=1,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=_hcg.send_next_group,
                        use_calc_stream=False,
                    )
            else:
                paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
                send_partial(
                    tensor_send_next,
                    dst=1,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.send_next_group,
                    use_calc_stream=False,
                )

        if tensor_recv_next is not None:
            if isinstance(tensor_recv_next, tuple):
                for d in tensor_recv_next:
                    task = recv_partial(
                        d,
                        src=1,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=_hcg.recv_next_group,
                        use_calc_stream=sync_recv,
                    )
                    if sync_recv:
                        allgather_partial(
                            d,
                            nranks=mp_degree,
                            rank_id=mp_rank,
                            group=mp_group,
                            use_calc_stream=True,
                        )
                    else:
                        tasks.append(task)
            else:
                task = recv_partial(
                    tensor_recv_next,
                    src=1,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.recv_next_group,
                    use_calc_stream=sync_recv,
                )
                if sync_recv:
                    allgather_partial(
                        tensor_recv_next,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=mp_group,
                        use_calc_stream=True,
                    )
                else:
                    tasks.append(task)

        if tensor_send_prev is not None:
            if isinstance(tensor_send_prev, tuple):
                for d in tensor_send_prev:
                    paddle.distributed.wait(d, use_calc_stream=True)
                    send_partial(
                        d,
                        dst=0,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=_hcg.send_prev_group,
                        use_calc_stream=False,
                    )
            else:
                paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
                send_partial(
                    tensor_send_prev,
                    dst=0,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.send_prev_group,
                    use_calc_stream=False,
                )
    else:
        _xpu_comm_group_start()
        if tensor_send_prev is not None:
            if isinstance(tensor_send_prev, tuple):
                for d in tensor_send_prev:
                    paddle.distributed.wait(d, use_calc_stream=True)
                    send_partial(
                        d,
                        dst=0,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=_hcg.send_prev_group,
                        use_calc_stream=False,
                    )
            else:
                paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
                send_partial(
                    tensor_send_prev,
                    dst=0,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.send_prev_group,
                    use_calc_stream=False,
                )

        if tensor_recv_prev is not None:
            if isinstance(tensor_recv_prev, tuple):
                for d in tensor_recv_prev:
                    task = recv_partial(
                        d,
                        src=0,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=_hcg.recv_prev_group,
                        use_calc_stream=sync_recv,
                    )
                    if sync_recv:
                        _xpu_comm_group_end()
                        allgather_partial(
                            d,
                            nranks=mp_degree,
                            rank_id=mp_rank,
                            group=mp_group,
                            use_calc_stream=True,
                        )
                    else:
                        tasks.append(task)
            else:
                task = recv_partial(
                    tensor_recv_prev,
                    src=0,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.recv_prev_group,
                    use_calc_stream=sync_recv,
                )
                if sync_recv:
                    _xpu_comm_group_end()
                    allgather_partial(
                        tensor_recv_prev,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=mp_group,
                        use_calc_stream=True,
                    )
                else:
                    tasks.append(task)

        if tensor_send_next is not None:
            if isinstance(tensor_send_next, tuple):
                for d in tensor_send_next:
                    paddle.distributed.wait(d, use_calc_stream=True)
                    send_partial(
                        d,
                        dst=1,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=_hcg.send_next_group,
                        use_calc_stream=False,
                    )
            else:
                paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
                send_partial(
                    tensor_send_next,
                    dst=1,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.send_next_group,
                    use_calc_stream=False,
                )

        if tensor_recv_next is not None:
            if isinstance(tensor_recv_next, tuple):
                for d in tensor_recv_next:
                    task = recv_partial(
                        d,
                        src=1,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=_hcg.recv_next_group,
                        use_calc_stream=sync_recv,
                    )
                    if sync_recv:
                        _xpu_comm_group_end()
                        allgather_partial(
                            d,
                            nranks=mp_degree,
                            rank_id=mp_rank,
                            group=mp_group,
                            use_calc_stream=True,
                        )
                    else:
                        tasks.append(task)
            else:
                task = recv_partial(
                    tensor_recv_next,
                    src=1,
                    nranks=mp_degree,
                    rank_id=mp_rank,
                    group=_hcg.recv_next_group,
                    use_calc_stream=sync_recv,
                )
                if sync_recv:
                    _xpu_comm_group_end()
                    allgather_partial(
                        tensor_recv_next,
                        nranks=mp_degree,
                        rank_id=mp_rank,
                        group=mp_group,
                        use_calc_stream=True,
                    )
                else:
                    tasks.append(task)
        _xpu_comm_group_end()

    if not sync_recv:
        if framework.in_dynamic_mode():
            # wait irecv tasks in eager dygraph mode with new comm library
            for task in tasks:
                assert task is not None
                task.wait()

        tensors_for_all_gather = []
        if tensor_recv_prev is not None:
            if isinstance(tensor_recv_prev, tuple):
                for d in tensor_recv_prev:
                    tensors_for_all_gather.append(d)
            else:
                tensors_for_all_gather.append(tensor_recv_prev)
        if tensor_recv_next is not None:
            if isinstance(tensor_recv_next, tuple):
                for d in tensor_recv_next:
                    tensors_for_all_gather.append(d)
            else:
                tensors_for_all_gather.append(tensor_recv_next)

        for tensor in tensors_for_all_gather:
            allgather_partial(
                tensor,
                nranks=mp_degree,
                rank_id=mp_rank,
                group=mp_group,
                use_calc_stream=True,
            )

    return tensor_recv_prev, tensor_recv_next
def recv_forward(pp_first_stage, sync_recv=True):
    global _timers
    if _timers is not None:
        _timers("recv_forward").start()
    if pp_first_stage:
        input_tensor = None
    else:
        if not _send_recv_meta.has_recv_meta:
            _send_recv_meta.recv_meta(_hcg.recv_prev_group)
            _send_recv_meta.has_recv_meta = _use_cache

        input_tensor, _ = _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False,
            sync_recv=sync_recv,
        )
    if _timers is not None:
        _timers("recv_forward").stop()
    return input_tensor


def recv_backward(pp_last_stage, sync_recv=True):
    global _timers
    if _timers is not None:
        _timers("recv_backward").start()
    if pp_last_stage:
        output_tensor_grad = None
    else:
        _, output_tensor_grad = _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            sync_recv=sync_recv,
        )
    if _timers is not None:
        _timers("recv_backward").stop()
    return output_tensor_grad


def send_forward(output_tensor, pp_last_stage):
    global _timers
    if _timers is not None:
        _timers("send_forward").start()
    if not pp_last_stage:
        if not _send_recv_meta.has_send_meta:
            _send_recv_meta.set_send_message(output_tensor)
            _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
            _send_recv_meta.has_send_meta = _use_cache

        _p2p_helper(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False,
        )
    if _timers is not None:
        _timers("send_forward").stop()


def send_backward(input_tensor_grad, pp_first_stage):
    global _timers
    if _timers is not None:
        _timers("send_backward").start()
    if not pp_first_stage:
        _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False,
        )
    if _timers is not None:
        _timers("send_backward").stop()


def send_forward_recv_backward(output_tensor, pp_last_stage):
    global _timers
    if _timers is not None:
        _timers("send_forward_recv_backward").start()
    if pp_last_stage:
        output_tensor_grad = None
    else:
        _, output_tensor_grad = _p2p_helper(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
        )
    if _timers is not None:
        _timers("send_forward_recv_backward").stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
    global _timers
    if _timers is not None:
        _timers("send_backward_recv_forward").start()
    if pp_first_stage:
        input_tensor = None
    else:
        input_tensor, _ = _p2p_helper(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
        )
    if _timers is not None:
        _timers("send_backward_recv_forward").stop()
    return input_tensor


def send_forward_backward_recv_forward_backward(
    output_tensor, input_tensor_grad, recv_prev, recv_next
):
    # always have to send dytpe info to downstream
    global _timers
    if _timers is not None:
        _timers("send_forward_backward_recv_forward_backward").start()
    if not _send_recv_meta.has_send_meta:
        _send_recv_meta.set_send_message(output_tensor)
        _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
        _send_recv_meta.has_send_meta = _use_cache
    if recv_prev and not _send_recv_meta.has_recv_meta:
        _send_recv_meta.recv_meta(_hcg.recv_prev_group)
        _send_recv_meta.has_recv_meta = _use_cache
    input_tensor, output_tensor_grad = _p2p_helper(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        sync_recv=False,
    )
    if _timers is not None:
        _timers("send_forward_backward_recv_forward_backward").stop()
    return input_tensor, output_tensor_grad


def send_forward_recv_forward(output_tensor, recv_prev):
    # always have to send dytpe info to downstream
    global _timers
    if _timers is not None:
        _timers("send_forward_recv_forward").start()
    if not _send_recv_meta.has_send_meta:
        _send_recv_meta.set_send_message(output_tensor)
        _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
        _send_recv_meta.has_send_meta = _use_cache
    if recv_prev and not _send_recv_meta.has_recv_meta:
        _send_recv_meta.recv_meta(_hcg.recv_prev_group)
        _send_recv_meta.has_recv_meta = _use_cache

    input_tensor, _ = _p2p_helper(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        sync_recv=False,
    )
    if _timers is not None:
        _timers("send_forward_recv_forward").stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next):
    global _timers
    if _timers is not None:
        _timers("send_backward_recv_backward").start()
    _, output_tensor_grad = _p2p_helper(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        sync_recv=False,
    )
    if _timers is not None:
        _timers("send_backward_recv_backward").stop()
    return output_tensor_grad
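For orientation, here is a hedged sketch (not part of the commit) of how the module's public API lines up in a naive stage step: initialize_p2p_groups wires up the four groups obtained from the hybrid communicate group, and the send_*/recv_* helpers move activations forward and gradients backward. The fleet setup, the placeholder stage_forward/stage_backward functions, and the exact call order below are assumptions for illustration; in practice PipelineParallel drives these calls internally.

from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.pp_utils import (
    four_directions_p2p_communication as p2p,
)

# Assumes fleet.init() ran with a pipeline-parallel strategy and that
# PADDLE_USE_FOUR_DIRECTIONS_P2P is set, so hcg.get_p2p_groups() is available.
hcg = fleet.get_hybrid_communicate_group()
p2p.initialize_p2p_groups(hcg, use_cache=True)

def stage_forward(x):        # placeholder for this stage's forward computation
    return x

def stage_backward(y, dy):   # placeholder for this stage's backward computation
    return dy

first_stage = hcg.is_first_stage
last_stage = hcg.is_last_stage

input_tensor = p2p.recv_forward(first_stage)        # activation from the previous stage
output_tensor = stage_forward(input_tensor)
p2p.send_forward(output_tensor, last_stage)         # activation to the next stage

output_grad = p2p.recv_backward(last_stage)         # gradient from the next stage
input_grad = stage_backward(output_tensor, output_grad)
p2p.send_backward(input_grad, first_stage)          # gradient to the previous stage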