diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 4126e942259434dc5035a48a7fd054a7b0433f98..d27af5c0dd9e0c783394028156fe9aca0eef3c6d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -20,6 +20,7 @@ from paddle.fluid.layers import array_length, array_read, array_write, create_ar from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment +from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME def convert_while_loop(cond, body, loop_vars): @@ -204,10 +205,45 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars): """ if isinstance(pred, Variable): - return _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, - return_vars) + out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, + return_vars) else: - return _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args) + out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args) + + return _remove_no_value_return_var(out) + + +def _remove_no_value_return_var(out): + if out and isinstance(out, tuple): + processed_out = out + align_ret = out[0] + if isinstance(align_ret, tuple): + for index, item in enumerate(align_ret): + if isinstance(item, Variable) and ( + RETURN_NO_VALUE_VAR_NAME in item.name): + # return None + if index == 0: + processed_out = (None, ) + out[1:] + elif index == 1: + processed_out = align_ret[:1] + out[1:] + else: + processed_out = (align_ret[:index], ) + out[1:] + break + + for index, item in enumerate(processed_out): + if isinstance(item, Variable) and ( + RETURN_NO_VALUE_VAR_NAME in item.name): + processed_out = processed_out[:index] + + if not processed_out: + return None + elif len(processed_out) == 1: + return processed_out[0] + else: + return processed_out + + else: + return out def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index b118eeadf7e7e5a93b75fa467b46af511156bb23..2cd6c5e43f7e1261d2bb48a8cbfc8151327c7dea 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -93,14 +93,14 @@ def create_fill_constant_node(name, value): func_code = "{} = paddle.fluid.layers.fill_constant(shape=[1], ".format( name) if isinstance(value, bool): - func_code += "dtype='bool', value={})".format(value) + func_code += "dtype='bool', value={}, name='{}')".format(value, name) return gast.parse(func_code).body[0] if isinstance(value, float): - func_code += "dtype='float64', value={})".format(value) + func_code += "dtype='float64', value={}, name='{}')".format(value, name) return gast.parse(func_code).body[0] if isinstance(value, int): - func_code += "dtype='int64', value={})".format(value) + func_code += "dtype='int64', value={}, name='{}')".format(value, name) return gast.parse(func_code).body[0] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 54dcc152fd6b281648991141973fc3a2b9a63f69..bb1942692fd9d2ac904496ba87f52f1dc54340d7 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -261,5 +261,100 @@ class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase): self.assertTrue(np.array_equal(out.numpy(), x.numpy())) +class TestIfElseNoValue(unittest.TestCase): + def test_else_ret_none(self): + input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + + @paddle.jit.to_static + def with_common_value(x, use_cache=False): + if use_cache: + y = x + 1 + z = x + 2 + return y, z + else: + c = x + 1 + z = x - 1 + return None + + @paddle.jit.to_static + def without_common_value(x, use_cache=False): + if use_cache: + y = x + 1 + z = x + 2 + return y, z + else: + c = x + 1 + return None + + out = with_common_value(input_x, False) + self.assertIsNone(out) + out = without_common_value(input_x, False) + self.assertIsNone(out) + + def test_else_ret_c(self): + input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + + @paddle.jit.to_static + def with_common_value(x, use_cache=False): + if use_cache: + y = x + 1 + z = x + 2 + return y, z + else: + c = x + 1 + z = x - 1 + return c + + @paddle.jit.to_static + def without_common_value(x, use_cache=False): + if use_cache: + y = x + 1 + z = x + 2 + return y, z + else: + c = x + 1 + return c + + out = with_common_value(input_x, False) + self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1)) + out = without_common_value(input_x, False) + self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1)) + y, z = with_common_value(input_x, True) + self.assertListEqual(paddle.tolist(y), paddle.tolist(input_x + 1)) + self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x + 2)) + + def test_else_ret_cz(self): + input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + + @paddle.jit.to_static + def with_common_value(x, use_cache=False): + if use_cache: + y = x + 1 + z = x + 2 + return y, z, 1 + else: + c = x + 1 + z = x - 1 + return c, z + + @paddle.jit.to_static + def without_common_value(x, use_cache=False): + if use_cache: + y = x + 1 + z = x + 2 + return y, z, 1 + else: + c = x + 1 + d = x - 1 + return c, d + + c, z = with_common_value(input_x, False) + self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1)) + self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x - 1)) + c, d = without_common_value(input_x, False) + self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1)) + self.assertListEqual(paddle.tolist(d), paddle.tolist(input_x - 1)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 9e12b6fa208505ef75c80516d1c65d06141a048e..6fef356326b81d00a0eca205586cc0d8247c1e5a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -64,7 +64,7 @@ def get_source_code(func): class StaticCode1(): def dyfunc_with_if_else(x_v, label=None): __return_value_init_0 = paddle.fluid.layers.fill_constant( - shape=[1], dtype='float64', value=0.0) + shape=[1], dtype='float64', value=0.0, name='__return_value_init_0') __return_value_0 = __return_value_init_0 def true_fn_0(x_v): @@ -116,7 +116,7 @@ class StaticCode2(): # TODO: Transform return statement def dyfunc_with_if_else(x_v, label=None): __return_value_init_1 = paddle.fluid.layers.fill_constant( - shape=[1], dtype='float64', value=0.0) + shape=[1], dtype='float64', value=0.0, name='__return_value_init_1') __return_value_1 = __return_value_init_1 def true_fn_3(x_v): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py index 3431c6aac4cbefd9827181b6a094bfa2655dc627..8500f46d974d8fb04b35b0f4b62532180977430c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py @@ -50,16 +50,22 @@ class TestDataLayerNotCheck(unittest.TestCase): class TestVariableTransFunc(unittest.TestCase): def test_create_fill_constant_node(self): node = create_fill_constant_node("a", 1.0) - source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0)" - self.assertEqual(ast_to_source_code(node).strip(), source) + source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0, name='a')" + self.assertEqual( + ast_to_source_code(node).replace('\n', '').replace(' ', ''), + source.replace(' ', '')) node = create_fill_constant_node("b", True) - source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)" - self.assertEqual(ast_to_source_code(node).strip(), source) + source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True, name='b')" + self.assertEqual( + ast_to_source_code(node).replace('\n', '').replace(' ', ''), + source.replace(' ', '')) node = create_fill_constant_node("c", 4293) - source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)" - self.assertEqual(ast_to_source_code(node).strip(), source) + source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293, name='c')" + self.assertEqual( + ast_to_source_code(node).replace('\n', '').replace(' ', ''), + source.replace(' ', '')) self.assertIsNone(create_fill_constant_node("e", None)) self.assertIsNone(create_fill_constant_node("e", []))