未验证 提交 e7305160 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Refine return mechanism in @to_static (#28116)

* remove some judgement

* fix len(outputs) == 1
上级 68449d19
......@@ -606,9 +606,11 @@ class ConcreteProgram(object):
error.attach_error_data(e)
raise
if not isinstance(outputs,
(tuple, list)) and outputs is not None:
outputs = [outputs]
if outputs is not None:
need_wrap_into_list = not isinstance(outputs, (
tuple, list)) or len(outputs) == 1
if need_wrap_into_list:
outputs = [outputs]
main_program = update_op_callstack_with_origin_info(main_program)
......
......@@ -18,8 +18,8 @@ import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph import ProgramTranslator
from paddle.jit import to_static
from paddle.jit import ProgramTranslator
from ifelse_simple_func import dyfunc_with_if_else
......@@ -27,13 +27,13 @@ SEED = 2020
np.random.seed(SEED)
@declarative
@to_static
def test_return_base(x):
x = fluid.dygraph.to_variable(x)
return x
@declarative
@to_static
def test_inside_func_base(x):
x = fluid.dygraph.to_variable(x)
......@@ -43,7 +43,7 @@ def test_inside_func_base(x):
return inner_func(x)
@declarative
@to_static
def test_return_if(x):
x = fluid.dygraph.to_variable(x)
if x < 0:
......@@ -53,7 +53,7 @@ def test_return_if(x):
return x
@declarative
@to_static
def test_return_if_else(x):
x = fluid.dygraph.to_variable(x)
if x > 0:
......@@ -66,7 +66,7 @@ def test_return_if_else(x):
x -= 8888 # useless statement to test our code can handle it.
@declarative
@to_static
def test_return_in_while(x):
x = fluid.dygraph.to_variable(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
......@@ -79,7 +79,7 @@ def test_return_in_while(x):
return x
@declarative
@to_static
def test_return_in_for(x):
x = fluid.dygraph.to_variable(x)
for i in range(10):
......@@ -91,13 +91,13 @@ def test_return_in_for(x):
return x - 1
@declarative
@to_static
def test_recursive_return(x):
x = fluid.dygraph.to_variable(x)
return dyfunc_with_if_else(x)
@declarative
@to_static
def test_return_different_length_if_body(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
......@@ -108,7 +108,7 @@ def test_return_different_length_if_body(x):
return x
@declarative
@to_static
def test_return_different_length_else(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
......@@ -119,13 +119,13 @@ def test_return_different_length_else(x):
return x
@declarative
@to_static
def test_no_return(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
@declarative
@to_static
def test_return_none(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
......@@ -136,7 +136,7 @@ def test_return_none(x):
return x, y
@declarative
@to_static
def test_return_no_variable(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
......@@ -147,6 +147,38 @@ def test_return_no_variable(x):
return
@to_static
def test_return_list_one_value(x):
x = fluid.dygraph.to_variable(x)
x += 1
return [x]
@to_static
def test_return_list_many_values(x):
x = fluid.dygraph.to_variable(x)
x += 1
y = x * 2
z = x * x
return [x, y, z]
@to_static
def test_return_tuple_one_value(x):
x = fluid.dygraph.to_variable(x)
x += 1
return (x, )
@to_static
def test_return_tuple_many_values(x):
x = fluid.dygraph.to_variable(x)
x += 1
y = x * 2
z = x * x
return (x, y, z)
class TestReturnBase(unittest.TestCase):
def setUp(self):
self.input = np.ones((1)).astype('int32')
......@@ -158,29 +190,19 @@ class TestReturnBase(unittest.TestCase):
def init_dygraph_func(self):
self.dygraph_func = test_return_base
def run_dygraph_mode(self):
self.program_translator.enable(False)
def _run(self, to_static=False):
self.program_translator.enable(to_static)
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, (tuple)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase):
return res.numpy()
return res
def run_static_mode(self):
self.program_translator.enable(True)
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, tuple):
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase):
return res.numpy()
return res
def test_transformed_static_result(self):
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
if isinstance(dygraph_res, tuple):
self.assertTrue(isinstance(static_res, tuple))
self.assertEqual(len(dygraph_res), len(static_res))
......@@ -255,5 +277,25 @@ class TestReturnNoVariable(TestReturnBase):
self.dygraph_func = test_return_no_variable
class TestReturnListOneValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_list_one_value
class TestReturnListManyValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_list_many_values
class TestReturnTupleOneValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_tuple_one_value
class TestReturnTupleManyValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_tuple_many_values
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册