未验证 提交 c7797802 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Support Python3 type hint (#33745)

* support type hint

* fix unittest
上级 ae79a56b
......@@ -485,8 +485,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
os.remove(filepath)
source = ast_to_source_code(ast_root)
import_fluid = "import paddle\nimport paddle.fluid as fluid\n"
source = import_fluid + source
source = _inject_import_statements() + source
f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8')
......@@ -519,6 +518,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
return callable_func, f.name
def _inject_import_statements():
import_statements = [
"import paddle", "import paddle.fluid as fluid", "from typing import *",
"import numpy as np"
]
return '\n'.join(import_statements) + '\n'
def recover_globals_attribute(src_obj, dst_obj):
attr_name = '__globals__'
......
......@@ -65,7 +65,7 @@ class TestOriginInfo(unittest.TestCase):
self.func = simple_func
def set_static_lineno(self):
self.static_abs_lineno_list = [3, 4, 5]
self.static_abs_lineno_list = [5, 6, 7]
def set_dygraph_info(self):
self.line_num = 3
......@@ -149,7 +149,7 @@ class TestOriginInfoWithNestedFunc(TestOriginInfo):
self.func = nested_func
def set_static_lineno(self):
self.static_abs_lineno_list = [3, 5, 6, 7, 8]
self.static_abs_lineno_list = [5, 7, 8, 9, 10]
def set_dygraph_info(self):
self.line_num = 5
......@@ -174,7 +174,7 @@ class TestOriginInfoWithDecoratedFunc(TestOriginInfo):
self.func = decorated_func
def set_static_lineno(self):
self.static_abs_lineno_list = [3, 4]
self.static_abs_lineno_list = [5, 6]
def set_dygraph_info(self):
self.line_num = 2
......@@ -208,7 +208,7 @@ class TestOriginInfoWithDecoratedFunc2(TestOriginInfo):
self.func = decorated_func2
def set_static_lineno(self):
self.static_abs_lineno_list = [3, 4]
self.static_abs_lineno_list = [5, 6]
def set_dygraph_info(self):
self.line_num = 2
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
from typing import Tuple, List, Dict, TypeVar
class BaseLayer(paddle.nn.Layer):
def __init__(self, in_size, out_size):
super(BaseLayer, self).__init__()
self._linear = paddle.nn.Linear(in_size, out_size)
def build(self, x):
out1 = self._linear(x)
out2 = paddle.mean(out1)
return out1, out2
class LinearNetWithTuple(BaseLayer):
def __init__(self, in_size, out_size):
super(LinearNetWithTuple, self).__init__(in_size, out_size)
def forward(self, x) -> Tuple[paddle.Tensor, str]:
out1, out2 = self.build(x)
return (out2, 'str')
class LinearNetWithTuple2(BaseLayer):
def __init__(self, in_size, out_size):
super(LinearNetWithTuple2, self).__init__(in_size, out_size)
def forward(self, x) -> Tuple[paddle.Tensor, np.array]:
out1, out2 = self.build(x)
return (out2, np.ones([4, 16]))
class LinearNetWithList(BaseLayer):
def __init__(self, in_size, out_size):
super(LinearNetWithList, self).__init__(in_size, out_size)
def forward(self, x) -> List[paddle.Tensor]:
out1, out2 = self.build(x)
return [out2]
class LinearNetWithDict(BaseLayer):
def __init__(self, in_size, out_size):
super(LinearNetWithDict, self).__init__(in_size, out_size)
def forward(self, x) -> Dict[str, paddle.Tensor]:
out1, out2 = self.build(x)
return {'out': out2}
class TestTyping(unittest.TestCase):
def setUp(self):
self.in_num = 16
self.out_num = 16
self.x = paddle.randn([4, 16])
self.spec = [paddle.static.InputSpec(shape=[None, 16], dtype='float32')]
def build_net(self):
return LinearNetWithTuple(self.in_num, self.out_num)
def save_and_load(self, suffix=''):
path = './layer_typing_' + suffix
paddle.jit.save(self.net, path, input_spec=self.spec)
return paddle.jit.load(path)
def run_dy(self):
out, _ = self.net(self.x)
return out
def test_type(self):
self.net = self.build_net()
out = self.run_dy()
load_net = self.save_and_load('tuple')
load_out = load_net(self.x)
self.assertTrue(np.allclose(out, load_out))
class TestTypingTuple(TestTyping):
def build_net(self):
return LinearNetWithTuple2(self.in_num, self.out_num)
def run_dy(self):
out, np_data = self.net(self.x)
self.assertTrue(np.equal(np_data, np.ones_like(np_data)).all())
return out
class TestTypingList(TestTyping):
def build_net(self):
return LinearNetWithList(self.in_num, self.out_num)
def run_dy(self):
out = self.net(self.x)[0]
return out
class TestTypingDict(TestTyping):
def build_net(self):
return LinearNetWithDict(self.in_num, self.out_num)
def run_dy(self):
out = self.net(self.x)['out']
return out
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册