diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 001116a74c9cc5f149de8ab1ebd7f8f5c2f68068..1513b9f5222e6d78fb37e1fa4fa0485aa603c205 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -487,8 +487,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): os.remove(filepath) source = ast_to_source_code(ast_root) - import_fluid = "import paddle\nimport paddle.fluid as fluid\n" - source = import_fluid + source + source = _inject_import_statements() + source if six.PY2: source = source.encode('utf-8') @@ -528,6 +527,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): return callable_func, f.name +def _inject_import_statements(): + import_statements = [ + "import paddle", "import paddle.fluid as fluid", "from typing import *", + "import numpy as np" + ] + return '\n'.join(import_statements) + '\n' + + def recover_globals_attribute(src_obj, dst_obj): attr_name = '__globals__' diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py index 144b16873aa9bc576338a71a7bc532e7df53aa4a..016a1b3b588ab015be14c6b45cd9a4145bb7cff5 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py @@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase): self.func = simple_func def set_static_lineno(self): - self.static_abs_lineno_list = [3, 4, 5] + self.static_abs_lineno_list = [5, 6, 7] def set_dygraph_info(self): self.line_num = 3 @@ -149,7 +149,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo): self.func = nested_func def set_static_lineno(self): - self.static_abs_lineno_list = [3, 5, 6, 7, 8] + self.static_abs_lineno_list = [5, 7, 8, 9, 10] def set_dygraph_info(self): self.line_num = 5 @@ -174,7 +174,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo): self.func = decorated_func def set_static_lineno(self): - self.static_abs_lineno_list = [3, 4] + self.static_abs_lineno_list = [5, 6] def set_dygraph_info(self): self.line_num = 2 @@ -208,7 +208,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo): self.func = decorated_func2 def set_static_lineno(self): - self.static_abs_lineno_list = [3, 4] + self.static_abs_lineno_list = [5, 6] def set_dygraph_info(self): self.line_num = 2 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typing.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..c3c0453bde3f405bf6e3422e95f363583e172b47 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typing.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import numpy as np +from typing import Tuple, List, Dict, TypeVar + + +class BaseLayer(paddle.nn.Layer): + def __init__(self, in_size, out_size): + super(BaseLayer, self).__init__() + self._linear = paddle.nn.Linear(in_size, out_size) + + def build(self, x): + out1 = self._linear(x) + out2 = paddle.mean(out1) + return out1, out2 + + +class LinearNetWithTuple(BaseLayer): + def __init__(self, in_size, out_size): + super(LinearNetWithTuple, self).__init__(in_size, out_size) + + def forward(self, x) -> Tuple[paddle.Tensor, str]: + out1, out2 = self.build(x) + return (out2, 'str') + + +class LinearNetWithTuple2(BaseLayer): + def __init__(self, in_size, out_size): + super(LinearNetWithTuple2, self).__init__(in_size, out_size) + + def forward(self, x) -> Tuple[paddle.Tensor, np.array]: + out1, out2 = self.build(x) + return (out2, np.ones([4, 16])) + + +class LinearNetWithList(BaseLayer): + def __init__(self, in_size, out_size): + super(LinearNetWithList, self).__init__(in_size, out_size) + + def forward(self, x) -> List[paddle.Tensor]: + out1, out2 = self.build(x) + return [out2] + + +class LinearNetWithDict(BaseLayer): + def __init__(self, in_size, out_size): + super(LinearNetWithDict, self).__init__(in_size, out_size) + + def forward(self, x) -> Dict[str, paddle.Tensor]: + out1, out2 = self.build(x) + return {'out': out2} + + +class TestTyping(unittest.TestCase): + def setUp(self): + self.in_num = 16 + self.out_num = 16 + self.x = paddle.randn([4, 16]) + self.spec = [paddle.static.InputSpec(shape=[None, 16], dtype='float32')] + + def build_net(self): + return LinearNetWithTuple(self.in_num, self.out_num) + + def save_and_load(self, suffix=''): + path = './layer_typing_' + suffix + paddle.jit.save(self.net, path, input_spec=self.spec) + return paddle.jit.load(path) + + def run_dy(self): + out, _ = self.net(self.x) + return out + + def test_type(self): + self.net = self.build_net() + out = self.run_dy() + load_net = self.save_and_load('tuple') + load_out = load_net(self.x) + self.assertTrue(np.allclose(out, load_out)) + + +class TestTypingTuple(TestTyping): + def build_net(self): + return LinearNetWithTuple2(self.in_num, self.out_num) + + def run_dy(self): + out, np_data = self.net(self.x) + self.assertTrue(np.equal(np_data, np.ones_like(np_data)).all()) + return out + + +class TestTypingList(TestTyping): + def build_net(self): + return LinearNetWithList(self.in_num, self.out_num) + + def run_dy(self): + out = self.net(self.x)[0] + return out + + +class TestTypingDict(TestTyping): + def build_net(self): + return LinearNetWithDict(self.in_num, self.out_num) + + def run_dy(self): + out = self.net(self.x)['out'] + return out + + +if __name__ == '__main__': + unittest.main()