Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6db6a347
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2301
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6db6a347
编写于
4月 27, 2023
作者:
S
ShenLiang
提交者:
GitHub
4月 27, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add segment methods for pp (#53368)
add utest fix utest
上级
27016144
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
47 addition
and
1 deletion
+47
-1
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+33
-1
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
.../paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
+14
-0
未找到文件。
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
浏览文件 @
6db6a347
...
...
@@ -109,7 +109,37 @@ class SegmentLayers:
),
"layer number should be greater than number of segments"
def
do_segment
(
self
):
if
self
.
method
==
"uniform"
:
if
isinstance
(
self
.
method
,
list
):
seg_method
=
self
.
method
[:]
source_num_parts
=
len
(
seg_method
)
-
1
def
check_sanity
():
assert
seg_method
[
0
]
==
0
,
"seg_method[0] should be 0"
for
part
in
seg_method
:
assert
isinstance
(
part
,
int
),
"part should be int"
assert
part
>=
0
,
f
"part[
{
part
}
] should be greater than 0"
assert
(
part
<=
self
.
num_items
),
"part[{}] should be less than num_items[{}]"
.
format
(
part
,
self
.
num_items
)
check_sanity
()
if
self
.
num_parts
==
source_num_parts
+
1
:
seg_method
.
append
(
self
.
num_items
)
return
seg_method
elif
self
.
num_parts
==
source_num_parts
:
return
seg_method
else
:
raise
ValueError
(
"We set seg_method as {}, this length is {}, but the number of stages is {}"
.
format
(
seg_method
,
len
(
seg_method
),
self
.
num_parts
)
)
elif
self
.
method
==
"uniform"
:
return
self
.
uniform
(
self
.
num_items
,
self
.
num_parts
)
elif
self
.
method
.
startswith
(
'layer:'
):
...
...
@@ -144,6 +174,8 @@ class SegmentLayers:
memory_counter
=
0
result
[
actual_num_parts
]
=
len
(
weights
)
return
result
else
:
raise
ValueError
(
f
"method
{
self
.
method
}
is not supported"
)
def
_gen_layer_weight
(
self
,
layername
):
weight_idxs
=
[]
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
浏览文件 @
6db6a347
...
...
@@ -136,6 +136,20 @@ class TestPipeLayerAPI(unittest.TestCase):
np
.
testing
.
assert_array_equal
(
param_a
.
name
,
param_b
.
name
)
np
.
testing
.
assert_allclose
(
param_a
.
numpy
(),
param_b
.
numpy
())
def
test_pipelayer_segment_method
(
self
):
init_net
=
AlexNetPipe
()
pipe_model
=
PipelineLayer
(
layers
=
init_net
.
to_layers
(),
num_stages
=
self
.
pipeline_parallel_size
,
seg_method
=
[
0
,
4
],
loss_fn
=
nn
.
CrossEntropyLoss
(),
)
stage_id
=
self
.
hcg
.
get_stage_id
()
if
stage_id
==
0
:
np
.
testing
.
assert_array_equal
(
len
(
pipe_model
.
parameters
()),
4
)
elif
stage_id
==
1
:
np
.
testing
.
assert_array_equal
(
len
(
pipe_model
.
parameters
()),
8
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录