diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index ba45dedc40faa473c3c1a7e1f2dfba5a47e2a381..3a7b012b02bee530d2340e8969f96db404168745 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -248,20 +248,6 @@ def _remove_no_value_return_var(out): def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, return_vars): - - return_var_ids = [id(var) for var in return_vars] - # NOTE 1: Returned vars of Paddle op `control_flow.cond` must be Paddle Tensors - # NOTE 2: Here uses id(var) not var, because `if var in return_var` use operator `==`, - # which will call `fluid.layers.equal` and causes error when var in return_vars is not initialized. - true_args = [ - to_static_variable(var) if id(var) in return_var_ids else var - for var in true_args - ] - false_args = [ - to_static_variable(var) if id(var) in return_var_ids else var - for var in false_args - ] - pred = cast_bool_if_necessary(pred) return control_flow.cond(pred, lambda: true_fn(*true_args), lambda: false_fn(*false_args)) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index af2316a9a443e29a3215b64473fd2034d88483f0..668cb01549f6c5665783f7b3219f34f537ce1a15 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -102,6 +102,41 @@ def select_input(inputs, mask): return out +def select_input_with_buildin_type(inputs, mask): + from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable + support_ret_buildin_type = (bool, float, six.integer_types) + false_var, true_var = inputs + + if isinstance(false_var, Variable) and isinstance(true_var, Variable): + return select_input(inputs, mask) + + elif (isinstance(false_var, (support_ret_buildin_type)) and + isinstance(false_var, type(true_var))): + if false_var == true_var: + return false_var + else: + inputs = [ + to_static_variable(false_var), to_static_variable(true_var) + ] + # Deal with the situations like this: false_var is int and true_var is Variable + elif ((isinstance(false_var, support_ret_buildin_type) and + isinstance(true_var, Variable)) or + (isinstance(true_var, support_ret_buildin_type) and + isinstance(false_var, Variable))): + inputs = [to_static_variable(false_var), to_static_variable(true_var)] + warnings.warn( + "Return results from different branches in cond are not same type: " + "false_var returned by fasle_fn is '{}' and true_var of true_fn is " + "'{}'".format(type(false_var), type(true_var))) + else: + raise TypeError( + "Unsupported return type of true_fn and false_fn in cond: false_var " + "returned by fasle_fn is '{}' and true_var of true_fn is '{}'". + format(type(false_var), type(true_var))) + + return select_input(inputs, mask) + + def split_lod_tensor(input, mask, level=0): """ This function takes in an input that contains the complete lod information, @@ -2282,8 +2317,8 @@ class ConditionalBlock(object): def copy_var_to_parent_block(var, layer_helper): - if var is None: - return None + if not isinstance(var, Variable): + return var prog = layer_helper.main_program parent_idx = prog.current_block().parent_idx assert parent_idx >= 0, "Got wrong parent block index when assigning var to parent scope in control_flow" @@ -2466,7 +2501,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): format(e)) mask = cast(pred, dtype='int32') - merge_func = lambda false_var, true_var : select_input([false_var, true_var], mask) + merge_func = lambda false_var, true_var : select_input_with_buildin_type([false_var, true_var], mask) merged_output = map_structure(merge_func, false_output, true_output) return merged_output diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index b343c54d6b1ee640a325ba719aa483e193f8cbf6..04f44f68b4234b2f4553c3d0eb0274ce23a06f31 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -340,3 +340,53 @@ def if_tensor_case(x): x += 1 return x + + +def dyfunc_ifelse_ret_int1(x): + index = 0 + pred = paddle.to_tensor([1]) + if pred: + y = x[index] + 1 + index = index + 1 + return y, index + else: + y = x[index] + 2 + index = index + 1 + return y, index + + +def dyfunc_ifelse_ret_int2(x): + index = 0 + pred = paddle.to_tensor([1]) + if pred: + y = x[index] + 1 + index = index + 1 + return y, index + else: + y = x[index] + 2 + index = index + 1 + return y + + +def dyfunc_ifelse_ret_int3(x): + index = 0 + pred = paddle.to_tensor([1]) + if pred: + y = x[index] + 1 + index = index + 1 + return index + else: + y = x[index] + 2 + return y + + +def dyfunc_ifelse_ret_int4(x): + index = 0 + pred = paddle.to_tensor([1]) + if pred: + y = x[index] + 1 + index = index + 1 + return 'unsupport ret' + else: + y = x[index] + 2 + return y diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 5db1bb2a384f582c30a7877e49745cd9582e096e..7e999e3b21a8824671641e91362597a670e98220 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -365,5 +365,60 @@ class TestNewVarCreateInOneBranch(unittest.TestCase): self.assertEqual(paddle.jit.to_static(case_func)(True), -2) +class TestDy2StIfElseRetInt1(unittest.TestCase): + def setUp(self): + self.x = np.random.random([5]).astype('float32') + self.dyfunc = dyfunc_ifelse_ret_int1 + self.out = self.get_dy2stat_out() + + def get_dy2stat_out(self): + ProgramTranslator().enable(True) + static_func = paddle.jit.to_static(self.dyfunc) + out = static_func(self.x) + ProgramTranslator().enable(False) + return out + + def test_ast_to_func(self): + self.assertIsInstance(self.out[0], paddle.Tensor) + self.assertIsInstance(self.out[1], int) + + +class TestDy2StIfElseRetInt2(TestDy2StIfElseRetInt1): + def setUp(self): + self.x = np.random.random([5]).astype('float32') + self.dyfunc = dyfunc_ifelse_ret_int2 + self.out = self.get_dy2stat_out() + + def test_ast_to_func(self): + self.assertIsInstance(self.out[0], paddle.Tensor) + self.assertIsInstance(self.out[1], paddle.Tensor) + + +class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1): + def setUp(self): + self.x = np.random.random([5]).astype('float32') + self.dyfunc = dyfunc_ifelse_ret_int3 + self.out = self.get_dy2stat_out() + + def test_ast_to_func(self): + self.assertIsInstance(self.out, paddle.Tensor) + + +class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1): + def setUp(self): + self.x = np.random.random([5]).astype('float32') + self.dyfunc = dyfunc_ifelse_ret_int4 + + def test_ast_to_func(self): + with self.assertRaises(TypeError): + ProgramTranslator().enable(True) + static_func = paddle.jit.to_static(self.dyfunc) + out = static_func(self.x) + + def __del__(self): + ProgramTranslator().enable(False) + super(TestDy2StIfElseRetInt4, self).__del__() + + if __name__ == '__main__': unittest.main()