未验证 提交 13a5f183 编写于 作者: X xiongkun 提交者: GitHub

[BugFix] while cond receives dict as input (#47299)

* fix bugs while cond receives dict as input

* add unittest

* change flatten -> _is_sequence_except_dict
上级 ac3b882f
......@@ -2839,7 +2839,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
# Merge ture and false output if they are not None
if return_names is None:
is_dy2staic = False
return_names = ["no name"] * len(to_sequence(true_output))
return_names = ["no name"] * len(_to_sequence_except_dict(true_output))
else:
"""
dy2static will set the return_names and expand the return values to UndefinedVar.
......@@ -2855,16 +2855,19 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
true_output, false_output, return_names
)
if len(to_sequence(true_output)) != len(to_sequence(false_output)):
if len(_to_sequence_except_dict(true_output)) != len(
_to_sequence_except_dict(false_output)
):
raise ValueError(
"true fn returns {} vars, but false fn returns {} vars, which is not equals".format(
len(to_sequence(true_output)), len(to_sequence(false_output))
len(_to_sequence_except_dict(true_output)),
len(_to_sequence_except_dict(false_output)),
)
)
for true_out, false_out, return_name in zip(
to_sequence(true_output),
to_sequence(false_output),
to_sequence(return_names),
_to_sequence_except_dict(true_output),
_to_sequence_except_dict(false_output),
_to_sequence_except_dict(return_names),
):
try:
assert_same_structure(true_out, false_out, check_types=False)
......@@ -2876,10 +2879,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
)
def check_ret_none(seq_true, seq_false, seq_names):
length = len(seq_true)
for i in range(length):
f_true = flatten(seq_true[i])
f_false = flatten(seq_false[i])
for f_true, f_false, f_name in zip(seq_true, seq_false, seq_names):
f_true = flatten(f_true)
f_false = flatten(f_false)
for idx in range(len(f_true)):
if (
f_true[idx] is None
......@@ -2891,7 +2893,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
seq_names[i],
f_name,
type(f_true[idx]),
f_true[idx],
type(f_false[idx]),
......@@ -2900,9 +2902,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
)
check_ret_none(
to_sequence(true_output),
to_sequence(false_output),
to_sequence(return_names),
_to_sequence_except_dict(true_output),
_to_sequence_except_dict(false_output),
_to_sequence_except_dict(return_names),
)
if is_dy2staic:
......@@ -2923,9 +2925,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
merged_output = list(
map(
merge_every_var_list,
to_sequence(false_output),
to_sequence(true_output),
to_sequence(return_names),
_to_sequence_except_dict(false_output),
_to_sequence_except_dict(true_output),
_to_sequence_except_dict(return_names),
)
)
merged_output = pack_sequence_as(false_output, flatten(merged_output))
......@@ -2945,6 +2947,24 @@ def change_none_to_undefinedvar(nest1, nest2):
return nest1_out, nest2_out
def _to_sequence_except_dict(x):
"""
In this function, dict is not viewed as sequence.
"""
if isinstance(x, dict):
return [x]
return to_sequence(x)
def _is_sequence_except_dict(x):
"""
In this function, dict is not viewed as sequence.
"""
if isinstance(x, dict):
return False
return is_sequence(x)
def expand_undefined_var(nest1, nest2, names):
"""TODO: make this function recursively.
nest1: Var1, (UndefinedVar, [1,2,3])
......@@ -2988,24 +3008,24 @@ def expand_undefined_var(nest1, nest2, names):
nest1_out = list(
map(
map_fn,
to_sequence(nest1),
to_sequence(nest2),
to_sequence(names),
[0 for i in to_sequence(names)],
_to_sequence_except_dict(nest1),
_to_sequence_except_dict(nest2),
_to_sequence_except_dict(names),
[0 for i in _to_sequence_except_dict(names)],
)
)
nest2_out = list(
map(
map_fn,
to_sequence(nest2),
to_sequence(nest1),
to_sequence(names),
[1 for i in to_sequence(names)],
_to_sequence_except_dict(nest2),
_to_sequence_except_dict(nest1),
_to_sequence_except_dict(names),
[1 for i in _to_sequence_except_dict(names)],
)
)
if not is_sequence(nest1):
if not _is_sequence_except_dict(nest1):
nest1_out = nest1_out[0]
if not is_sequence(nest2):
if not _is_sequence_except_dict(nest2):
nest2_out = nest2_out[0]
return nest1_out, nest2_out
......
......@@ -676,5 +676,44 @@ class TestCondWithError(unittest.TestCase):
layers.cond(pred, func, func, set())
class TestCondWithDict(unittest.TestCase):
def test_input_with_dict(self):
paddle.enable_static()
main_program = framework.Program()
startup_program = framework.Program()
with framework.program_guard(main_program, startup_program):
def true_func():
return {
'1': paddle.full(shape=[3, 2], dtype='int32', fill_value=1),
'2': paddle.full(
shape=[2, 3], dtype='bool', fill_value=True
),
}
def false_func():
return {
'1': paddle.full(
shape=[3, 4], dtype='float32', fill_value=3
),
'2': paddle.full(shape=[4, 5], dtype='int64', fill_value=2),
}
x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
pred = paddle.less_than(x=x, y=y, name=None)
ret = paddle.static.nn.cond(pred, true_func, false_func)
self.assertEqual(
ret['1'].shape,
(3, -1),
f"The shape is not correct, expects (3, -1) but gets {ret['1'].shape}.",
)
self.assertEqual(
ret['2'].shape,
(-1, -1),
f"The shape is not correct, expects (-1, -1) but gets {ret['2'].shape}.",
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册