diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index fd146d77632ca1ccae52ded0f4bbfea7f2b428b0..e045348e6c942acfa04c168e587e2d6009048db2 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -18,6 +18,7 @@ from __future__ import print_function # It provides a compatibility layer between the AST of various Python versions, # as produced by ast.parse from the standard ast module. # See details in https://github.com/serge-sans-paille/gast/ + import os from paddle.utils import gast from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer @@ -38,6 +39,7 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTran from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer +from paddle.fluid.dygraph.dygraph_to_static.decorator_transformer import DecoratorTransformer from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code @@ -45,8 +47,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name __all__ = ['DygraphToStaticAst'] -DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func'] - def apply_optimization(transformers): """ @@ -105,6 +105,7 @@ class DygraphToStaticAst(BaseTransformer): CallTransformer, # transform call recursively CastTransformer, # type casting statement GradTransformer, # transform paddle.grad to paddle.gradients + DecoratorTransformer, # transform decorators to function call ] apply_optimization(transformers) @@ -120,30 +121,6 @@ class DygraphToStaticAst(BaseTransformer): self.decorate_func_name = node.name self.generic_visit(node) - # Remove the decorated name of dygraph_to_static - if hasattr(node, 'decorator_list'): - decorator_list = [] - ignore_list = ["staticmethod"] - for d in node.decorator_list: - if isinstance(d, gast.Name) and d.id in ignore_list: - continue - if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES: - raise NotImplementedError( - "ProgramTranslator hasn't implemented multiple decorators. Please remove " - + d.id + " in " + self.decorate_func_name) - if isinstance(d, gast.Attribute): - full_attribute_name = get_attribute_full_name(d) - has_translate_decorator = False - for deco in DECORATOR_NAMES: - if deco in full_attribute_name: - has_translate_decorator = True - break - if not has_translate_decorator: - raise NotImplementedError( - "ProgramTranslator hasn't implemented multiple decorators. Please remove " - + full_attribute_name + " in " + - self.decorate_func_name) - node.decorator_list = decorator_list return node def get_module_name(self): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index fda668dc7455f2901b8b07f7448b5ea0f4d5a529..a3d96b6fe0ad868ee4e32455fbf539359e14f7c2 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -33,7 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogge from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators -from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap from paddle.fluid.dygraph.layers import Layer __all__ = ["convert_call"] @@ -206,8 +206,9 @@ def convert_call(func): # `foo` will be converted into a wrapper class, suppose as `StaticFunction`. # And `foo.__globals__['foo']` will still return this `StaticFunction` instead of # `foo` function. So `isinstance(fn, StaticFunction)` is added here. + _origfunc = unwrap(func) global_functions = set() - for fn in func.__globals__.values(): + for fn in _origfunc.__globals__.values(): if inspect.isfunction(fn): global_functions.add(fn) elif isinstance(fn, StaticFunction): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8442403e04c83ef1653572dba12f196e97970d1a --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py @@ -0,0 +1,134 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.utils import gast +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer +from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code, is_paddle_api, Dygraph2StaticException +import warnings + +import re + +IGNORE_NAMES = [ + 'declarative', 'to_static', 'dygraph_to_static_func', 'wraps', + 'staticmethod', 'classmethod', 'decorator' +] + + +class DecoratorTransformer(BaseTransformer): + """ + Transform decorators. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Type of input node should be AstNodeWrapper, but received %s ." % type( + wrapper_root) + self.root = wrapper_root.node + + self.ancestor_nodes = [] + + def transform(self): + """ + Main function to transform AST. + """ + self.visit(self.root) + + def visit_FunctionDef(self, node): + assert isinstance(node, gast.FunctionDef) + self.generic_visit(node) + + deco_list = node.decorator_list + node.decorator_list = [] + + # every decorator will append a node + decofun_nodes = [] + # func to be decoed next time + deco_target = '_orig_' + node.name + # last decoed func + decoed_func = '' + + for deco in reversed(deco_list): + # skip INGNORE_NAMES + if isinstance(deco, gast.Attribute): + deco_name = deco.attr + elif isinstance(deco, gast.Call): + if hasattr(deco.func, 'args'): + deco_name = deco.func.args[0].id + elif hasattr(deco.func, 'attr'): + deco_name = deco.func.attr + else: + deco_name = deco.func.id + else: + deco_name = deco.id + if deco_name in IGNORE_NAMES: + continue + elif deco_name == 'contextmanager': + warnings.warn( + "Dy2Static : A context manager decorator is used, this may not work correctly after transform." + ) + + deco_full_name = ast_to_source_code(deco).strip() + decoed_func = '_decoedby_' + deco_name + + # get function after decoration + if isinstance(deco, gast.Call): + if '_jst.Call' in deco_full_name: + # in this case , the deco_full_name will be like: + # '_jst.Call(deco)(5)' + rematch = re.match(r'\_jst\.Call\((.+?)\)\((.*)\)', + deco_full_name) + re_name = rematch.group(1) + re_args = rematch.group(2) + re_args_with_func = deco_target + ', ' + re_args + decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\ + .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + else: + # paddle api will not be transformed to '_jst.Call' + rematch = re.match(r'(.+?)\((.*)\)', deco_full_name) + re_name = rematch.group(1) + re_args = rematch.group(2) + re_args_with_func = deco_target + ', ' + re_args + decofun_str = 'try:\n\t{0} = {1}({2})\nexcept:\n\t{0} = {1}({3})({4})'\ + .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + + else: + decofun_str = '{} = _jst.Call({})({})'.format( + decoed_func, deco_full_name, deco_target) + + decofun_nodes.extend(gast.parse(decofun_str).body) + deco_target = decoed_func + + if not decofun_nodes: + return node + + orig_func_node = gast.FunctionDef(name='_orig_' + node.name, + args=node.args, + body=node.body, + decorator_list=[], + returns=None, + type_comment=None) + + args = [arg.id for arg in node.args.args] + arg_str = ','.join(args) + callfun_str = 'return {}({})'.format(decoed_func, arg_str) + callfun_node = gast.parse(callfun_str).body[0] + + node.body = [orig_func_node] + decofun_nodes + [callfun_node] + + return node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py index 3eadd455e1033e32e805da00ae8407f517556d02..ed2a739936e1e337da693570d57fcc083226623f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py @@ -228,6 +228,20 @@ class ReturnTransformer(BaseTransformer): # Prepend no value placeholders self.function_def.pop() + + # Need update self.pre_analysis after pop + # For fix this case: + ''' + def fun(cond): + def inner(): + pass + if cond: + return True + else: + return False + ''' + if self.function_def: + self.pre_analysis = ReturnAnalysisVisitor(self.function_def[-1]) return node def visit_Return(self, node): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/decos.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/decos.py new file mode 100644 index 0000000000000000000000000000000000000000..6e3333c15a0ce1c1ac44fcb37c5190c91b783620 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/decos.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy +import paddle + +from functools import wraps + + +def deco1(fun): + + @wraps(fun) + def inner(*args, **kwargs): + print('in decos.deco1, added 1') + _t = paddle.to_tensor([1]) + _tt = fun(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + +def deco2(x=0): + + def inner_deco(func): + + @wraps(func) + def inner(*args, **kwargs): + print('in decos.deco2, added {}'.format(x)) + _t = paddle.to_tensor(x) + _tt = func(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + return inner_deco diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..4acc789a451bb09397a9e4ac19e94eef362a591d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py @@ -0,0 +1,223 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import unittest +import numpy as np +import decos +import warnings +from functools import wraps +from contextlib import contextmanager + + +def deco1(func): + + @wraps(func) + def inner(*args, **kwargs): + print('in deco1, added 1') + _x = 2 + if (_x < 1): + _x += 1 + else: + _x -= 1 + _t = paddle.to_tensor([1]) + _tt = func(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + +def deco2(fun): + + @wraps(fun) + def inner(*args, **kwargs): + print('in deco2, added 2') + _t = paddle.to_tensor([2]) + _tt = fun(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + +def deco3(x=3): + + def inner_deco(func): + + @wraps(func) + def inner(*args, **kwargs): + print('in deco3, added {}'.format(x)) + _t = paddle.to_tensor(x) + _tt = func(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner + + return inner_deco + + +def deco4(func=None, x=0): + + def decorated(pyfunc): + + @wraps(pyfunc) + def inner_deco(*args, **kwargs): + print('in deco4, added {}'.format(x)) + _t = paddle.to_tensor(x) + _tt = pyfunc(*args, **kwargs) + return paddle.add(_t, _tt) + + return inner_deco + + if func == None: + return decorated + return decorated(func) + + +def deco5(): + return deco2 + + +def deco6(x=0): + return deco2 + + +@deco2 +def fun1(x, y=0): + a = paddle.to_tensor(y) + print('in fun1, x=%d' % (x)) + return a + + +@deco1 +@deco2 +def fun2(x, y=0): + a = paddle.to_tensor(y) + print('in fun2, x=%d' % (x)) + return a + + +@deco3(3) +def fun3(x, y=0): + a = paddle.to_tensor(y) + print('in fun3, x=%d' % (x)) + return a + + +@deco4(x=4) +def fun4(x, y=0): + a = paddle.to_tensor(y) + print('in fun4, x=%d' % (x)) + return a + + +@deco2 +@deco4() +def fun5(x, y=0): + a = paddle.to_tensor(y) + print('in fun5, x=%d' % (x)) + return a + + +@decos.deco1 +@decos.deco2(2) +def fun6(x, y=0): + a = paddle.to_tensor(y) + print('in fun6, x=%d' % (x)) + return a + + +@deco5() +def fun7(x, y=0): + a = paddle.to_tensor(y) + print('in fun7, x=%d' % (x)) + return a + + +@deco6(2) +def fun8(x, y=0): + a = paddle.to_tensor(y) + print('in fun8, x=%d' % (x)) + return a + + +@paddle.jit.to_static +def forward(): + funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8] + out = [] + for idx, fun in enumerate(funcs): + out.append(fun(idx + 1, idx + 1)) + return out + + +@contextmanager +def contextmanager_warning(): + yield + + +@contextmanager_warning() +def fun9(): + print('in fun9 want contextmanager warning') + + +@paddle.jit.to_static +def warn1(): + fun9() + + +@paddle.no_grad() +def fun10(): + print('in fun10, paddle api decorated') + return True + + +@paddle.jit.to_static +def deco_with_paddle_api(): + return fun10() + + +class TestDecoratorTransform(unittest.TestCase): + + def test_deco_transform(self): + outs = forward() + np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05) + np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05) + np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05) + np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05) + np.testing.assert_allclose(outs[4], np.array(7), rtol=1e-05) + np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05) + np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05) + np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05) + + def test_contextmanager_warning(self): + paddle.disable_static() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn1() + flag = False + for warn in w: + if (issubclass(warn.category, UserWarning) + ) and "A context manager decorator is used" in str( + warn.message): + flag = True + break + self.assertTrue(flag) + + def test_deco_with_paddle_api(self): + self.assertTrue(deco_with_paddle_api()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index 27d7389b903cc481844783e4c861b83b56febf7d..97f0cf99b5f65df524b3bf66356526d5b155c504 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -399,16 +399,6 @@ class TestJitSaveInCompiletime(TestErrorBase): # # Situation 4: NotImplementedError -class TestErrorInOther(unittest.TestCase): - def test(self): - paddle.disable_static() - prog_trans = paddle.jit.ProgramTranslator() - with self.assertRaises(NotImplementedError): - prog_trans.get_output(func_decorated_by_other_1) - - with self.assertRaises(NotImplementedError): - func_decorated_by_other_2() - class TestSuggestionErrorInRuntime(TestErrorBase): def set_func(self):