Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
781dc722
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
781dc722
编写于
9月 01, 2018
作者:
X
Xin Pan
提交者:
GitHub
9月 01, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13128 from chengduoZH/refine_while_op
Refine while op
上级
91e10fb0
16359da2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
14 addition
and
3 deletion
+14
-3
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/while_op.cc
paddle/fluid/operators/while_op.cc
+8
-0
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+5
-2
未找到文件。
paddle/fluid/API.spec
浏览文件 @
781dc722
...
...
@@ -190,7 +190,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None
paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', '
name'], varargs=None, keywords=None, defaults=(None,
))
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', '
is_test', 'name'], varargs=None, keywords=None, defaults=(False, None
))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/operators/while_op.cc
浏览文件 @
781dc722
...
...
@@ -55,6 +55,7 @@ class WhileOp : public framework::OperatorBase {
auto
step_scopes
=
scope
.
FindVar
(
Output
(
kStepScopes
))
->
GetMutable
<
StepScopeVar
>
();
bool
is_test
=
Attr
<
bool
>
(
"is_test"
);
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
cond
.
place
()),
"Condition of while op must in CPU memory."
);
while
(
cond
.
data
<
bool
>
()[
0
])
{
...
...
@@ -63,6 +64,10 @@ class WhileOp : public framework::OperatorBase {
executor
.
Run
(
*
program
,
&
current_scope
,
block
->
ID
(),
false
/*create_local_scope*/
);
if
(
is_test
)
{
scope
.
DeleteScope
(
&
current_scope
);
}
}
}
};
...
...
@@ -88,6 +93,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"variables generated in the i'th step."
);
AddAttr
<
framework
::
BlockDesc
*>
(
kStepBlock
,
"The step block inside WhileOp"
);
AddAttr
<
bool
>
(
"is_test"
,
"True if in test phase."
).
SetDefault
(
false
);
AddComment
(
R"DOC(
)DOC"
);
}
...
...
@@ -103,6 +109,8 @@ class WhileGradOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_ENFORCE
(
!
Attr
<
bool
>
(
"is_test"
),
"GradOp is only callable when is_test is false"
);
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
781dc722
...
...
@@ -661,6 +661,7 @@ class While(object):
Args:
cond (Variable): condition used to compare.
is_test(bool): A flag indicating whether execution is in test phase.
name (str): The name of this layer.
Examples:
...
...
@@ -683,7 +684,7 @@ class While(object):
IN_WHILE_BLOCK
=
1
AFTER_WHILE_BLOCK
=
2
def
__init__
(
self
,
cond
,
name
=
None
):
def
__init__
(
self
,
cond
,
is_test
=
False
,
name
=
None
):
self
.
helper
=
LayerHelper
(
"while"
,
name
=
name
)
self
.
status
=
While
.
BEFORE_WHILE_BLOCK
if
not
isinstance
(
cond
,
Variable
):
...
...
@@ -694,6 +695,7 @@ class While(object):
if
reduce
(
lambda
a
,
b
:
a
*
b
,
cond
.
shape
,
1
)
!=
1
:
raise
TypeError
(
"condition should be a bool scalar"
)
self
.
cond_var
=
cond
self
.
is_test
=
is_test
def
block
(
self
):
return
WhileGuard
(
self
)
...
...
@@ -735,7 +737,8 @@ class While(object):
},
outputs
=
{
'Out'
:
out_vars
,
'StepScopes'
:
[
step_scope
]},
attrs
=
{
'sub_block'
:
while_block
})
attrs
=
{
'sub_block'
:
while_block
,
"is_test"
:
self
.
is_test
})
def
lod_rank_table
(
x
,
level
=
0
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录