Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f6985774
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f6985774
编写于
10月 21, 2021
作者:
Y
YipZLF
提交者:
GitHub
10月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixed unit test for auto parallel cost model (#36574)
上级
1d38a013
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
25 addition
and
28 deletion
+25
-28
python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
...le/fluid/tests/unittests/test_auto_parallel_cost_model.py
+25
-28
未找到文件。
python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
浏览文件 @
f6985774
...
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
unittest
import
copy
import
paddle
import
paddle.nn
as
nn
import
paddle.static
as
static
...
...
@@ -141,28 +142,24 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss
,
train_program
,
startup_program
=
mlp_forward
(
train_program
,
startup_program
)
dist_strategy
=
fleet
.
DistributedStrategy
()
# auto completion
complete_train_program
=
auto
.
complete_annotation
(
train_program
,
dist_context
)
partitioner
=
Partitioner
(
dist_strategy
,
dist_context
,
rank_id
)
# logical partition
auto_parallel_main_prog
,
auto_parallel_startup_prog
=
partitioner
.
transpile_forward
(
complete_train_program
,
startup_program
)
dist_params_grads
=
partitioner
.
apply_backward
(
loss
,
complete_train_program
,
startup_program
,
auto_parallel_main_prog
,
auto_parallel_startup_prog
)
optimizer
=
paddle
.
fluid
.
optimizer
.
AdamOptimizer
()
opt_ops
=
partitioner
.
apply_optimize
(
optimizer
,
dist_params_grads
,
auto_parallel_main_prog
,
auto_parallel_startup_prog
)
dist_strategy
=
fleet
.
DistributedStrategy
()
dist_main_prog
=
[]
dist_startup_prog
=
[]
for
rank_id
in
range
(
NUM_RANKS
):
partitioner
=
Partitioner
(
dist_strategy
,
dist_context
,
rank_id
)
# logical partition
auto_parallel_main_prog
,
auto_parallel_startup_prog
=
partitioner
.
transpile_forward
(
complete_train_program
,
startup_program
)
dist_params_grads
=
partitioner
.
apply_backward
(
loss
,
complete_train_program
,
startup_program
,
auto_parallel_main_prog
,
auto_parallel_startup_prog
)
optimizer
=
paddle
.
fluid
.
optimizer
.
AdamOptimizer
()
opt_ops
=
partitioner
.
apply_optimize
(
optimizer
,
dist_params_grads
,
auto_parallel_main_prog
,
auto_parallel_startup_prog
)
dist_main_prog
.
append
(
auto_parallel_main_prog
)
dist_startup_prog
.
append
(
auto_parallel_startup_prog
)
return
dist_main_prog
,
dist_startup_prog
return
auto_parallel_main_prog
,
auto_parallel_startup_prog
def
check_runtime_estimation
(
cost
):
...
...
@@ -210,20 +207,20 @@ class TestCostModel(unittest.TestCase):
self
.
assertTrue
(
check_empty_program_memory
(
cost
))
def
test_auto_parallel_cost_model
(
self
):
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
dist_context
=
DistributedContext
()
standalone_cost_data
=
get_single_node_data
()
distributed_program
,
dist_startup_prog
=
get_dist_prog
(
train_program
,
startup_program
,
dist_context
,
0
)
dist_program
=
[]
for
rank_id
in
range
(
NUM_RANKS
):
complete_backward_annotation
(
distributed_program
[
rank_id
],
dist_context
)
reshard
(
distributed_program
[
rank_id
],
dist_startup_prog
[
rank_id
],
rank_id
,
dist_context
)
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
dist_context
=
DistributedContext
()
distributed_program
,
dist_startup_prog
=
get_dist_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
)
reshard
(
distributed_program
,
dist_startup_prog
,
rank_id
,
dist_context
)
dist_program
.
append
(
distributed_program
)
cluster
=
None
cost
=
estimate_cost
(
dist
ributed
_program
,
dist_program
,
cluster
=
cluster
,
pipeline_config
=
pp_cfg
,
standalone_cost_data
=
standalone_cost_data
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录