diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index eb654458fbf29ecc575f301c566df9287a947c44..1046c27fef54dfdfd634ce5050b1668326364b1f 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -2839,7 +2839,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): # Merge ture and false output if they are not None if return_names is None: is_dy2staic = False - return_names = ["no name"] * len(to_sequence(true_output)) + return_names = ["no name"] * len(_to_sequence_except_dict(true_output)) else: """ dy2static will set the return_names and expand the return values to UndefinedVar. @@ -2855,16 +2855,19 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): true_output, false_output, return_names ) - if len(to_sequence(true_output)) != len(to_sequence(false_output)): + if len(_to_sequence_except_dict(true_output)) != len( + _to_sequence_except_dict(false_output) + ): raise ValueError( "true fn returns {} vars, but false fn returns {} vars, which is not equals".format( - len(to_sequence(true_output)), len(to_sequence(false_output)) + len(_to_sequence_except_dict(true_output)), + len(_to_sequence_except_dict(false_output)), ) ) for true_out, false_out, return_name in zip( - to_sequence(true_output), - to_sequence(false_output), - to_sequence(return_names), + _to_sequence_except_dict(true_output), + _to_sequence_except_dict(false_output), + _to_sequence_except_dict(return_names), ): try: assert_same_structure(true_out, false_out, check_types=False) @@ -2876,10 +2879,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): ) def check_ret_none(seq_true, seq_false, seq_names): - length = len(seq_true) - for i in range(length): - f_true = flatten(seq_true[i]) - f_false = flatten(seq_false[i]) + for f_true, f_false, f_name in zip(seq_true, seq_false, seq_names): + f_true = flatten(f_true) + f_false = flatten(f_false) for idx in range(len(f_true)): if ( f_true[idx] is None @@ -2891,7 +2893,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): "In cond : Var '{}' or part of it is set differently in ifelse branchs, " "<{}, {}> in true branch and <{}, {}> in false branch. Set var to " "'None' in ifelse block might lead to error.".format( - seq_names[i], + f_name, type(f_true[idx]), f_true[idx], type(f_false[idx]), @@ -2900,9 +2902,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): ) check_ret_none( - to_sequence(true_output), - to_sequence(false_output), - to_sequence(return_names), + _to_sequence_except_dict(true_output), + _to_sequence_except_dict(false_output), + _to_sequence_except_dict(return_names), ) if is_dy2staic: @@ -2923,9 +2925,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): merged_output = list( map( merge_every_var_list, - to_sequence(false_output), - to_sequence(true_output), - to_sequence(return_names), + _to_sequence_except_dict(false_output), + _to_sequence_except_dict(true_output), + _to_sequence_except_dict(return_names), ) ) merged_output = pack_sequence_as(false_output, flatten(merged_output)) @@ -2945,6 +2947,24 @@ def change_none_to_undefinedvar(nest1, nest2): return nest1_out, nest2_out +def _to_sequence_except_dict(x): + """ + In this function, dict is not viewed as sequence. + """ + if isinstance(x, dict): + return [x] + return to_sequence(x) + + +def _is_sequence_except_dict(x): + """ + In this function, dict is not viewed as sequence. + """ + if isinstance(x, dict): + return False + return is_sequence(x) + + def expand_undefined_var(nest1, nest2, names): """TODO: make this function recursively. nest1: Var1, (UndefinedVar, [1,2,3]) @@ -2988,24 +3008,24 @@ def expand_undefined_var(nest1, nest2, names): nest1_out = list( map( map_fn, - to_sequence(nest1), - to_sequence(nest2), - to_sequence(names), - [0 for i in to_sequence(names)], + _to_sequence_except_dict(nest1), + _to_sequence_except_dict(nest2), + _to_sequence_except_dict(names), + [0 for i in _to_sequence_except_dict(names)], ) ) nest2_out = list( map( map_fn, - to_sequence(nest2), - to_sequence(nest1), - to_sequence(names), - [1 for i in to_sequence(names)], + _to_sequence_except_dict(nest2), + _to_sequence_except_dict(nest1), + _to_sequence_except_dict(names), + [1 for i in _to_sequence_except_dict(names)], ) ) - if not is_sequence(nest1): + if not _is_sequence_except_dict(nest1): nest1_out = nest1_out[0] - if not is_sequence(nest2): + if not _is_sequence_except_dict(nest2): nest2_out = nest2_out[0] return nest1_out, nest2_out diff --git a/python/paddle/fluid/tests/unittests/test_cond.py b/python/paddle/fluid/tests/unittests/test_cond.py index 3d05d7694fa217f39ef1ffcfd32a4da7d264911c..a09ff49df2efc4a06401826dc0e40a10a7458fb6 100644 --- a/python/paddle/fluid/tests/unittests/test_cond.py +++ b/python/paddle/fluid/tests/unittests/test_cond.py @@ -676,5 +676,44 @@ class TestCondWithError(unittest.TestCase): layers.cond(pred, func, func, set()) +class TestCondWithDict(unittest.TestCase): + def test_input_with_dict(self): + paddle.enable_static() + main_program = framework.Program() + startup_program = framework.Program() + with framework.program_guard(main_program, startup_program): + + def true_func(): + return { + '1': paddle.full(shape=[3, 2], dtype='int32', fill_value=1), + '2': paddle.full( + shape=[2, 3], dtype='bool', fill_value=True + ), + } + + def false_func(): + return { + '1': paddle.full( + shape=[3, 4], dtype='float32', fill_value=3 + ), + '2': paddle.full(shape=[4, 5], dtype='int64', fill_value=2), + } + + x = paddle.full(shape=[1], dtype='float32', fill_value=0.1) + y = paddle.full(shape=[1], dtype='float32', fill_value=0.23) + pred = paddle.less_than(x=x, y=y, name=None) + ret = paddle.static.nn.cond(pred, true_func, false_func) + self.assertEqual( + ret['1'].shape, + (3, -1), + f"The shape is not correct, expects (3, -1) but gets {ret['1'].shape}.", + ) + self.assertEqual( + ret['2'].shape, + (-1, -1), + f"The shape is not correct, expects (-1, -1) but gets {ret['2'].shape}.", + ) + + if __name__ == '__main__': unittest.main()