Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e494b73b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e494b73b
编写于
3月 30, 2022
作者:
C
caozhou
提交者:
GitHub
3月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix reshard bug (#41106)
上级
ee8eeb45
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
12 addition
and
13 deletion
+12
-13
python/paddle/distributed/auto_parallel/planner.py
python/paddle/distributed/auto_parallel/planner.py
+9
-12
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+3
-1
未找到文件。
python/paddle/distributed/auto_parallel/planner.py
浏览文件 @
e494b73b
...
...
@@ -15,7 +15,6 @@
import
copy
import
time
import
random
import
logging
from
functools
import
reduce
from
itertools
import
chain
,
product
from
collections
import
OrderedDict
...
...
@@ -741,7 +740,7 @@ class MCMC(SearchAlgorithm):
return
best_dist_context
,
min_cost
def
search
(
self
):
logging
.
info
(
"Start MCMC searching."
)
print
(
"Start MCMC searching."
)
start_time
=
time
.
time
()
train_program
=
self
.
serial_program_info
.
train_program
cluster
=
self
.
serial_program_info
.
cluster
...
...
@@ -757,9 +756,8 @@ class MCMC(SearchAlgorithm):
searched_pipeline_dist_context
=
None
pipeline_min_cost
=
None
for
process_mesh_topology
in
process_mesh_topology_list
:
logging
.
info
(
"MCMC search: search process mesh {} with pipeline mode."
.
format
(
process_mesh_topology
))
print
(
"MCMC search: search process mesh {} with pipeline mode."
.
format
(
process_mesh_topology
))
valid_dist_attr_dict
,
pipeline_process_meshes
,
global_process_mesh
=
PlanSpace
.
enum_valid_dist_attr_for_program
(
train_program
,
process_mesh_topology
,
True
)
init_dist_context
=
self
.
init_program
(
...
...
@@ -768,7 +766,7 @@ class MCMC(SearchAlgorithm):
best_dist_context
,
cost
=
self
.
_search_core
(
valid_dist_attr_dict
,
init_dist_context
,
pipeline_process_meshes
)
logging
.
info
(
print
(
"MCMC search: the min cost is {} in the process mesh {} with pipeline mode."
.
format
(
cost
,
process_mesh_topology
))
best_dist_context
.
_dist_op_context
=
DistributedOperatorContext
()
...
...
@@ -784,9 +782,8 @@ class MCMC(SearchAlgorithm):
# if process_mesh_topology shape is 3, include pipeline mode by default
if
len
(
process_mesh_topology
)
==
3
:
continue
logging
.
info
(
"MCMC search: search process mesh {} without pipeline mode."
.
format
(
process_mesh_topology
))
print
(
"MCMC search: search process mesh {} without pipeline mode."
.
format
(
process_mesh_topology
))
valid_dist_attr_dict
,
pipeline_process_meshes
,
global_process_mesh
=
PlanSpace
.
enum_valid_dist_attr_for_program
(
train_program
,
process_mesh_topology
,
False
)
init_dist_context
=
self
.
init_program
(
...
...
@@ -795,7 +792,7 @@ class MCMC(SearchAlgorithm):
best_dist_context
,
cost
=
self
.
_search_core
(
valid_dist_attr_dict
,
init_dist_context
,
pipeline_process_meshes
)
logging
.
info
(
print
(
"MCMC search: the min cost is {} in the process mesh {} without pipeline mode."
.
format
(
cost
,
process_mesh_topology
))
best_dist_context
.
_dist_op_context
=
DistributedOperatorContext
()
...
...
@@ -808,7 +805,7 @@ class MCMC(SearchAlgorithm):
if
non_pipeline_min_cost
>
pipeline_min_cost
:
searched_dist_context
=
searched_pipeline_dist_context
min_cost
=
pipeline_min_cost
logging
.
info
(
print
(
"Better set FLAGS_benchmark=1 to avoid hang problem in the pipeline mode."
)
else
:
...
...
@@ -820,7 +817,7 @@ class MCMC(SearchAlgorithm):
for
process_mesh
in
searched_dist_context
.
_process_meshes
:
pg0
.
add_ranks
(
process_mesh
.
processes
)
end_time
=
time
.
time
()
logging
.
info
(
print
(
"End MCMC searching: the min cost is {} and the search time is {}s."
.
format
(
min_cost
,
end_time
-
start_time
))
return
searched_dist_context
,
min_cost
...
...
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
e494b73b
...
...
@@ -1239,7 +1239,9 @@ class Resharder:
for
item
in
self
.
has_allgather
[
var_name
]:
if
op_desc
.
group
==
item
[
0
]:
tensor_list
=
[
program
.
global_block
().
vars
[
var_name
]
get_var_with_recursion
(
var_name
,
block
,
self
.
auto_parallel_main_prog
)
for
var_name
in
item
[
1
]
]
break
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录