Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
01c5ca73
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
01c5ca73
编写于
3月 29, 2018
作者:
J
JiayiFeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs
上级
917b205c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
25 addition
and
6 deletion
+25
-6
paddle/fluid/operators/compare_op.cc
paddle/fluid/operators/compare_op.cc
+8
-1
paddle/fluid/operators/while_op.cc
paddle/fluid/operators/while_op.cc
+2
-0
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+15
-5
未找到文件。
paddle/fluid/operators/compare_op.cc
浏览文件 @
01c5ca73
...
...
@@ -29,6 +29,11 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Y"
,
string
::
Sprintf
(
"(LoDTensor) the right hand operand of %s operator"
,
comment
.
type
));
AddAttr
<
bool
>
(
"force_cpu"
,
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device"
)
.
SetDefault
(
false
);
AddOutput
(
"Out"
,
string
::
Sprintf
(
"(LoDTensor) n-dim bool tensor. Each element is %s"
,
comment
.
equation
));
...
...
@@ -75,7 +80,9 @@ class CompareOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
OpKernelType
kt
=
OperatorWithKernel
::
GetExpectedKernelType
(
ctx
);
// CompareOp kernel's device type is decided by input tensor place
kt
.
place_
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
place
();
bool
force_cpu
=
ctx
.
Attr
<
bool
>
(
"force_cpu"
);
kt
.
place_
=
force_cpu
?
platform
::
CPUPlace
()
:
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
place
();
return
kt
;
}
};
...
...
paddle/fluid/operators/while_op.cc
浏览文件 @
01c5ca73
...
...
@@ -54,6 +54,8 @@ class WhileOp : public framework::OperatorBase {
auto
step_scopes
=
scope
.
FindVar
(
Output
(
kStepScopes
))
->
GetMutable
<
StepScopeVar
>
();
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
cond
.
place
()),
"Condition of while op must in CPU memory."
);
while
(
cond
.
data
<
bool
>
()[
0
])
{
auto
&
current_scope
=
scope
.
NewScope
();
step_scopes
->
push_back
(
&
current_scope
);
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
01c5ca73
...
...
@@ -18,6 +18,7 @@ from tensor import assign, fill_constant
from
..
import
core
from
..framework
import
Program
,
Variable
,
Operator
from
..layer_helper
import
LayerHelper
,
unique_name
from
..initializer
import
force_init_on_cpu
from
ops
import
logical_and
,
logical_not
,
logical_or
__all__
=
[
...
...
@@ -949,7 +950,7 @@ def create_array(dtype):
dtype
=
dtype
)
def
less_than
(
x
,
y
,
cond
=
None
,
**
ignored
):
def
less_than
(
x
,
y
,
force_cpu
=
True
,
cond
=
None
,
**
ignored
):
"""
**Less than**
...
...
@@ -958,6 +959,7 @@ def less_than(x, y, cond=None, **ignored):
Args:
x(Variable): First operand of *less_than*
y(Variable): Second operand of *less_than*
force_cpu(Bool|True): The output data will be on CPU if set true.
cond(Variable|None): Optional output variable to store the result of *less_than*
Returns:
...
...
@@ -974,8 +976,11 @@ def less_than(x, y, cond=None, **ignored):
cond
.
stop_gradient
=
True
helper
.
append_op
(
type
=
'less_than'
,
inputs
=
{
'X'
:
[
x
],
'Y'
:
[
y
]},
outputs
=
{
'Out'
:
[
cond
]})
type
=
'less_than'
,
inputs
=
{
'X'
:
[
x
],
'Y'
:
[
y
]},
outputs
=
{
'Out'
:
[
cond
]},
attrs
=
{
'force_cpu'
:
force_cpu
or
force_init_on_cpu
()})
return
cond
...
...
@@ -1395,7 +1400,8 @@ class DynamicRNN(object):
type
=
'less_than'
,
inputs
=
{
'X'
:
self
.
step_idx
,
'Y'
:
self
.
max_seq_len
},
outputs
=
{
'Out'
:
self
.
cond
})
outputs
=
{
'Out'
:
self
.
cond
},
attrs
=
{
'force_cpu'
:
True
})
input_array
=
parent_block
.
create_var
(
name
=
unique_name
.
generate
(
'dynamic_rnn_input_array'
),
...
...
@@ -1443,7 +1449,11 @@ class DynamicRNN(object):
for
new_mem
,
mem_array
in
self
.
mem_link
:
array_write
(
x
=
new_mem
,
i
=
self
.
step_idx
,
array
=
mem_array
)
less_than
(
x
=
self
.
step_idx
,
y
=
self
.
max_seq_len
,
cond
=
self
.
cond
)
less_than
(
x
=
self
.
step_idx
,
y
=
self
.
max_seq_len
,
force_cpu
=
True
,
cond
=
self
.
cond
)
self
.
status
=
DynamicRNN
.
AFTER_RNN
for
each_array
in
self
.
output_array
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录