Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9b6c7eb9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9b6c7eb9
编写于
8月 03, 2021
作者:
S
ShenLiang
提交者:
GitHub
8月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[HybridParallel] Support segment for PipelineParallel (#34529)
* add layer segment * add segement for transformer * add utest
上级
2714fc7e
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
78 addition
and
23 deletion
+78
-23
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+73
-21
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
...e/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
+5
-2
未找到文件。
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
浏览文件 @
9b6c7eb9
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
math
import
paddle
import
re
from
paddle.fluid.dygraph.layers
import
Layer
from
...utils.log_util
import
logger
,
layer_to_str
from
functools
import
partial
...
...
@@ -20,27 +21,6 @@ from functools import partial
__all__
=
[]
class
SegmentLayers
(
object
):
def
__init__
(
self
,
layers_desc
,
num_parts
,
method
=
"uniform"
):
self
.
_layers_desc
=
layers_desc
self
.
method
=
method
self
.
num_parts
=
num_parts
self
.
num_items
=
len
(
layers_desc
)
assert
self
.
num_items
>=
self
.
num_parts
,
"layer number should be greater than number of segments"
def
do_segment
(
self
):
if
self
.
method
==
"uniform"
:
return
self
.
uniform
(
self
.
num_items
,
self
.
num_parts
)
def
uniform
(
self
,
num_items
,
num_parts
):
result
=
[
0
for
_
in
range
(
num_parts
+
1
)]
part_size
=
math
.
floor
(
num_items
/
num_parts
)
for
i
in
range
(
num_parts
):
result
[
i
]
=
int
(
min
(
part_size
*
i
,
num_items
))
result
[
num_parts
]
=
num_items
return
result
class
LayerDesc
(
object
):
def
__init__
(
self
,
layer_func
,
*
inputs
,
**
kwargs
):
self
.
layer_func
=
layer_func
...
...
@@ -73,6 +53,75 @@ class SharedLayerDesc(LayerDesc):
self
.
shared_weight_attr
=
shared_weight_attr
class
SegmentLayers
(
object
):
def
__init__
(
self
,
layers_desc
,
num_parts
,
method
=
"uniform"
):
self
.
_layers_desc
=
layers_desc
self
.
method
=
method
self
.
num_parts
=
num_parts
self
.
num_items
=
len
(
layers_desc
)
assert
self
.
num_items
>=
self
.
num_parts
,
"layer number should be greater than number of segments"
def
do_segment
(
self
):
if
self
.
method
==
"uniform"
:
return
self
.
uniform
(
self
.
num_items
,
self
.
num_parts
)
elif
self
.
method
.
startswith
(
'layer:'
):
# Divide equally according to the specified layer
layername
=
self
.
method
.
split
(
':'
)[
1
]
weights
=
[
0
]
*
len
(
self
.
_layers_desc
)
weight_idxs
=
self
.
_gen_layer_weight
(
layername
)
for
idx
in
weight_idxs
:
weights
[
idx
]
=
1
assert
sum
(
weights
)
%
self
.
num_parts
==
0
,
"number of layers ({}) should be divided by part number({})"
.
format
(
sum
(
weights
),
self
.
num_parts
)
part_size
=
sum
(
weights
)
//
self
.
num_parts
result
=
[
0
for
_
in
range
(
self
.
num_parts
+
1
)]
memory_counter
=
0
result_idx
=
1
for
idx
,
weight
in
enumerate
(
weights
):
memory_counter
+=
weight
if
memory_counter
==
part_size
:
result
[
result_idx
]
=
idx
+
1
result_idx
+=
1
memory_counter
=
0
result
[
self
.
num_parts
]
=
len
(
weights
)
return
result
def
_gen_layer_weight
(
self
,
layername
):
weight_idxs
=
[]
regex
=
re
.
compile
(
layername
,
re
.
IGNORECASE
)
for
idx
,
layer
in
enumerate
(
self
.
_layers_desc
):
name
=
None
if
isinstance
(
layer
,
Layer
):
name
=
layer
.
__class__
.
__name__
elif
isinstance
(
layer
,
LayerDesc
):
name
=
layer
.
layer_func
.
__name__
else
:
try
:
name
=
layer
.
__name__
except
AttributeError
:
# it is not error
continue
if
regex
.
search
(
name
):
weight_idxs
.
append
(
idx
)
assert
len
(
weight_idxs
)
>
0
,
"weight_idxs' length should be greater than 0"
return
weight_idxs
def
uniform
(
self
,
num_items
,
num_parts
):
result
=
[
0
for
_
in
range
(
num_parts
+
1
)]
part_size
=
math
.
floor
(
num_items
/
num_parts
)
for
i
in
range
(
num_parts
):
result
[
i
]
=
int
(
min
(
part_size
*
i
,
num_items
))
result
[
num_parts
]
=
num_items
return
result
class
PipelineLayer
(
Layer
):
def
__init__
(
self
,
layers
,
...
...
@@ -205,6 +254,9 @@ class PipelineLayer(Layer):
self
.
_layers_desc
,
num_parts
=
self
.
_num_stages
,
method
=
seg_method
)
self
.
segment_parts
=
seg
.
do_segment
()
logger
.
info
(
"segment result:"
+
", "
.
join
(
str
(
arg
)
for
arg
in
self
.
segment_parts
))
self
.
_start_pos
=
self
.
segment_parts
[
self
.
_stage_id
]
self
.
_end_pos
=
self
.
segment_parts
[
self
.
_stage_id
+
1
]
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
浏览文件 @
9b6c7eb9
...
...
@@ -121,13 +121,16 @@ class ModelPipe(PipelineLayer):
self
.
descs
=
[]
self
.
descs
.
append
(
LayerDesc
(
EmbeddingPipe
))
for
x
in
range
(
5
):
for
x
in
range
(
6
):
self
.
descs
.
append
(
LayerDesc
(
TransformerNetPipe
))
self
.
descs
.
append
(
lambda
x
:
x
[
0
])
super
().
__init__
(
layers
=
self
.
descs
,
loss_fn
=
CriterionPipe
(),
topology
=
topology
)
layers
=
self
.
descs
,
loss_fn
=
CriterionPipe
(),
topology
=
topology
,
seg_method
=
"layer:TransformerNetPipe"
)
class
TestDistPPTraning
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录