diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index b936c47b51135801eb8eba89aa4b81bc22c72731..e045348e6c942acfa04c168e587e2d6009048db2 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -47,8 +47,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name __all__ = ['DygraphToStaticAst'] -DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func'] - def apply_optimization(transformers): """ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py index ab193b674c25c430445ee7d44c130862530fcc3c..8442403e04c83ef1653572dba12f196e97970d1a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py @@ -18,13 +18,14 @@ from __future__ import print_function from paddle.utils import gast from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer -from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code, is_paddle_api, Dygraph2StaticException +import warnings import re IGNORE_NAMES = [ 'declarative', 'to_static', 'dygraph_to_static_func', 'wraps', - 'staticmethod', 'classmethod' + 'staticmethod', 'classmethod', 'decorator' ] @@ -77,20 +78,35 @@ class DecoratorTransformer(BaseTransformer): deco_name = deco.id if deco_name in IGNORE_NAMES: continue + elif deco_name == 'contextmanager': + warnings.warn( + "Dy2Static : A context manager decorator is used, this may not work correctly after transform." + ) - # get function after decoration deco_full_name = ast_to_source_code(deco).strip() - decoed_func = '_decoby_' + deco_name + decoed_func = '_decoedby_' + deco_name + + # get function after decoration if isinstance(deco, gast.Call): - # in this case , the deco_full_name will be like: - # '_jst.Call(deco)(5)' - rematch = re.match(r'\_jst\.Call\((.+?)\)\((.+?)\)', - deco_full_name) - re_name = rematch.group(1) - re_args = rematch.group(2) - re_args_with_func = deco_target + ', ' + re_args - decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\ - .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + if '_jst.Call' in deco_full_name: + # in this case , the deco_full_name will be like: + # '_jst.Call(deco)(5)' + rematch = re.match(r'\_jst\.Call\((.+?)\)\((.*)\)', + deco_full_name) + re_name = rematch.group(1) + re_args = rematch.group(2) + re_args_with_func = deco_target + ', ' + re_args + decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\ + .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + else: + # paddle api will not be transformed to '_jst.Call' + rematch = re.match(r'(.+?)\((.*)\)', deco_full_name) + re_name = rematch.group(1) + re_args = rematch.group(2) + re_args_with_func = deco_target + ', ' + re_args + decofun_str = 'try:\n\t{0} = {1}({2})\nexcept:\n\t{0} = {1}({3})({4})'\ + .format(decoed_func, re_name, re_args_with_func, re_args, deco_target) + else: decofun_str = '{} = _jst.Call({})({})'.format( decoed_func, deco_full_name, deco_target) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py index c6c2750e307939884ac104eb3283cb5dc80c6c7b..4acc789a451bb09397a9e4ac19e94eef362a591d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_decorator_transform.py @@ -18,7 +18,9 @@ import paddle import unittest import numpy as np import decos +import warnings from functools import wraps +from contextlib import contextmanager def deco1(func): @@ -84,6 +86,14 @@ def deco4(func=None, x=0): return decorated(func) +def deco5(): + return deco2 + + +def deco6(x=0): + return deco2 + + @deco2 def fun1(x, y=0): a = paddle.to_tensor(y) @@ -114,7 +124,7 @@ def fun4(x, y=0): @deco2 -@deco4(x=5) +@deco4() def fun5(x, y=0): a = paddle.to_tensor(y) print('in fun5, x=%d' % (x)) @@ -129,15 +139,55 @@ def fun6(x, y=0): return a +@deco5() +def fun7(x, y=0): + a = paddle.to_tensor(y) + print('in fun7, x=%d' % (x)) + return a + + +@deco6(2) +def fun8(x, y=0): + a = paddle.to_tensor(y) + print('in fun8, x=%d' % (x)) + return a + + @paddle.jit.to_static def forward(): - funcs = [fun1, fun2, fun3, fun4, fun5, fun6] + funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8] out = [] for idx, fun in enumerate(funcs): out.append(fun(idx + 1, idx + 1)) return out +@contextmanager +def contextmanager_warning(): + yield + + +@contextmanager_warning() +def fun9(): + print('in fun9 want contextmanager warning') + + +@paddle.jit.to_static +def warn1(): + fun9() + + +@paddle.no_grad() +def fun10(): + print('in fun10, paddle api decorated') + return True + + +@paddle.jit.to_static +def deco_with_paddle_api(): + return fun10() + + class TestDecoratorTransform(unittest.TestCase): def test_deco_transform(self): @@ -146,8 +196,27 @@ class TestDecoratorTransform(unittest.TestCase): np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05) np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05) np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05) - np.testing.assert_allclose(outs[4], np.array(12), rtol=1e-05) + np.testing.assert_allclose(outs[4], np.array(7), rtol=1e-05) np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05) + np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05) + np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05) + + def test_contextmanager_warning(self): + paddle.disable_static() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn1() + flag = False + for warn in w: + if (issubclass(warn.category, UserWarning) + ) and "A context manager decorator is used" in str( + warn.message): + flag = True + break + self.assertTrue(flag) + + def test_deco_with_paddle_api(self): + self.assertTrue(deco_with_paddle_api()) if __name__ == '__main__':