From e73051609075ed273b283668015d45d1b027151e Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 22 Oct 2020 11:30:04 +0800 Subject: [PATCH] [Dy2stat] Refine return mechanism in @to_static (#28116) * remove some judgement * fix len(outputs) == 1 --- .../dygraph_to_static/program_translator.py | 8 +- .../dygraph_to_static/test_return.py | 100 +++++++++++++----- 2 files changed, 76 insertions(+), 32 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 2ff3fe833d6..6d9bfc909a1 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -606,9 +606,11 @@ class ConcreteProgram(object): error.attach_error_data(e) raise - if not isinstance(outputs, - (tuple, list)) and outputs is not None: - outputs = [outputs] + if outputs is not None: + need_wrap_into_list = not isinstance(outputs, ( + tuple, list)) or len(outputs) == 1 + if need_wrap_into_list: + outputs = [outputs] main_program = update_op_callstack_with_origin_info(main_program) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py index 1f4f8214664..f592b7ed244 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py @@ -18,8 +18,8 @@ import unittest import numpy as np import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.fluid.dygraph import declarative -from paddle.fluid.dygraph import ProgramTranslator +from paddle.jit import to_static +from paddle.jit import ProgramTranslator from ifelse_simple_func import dyfunc_with_if_else @@ -27,13 +27,13 @@ SEED = 2020 np.random.seed(SEED) -@declarative +@to_static def test_return_base(x): x = fluid.dygraph.to_variable(x) return x -@declarative +@to_static def test_inside_func_base(x): x = fluid.dygraph.to_variable(x) @@ -43,7 +43,7 @@ def test_inside_func_base(x): return inner_func(x) -@declarative +@to_static def test_return_if(x): x = fluid.dygraph.to_variable(x) if x < 0: @@ -53,7 +53,7 @@ def test_return_if(x): return x -@declarative +@to_static def test_return_if_else(x): x = fluid.dygraph.to_variable(x) if x > 0: @@ -66,7 +66,7 @@ def test_return_if_else(x): x -= 8888 # useless statement to test our code can handle it. -@declarative +@to_static def test_return_in_while(x): x = fluid.dygraph.to_variable(x) i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) @@ -79,7 +79,7 @@ def test_return_in_while(x): return x -@declarative +@to_static def test_return_in_for(x): x = fluid.dygraph.to_variable(x) for i in range(10): @@ -91,13 +91,13 @@ def test_return_in_for(x): return x - 1 -@declarative +@to_static def test_recursive_return(x): x = fluid.dygraph.to_variable(x) return dyfunc_with_if_else(x) -@declarative +@to_static def test_return_different_length_if_body(x): x = fluid.dygraph.to_variable(x) y = x + 1 @@ -108,7 +108,7 @@ def test_return_different_length_if_body(x): return x -@declarative +@to_static def test_return_different_length_else(x): x = fluid.dygraph.to_variable(x) y = x + 1 @@ -119,13 +119,13 @@ def test_return_different_length_else(x): return x -@declarative +@to_static def test_no_return(x): x = fluid.dygraph.to_variable(x) y = x + 1 -@declarative +@to_static def test_return_none(x): x = fluid.dygraph.to_variable(x) y = x + 1 @@ -136,7 +136,7 @@ def test_return_none(x): return x, y -@declarative +@to_static def test_return_no_variable(x): x = fluid.dygraph.to_variable(x) y = x + 1 @@ -147,6 +147,38 @@ def test_return_no_variable(x): return +@to_static +def test_return_list_one_value(x): + x = fluid.dygraph.to_variable(x) + x += 1 + return [x] + + +@to_static +def test_return_list_many_values(x): + x = fluid.dygraph.to_variable(x) + x += 1 + y = x * 2 + z = x * x + return [x, y, z] + + +@to_static +def test_return_tuple_one_value(x): + x = fluid.dygraph.to_variable(x) + x += 1 + return (x, ) + + +@to_static +def test_return_tuple_many_values(x): + x = fluid.dygraph.to_variable(x) + x += 1 + y = x * 2 + z = x * x + return (x, y, z) + + class TestReturnBase(unittest.TestCase): def setUp(self): self.input = np.ones((1)).astype('int32') @@ -158,29 +190,19 @@ class TestReturnBase(unittest.TestCase): def init_dygraph_func(self): self.dygraph_func = test_return_base - def run_dygraph_mode(self): - self.program_translator.enable(False) + def _run(self, to_static=False): + self.program_translator.enable(to_static) with fluid.dygraph.guard(): res = self.dygraph_func(self.input) - if isinstance(res, (tuple)): - return tuple(r.numpy() for r in res) - elif isinstance(res, core.VarBase): - return res.numpy() - return res - - def run_static_mode(self): - self.program_translator.enable(True) - with fluid.dygraph.guard(): - res = self.dygraph_func(self.input) - if isinstance(res, tuple): + if isinstance(res, (tuple, list)): return tuple(r.numpy() for r in res) elif isinstance(res, core.VarBase): return res.numpy() return res def test_transformed_static_result(self): - dygraph_res = self.run_dygraph_mode() - static_res = self.run_static_mode() + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) if isinstance(dygraph_res, tuple): self.assertTrue(isinstance(static_res, tuple)) self.assertEqual(len(dygraph_res), len(static_res)) @@ -255,5 +277,25 @@ class TestReturnNoVariable(TestReturnBase): self.dygraph_func = test_return_no_variable +class TestReturnListOneValue(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_list_one_value + + +class TestReturnListManyValue(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_list_many_values + + +class TestReturnTupleOneValue(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_tuple_one_value + + +class TestReturnTupleManyValue(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_tuple_many_values + + if __name__ == '__main__': unittest.main() -- GitLab