Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ebbd3564
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ebbd3564
编写于
12月 23, 2021
作者:
J
JZ-LIANG
提交者:
GitHub
12月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove unitest for auto_searcher (#38370)
上级
4d5a6064
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
0 addition
and
35 deletion
+0
-35
python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
...ddle/fluid/tests/unittests/test_auto_parallel_searcher.py
+0
-35
未找到文件。
python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
100644 → 100755
浏览文件 @
ebbd3564
...
...
@@ -212,41 +212,6 @@ class TestMLPSearcher(unittest.TestCase):
self
.
assertTrue
(
check_nonpipeline_enumerater
(
train_program
,
process_mesh_topology
))
def
test_get_dist_programs
(
self
):
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
loss
,
train_program
,
startup_program
=
mlp_forward
(
train_program
,
startup_program
)
process_mesh_topology
=
[
4
]
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.00001
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-08
,
grad_clip
=
None
)
valid_dist_attr_dict
,
pipeline_process_meshes
,
global_process_mesh
=
PlanSpace
.
enum_valid_dist_attr_for_program
(
train_program
,
process_mesh_topology
,
False
)
from
test_auto_parallel_cluster
import
cluster_json
cluster_json_file
=
""
cluster_json_object
=
json
.
loads
(
cluster_json
)
with
open
(
"./auto_parallel_cluster.json"
,
"w"
)
as
cluster_json_file
:
json
.
dump
(
cluster_json_object
,
cluster_json_file
)
cluster
=
Cluster
()
cluster
.
build_from_file
(
"./auto_parallel_cluster.json"
)
os
.
remove
(
"./auto_parallel_cluster.json"
)
ops
=
train_program
.
global_block
().
ops
vars
=
train_program
.
global_block
().
vars
new_dist_context
=
DistributedContext
()
set_default_dist_attr
(
train_program
,
new_dist_context
,
global_process_mesh
)
serial_program_info
=
SerialProgramInfo
(
train_program
,
startup_program
,
loss
,
optimizer
,
cluster
)
result
=
get_all_distributed_main_program
(
serial_program_info
,
new_dist_context
)
self
.
assertEqual
(
len
(
result
),
4
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录