未验证 提交 53a62ea4 编写于 作者: A Aurelius84 提交者: GitHub

[ControlFlow] Fix contrib API bug in while_loop (#41230)

* [ControlFlow] Fix contrib API bug in while_loop

* format code
上级 34241dd1
......@@ -974,6 +974,19 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs,
:return: inner_inputs, inner_outputs
"""
def is_ignore_vars(op, var_name):
# NOTE(dev): There are some persistable var created in some non-standard API
# such as "contrib.layers.shuffle_batch". It create a "Seed" used both in
# Input and Output. This var shall not be considered as a loop_var in
# control_flow.
IGNORE_VAR_NAMES = {"shuffle_batch": ["shuffle_batch_seed"]}
if op.type in IGNORE_VAR_NAMES:
var_names = IGNORE_VAR_NAMES[op.type]
for name in var_names:
if name in var_name:
return True
return False
# Step1: update inner_inputs and inner_outputs
# NOTE: Here assumes that all variables are input or output of Ops,
# but some variables are created without appendding a real op.
......@@ -982,7 +995,8 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs,
assert isinstance(op, Operator)
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in inner_outputs:
if in_var_name not in inner_outputs and not is_ignore_vars(
op, in_var_name):
inner_inputs.add(in_var_name)
for oname in op.output_names:
......
......@@ -137,5 +137,44 @@ class BadInputTest(unittest.TestCase):
self.assertRaises(TypeError, test_bad_x)
class TestIgnoreVarNameInWhile(unittest.TestCase):
def test_ignore_var(self):
def cond(i, ten, temp, y):
return i < ten
def body_func(i, ten, batch_info, origin_seq):
print(batch_info)
batch_info = fluid.contrib.layers.shuffle_batch(batch_info)
print(batch_info)
i = i + 1
return [i, ten, batch_info, origin_seq]
x = fluid.layers.data(name='x', shape=[-1, 1, 4])
y = fluid.layers.data(name='y', shape=[-1, 1, 1])
temp = layers.concat(input=[x, y], axis=-1)
i = layers.fill_constant(shape=[1], value=0, dtype='int32')
num = layers.fill_constant(shape=[1], value=5, dtype='int32')
i, ten, shuffle_temp, y = layers.while_loop(cond, body_func,
[i, num, temp, y])
output = shuffle_temp
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
input_x = numpy.array([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]])
input_x = input_x.reshape(3, 1, 4)
input_y = numpy.array([[10], [12], [33]])
input_y = input_y.reshape(3, 1, 1)
res, = exe.run(fluid.default_main_program(),
feed={'x': input_x,
'y': input_y},
fetch_list=[output])
self.assertListEqual(list(res.shape), [3, 1, 5])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册