未验证 提交 e7305160 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Refine return mechanism in @to_static (#28116)

* remove some judgement

* fix len(outputs) == 1
上级 68449d19
...@@ -606,8 +606,10 @@ class ConcreteProgram(object): ...@@ -606,8 +606,10 @@ class ConcreteProgram(object):
error.attach_error_data(e) error.attach_error_data(e)
raise raise
if not isinstance(outputs, if outputs is not None:
(tuple, list)) and outputs is not None: need_wrap_into_list = not isinstance(outputs, (
tuple, list)) or len(outputs) == 1
if need_wrap_into_list:
outputs = [outputs] outputs = [outputs]
main_program = update_op_callstack_with_origin_info(main_program) main_program = update_op_callstack_with_origin_info(main_program)
......
...@@ -18,8 +18,8 @@ import unittest ...@@ -18,8 +18,8 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.dygraph import declarative from paddle.jit import to_static
from paddle.fluid.dygraph import ProgramTranslator from paddle.jit import ProgramTranslator
from ifelse_simple_func import dyfunc_with_if_else from ifelse_simple_func import dyfunc_with_if_else
...@@ -27,13 +27,13 @@ SEED = 2020 ...@@ -27,13 +27,13 @@ SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
@declarative @to_static
def test_return_base(x): def test_return_base(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
return x return x
@declarative @to_static
def test_inside_func_base(x): def test_inside_func_base(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
...@@ -43,7 +43,7 @@ def test_inside_func_base(x): ...@@ -43,7 +43,7 @@ def test_inside_func_base(x):
return inner_func(x) return inner_func(x)
@declarative @to_static
def test_return_if(x): def test_return_if(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
if x < 0: if x < 0:
...@@ -53,7 +53,7 @@ def test_return_if(x): ...@@ -53,7 +53,7 @@ def test_return_if(x):
return x return x
@declarative @to_static
def test_return_if_else(x): def test_return_if_else(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
if x > 0: if x > 0:
...@@ -66,7 +66,7 @@ def test_return_if_else(x): ...@@ -66,7 +66,7 @@ def test_return_if_else(x):
x -= 8888 # useless statement to test our code can handle it. x -= 8888 # useless statement to test our code can handle it.
@declarative @to_static
def test_return_in_while(x): def test_return_in_while(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
...@@ -79,7 +79,7 @@ def test_return_in_while(x): ...@@ -79,7 +79,7 @@ def test_return_in_while(x):
return x return x
@declarative @to_static
def test_return_in_for(x): def test_return_in_for(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
for i in range(10): for i in range(10):
...@@ -91,13 +91,13 @@ def test_return_in_for(x): ...@@ -91,13 +91,13 @@ def test_return_in_for(x):
return x - 1 return x - 1
@declarative @to_static
def test_recursive_return(x): def test_recursive_return(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
return dyfunc_with_if_else(x) return dyfunc_with_if_else(x)
@declarative @to_static
def test_return_different_length_if_body(x): def test_return_different_length_if_body(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
...@@ -108,7 +108,7 @@ def test_return_different_length_if_body(x): ...@@ -108,7 +108,7 @@ def test_return_different_length_if_body(x):
return x return x
@declarative @to_static
def test_return_different_length_else(x): def test_return_different_length_else(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
...@@ -119,13 +119,13 @@ def test_return_different_length_else(x): ...@@ -119,13 +119,13 @@ def test_return_different_length_else(x):
return x return x
@declarative @to_static
def test_no_return(x): def test_no_return(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
@declarative @to_static
def test_return_none(x): def test_return_none(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
...@@ -136,7 +136,7 @@ def test_return_none(x): ...@@ -136,7 +136,7 @@ def test_return_none(x):
return x, y return x, y
@declarative @to_static
def test_return_no_variable(x): def test_return_no_variable(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
...@@ -147,6 +147,38 @@ def test_return_no_variable(x): ...@@ -147,6 +147,38 @@ def test_return_no_variable(x):
return return
@to_static
def test_return_list_one_value(x):
x = fluid.dygraph.to_variable(x)
x += 1
return [x]
@to_static
def test_return_list_many_values(x):
x = fluid.dygraph.to_variable(x)
x += 1
y = x * 2
z = x * x
return [x, y, z]
@to_static
def test_return_tuple_one_value(x):
x = fluid.dygraph.to_variable(x)
x += 1
return (x, )
@to_static
def test_return_tuple_many_values(x):
x = fluid.dygraph.to_variable(x)
x += 1
y = x * 2
z = x * x
return (x, y, z)
class TestReturnBase(unittest.TestCase): class TestReturnBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.ones((1)).astype('int32') self.input = np.ones((1)).astype('int32')
...@@ -158,29 +190,19 @@ class TestReturnBase(unittest.TestCase): ...@@ -158,29 +190,19 @@ class TestReturnBase(unittest.TestCase):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_return_base self.dygraph_func = test_return_base
def run_dygraph_mode(self): def _run(self, to_static=False):
self.program_translator.enable(False) self.program_translator.enable(to_static)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
res = self.dygraph_func(self.input) res = self.dygraph_func(self.input)
if isinstance(res, (tuple)): if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase):
return res.numpy()
return res
def run_static_mode(self):
self.program_translator.enable(True)
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, tuple):
return tuple(r.numpy() for r in res) return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase): elif isinstance(res, core.VarBase):
return res.numpy() return res.numpy()
return res return res
def test_transformed_static_result(self): def test_transformed_static_result(self):
dygraph_res = self.run_dygraph_mode() dygraph_res = self._run(to_static=False)
static_res = self.run_static_mode() static_res = self._run(to_static=True)
if isinstance(dygraph_res, tuple): if isinstance(dygraph_res, tuple):
self.assertTrue(isinstance(static_res, tuple)) self.assertTrue(isinstance(static_res, tuple))
self.assertEqual(len(dygraph_res), len(static_res)) self.assertEqual(len(dygraph_res), len(static_res))
...@@ -255,5 +277,25 @@ class TestReturnNoVariable(TestReturnBase): ...@@ -255,5 +277,25 @@ class TestReturnNoVariable(TestReturnBase):
self.dygraph_func = test_return_no_variable self.dygraph_func = test_return_no_variable
class TestReturnListOneValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_list_one_value
class TestReturnListManyValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_list_many_values
class TestReturnTupleOneValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_tuple_one_value
class TestReturnTupleManyValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_tuple_many_values
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册