Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
72b5b5bf
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
72b5b5bf
编写于
9月 06, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
9月 06, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[dygraph hybrid pp for interleave] The interleave scheduler for pipeline parallel (#45497)
上级
fd86a938
变更
13
展开全部
隐藏空白更改
内联
并排
Showing
13 changed file
with
783 addition
and
150 deletion
+783
-150
python/paddle/distributed/collective.py
python/paddle/distributed/collective.py
+2
-2
python/paddle/distributed/fleet/base/topology.py
python/paddle/distributed/fleet/base/topology.py
+12
-0
python/paddle/distributed/fleet/meta_parallel/__init__.py
python/paddle/distributed/fleet/meta_parallel/__init__.py
+1
-0
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+4
-1
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+365
-27
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
...ributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+169
-107
python/paddle/distributed/fleet/model.py
python/paddle/distributed/fleet/model.py
+12
-2
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+0
-7
python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
...dle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
+14
-0
python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_layer_with_virtual_stage.py
...tive/fleet/hybrid_parallel_pp_layer_with_virtual_stage.py
+4
-2
python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py
...leet/hybrid_parallel_pp_transformer_with_virtual_stage.py
+195
-0
python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py
..._parallel_dygraph_pipeline_parallel_with_virtual_stage.py
+4
-2
python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
...ddle/fluid/tests/unittests/collective/fleet/testslist.csv
+1
-0
未找到文件。
python/paddle/distributed/collective.py
浏览文件 @
72b5b5bf
...
...
@@ -2384,7 +2384,7 @@ def isend(tensor, dst, group=None):
assert
group_dst_rank
>=
0
,
(
"dst rank out of group, need global rank"
)
return
group
.
process_group
.
send
(
tensor
,
group_dst_rank
)
else
:
raise
RuntimeError
(
"
Don't support static graph mode currently
."
)
raise
RuntimeError
(
"
Only support eager dygraph mode
."
)
def
irecv
(
tensor
,
src
=
None
,
group
=
None
):
...
...
@@ -2433,7 +2433,7 @@ def irecv(tensor, src=None, group=None):
assert
group_src_rank
>=
0
,
(
"src rank out of group, need global rank"
)
return
group
.
process_group
.
recv
(
tensor
,
group_src_rank
)
else
:
raise
RuntimeError
(
"
Don't support static graph mode currently
."
)
raise
RuntimeError
(
"
Only support eager dygraph mode
."
)
class
P2POp
(
object
):
...
...
python/paddle/distributed/fleet/base/topology.py
浏览文件 @
72b5b5bf
...
...
@@ -240,6 +240,14 @@ class HybridCommunicateGroup(object):
return
parallel_group
,
parallel_comm_group
def
_get_p2p_next_rank
(
self
):
assert
hasattr
(
self
,
'next_rank'
),
"next_rank has not been inited"
return
self
.
next_rank
def
_get_p2p_prev_rank
(
self
):
assert
hasattr
(
self
,
'prev_rank'
),
"prev_rank has not been inited"
return
self
.
prev_rank
def
_set_p2p_group
(
self
):
comm_lists
=
self
.
_topo
.
get_comm_list
(
'pipe'
)
...
...
@@ -255,6 +263,10 @@ class HybridCommunicateGroup(object):
next_rank
=
comm_ranks
[(
idx
+
1
)
%
self
.
_pp_degree
]
prev_rank
=
comm_ranks
[(
idx
-
1
)
%
self
.
_pp_degree
]
if
self
.
global_rank
==
curr_rank
:
self
.
next_rank
=
next_rank
self
.
prev_rank
=
prev_rank
next_group
=
paddle
.
distributed
.
new_group
(
ranks
=
[
curr_rank
,
next_rank
])
if
self
.
global_rank
==
curr_rank
:
...
...
python/paddle/distributed/fleet/meta_parallel/__init__.py
浏览文件 @
72b5b5bf
...
...
@@ -24,6 +24,7 @@ from .parallel_layers import model_parallel_random_seed # noqa: F401
from
.parallel_layers
import
get_rng_state_tracker
# noqa: F401
from
.tensor_parallel
import
TensorParallel
# noqa: F401
from
.pipeline_parallel
import
PipelineParallel
# noqa: F401
from
.pipeline_parallel
import
PipelineParallelWithInterleave
# noqa: F401
from
.sharding_parallel
import
ShardingParallel
# noqa: F401
__all__
=
[]
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
浏览文件 @
72b5b5bf
...
...
@@ -189,7 +189,7 @@ class PipelineLayerChunk(Layer):
# Users shouldn't call PipelineLayerChunk directly, since all logics relating with recompute
# are in the forward function of PipelineLayer. Any directly call will bring unexpected
# behavior under recompute circumstance.
raise
NotImplemented
Error
(
raise
Permission
Error
(
"The forward function of PipelineLayerChunk cannot be called directly. "
"Please call forward function of PipelineLayer."
)
...
...
@@ -385,6 +385,9 @@ class PipelineLayer(Layer):
start_idx
+
stage
+
1
]:
return
stage
def
get_num_virtual_stages
(
self
):
return
self
.
_num_virtual_pipeline_stages
def
get_model_chunks
(
self
):
return
None
if
self
.
_num_virtual_pipeline_stages
==
1
else
self
.
_model_chunks
...
...
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
72b5b5bf
此差异已折叠。
点击以展开。
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
浏览文件 @
72b5b5bf
...
...
@@ -54,7 +54,7 @@ class SendRecvMeta:
def
_recv_shape_dtype
(
self
,
group
):
# recv len(shape)
dims
=
paddle
.
to_tensor
([
0
])
src_rank
=
group
.
ranks
[
0
]
src_rank
=
_hcg
.
_get_p2p_prev_rank
()
paddle
.
distributed
.
recv
(
dims
,
src
=
src_rank
,
group
=
group
)
dims
=
dims
.
item
()
...
...
@@ -74,7 +74,7 @@ class SendRecvMeta:
def
recv_meta
(
self
,
group
):
tensor_type
=
paddle
.
to_tensor
([
0
])
src_rank
=
group
.
ranks
[
0
]
src_rank
=
_hcg
.
_get_p2p_prev_rank
()
paddle
.
distributed
.
recv
(
tensor_type
,
src
=
src_rank
,
group
=
group
)
tensor_type
=
tensor_type
.
item
()
...
...
@@ -105,7 +105,7 @@ class SendRecvMeta:
def
_send_dims_shape_dtype
(
self
,
tensor
,
group
):
# send len(shape)
dims
=
paddle
.
to_tensor
(
len
(
tensor
.
shape
))
dst_rank
=
group
.
ranks
[
1
]
dst_rank
=
_hcg
.
_get_p2p_next_rank
()
paddle
.
distributed
.
send
(
dims
,
dst
=
dst_rank
,
group
=
group
)
...
...
@@ -122,7 +122,7 @@ class SendRecvMeta:
paddle
.
distributed
.
send
(
stop_grad
,
dst
=
dst_rank
,
group
=
group
)
def
send_meta
(
self
,
tensor
,
group
):
dst_rank
=
group
.
ranks
[
1
]
dst_rank
=
_hcg
.
_get_p2p_next_rank
()
if
isinstance
(
tensor
,
(
paddle
.
Tensor
,
core
.
eager
.
Tensor
)):
tensor_type
=
paddle
.
to_tensor
([
0
])
...
...
@@ -165,20 +165,17 @@ def _is_valid_send_recv_partial(tensor, mp_degree):
def
_partial_send_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
dst
,
nranks
,
rank_id
):
dst_rank_in_group
=
dst
if
group
is
None
else
group
.
get_group_rank
(
dst
)
if
_in_legacy_dygraph
():
return
_legacy_C_ops
.
partial_send
(
tensor
.
detach
(),
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
'peer'
,
dst
,
'num'
,
nranks
,
'id
'
,
rank_id
)
'peer'
,
dst
_rank_in_group
,
'num
'
,
nranks
,
'id'
,
rank_id
)
elif
in_dygraph_mode
():
group
=
paddle
.
distributed
.
collective
.
_get_default_group
(
)
if
group
is
None
else
group
task
=
group
.
process_group
.
send_partial
(
tensor
,
dst
,
nranks
,
rank_id
)
if
use_calc_stream
:
task
.
wait
()
return
None
else
:
return
task
return
group
.
process_group
.
send_partial
(
tensor
,
dst_rank_in_group
,
nranks
,
rank_id
)
def
send_partial
(
tensor
,
...
...
@@ -192,33 +189,35 @@ def send_partial(tensor,
return
ring_id
=
0
if
group
is
None
else
group
.
id
dst_rank
=
_hcg
.
_get_p2p_next_rank
(
)
if
dst
==
1
else
_hcg
.
_get_p2p_prev_rank
()
if
_is_valid_send_recv_partial
(
tensor
,
nranks
):
return
_partial_send_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
dst
,
nranks
,
rank_id
)
return
_partial_send_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
dst_rank
,
nranks
,
rank_id
)
else
:
return
paddle
.
distributed
.
send
(
tensor
.
detach
(),
dst
=
group
.
ranks
[
dst
],
group
=
group
,
use_calc_stream
=
use_calc_stream
)
if
_in_legacy_dygraph
():
send_op
=
paddle
.
distributed
.
send
elif
in_dygraph_mode
():
send_op
=
paddle
.
distributed
.
isend
return
send_op
(
tensor
.
detach
(),
dst
=
dst_rank
,
group
=
group
)
def
_partial_recv_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
src
,
nranks
,
rank_id
):
src_rank_in_group
=
src
if
group
is
None
else
group
.
get_group_rank
(
src
)
if
_in_legacy_dygraph
():
return
_legacy_C_ops
.
partial_recv
(
tensor
.
detach
(),
'use_calc_stream'
,
use_calc_stream
,
'ring_id'
,
ring_id
,
'peer'
,
src
,
'num'
,
nranks
,
'id'
,
rank_id
,
'dtype'
,
tensor
.
dtype
,
'out_shape'
,
tensor
.
shape
)
'peer'
,
src_rank_in_group
,
'num'
,
nranks
,
'id'
,
rank_id
,
'dtype'
,
tensor
.
dtype
,
'out_shape'
,
tensor
.
shape
)
elif
in_dygraph_mode
():
group
=
paddle
.
distributed
.
collective
.
_get_default_group
(
)
if
group
is
None
else
group
task
=
group
.
process_group
.
recv_partial
(
tensor
,
src
,
nranks
,
rank_id
)
if
use_calc_stream
:
task
.
wait
()
return
None
else
:
return
task
return
group
.
process_group
.
recv_partial
(
tensor
,
src_rank_in_group
,
nranks
,
rank_id
)
def
recv_partial
(
tensor
,
...
...
@@ -232,14 +231,18 @@ def recv_partial(tensor,
return
ring_id
=
0
if
group
is
None
else
group
.
id
src_rank
=
_hcg
.
_get_p2p_prev_rank
(
)
if
src
==
0
else
_hcg
.
_get_p2p_next_rank
()
if
_is_valid_send_recv_partial
(
tensor
,
nranks
):
return
_partial_recv_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
src
,
nranks
,
rank_id
)
return
_partial_recv_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
src_rank
,
nranks
,
rank_id
)
else
:
return
paddle
.
distributed
.
recv
(
tensor
.
detach
(),
src
=
group
.
ranks
[
src
],
group
=
group
,
use_calc_stream
=
use_calc_stream
)
if
_in_legacy_dygraph
():
recv_op
=
paddle
.
distributed
.
recv
elif
in_dygraph_mode
():
recv_op
=
paddle
.
distributed
.
irecv
return
recv_op
(
tensor
.
detach
(),
src
=
src_rank
,
group
=
group
)
def
_partial_allgather_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
nranks
,
...
...
@@ -253,13 +256,8 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
elif
in_dygraph_mode
():
group
=
paddle
.
distributed
.
collective
.
_get_default_group
(
)
if
group
is
None
else
group
task
=
group
.
process_group
.
all_gather_partial
(
tensor
,
tensor
,
nranks
,
return
group
.
process_group
.
all_gather_partial
(
tensor
,
tensor
,
nranks
,
rank_id
)
if
use_calc_stream
:
task
.
wait
()
return
None
else
:
return
task
def
allgather_partial
(
tensor
,
...
...
@@ -268,9 +266,9 @@ def allgather_partial(tensor,
group
=
None
,
use_calc_stream
=
True
):
if
not
_is_valid_send_recv_partial
(
tensor
,
nranks
):
return
tensor
return
None
if
group
is
not
None
and
not
group
.
is_member
():
return
return
None
ring_id
=
0
if
group
is
None
else
group
.
id
return
_partial_allgather_op
(
tensor
,
group
,
use_calc_stream
,
ring_id
,
...
...
@@ -323,105 +321,124 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
tensor_recv_next
=
paddle
.
empty
(
shape
=
send_shape_msg
,
dtype
=
number_2_dtype
(
send_dtype_msg
))
# TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
tasks
=
[]
# start to p2p communicate
if
tensor_send_prev
is
not
None
:
if
isinstance
(
tensor_send_prev
,
tuple
):
for
d
in
tensor_send_prev
:
paddle
.
distributed
.
wait
(
d
,
use_calc_stream
=
True
)
send_partial
(
d
,
tasks
.
append
(
send_partial
(
d
,
dst
=
0
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
send_prev_group
,
use_calc_stream
=
False
))
else
:
paddle
.
distributed
.
wait
(
tensor_send_prev
,
use_calc_stream
=
True
)
tasks
.
append
(
send_partial
(
tensor_send_prev
,
dst
=
0
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
send_prev_group
,
use_calc_stream
=
False
)
else
:
paddle
.
distributed
.
wait
(
tensor_send_prev
,
use_calc_stream
=
True
)
send_partial
(
tensor_send_prev
,
dst
=
0
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
send_prev_group
,
use_calc_stream
=
False
)
use_calc_stream
=
False
))
if
tensor_recv_prev
is
not
None
:
if
isinstance
(
tensor_recv_prev
,
tuple
):
for
d
in
tensor_recv_prev
:
recv_partial
(
d
,
tasks
.
append
(
recv_partial
(
d
,
src
=
0
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_prev_group
,
use_calc_stream
=
True
))
tasks
.
append
(
allgather_partial
(
d
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
mp_group
,
use_calc_stream
=
True
))
else
:
tasks
.
append
(
recv_partial
(
tensor_recv_prev
,
src
=
0
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_prev_group
,
use_calc_stream
=
True
)
allgather_partial
(
d
,
use_calc_stream
=
True
))
tasks
.
append
(
allgather_partial
(
tensor_recv_prev
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
mp_group
,
use_calc_stream
=
True
)
else
:
recv_partial
(
tensor_recv_prev
,
src
=
0
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_prev_group
,
use_calc_stream
=
True
)
allgather_partial
(
tensor_recv_prev
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
mp_group
,
use_calc_stream
=
True
)
use_calc_stream
=
True
))
if
tensor_send_next
is
not
None
:
if
isinstance
(
tensor_send_next
,
tuple
):
for
d
in
tensor_send_next
:
paddle
.
distributed
.
wait
(
d
,
use_calc_stream
=
True
)
send_partial
(
d
,
tasks
.
append
(
send_partial
(
d
,
dst
=
1
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
send_next_group
,
use_calc_stream
=
False
))
else
:
paddle
.
distributed
.
wait
(
tensor_send_next
,
use_calc_stream
=
True
)
tasks
.
append
(
send_partial
(
tensor_send_next
,
dst
=
1
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
send_next_group
,
use_calc_stream
=
False
)
else
:
paddle
.
distributed
.
wait
(
tensor_send_next
,
use_calc_stream
=
True
)
send_partial
(
tensor_send_next
,
dst
=
1
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
send_next_group
,
use_calc_stream
=
False
)
use_calc_stream
=
False
))
if
tensor_recv_next
is
not
None
:
if
isinstance
(
tensor_recv_next
,
tuple
):
for
d
in
tensor_recv_next
:
recv_partial
(
d
,
tasks
.
append
(
recv_partial
(
d
,
src
=
1
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_next_group
,
use_calc_stream
=
True
))
tasks
.
append
(
allgather_partial
(
d
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
mp_group
,
use_calc_stream
=
True
))
else
:
tasks
.
append
(
recv_partial
(
tensor_recv_next
,
src
=
1
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_next_group
,
use_calc_stream
=
True
)
allgather_partial
(
d
,
use_calc_stream
=
True
))
tasks
.
append
(
allgather_partial
(
tensor_recv_next
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
mp_group
,
use_calc_stream
=
True
)
else
:
recv_partial
(
tensor_recv_next
,
src
=
1
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
_hcg
.
recv_next_group
,
use_calc_stream
=
True
)
allgather_partial
(
tensor_recv_next
,
nranks
=
mp_degree
,
rank_id
=
mp_rank
,
group
=
mp_group
,
use_calc_stream
=
True
)
use_calc_stream
=
True
))
if
in_dygraph_mode
():
# wait tasks in new dygraph mode with new comm library
for
task
in
tasks
:
if
task
is
not
None
:
task
.
wait
()
return
tensor_recv_prev
,
tensor_recv_next
def
recv_forward
():
if
_hcg
.
is
_first_stage
:
def
recv_forward
(
pp_first_stage
):
if
pp
_first_stage
:
input_tensor
=
None
else
:
if
not
_send_recv_meta
.
has_recv_meta
:
...
...
@@ -435,8 +452,8 @@ def recv_forward():
return
input_tensor
def
recv_backward
():
if
_hcg
.
is
_last_stage
:
def
recv_backward
(
pp_last_stage
):
if
pp
_last_stage
:
output_tensor_grad
=
None
else
:
_
,
output_tensor_grad
=
_p2p_helper
(
tensor_send_next
=
None
,
...
...
@@ -446,8 +463,8 @@ def recv_backward():
return
output_tensor_grad
def
send_forward
(
output_tensor
):
if
not
_hcg
.
is
_last_stage
:
def
send_forward
(
output_tensor
,
pp_last_stage
):
if
not
pp
_last_stage
:
if
not
_send_recv_meta
.
has_send_meta
:
_send_recv_meta
.
set_send_message
(
output_tensor
)
_send_recv_meta
.
send_meta
(
output_tensor
,
_hcg
.
send_next_group
)
...
...
@@ -459,16 +476,16 @@ def send_forward(output_tensor):
recv_next
=
False
)
def
send_backward
(
input_tensor_grad
):
if
not
_hcg
.
is
_first_stage
:
def
send_backward
(
input_tensor_grad
,
pp_first_stage
):
if
not
pp
_first_stage
:
_p2p_helper
(
tensor_send_next
=
None
,
tensor_send_prev
=
input_tensor_grad
,
recv_prev
=
False
,
recv_next
=
False
)
def
send_forward_recv_backward
(
output_tensor
):
if
_hcg
.
is
_last_stage
:
def
send_forward_recv_backward
(
output_tensor
,
pp_last_stage
):
if
pp
_last_stage
:
output_tensor_grad
=
None
else
:
_
,
output_tensor_grad
=
_p2p_helper
(
tensor_send_next
=
output_tensor
,
...
...
@@ -478,8 +495,8 @@ def send_forward_recv_backward(output_tensor):
return
output_tensor_grad
def
send_backward_recv_forward
(
input_tensor_grad
):
if
_hcg
.
is
_first_stage
:
def
send_backward_recv_forward
(
input_tensor_grad
,
pp_first_stage
):
if
pp
_first_stage
:
input_tensor
=
None
else
:
input_tensor
,
_
=
_p2p_helper
(
tensor_send_next
=
None
,
...
...
@@ -487,3 +504,48 @@ def send_backward_recv_forward(input_tensor_grad):
recv_prev
=
True
,
recv_next
=
False
)
return
input_tensor
def
send_forward_backward_recv_forward_backward
(
output_tensor
,
input_tensor_grad
,
recv_prev
,
recv_next
):
# always have to send dytpe info to downstream
if
not
_send_recv_meta
.
has_send_meta
:
_send_recv_meta
.
set_send_message
(
output_tensor
)
_send_recv_meta
.
send_meta
(
output_tensor
,
_hcg
.
send_next_group
)
_send_recv_meta
.
has_send_meta
=
_use_cache
if
recv_prev
and
not
_send_recv_meta
.
has_recv_meta
:
_send_recv_meta
.
recv_meta
(
_hcg
.
recv_prev_group
)
_send_recv_meta
.
has_recv_meta
=
_use_cache
input_tensor
,
output_tensor_grad
=
_p2p_helper
(
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
input_tensor_grad
,
recv_prev
=
recv_prev
,
recv_next
=
recv_next
)
return
input_tensor
,
output_tensor_grad
def
send_forward_recv_forward
(
output_tensor
,
recv_prev
):
# always have to send dytpe info to downstream
if
not
_send_recv_meta
.
has_send_meta
:
_send_recv_meta
.
set_send_message
(
output_tensor
)
_send_recv_meta
.
send_meta
(
output_tensor
,
_hcg
.
send_next_group
)
_send_recv_meta
.
has_send_meta
=
_use_cache
if
recv_prev
and
not
_send_recv_meta
.
has_recv_meta
:
_send_recv_meta
.
recv_meta
(
_hcg
.
recv_prev_group
)
_send_recv_meta
.
has_recv_meta
=
_use_cache
input_tensor
,
_
=
_p2p_helper
(
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
None
,
recv_prev
=
recv_prev
,
recv_next
=
False
)
return
input_tensor
def
send_backward_recv_backward
(
input_tensor_grad
,
recv_next
):
_
,
output_tensor_grad
=
_p2p_helper
(
tensor_send_next
=
None
,
tensor_send_prev
=
input_tensor_grad
,
recv_prev
=
False
,
recv_next
=
recv_next
)
return
output_tensor_grad
python/paddle/distributed/fleet/model.py
浏览文件 @
72b5b5bf
...
...
@@ -18,7 +18,7 @@ import numpy as np
from
.base
import
topology
as
tp
from
.base.topology
import
ParallelMode
from
.meta_parallel
import
TensorParallel
,
model_parallel_random_seed
from
.meta_parallel
import
PipelineParallel
,
ShardingParallel
from
.meta_parallel
import
PipelineParallel
,
ShardingParallel
,
PipelineParallelWithInterleave
,
PipelineLayer
from
paddle.fluid
import
core
from
paddle.distributed.fleet.utils.recompute
import
LegacyRecomputeFunction
from
paddle.fluid.dygraph.varbase_patch_methods
import
_grad_scalar
...
...
@@ -185,6 +185,16 @@ def distributed_model(model):
elif
fleet_env
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
TENSOR_PARALLEL
:
model
=
TensorParallel
(
model
,
fleet_env
.
_hcg
,
strategy
=
strategy
)
elif
fleet_env
.
_hcg
.
get_parallel_mode
()
==
ParallelMode
.
PIPELINE_PARALLEL
:
model
=
PipelineParallel
(
model
,
fleet_env
.
_hcg
,
strategy
=
strategy
)
assert
isinstance
(
model
,
PipelineLayer
),
"For pipeline parallel, the model should an instance of PipelineLayer"
if
model
.
get_num_virtual_stages
()
==
1
:
# 1f1b pipeline
model
=
PipelineParallel
(
model
,
fleet_env
.
_hcg
,
strategy
=
strategy
)
else
:
# interleave pipeline
model
=
PipelineParallelWithInterleave
(
model
,
fleet_env
.
_hcg
,
strategy
=
strategy
)
return
model
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
72b5b5bf
...
...
@@ -27,8 +27,6 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel_fused_feedforward
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel_fused_attention
)
list
(
APPEND DIST_TEST_OPS test_static_model_parallel_fused_multi_transformer
)
list
(
APPEND DIST_TEST_OPS
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_data_unshard
)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_save_load
)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_autoconvert
)
...
...
@@ -178,8 +176,6 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
# TODO(shenliang03): batch_fc_op support CPU device in future
# TODO(Yancey1989): parallel dygraph support CPU device in future
list
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel
)
list
(
REMOVE_ITEM TEST_OPS
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
)
list
(
REMOVE_ITEM TEST_OPS test_fleet_base_single
)
list
(
REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner
)
list
(
REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt
)
...
...
@@ -1178,9 +1174,6 @@ set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60)
if
(
WITH_DISTRIBUTE
AND WITH_GPU
AND WITH_NCCL
)
set_tests_properties
(
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
PROPERTIES TIMEOUT 500
)
set_tests_properties
(
test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_auto_parallel_save_load PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt
浏览文件 @
72b5b5bf
...
...
@@ -204,6 +204,20 @@ if((WITH_GPU) AND LOCAL_ALL_PLAT)
set_tests_properties
(
test_parallel_dygraph_pipeline_parallel
PROPERTIES TIMEOUT
"500"
)
endif
()
if
((
WITH_GPU
)
AND LOCAL_ALL_PLAT
)
bash_test_modules
(
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
START_BASH
../../dist_test.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21282;http_proxy=;https_proxy=;PYTHONPATH=../..:
${
PADDLE_BINARY_DIR
}
/python"
)
set_tests_properties
(
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
PROPERTIES TIMEOUT
"500"
RUN_SERIAL 1
)
endif
()
if
((
WITH_GPU
OR WITH_XPU
OR WITH_ASCEND
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer_with_virtual_stage.py
→
python/paddle/fluid/tests/unittests/
collective/fleet/
hybrid_parallel_pp_layer_with_virtual_stage.py
浏览文件 @
72b5b5bf
...
...
@@ -19,7 +19,7 @@ import paddle
from
paddle.distributed
import
fleet
import
paddle.nn
as
nn
from
paddle.fluid.dygraph.layers
import
Layer
from
paddle.distributed.fleet.meta_parallel
import
LayerDesc
,
PipelineLayer
from
paddle.distributed.fleet.meta_parallel
import
LayerDesc
,
PipelineLayer
,
PipelineParallelWithInterleave
import
paddle.nn.functional
as
F
...
...
@@ -87,7 +87,8 @@ class TestPipeLayerAPI(unittest.TestCase):
try
:
model_chunks
[
0
](
paddle
.
to_tensor
([
1.
,
2.
]))
except
NotImplementedError
:
raise
NotImplementedError
except
PermissionError
:
pass
# fake call for the forward function of virtual pipeline layer
...
...
@@ -102,6 +103,7 @@ class TestPipeLayerAPI(unittest.TestCase):
# just make sure the model can be wrapped with distributed model
dist_model
=
fleet
.
distributed_model
(
pipe_model
)
assert
isinstance
(
dist_model
,
PipelineParallelWithInterleave
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py
0 → 100644
浏览文件 @
72b5b5bf
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
paddle
import
numpy
as
np
import
random
import
paddle.distributed
as
dist
import
paddle.distributed.fleet
as
fleet
from
paddle.fluid
import
layers
import
paddle.nn.functional
as
F
from
paddle.distributed.fleet.meta_parallel
import
PipelineLayer
,
LayerDesc
from
paddle.fluid.dygraph.layers
import
Layer
import
paddle.nn
as
nn
def
set_random_seed
(
seed
,
dp_id
,
rank_id
):
"""Set random seed for reproducability."""
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
+
dp_id
)
paddle
.
seed
(
seed
+
dp_id
)
batch_size
=
8
length
=
8
micro_batch_size
=
2
num_virtual_pipeline_stages
=
2
vocab_size
=
128
hidden_size
=
16
d_model
=
hidden_size
dim_feedforward
=
4
*
d_model
class
EmbeddingNet
(
Layer
):
def
__init__
(
self
):
super
(
EmbeddingNet
,
self
).
__init__
()
self
.
word_embeddings
=
nn
.
Embedding
(
vocab_size
,
hidden_size
)
self
.
position_embeddings
=
nn
.
Embedding
(
vocab_size
,
hidden_size
)
def
forward
(
self
,
x
):
attention_mask
=
paddle
.
tensor
.
triu
((
paddle
.
ones
(
(
length
,
length
),
dtype
=
"float32"
)
*
-
1e9
),
1
)
no_used
=
paddle
.
ones
((
3
,
3
),
dtype
=
"int32"
)
w_emb
=
self
.
word_embeddings
(
x
)
p_emb
=
self
.
position_embeddings
(
x
)
w_emb
=
w_emb
+
p_emb
attention_mask
.
stop_gradient
=
True
no_used
.
stop_gradient
=
True
# need to fix bug of backward()
return
w_emb
,
attention_mask
,
no_used
,
p_emb
class
TransformerNet
(
Layer
):
def
__init__
(
self
):
super
(
TransformerNet
,
self
).
__init__
()
self
.
linear1
=
nn
.
Linear
(
d_model
,
dim_feedforward
)
self
.
linear2
=
nn
.
Linear
(
dim_feedforward
,
d_model
)
self
.
q_proj
=
nn
.
Linear
(
d_model
,
d_model
)
self
.
k_proj
=
nn
.
Linear
(
d_model
,
d_model
)
self
.
v_proj
=
nn
.
Linear
(
d_model
,
d_model
)
self
.
norm1
=
nn
.
LayerNorm
(
d_model
,
epsilon
=
1e-5
)
def
forward
(
self
,
x
,
mask
):
q
=
self
.
q_proj
(
x
)
k
=
self
.
k_proj
(
x
)
v
=
self
.
v_proj
(
x
)
product
=
layers
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
,
alpha
=
d_model
**-
0.5
)
weights
=
F
.
softmax
(
product
+
mask
)
tgt
=
layers
.
matmul
(
weights
,
v
)
residual
=
tgt
tgt
=
self
.
norm1
(
tgt
)
tgt
=
residual
+
tgt
out
=
self
.
linear2
(
F
.
gelu
(
self
.
linear1
(
tgt
),
approximate
=
True
))
return
out
class
EmbeddingPipe
(
EmbeddingNet
):
def
forward
(
self
,
x
):
return
super
().
forward
(
x
)
class
TransformerNetPipe
(
TransformerNet
):
def
forward
(
self
,
args
):
x
,
mask
,
no_used
,
p_emb
=
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]
output
=
super
().
forward
(
x
,
mask
)
output
=
output
+
p_emb
mask
.
stop_gradient
=
True
return
output
,
mask
,
no_used
,
p_emb
class
CriterionPipe
(
Layer
):
def
__init__
(
self
):
super
(
CriterionPipe
,
self
).
__init__
()
def
forward
(
self
,
out
,
label
):
loss
=
out
.
mean
()
return
loss
class
ModelPipe
(
PipelineLayer
):
def
__init__
(
self
,
topology
):
self
.
descs
=
[]
self
.
descs
.
append
(
LayerDesc
(
EmbeddingPipe
))
for
x
in
range
(
8
):
self
.
descs
.
append
(
LayerDesc
(
TransformerNetPipe
))
self
.
descs
.
append
(
lambda
x
:
x
[
0
])
super
().
__init__
(
layers
=
self
.
descs
,
loss_fn
=
CriterionPipe
(),
topology
=
topology
,
num_virtual_pipeline_stages
=
num_virtual_pipeline_stages
,
seg_method
=
"layer:TransformerNetPipe"
)
class
TestDistPPTraning
(
unittest
.
TestCase
):
def
setUp
(
self
):
strategy
=
fleet
.
DistributedStrategy
()
self
.
model_parallel_size
=
1
self
.
data_parallel_size
=
1
self
.
pipeline_parallel_size
=
2
strategy
.
hybrid_configs
=
{
"dp_degree"
:
self
.
data_parallel_size
,
"mp_degree"
:
self
.
model_parallel_size
,
"pp_degree"
:
self
.
pipeline_parallel_size
,
}
strategy
.
pipeline_configs
=
{
"accumulate_steps"
:
batch_size
//
micro_batch_size
,
"micro_batch_size"
:
micro_batch_size
}
fleet
.
init
(
is_collective
=
True
,
strategy
=
strategy
)
def
test_pp_model
(
self
):
hcg
=
fleet
.
get_hybrid_communicate_group
()
word_size
=
hcg
.
get_model_parallel_world_size
()
dp_id
=
hcg
.
get_data_parallel_rank
()
pp_id
=
hcg
.
get_stage_id
()
rank_id
=
dist
.
get_rank
()
topology
=
hcg
.
topology
()
set_random_seed
(
1024
,
dp_id
,
rank_id
)
model
=
ModelPipe
(
topology
)
scheduler
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
boundaries
=
[
2
],
values
=
[
0.001
,
0.002
],
verbose
=
True
)
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
scheduler
,
parameters
=
model
.
parameters
())
model
=
fleet
.
distributed_model
(
model
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
for
step_id
in
range
(
5
):
x_data
=
np
.
random
.
randint
(
0
,
vocab_size
,
size
=
[
batch_size
,
length
])
x
=
paddle
.
to_tensor
(
x_data
)
x
.
stop_gradient
=
True
e_loss
=
model
.
eval_batch
([
x
,
x
],
True
)
loss
=
model
.
train_batch
([
x
,
x
],
optimizer
,
scheduler
)
np
.
testing
.
assert_allclose
(
loss
.
numpy
(),
e_loss
.
numpy
())
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py
→
python/paddle/fluid/tests/unittests/
collective/fleet/
test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py
浏览文件 @
72b5b5bf
...
...
@@ -25,8 +25,10 @@ class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus):
def
test_hybrid_parallel_pp_layer_with_virtual_stage
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_layer_with_virtual_stage.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_layer_with_virtual_stage.py'
,
eager_mode
=
False
)
def
test_hybrid_parallel_pp_transformer_with_virtual_stage
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_transformer_with_virtual_stage.py'
)
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
浏览文件 @
72b5b5bf
...
...
@@ -16,6 +16,7 @@ test_fleet_graph_execution_meta_optimizer,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,../../
test_communicator_half_async,,,120,DIST,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_fleet_graph_executor,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../dist_test.sh,2,1,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_localsgd_meta_optimizer,LINUX,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_class_center_sample,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录