Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
53a62ea4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
53a62ea4
编写于
4月 01, 2022
作者:
A
Aurelius84
提交者:
GitHub
4月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ControlFlow] Fix contrib API bug in while_loop (#41230)
* [ControlFlow] Fix contrib API bug in while_loop * format code
上级
34241dd1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
54 addition
and
1 deletion
+54
-1
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+15
-1
python/paddle/fluid/tests/unittests/test_while_op.py
python/paddle/fluid/tests/unittests/test_while_op.py
+39
-0
未找到文件。
python/paddle/fluid/layers/control_flow.py
浏览文件 @
53a62ea4
...
...
@@ -974,6 +974,19 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs,
:return: inner_inputs, inner_outputs
"""
def
is_ignore_vars
(
op
,
var_name
):
# NOTE(dev): There are some persistable var created in some non-standard API
# such as "contrib.layers.shuffle_batch". It create a "Seed" used both in
# Input and Output. This var shall not be considered as a loop_var in
# control_flow.
IGNORE_VAR_NAMES
=
{
"shuffle_batch"
:
[
"shuffle_batch_seed"
]}
if
op
.
type
in
IGNORE_VAR_NAMES
:
var_names
=
IGNORE_VAR_NAMES
[
op
.
type
]
for
name
in
var_names
:
if
name
in
var_name
:
return
True
return
False
# Step1: update inner_inputs and inner_outputs
# NOTE: Here assumes that all variables are input or output of Ops,
# but some variables are created without appendding a real op.
...
...
@@ -982,7 +995,8 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs,
assert
isinstance
(
op
,
Operator
)
for
iname
in
op
.
input_names
:
for
in_var_name
in
op
.
input
(
iname
):
if
in_var_name
not
in
inner_outputs
:
if
in_var_name
not
in
inner_outputs
and
not
is_ignore_vars
(
op
,
in_var_name
):
inner_inputs
.
add
(
in_var_name
)
for
oname
in
op
.
output_names
:
...
...
python/paddle/fluid/tests/unittests/test_while_op.py
浏览文件 @
53a62ea4
...
...
@@ -137,5 +137,44 @@ class BadInputTest(unittest.TestCase):
self
.
assertRaises
(
TypeError
,
test_bad_x
)
class
TestIgnoreVarNameInWhile
(
unittest
.
TestCase
):
def
test_ignore_var
(
self
):
def
cond
(
i
,
ten
,
temp
,
y
):
return
i
<
ten
def
body_func
(
i
,
ten
,
batch_info
,
origin_seq
):
print
(
batch_info
)
batch_info
=
fluid
.
contrib
.
layers
.
shuffle_batch
(
batch_info
)
print
(
batch_info
)
i
=
i
+
1
return
[
i
,
ten
,
batch_info
,
origin_seq
]
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
-
1
,
1
,
4
])
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
-
1
,
1
,
1
])
temp
=
layers
.
concat
(
input
=
[
x
,
y
],
axis
=-
1
)
i
=
layers
.
fill_constant
(
shape
=
[
1
],
value
=
0
,
dtype
=
'int32'
)
num
=
layers
.
fill_constant
(
shape
=
[
1
],
value
=
5
,
dtype
=
'int32'
)
i
,
ten
,
shuffle_temp
,
y
=
layers
.
while_loop
(
cond
,
body_func
,
[
i
,
num
,
temp
,
y
])
output
=
shuffle_temp
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
fluid
.
default_startup_program
())
input_x
=
numpy
.
array
([[
1
,
2
,
3
,
4
],
[
4
,
5
,
6
,
7
],
[
7
,
8
,
9
,
10
]])
input_x
=
input_x
.
reshape
(
3
,
1
,
4
)
input_y
=
numpy
.
array
([[
10
],
[
12
],
[
33
]])
input_y
=
input_y
.
reshape
(
3
,
1
,
1
)
res
,
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
'x'
:
input_x
,
'y'
:
input_y
},
fetch_list
=
[
output
])
self
.
assertListEqual
(
list
(
res
.
shape
),
[
3
,
1
,
5
])
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录