Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
71cdf009
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
71cdf009
编写于
6月 20, 2023
作者:
S
ShenLiang
提交者:
GitHub
6月 20, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
solve conflict (#54747)
上级
f469f176
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
116 addition
and
218 deletion
+116
-218
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+3
-2
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
...ributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+113
-216
未找到文件。
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
71cdf009
...
@@ -10,11 +10,10 @@
...
@@ -10,11 +10,10 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
import
os
import
time
import
time
import
warnings
import
warnings
import
os
import
paddle
import
paddle
from
paddle
import
framework
from
paddle
import
framework
...
@@ -176,6 +175,7 @@ class PipelineParallel(MetaParallelBase):
...
@@ -176,6 +175,7 @@ class PipelineParallel(MetaParallelBase):
self
.
_enable_timer
=
self
.
_strategy
.
hybrid_configs
[
self
.
_enable_timer
=
self
.
_strategy
.
hybrid_configs
[
"pp_configs"
"pp_configs"
].
enable_timer
].
enable_timer
self
.
_profiling
=
self
.
_strategy
.
hybrid_configs
[
"pp_configs"
].
profiling
self
.
_profiling
=
self
.
_strategy
.
hybrid_configs
[
"pp_configs"
].
profiling
self
.
_records
=
[]
self
.
_records
=
[]
self
.
_record_format
=
(
self
.
_record_format
=
(
...
@@ -303,6 +303,7 @@ class PipelineParallel(MetaParallelBase):
...
@@ -303,6 +303,7 @@ class PipelineParallel(MetaParallelBase):
for
model
in
models
:
for
model
in
models
:
# For virtual pipeline. Will separate parameters in different chunk into
# For virtual pipeline. Will separate parameters in different chunk into
# different groups to get the best performance.
# different groups to get the best performance.
parameter_list
=
[
parameter_list
=
[
p
for
p
in
model
.
parameters
()
if
not
p
.
stop_gradient
p
for
p
in
model
.
parameters
()
if
not
p
.
stop_gradient
]
]
...
...
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
浏览文件 @
71cdf009
...
@@ -175,30 +175,63 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
...
@@ -175,30 +175,63 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
if
not
_enable_partial_send_recv
:
if
not
_enable_partial_send_recv
:
return
False
return
False
tensor_numel
=
np
.
prod
(
tensor
.
shape
)
tensor_numel
=
np
.
prod
(
tensor
.
shape
)
assert
tensor_numel
!=
0
,
"can't send/recv zero element"
assert
tensor_numel
>
0
,
"can't send/recv zero element"
return
mp_degree
>
1
and
tensor_numel
%
mp_degree
==
0
return
mp_degree
>
1
and
tensor_numel
%
mp_degree
==
0
def
_
partial_send_op
(
tensor
,
group
,
dst
,
nranks
,
rank_id
):
def
_
send_on_calc_stream
(
tensor
,
group
,
dst
,
nranks
=
1
,
rank_id
=
0
):
assert
(
assert
(
group
is
not
None
group
is
not
None
),
"Group should be an instance for _
partial_send_op
."
),
"Group should be an instance for _
send_on_calc_stream
."
dst_rank_in_group
=
group
.
get_group_rank
(
dst
)
dst_rank_in_group
=
group
.
get_group_rank
(
dst
)
if
framework
.
in_dynamic_mode
(
):
if
_is_valid_send_recv_partial
(
tensor
,
nranks
):
return
group
.
process_group
.
send_partial
(
return
group
.
process_group
.
send_partial
_on_calc_stream
(
tensor
,
dst_rank_in_group
,
nranks
,
rank_id
tensor
,
dst_rank_in_group
,
nranks
,
rank_id
)
)
else
:
return
group
.
process_group
.
send_on_calc_stream
(
tensor
,
dst_rank_in_group
)
def
_
partial_recv_op
(
tensor
,
group
,
src
,
nranks
,
rank_id
):
def
_
recv_on_calc_stream
(
tensor
,
group
,
src
,
nranks
=
1
,
rank_id
=
0
):
assert
(
assert
(
group
is
not
None
group
is
not
None
),
"Group should be an instance for _
partial_recv_op
."
),
"Group should be an instance for _
recv_on_calc_stream
."
src_rank_in_group
=
group
.
get_group_rank
(
src
)
src_rank_in_group
=
group
.
get_group_rank
(
src
)
if
framework
.
in_dynamic_mode
(
):
if
_is_valid_send_recv_partial
(
tensor
,
nranks
):
return
group
.
process_group
.
recv_partial
(
return
group
.
process_group
.
recv_partial
_on_calc_stream
(
tensor
,
src_rank_in_group
,
nranks
,
rank_id
tensor
,
src_rank_in_group
,
nranks
,
rank_id
)
)
else
:
return
group
.
process_group
.
recv_on_calc_stream
(
tensor
,
src_rank_in_group
)
class
P2PonCalcStream
:
def
__init__
(
self
,
op
,
tensor
,
peer
,
group
,
nranks
=
1
,
rank_id
=
0
):
"""
Args:
op (function): The function to be executed on the calc stream.
tensor (Tensor): The tensor to be sent or received.
peer (int): The peer rank.
group (Group): The process group to p2p.
nranks (int): The number of ranks in model parallel group.
rank_id (int): The rank id in the model parallel group.
"""
if
op
not
in
[
_send_on_calc_stream
,
_recv_on_calc_stream
]:
raise
RuntimeError
(
"Invalid ``op`` function. Expected ``op`` "
"to be of type ``_send_on_calc_stream`` or "
"``_recv_on_calc_stream``."
)
self
.
op
=
op
self
.
tensor
=
tensor
self
.
peer
=
peer
self
.
group
=
group
self
.
nranks
=
nranks
self
.
rank_id
=
rank_id
def
_partial_allgather_op
(
def
_partial_allgather_op
(
...
@@ -231,15 +264,12 @@ def allgather_partial(
...
@@ -231,15 +264,12 @@ def allgather_partial(
)
)
def
partial_batch_isend_irecv
(
p2p_op_list
):
def
batch_send_recv_on_calc_stream
(
p2p_op_list
):
group
=
p2p_op_list
[
0
].
group
group
=
p2p_op_list
[
0
].
group
if
_warn_cur_rank_not_in_group
(
group
):
if
_warn_cur_rank_not_in_group
(
group
):
return
return
if
framework
.
in_dynamic_mode
():
group
=
_get_global_group
()
if
group
is
None
else
group
group
=
_get_global_group
()
if
group
is
None
else
group
backend
=
group
.
backend
backend
=
group
.
backend
tasks
=
[]
with
_with_batch_p2p_guard
(
backend
):
with
_with_batch_p2p_guard
(
backend
):
for
p2p_op
in
p2p_op_list
:
for
p2p_op
in
p2p_op_list
:
op
=
p2p_op
.
op
op
=
p2p_op
.
op
...
@@ -248,29 +278,25 @@ def partial_batch_isend_irecv(p2p_op_list):
...
@@ -248,29 +278,25 @@ def partial_batch_isend_irecv(p2p_op_list):
comm_group
=
p2p_op
.
group
comm_group
=
p2p_op
.
group
nranks
=
p2p_op
.
nranks
nranks
=
p2p_op
.
nranks
rank_id
=
p2p_op
.
rank_id
rank_id
=
p2p_op
.
rank_id
task
=
op
(
tensor
,
comm_group
,
peer
,
nranks
,
rank_id
)
op
(
tensor
,
comm_group
,
peer
,
nranks
,
rank_id
)
if
task
is
not
None
:
tasks
.
append
(
task
)
return
tasks
else
:
raise
RuntimeError
(
"Don't support static graph mode currently."
)
class
PartialP2POp
:
def
_process_p2p_tuple_or_tensor
(
def
__init__
(
self
,
op
,
nranks
,
rank_id
,
tensor
,
peer
,
group
):
tensors
,
p2p_func
,
pp_rank
,
pp_group
,
mp_degree
=
1
,
mp_rank
=
0
if
op
not
in
[
_partial_recv_op
,
_partial_send_op
]:
):
raise
RuntimeError
(
ops
=
[]
"Invalid ``op`` function. Expected ``op`` "
if
isinstance
(
tensors
,
tuple
):
"to be of type ``_partial_send_op`` or "
for
tensor
in
tensors
:
"``_partial_recv_op``."
op
=
P2PonCalcStream
(
p2p_func
,
tensor
,
pp_rank
,
pp_group
,
mp_degree
,
mp_rank
)
)
ops
.
append
(
op
)
self
.
op
=
op
else
:
self
.
nranks
=
nranks
op
=
P2PonCalcStream
(
self
.
rank_id
=
rank_id
p2p_func
,
tensors
,
pp_rank
,
pp_group
,
mp_degree
,
mp_rank
self
.
tensor
=
tensor
)
self
.
peer
=
peer
ops
.
append
(
op
)
self
.
group
=
group
return
ops
def
_p2p_helper
(
def
_p2p_helper
(
...
@@ -326,189 +352,60 @@ def _p2p_helper(
...
@@ -326,189 +352,60 @@ def _p2p_helper(
)
)
ops
=
[]
ops
=
[]
partial_ops
=
[]
pipe_group
=
_hcg
.
get_pipe_parallel_group
()
pipe_group
=
_hcg
.
get_pipe_parallel_group
()
# start to p2p communicate
# start to p2p communicate
if
tensor_send_prev
is
not
None
:
if
tensor_send_prev
is
not
None
:
src_rank
=
_hcg
.
_get_p2p_prev_rank
()
src_rank
=
_hcg
.
_get_p2p_prev_rank
()
if
isinstance
(
tensor_send_prev
,
tuple
):
ops
.
extend
(
for
d
in
tensor_send_prev
:
_process_p2p_tuple_or_tensor
(
if
_is_valid_send_recv_partial
(
d
,
mp_degree
):
tensor_send_prev
,
op
=
PartialP2POp
(
_send_on_calc_stream
,
_partial_send_op
,
mp_degree
,
mp_rank
,
d
,
src_rank
,
pipe_group
,
)
partial_ops
.
append
(
op
)
else
:
op
=
paddle
.
distributed
.
P2POp
(
paddle
.
distributed
.
isend
,
d
,
src_rank
,
src_rank
,
pipe_group
,
pipe_group
,
)
ops
.
append
(
op
)
else
:
if
_is_valid_send_recv_partial
(
tensor_send_prev
,
mp_degree
):
op
=
PartialP2POp
(
_partial_send_op
,
mp_degree
,
mp_degree
,
mp_rank
,
mp_rank
,
tensor_send_prev
,
src_rank
,
pipe_group
,
)
)
partial_ops
.
append
(
op
)
else
:
op
=
paddle
.
distributed
.
P2POp
(
paddle
.
distributed
.
isend
,
tensor_send_prev
,
src_rank
,
pipe_group
,
)
)
ops
.
append
(
op
)
if
tensor_recv_prev
is
not
None
:
if
tensor_recv_prev
is
not
None
:
dst_rank
=
_hcg
.
_get_p2p_prev_rank
()
dst_rank
=
_hcg
.
_get_p2p_prev_rank
()
if
isinstance
(
tensor_recv_prev
,
tuple
):
ops
.
extend
(
for
d
in
tensor_recv_prev
:
_process_p2p_tuple_or_tensor
(
if
_is_valid_send_recv_partial
(
d
,
mp_degree
):
tensor_recv_prev
,
op
=
PartialP2POp
(
_recv_on_calc_stream
,
_partial_recv_op
,
mp_degree
,
mp_rank
,
d
,
dst_rank
,
pipe_group
,
)
partial_ops
.
append
(
op
)
else
:
op
=
paddle
.
distributed
.
P2POp
(
paddle
.
distributed
.
irecv
,
d
,
dst_rank
,
dst_rank
,
pipe_group
,
pipe_group
,
)
ops
.
append
(
op
)
else
:
if
_is_valid_send_recv_partial
(
tensor_recv_prev
,
mp_degree
):
op
=
PartialP2POp
(
_partial_recv_op
,
mp_degree
,
mp_degree
,
mp_rank
,
mp_rank
,
tensor_recv_prev
,
dst_rank
,
pipe_group
,
)
)
partial_ops
.
append
(
op
)
else
:
op
=
paddle
.
distributed
.
P2POp
(
paddle
.
distributed
.
irecv
,
tensor_recv_prev
,
dst_rank
,
pipe_group
,
)
)
ops
.
append
(
op
)
if
tensor_send_next
is
not
None
:
if
tensor_send_next
is
not
None
:
src_rank
=
_hcg
.
_get_p2p_next_rank
()
src_rank
=
_hcg
.
_get_p2p_next_rank
()
if
isinstance
(
tensor_send_next
,
tuple
):
ops
.
extend
(
for
d
in
tensor_send_next
:
_process_p2p_tuple_or_tensor
(
if
_is_valid_send_recv_partial
(
d
,
mp_degree
):
tensor_send_next
,
op
=
PartialP2POp
(
_send_on_calc_stream
,
_partial_send_op
,
mp_degree
,
mp_rank
,
d
,
src_rank
,
pipe_group
,
)
partial_ops
.
append
(
op
)
else
:
op
=
paddle
.
distributed
.
P2POp
(
paddle
.
distributed
.
isend
,
d
,
src_rank
,
src_rank
,
pipe_group
,
pipe_group
,
)
ops
.
append
(
op
)
else
:
if
_is_valid_send_recv_partial
(
tensor_send_next
,
mp_degree
):
op
=
PartialP2POp
(
_partial_send_op
,
mp_degree
,
mp_degree
,
mp_rank
,
mp_rank
,
tensor_send_next
,
src_rank
,
pipe_group
,
)
)
partial_ops
.
append
(
op
)
else
:
op
=
paddle
.
distributed
.
P2POp
(
paddle
.
distributed
.
isend
,
tensor_send_next
,
src_rank
,
pipe_group
,
)
)
ops
.
append
(
op
)
if
tensor_recv_next
is
not
None
:
if
tensor_recv_next
is
not
None
:
dst_rank
=
_hcg
.
_get_p2p_next_rank
()
dst_rank
=
_hcg
.
_get_p2p_next_rank
()
if
isinstance
(
tensor_recv_next
,
tuple
):
ops
.
extend
(
for
d
in
tensor_recv_next
:
_process_p2p_tuple_or_tensor
(
if
_is_valid_send_recv_partial
(
d
,
mp_degree
):
tensor_recv_next
,
op
=
PartialP2POp
(
_recv_on_calc_stream
,
_partial_recv_op
,
mp_degree
,
mp_rank
,
d
,
dst_rank
,
pipe_group
,
)
partial_ops
.
append
(
op
)
else
:
op
=
paddle
.
distributed
.
P2POp
(
paddle
.
distributed
.
irecv
,
d
,
dst_rank
,
dst_rank
,
pipe_group
,
pipe_group
,
)
ops
.
append
(
op
)
else
:
if
_is_valid_send_recv_partial
(
tensor_recv_next
,
mp_degree
):
op
=
PartialP2POp
(
_partial_recv_op
,
mp_degree
,
mp_degree
,
mp_rank
,
mp_rank
,
tensor_recv_next
,
dst_rank
,
pipe_group
,
)
)
partial_ops
.
append
(
op
)
else
:
op
=
paddle
.
distributed
.
P2POp
(
paddle
.
distributed
.
irecv
,
tensor_recv_next
,
dst_rank
,
pipe_group
,
)
)
ops
.
append
(
op
)
if
len
(
ops
)
>
0
:
if
len
(
ops
)
>
0
:
reqs
=
paddle
.
distributed
.
batch_isend_irecv
(
ops
)
batch_send_recv_on_calc_stream
(
ops
)
for
req
in
reqs
:
req
.
wait
()
if
len
(
partial_ops
)
>
0
:
reqs
=
partial_batch_isend_irecv
(
partial_ops
)
for
req
in
reqs
:
req
.
wait
()
# block cpu to wait the result
paddle
.
device
.
synchronize
()
tensors_for_all_gather
=
[]
tensors_for_all_gather
=
[]
if
tensor_recv_prev
is
not
None
:
if
tensor_recv_prev
is
not
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录