diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 5ec1dbea504cc6f536d13c34260b1c77e3681366..5e23acfe4249c4c74448810a7c3647591aa18028 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -28,6 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Br from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer +from paddle.fluid.dygraph.dygraph_to_static.typehint_transformer import TypeHintTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer @@ -104,6 +105,7 @@ class DygraphToStaticAst(BaseTransformer): CastTransformer, # type casting statement #GradTransformer, # transform paddle.grad to paddle.gradients DecoratorTransformer, # transform decorators to function call + TypeHintTransformer, # remove all typehint in gast.Name ] apply_optimization(transformers) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/typehint_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/typehint_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f258b98b50711942f62fac9ad0c874328ce0cdfd --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/typehint_transformer.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.utils import gast +import warnings + +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static import utils +from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer + + +class TypeHintTransformer(BaseTransformer): + """ + A class that removes all the type hints in gast.Name (annotation). + Please put it behind other transformers because other transformers may rely on type hints. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of TypeHintTransformer." + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + def transform(self): + self.visit(self.root) + + def visit_FunctionDef(self, node): + node.returns = None + self.generic_visit(node) + return node + + def visit_Name(self, node): + node.annotation = None + self.generic_visit(node) + return node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typehint.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typehint.py new file mode 100644 index 0000000000000000000000000000000000000000..b8addd53d5b1804f127658c46ac431facc0821b4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_typehint.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid +import unittest + +from paddle.fluid.dygraph.jit import declarative + +SEED = 2020 +np.random.seed(SEED) + + +class A: + pass + + +def function(x: A) -> A: + t: A = A() + return 2 * x + + +class TestTransformWhileLoop(unittest.TestCase): + + def setUp(self): + self.place = fluid.CUDAPlace( + 0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace() + self.x = np.zeros(shape=(1), dtype=np.int32) + self._init_dyfunc() + + def _init_dyfunc(self): + self.dyfunc = function + + def _run_static(self): + return self._run(to_static=True) + + def _run_dygraph(self): + return self._run(to_static=False) + + def _run(self, to_static): + with fluid.dygraph.guard(self.place): + # Set the input of dyfunc to VarBase + tensor_x = fluid.dygraph.to_variable(self.x, zero_copy=False) + if to_static: + ret = declarative(self.dyfunc)(tensor_x) + else: + ret = self.dyfunc(tensor_x) + if hasattr(ret, "numpy"): + return ret.numpy() + else: + return ret + + def test_ast_to_func(self): + static_numpy = self._run_static() + dygraph_numpy = self._run_dygraph() + print(static_numpy, dygraph_numpy) + np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05) + + +class TestTypeHint(TestTransformWhileLoop): + + def _init_dyfunc(self): + self.dyfunc = function + + +if __name__ == '__main__': + with fluid.framework._test_eager_guard(): + unittest.main()