未验证 提交 22530137 编写于 作者: F feifei-111 提交者: GitHub

[BugFix] fixed a bug in decorator transformer, it can not analyze decorator...

[BugFix] fixed a bug in decorator transformer, it can not analyze decorator with params correctly (#46055)

* fix deco call

* add raise

* add test

* add warn, fix paddle api

* fix error type

* fix coverage
上级 abe1dca3
......@@ -47,8 +47,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
__all__ = ['DygraphToStaticAst']
DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func']
def apply_optimization(transformers):
"""
......
......@@ -18,13 +18,14 @@ from __future__ import print_function
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code, is_paddle_api, Dygraph2StaticException
import warnings
import re
IGNORE_NAMES = [
'declarative', 'to_static', 'dygraph_to_static_func', 'wraps',
'staticmethod', 'classmethod'
'staticmethod', 'classmethod', 'decorator'
]
......@@ -77,20 +78,35 @@ class DecoratorTransformer(BaseTransformer):
deco_name = deco.id
if deco_name in IGNORE_NAMES:
continue
elif deco_name == 'contextmanager':
warnings.warn(
"Dy2Static : A context manager decorator is used, this may not work correctly after transform."
)
# get function after decoration
deco_full_name = ast_to_source_code(deco).strip()
decoed_func = '_decoby_' + deco_name
decoed_func = '_decoedby_' + deco_name
# get function after decoration
if isinstance(deco, gast.Call):
# in this case , the deco_full_name will be like:
# '_jst.Call(deco)(5)'
rematch = re.match(r'\_jst\.Call\((.+?)\)\((.+?)\)',
deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
if '_jst.Call' in deco_full_name:
# in this case , the deco_full_name will be like:
# '_jst.Call(deco)(5)'
rematch = re.match(r'\_jst\.Call\((.+?)\)\((.*)\)',
deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
else:
# paddle api will not be transformed to '_jst.Call'
rematch = re.match(r'(.+?)\((.*)\)', deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = {1}({2})\nexcept:\n\t{0} = {1}({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
else:
decofun_str = '{} = _jst.Call({})({})'.format(
decoed_func, deco_full_name, deco_target)
......
......@@ -18,7 +18,9 @@ import paddle
import unittest
import numpy as np
import decos
import warnings
from functools import wraps
from contextlib import contextmanager
def deco1(func):
......@@ -84,6 +86,14 @@ def deco4(func=None, x=0):
return decorated(func)
def deco5():
return deco2
def deco6(x=0):
return deco2
@deco2
def fun1(x, y=0):
a = paddle.to_tensor(y)
......@@ -114,7 +124,7 @@ def fun4(x, y=0):
@deco2
@deco4(x=5)
@deco4()
def fun5(x, y=0):
a = paddle.to_tensor(y)
print('in fun5, x=%d' % (x))
......@@ -129,15 +139,55 @@ def fun6(x, y=0):
return a
@deco5()
def fun7(x, y=0):
a = paddle.to_tensor(y)
print('in fun7, x=%d' % (x))
return a
@deco6(2)
def fun8(x, y=0):
a = paddle.to_tensor(y)
print('in fun8, x=%d' % (x))
return a
@paddle.jit.to_static
def forward():
funcs = [fun1, fun2, fun3, fun4, fun5, fun6]
funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8]
out = []
for idx, fun in enumerate(funcs):
out.append(fun(idx + 1, idx + 1))
return out
@contextmanager
def contextmanager_warning():
yield
@contextmanager_warning()
def fun9():
print('in fun9 want contextmanager warning')
@paddle.jit.to_static
def warn1():
fun9()
@paddle.no_grad()
def fun10():
print('in fun10, paddle api decorated')
return True
@paddle.jit.to_static
def deco_with_paddle_api():
return fun10()
class TestDecoratorTransform(unittest.TestCase):
def test_deco_transform(self):
......@@ -146,8 +196,27 @@ class TestDecoratorTransform(unittest.TestCase):
np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05)
np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05)
np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05)
np.testing.assert_allclose(outs[4], np.array(12), rtol=1e-05)
np.testing.assert_allclose(outs[4], np.array(7), rtol=1e-05)
np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05)
np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05)
np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05)
def test_contextmanager_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
warn1()
flag = False
for warn in w:
if (issubclass(warn.category, UserWarning)
) and "A context manager decorator is used" in str(
warn.message):
flag = True
break
self.assertTrue(flag)
def test_deco_with_paddle_api(self):
self.assertTrue(deco_with_paddle_api())
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册