Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5ce58d57
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5ce58d57
编写于
7月 16, 2021
作者:
W
WangXi
提交者:
GitHub
7月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid check] improve pipeline stage check (#34193)
上级
4e5cb7d8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
12 addition
and
7 deletion
+12
-7
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+12
-7
未找到文件。
python/paddle/fluid/optimizer.py
浏览文件 @
5ce58d57
...
...
@@ -4663,6 +4663,7 @@ class PipelineOptimizer(object):
pre_stage_id
=
None
decrease_flag
=
False
in_optimize
=
False
in_forward
=
True
for
op
in
block
.
ops
:
if
not
op
.
_has_kernel
(
op
.
type
):
assert
op
.
type
==
"conditional_block"
and
(
...
...
@@ -4680,6 +4681,8 @@ class PipelineOptimizer(object):
valid_op_role_value
)
if
int
(
op_role
)
==
int
(
self
.
_op_role
.
Optimize
):
in_optimize
=
True
if
int
(
op_role
)
==
int
(
self
.
_op_role
.
Backward
):
in_forward
=
False
assert
op
.
has_attr
(
self
.
_op_device_key
),
(
"op ({}) has no {} attribute."
.
format
(
op
.
type
,
...
...
@@ -4707,14 +4710,16 @@ class PipelineOptimizer(object):
"but the interval of op={} and prev op is {}"
.
format
(
op
,
interval
)
# stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0)
# if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error
if
interval
==
-
1
:
decrease_flag
=
True
if
interval
==
1
:
# FIXME(wangxi): recompute failed
if
in_forward
:
assert
interval
>=
0
,
\
"Pipeline stage must be sequential increment in Forward, prev_stage={}, "
\
"please check the stage of op={}"
.
format
(
pre_stage_id
,
op
)
else
:
# FIXME(wangxi): recompute check failed
pass
#assert
decrease_flag is False
, \
# "Pipeline stage must be
in order
, " \
# "please check the stage of op={}".format(op)
#assert
interval <=0
, \
# "Pipeline stage must be
sequential decrement in Backward, prev_stage={}
, " \
# "please check the stage of op={}".format(
pre_stage_id,
op)
pre_stage_id
=
stage_id
return
device_list
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录