Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5c9fce0e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5c9fce0e
编写于
7月 01, 2021
作者:
S
ShenLiang
提交者:
GitHub
7月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add p2p (#33864)
上级
c522530a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
126 addition
and
75 deletion
+126
-75
python/paddle/distributed/fleet/base/topology.py
python/paddle/distributed/fleet/base/topology.py
+21
-0
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+35
-75
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
...ributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+70
-0
未找到文件。
python/paddle/distributed/fleet/base/topology.py
浏览文件 @
5c9fce0e
...
...
@@ -164,9 +164,27 @@ class HybridCommunicateGroup(object):
self
.
_dp_group
,
self
.
_check_group
)
logger
.
info
(
debug_str
)
# create p2p_groups and no new group
self
.
_p2p_groups
=
self
.
_build_p2p_lists
()
global
_HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP
=
self
def
_build_p2p_lists
(
self
):
comm_lists
=
self
.
_topo
.
get_comm_list
(
'pipe'
)
p2p_lists
=
[]
for
rank
in
range
(
self
.
nranks
):
for
comm_ranks
in
comm_lists
:
assert
len
(
comm_ranks
)
==
self
.
_pp_degree
if
rank
in
comm_ranks
:
idx
=
comm_ranks
.
index
(
rank
)
next_rank
=
comm_ranks
[(
idx
+
1
)
%
self
.
_pp_degree
]
p2p_lists
.
append
([
rank
,
next_rank
])
break
assert
len
(
p2p_lists
)
==
self
.
nranks
,
"len(p2p_lists) should be equal nranks"
return
p2p_lists
def
get_parallel_mode
(
self
):
# there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
# NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
...
...
@@ -286,6 +304,9 @@ class HybridCommunicateGroup(object):
# TODO should the src rank related to the shard rank for each parameter ?
return
self
.
_sharding_comm_group
.
ranks
[
0
]
def
get_p2p_groups
(
self
):
return
self
.
_p2p_groups
# check parallel group
def
get_check_parallel_group
(
self
):
return
self
.
_check_comm_group
...
...
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
5c9fce0e
...
...
@@ -24,6 +24,7 @@ from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from
..utils.hybrid_parallel_util
import
broadcast_dp_parameters
from
..utils.log_util
import
logger
from
..meta_optimizers.dygraph_optimizer
import
HybridParallelOptimizer
from
.pp_utils
import
p2p_communication
as
p2p
__all__
=
[]
...
...
@@ -63,6 +64,7 @@ class PipelineParallel(MetaParallelBase):
self
.
prev_stage_id
=
self
.
stage_id
-
1
self
.
next_stage_id
=
self
.
stage_id
+
1
self
.
pp_group
=
self
.
_hcg
.
get_pipe_parallel_group
()
p2p
.
initialize_p2p_groups
(
hcg
)
self
.
is_first_stage
=
self
.
stage_id
==
0
self
.
is_last_stage
=
(
self
.
stage_id
==
(
self
.
num_stages
-
1
))
...
...
@@ -275,97 +277,86 @@ class PipelineParallel(MetaParallelBase):
if
isinstance
(
data
,
paddle
.
Tensor
):
tensor_type
=
paddle
.
to_tensor
([
0
])
# send tensor type
paddle
.
distributed
.
send
(
tensor_type
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
tensor_type
,
self
.
next_stage_id
)
# send len(shape)
dims
=
paddle
.
to_tensor
(
len
(
data
.
shape
))
paddle
.
distributed
.
send
(
dims
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
dims
,
self
.
next_stage_id
)
# send shape
shape
=
paddle
.
to_tensor
(
data
.
shape
)
paddle
.
distributed
.
send
(
shape
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
shape
,
self
.
next_stage_id
)
# send dtype
dtype
=
paddle
.
to_tensor
(
paddle_2_number
(
data
.
dtype
))
paddle
.
distributed
.
send
(
dtype
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
dtype
,
self
.
next_stage_id
)
elif
isinstance
(
data
,
tuple
):
tensor_type
=
paddle
.
to_tensor
([
1
])
p
addle
.
distributed
.
send
(
tensor_type
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p
2p
.
send
(
tensor_type
,
self
.
next_stage_id
)
nums
=
paddle
.
to_tensor
(
len
(
data
))
p
addle
.
distributed
.
send
(
nums
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p
2p
.
send
(
nums
,
self
.
next_stage_id
)
for
idx
,
d
in
enumerate
(
data
):
assert
isinstance
(
d
,
paddle
.
Tensor
)
# send len(shape)
dims
=
paddle
.
to_tensor
(
len
(
d
.
shape
))
paddle
.
distributed
.
send
(
dims
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
dims
,
self
.
next_stage_id
)
# send shape
shape
=
paddle
.
to_tensor
(
d
.
shape
)
paddle
.
distributed
.
send
(
shape
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
shape
,
self
.
next_stage_id
)
# send dtype
dtype
=
paddle
.
to_tensor
(
paddle_2_number
(
d
.
dtype
))
paddle
.
distributed
.
send
(
dtype
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
dtype
,
self
.
next_stage_id
)
def
_recv_meta
(
self
,
peer
):
tensor_type
=
paddle
.
to_tensor
([
0
])
p
addle
.
distributed
.
recv
(
tensor_type
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p
2p
.
recv
(
tensor_type
,
self
.
prev_stage_id
)
tensor_type
=
tensor_type
.
item
()
if
tensor_type
==
0
:
# recv len(shape)
dims
=
paddle
.
to_tensor
([
0
])
p
addle
.
distributed
.
recv
(
dims
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p
2p
.
recv
(
dims
,
self
.
prev_stage_id
)
dims
=
dims
.
item
()
# recv shape
shape
=
paddle
.
to_tensor
([
0
]
*
dims
)
p
addle
.
distributed
.
recv
(
shape
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p
2p
.
recv
(
shape
,
self
.
prev_stage_id
)
shape
=
shape
.
numpy
().
tolist
()
# recv dtype
dtype
=
paddle
.
to_tensor
([
0
])
p
addle
.
distributed
.
recv
(
dtype
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p
2p
.
recv
(
dtype
,
self
.
prev_stage_id
)
return
self
.
_allocate_cache
(
shape
,
dtype
=
number_2_dtype
(
dtype
.
item
()),
num_caches
=
1
)[
0
]
elif
tensor_type
==
1
:
num
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
num
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
recv
(
num
,
self
.
prev_stage_id
)
num
=
num
.
item
()
shapes
=
[]
dtypes
=
[]
for
i
in
range
(
num
):
# recv len(shape)
dims
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
dims
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
recv
(
dims
,
self
.
prev_stage_id
)
# recv shape
dims
=
dims
.
item
()
shape
=
paddle
.
to_tensor
([
0
]
*
dims
)
paddle
.
distributed
.
recv
(
shape
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
recv
(
shape
,
self
.
prev_stage_id
)
shapes
.
append
(
shape
.
numpy
().
tolist
())
# recv dtype
dtype
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
dtype
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
recv
(
dtype
,
self
.
prev_stage_id
)
dtypes
.
append
(
number_2_dtype
(
dtype
.
item
()))
caches
=
self
.
_allocate_caches
(
shapes
,
dtypes
,
num_caches
=
1
)[
0
]
...
...
@@ -380,39 +371,25 @@ class PipelineParallel(MetaParallelBase):
self
.
_send_meta
(
outputs
,
self
.
next_stage_id
)
if
isinstance
(
outputs
,
paddle
.
Tensor
):
paddle
.
distributed
.
send
(
outputs
,
self
.
next_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
outputs
,
self
.
next_stage_id
)
elif
isinstance
(
outputs
,
tuple
):
for
output
in
outputs
:
paddle
.
distributed
.
send
(
output
,
self
.
next_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
output
,
self
.
next_stage_id
)
def
_send_gradients
(
self
,
cache_id
):
inputs
=
self
.
caches
[
'inputs'
][
cache_id
]
if
isinstance
(
inputs
,
paddle
.
Tensor
):
assert
inputs
.
grad
is
not
None
paddle
.
distributed
.
send
(
paddle
.
to_tensor
(
inputs
.
grad
),
self
.
prev_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
inputs
.
grad
,
self
.
prev_stage_id
)
else
:
for
idx
,
d
in
enumerate
(
inputs
):
# Skip tensors that will not produce a grad
if
not
is_float_tensor
(
d
):
assert
d
.
grad
is
None
continue
paddle
.
distributed
.
send
(
d
.
grad
,
self
.
prev_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
send
(
d
.
grad
,
self
.
prev_stage_id
)
self
.
caches
[
'inputs'
][
cache_id
]
=
None
def
_recv_activations
(
self
,
cache_id
):
...
...
@@ -421,11 +398,7 @@ class PipelineParallel(MetaParallelBase):
self
.
recv_cache
=
self
.
_recv_meta
(
self
.
prev_stage_id
)
if
isinstance
(
self
.
recv_cache
,
paddle
.
Tensor
):
paddle
.
distributed
.
recv
(
self
.
recv_cache
,
self
.
prev_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
recv
(
self
.
recv_cache
,
self
.
prev_stage_id
)
inputs
=
self
.
recv_cache
.
clone
().
detach
()
inputs
.
stop_gradient
=
not
is_float_tensor
(
inputs
)
else
:
...
...
@@ -433,12 +406,7 @@ class PipelineParallel(MetaParallelBase):
inputs
=
[
None
]
*
len
(
self
.
recv_cache
)
for
idx
,
d
in
enumerate
(
self
.
recv_cache
):
assert
isinstance
(
d
,
paddle
.
Tensor
)
paddle
.
distributed
.
recv
(
d
,
self
.
prev_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
recv
(
d
,
self
.
prev_stage_id
)
inputs
[
idx
]
=
d
.
clone
().
detach
()
inputs
=
tuple
(
inputs
)
...
...
@@ -466,19 +434,11 @@ class PipelineParallel(MetaParallelBase):
sizes
,
dtypes
,
num_caches
=
1
)[
0
]
if
isinstance
(
self
.
grad_tensors
,
paddle
.
Tensor
):
paddle
.
distributed
.
recv
(
self
.
grad_tensors
,
self
.
next_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
recv
(
self
.
grad_tensors
,
self
.
next_stage_id
)
else
:
assert
isinstance
(
outputs
,
tuple
)
for
d
in
self
.
grad_tensors
:
paddle
.
distributed
.
recv
(
d
,
self
.
next_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
p2p
.
recv
(
d
,
self
.
next_stage_id
)
def
_step
(
self
):
self
.
optimizer
.
step
()
...
...
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
0 → 100644
浏览文件 @
5c9fce0e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.distributed
as
dist
_groups
=
None
_hcg
=
None
def
initialize_p2p_groups
(
hcg
):
global
_groups
,
_hcg
_groups
=
[
dist
.
new_group
(
ranks
=
group
)
for
group
in
hcg
.
get_p2p_groups
()]
_hcg
=
hcg
def
send
(
tensor
,
dest_stage
):
global
_groups
,
_hcg
src_stage
=
_hcg
.
get_stage_id
()
src_rank
=
_hcg
.
get_rank_from_stage
(
stage_id
=
src_stage
)
_is_valid_communciate
(
src_stage
,
dest_stage
)
group
=
_get_send_recv_group
(
src_stage
,
dest_stage
)
dst_rank
=
_hcg
.
get_rank_from_stage
(
stage_id
=
dest_stage
)
return
dist
.
broadcast
(
tensor
,
src_rank
,
group
=
group
)
def
recv
(
tensor
,
src_stage
):
global
_groups
,
_hcg
dest_stage
=
_hcg
.
get_stage_id
()
_is_valid_communciate
(
src_stage
,
dest_stage
)
group
=
_get_send_recv_group
(
src_stage
,
dest_stage
)
src_rank
=
_hcg
.
get_rank_from_stage
(
stage_id
=
src_stage
)
return
dist
.
broadcast
(
tensor
,
src_rank
,
group
=
group
)
def
_is_valid_communciate
(
src_stage
,
dest_stage
):
first_stage
=
0
last_stage
=
_hcg
.
get_pipe_parallel_world_size
()
-
1
assert
abs
(
src_stage
-
dest_stage
)
==
1
or
\
(
src_stage
==
first_stage
and
dest_stage
==
last_stage
)
or
\
(
src_stage
==
last_stage
and
dest_stage
==
first_stage
)
def
_get_send_recv_group
(
src_stage
,
dest_stage
):
global
_groups
,
_hcg
stage_id
=
None
first_stage
=
0
last_stage
=
_hcg
.
get_pipe_parallel_world_size
()
-
1
if
(
src_stage
==
first_stage
and
dest_stage
==
last_stage
)
or
\
(
dest_stage
==
first_stage
and
src_stage
==
last_stage
):
stage_id
=
last_stage
elif
src_stage
>
dest_stage
:
stage_id
=
dest_stage
else
:
stage_id
=
src_stage
group_id
=
_hcg
.
get_rank_from_stage
(
stage_id
=
stage_id
)
return
_groups
[
group_id
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录