Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8ad90558
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8ad90558
编写于
8月 23, 2018
作者:
C
chengduo
提交者:
GitHub
8月 23, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add is_test for while_op (#12874)
* add is_test for while_op * Change API
上级
c6f212a3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
13 addition
and
3 deletion
+13
-3
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/while_op.cc
paddle/fluid/operators/while_op.cc
+7
-0
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+5
-2
未找到文件。
paddle/fluid/API.spec
浏览文件 @
8ad90558
...
...
@@ -191,7 +191,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None
paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', '
name'], varargs=None, keywords=None, defaults=(None,
))
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', '
is_test', 'name'], varargs=None, keywords=None, defaults=(False, None
))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/operators/while_op.cc
浏览文件 @
8ad90558
...
...
@@ -58,11 +58,15 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
cond
.
place
()),
"Condition of while op must in CPU memory."
);
bool
is_test
=
Attr
<
bool
>
(
"is_test"
);
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
());
while
(
cond
.
data
<
bool
>
()[
0
])
{
auto
&
current_scope
=
scope
.
NewScope
();
step_scopes
->
push_back
(
&
current_scope
);
executor
.
RunPreparedContext
(
ctx
.
get
(),
&
current_scope
,
false
);
if
(
is_test
)
{
scope
.
DeleteScope
(
&
current_scope
);
}
}
}
};
...
...
@@ -88,6 +92,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"variables generated in the i'th step."
);
AddAttr
<
framework
::
BlockDesc
*>
(
kStepBlock
,
"The step block inside WhileOp"
);
AddAttr
<
bool
>
(
"is_test"
,
"True if in test phase."
).
SetDefault
(
false
);
AddComment
(
R"DOC(
)DOC"
);
}
...
...
@@ -103,6 +108,8 @@ class WhileGradOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_ENFORCE
(
!
Attr
<
bool
>
(
"is_test"
),
"GradOp is only callable when is_test is false"
);
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
8ad90558
...
...
@@ -661,6 +661,7 @@ class While(object):
Args:
cond (Variable): condition used to compare.
is_test(bool): A flag indicating whether execution is in test phase.
name (str): The name of this layer.
Examples:
...
...
@@ -683,7 +684,7 @@ class While(object):
IN_WHILE_BLOCK
=
1
AFTER_WHILE_BLOCK
=
2
def
__init__
(
self
,
cond
,
name
=
None
):
def
__init__
(
self
,
cond
,
is_test
=
False
,
name
=
None
):
self
.
helper
=
LayerHelper
(
"while"
,
name
=
name
)
self
.
status
=
While
.
BEFORE_WHILE_BLOCK
if
not
isinstance
(
cond
,
Variable
):
...
...
@@ -694,6 +695,7 @@ class While(object):
if
reduce
(
lambda
a
,
b
:
a
*
b
,
cond
.
shape
,
1
)
!=
1
:
raise
TypeError
(
"condition should be a bool scalar"
)
self
.
cond_var
=
cond
self
.
is_test
=
is_test
def
block
(
self
):
return
WhileGuard
(
self
)
...
...
@@ -735,7 +737,8 @@ class While(object):
},
outputs
=
{
'Out'
:
out_vars
,
'StepScopes'
:
[
step_scope
]},
attrs
=
{
'sub_block'
:
while_block
})
attrs
=
{
'sub_block'
:
while_block
,
"is_test"
:
self
.
is_test
})
def
lod_rank_table
(
x
,
level
=
0
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录