未验证 提交 cc8ca8ce 编写于 作者: A Aurelius84 提交者: GitHub

Polish error Info in while_loop (#23183)

* Polish error Info in while_loop test=develop
上级 a486a739
......@@ -1031,9 +1031,12 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
output_vars = body(*loop_vars)
if not isinstance(output_vars, (list, tuple)):
output_vars = [output_vars]
if len(output_vars) != len(loop_vars):
try:
assert_same_structure(output_vars, loop_vars, check_types=False)
except ValueError as e:
raise ValueError("body in while_loop should return the same arity "
"(length and structure) and types as loop_vars")
"(length and structure) as loop_vars: {0}".format(
e))
now_cond = cond(*output_vars)
map_structure(assign, output_vars, loop_vars)
assign(now_cond, pre_cond)
......
......@@ -384,6 +384,23 @@ class TestApiWhileLoop_Error(unittest.TestCase):
def body_returns_error_type(i, ten):
return layers.increment(i)
def cond_returns_with_mutable_dict(i, test_dict):
return i > 0
def body_returns_with_mutable_dict(i, test_dict):
test_dict['new_key'] = layers.fill_constant(
shape=[1], dtype='int64', value=1)
return layers.increment(i), test_dict
def cond_returns_with_mutable_list(i, test_list):
return i > 0
def body_returns_with_mutable_list(i, test_list):
test_list.append(
layers.fill_constant(
shape=[1], dtype='int64', value=1))
return layers.increment(i), test_list
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
......@@ -451,6 +468,31 @@ class TestApiWhileLoop_Error(unittest.TestCase):
self.assertRaises(ValueError, value_error_body_returns_error_type)
# The length of `output_vars` with mutable value should keep same with `loop_vars`
def value_error_body_returns_with_mutable_dict():
test_dict = {
"int_constant": layers.fill_constant(
shape=[2, 2], dtype='int64', value=1)
}
out = layers.while_loop(cond_returns_with_mutable_dict,
body_returns_with_mutable_dict,
[data, test_dict])
self.assertRaises(ValueError,
value_error_body_returns_with_mutable_dict)
def value_error_body_returns_with_mutable_list():
test_list = [
layers.fill_constant(
shape=[2, 2], dtype='int64', value=1)
]
out = layers.while_loop(cond_returns_with_mutable_list,
body_returns_with_mutable_list,
[data, test_list])
self.assertRaises(ValueError,
value_error_body_returns_with_mutable_list)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册