未验证 提交 22530137 编写于 作者: F feifei-111 提交者: GitHub

[BugFix] fixed a bug in decorator transformer, it can not analyze decorator...

[BugFix] fixed a bug in decorator transformer, it can not analyze decorator with params correctly (#46055)

* fix deco call

* add raise

* add test

* add warn, fix paddle api

* fix error type

* fix coverage
上级 abe1dca3
...@@ -47,8 +47,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name ...@@ -47,8 +47,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
__all__ = ['DygraphToStaticAst'] __all__ = ['DygraphToStaticAst']
DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func']
def apply_optimization(transformers): def apply_optimization(transformers):
""" """
......
...@@ -18,13 +18,14 @@ from __future__ import print_function ...@@ -18,13 +18,14 @@ from __future__ import print_function
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code, is_paddle_api, Dygraph2StaticException
import warnings
import re import re
IGNORE_NAMES = [ IGNORE_NAMES = [
'declarative', 'to_static', 'dygraph_to_static_func', 'wraps', 'declarative', 'to_static', 'dygraph_to_static_func', 'wraps',
'staticmethod', 'classmethod' 'staticmethod', 'classmethod', 'decorator'
] ]
...@@ -77,20 +78,35 @@ class DecoratorTransformer(BaseTransformer): ...@@ -77,20 +78,35 @@ class DecoratorTransformer(BaseTransformer):
deco_name = deco.id deco_name = deco.id
if deco_name in IGNORE_NAMES: if deco_name in IGNORE_NAMES:
continue continue
elif deco_name == 'contextmanager':
warnings.warn(
"Dy2Static : A context manager decorator is used, this may not work correctly after transform."
)
# get function after decoration
deco_full_name = ast_to_source_code(deco).strip() deco_full_name = ast_to_source_code(deco).strip()
decoed_func = '_decoby_' + deco_name decoed_func = '_decoedby_' + deco_name
# get function after decoration
if isinstance(deco, gast.Call): if isinstance(deco, gast.Call):
if '_jst.Call' in deco_full_name:
# in this case , the deco_full_name will be like: # in this case , the deco_full_name will be like:
# '_jst.Call(deco)(5)' # '_jst.Call(deco)(5)'
rematch = re.match(r'\_jst\.Call\((.+?)\)\((.+?)\)', rematch = re.match(r'\_jst\.Call\((.+?)\)\((.*)\)',
deco_full_name) deco_full_name)
re_name = rematch.group(1) re_name = rematch.group(1)
re_args = rematch.group(2) re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\ decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target) .format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
else:
# paddle api will not be transformed to '_jst.Call'
rematch = re.match(r'(.+?)\((.*)\)', deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = {1}({2})\nexcept:\n\t{0} = {1}({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
else: else:
decofun_str = '{} = _jst.Call({})({})'.format( decofun_str = '{} = _jst.Call({})({})'.format(
decoed_func, deco_full_name, deco_target) decoed_func, deco_full_name, deco_target)
......
...@@ -18,7 +18,9 @@ import paddle ...@@ -18,7 +18,9 @@ import paddle
import unittest import unittest
import numpy as np import numpy as np
import decos import decos
import warnings
from functools import wraps from functools import wraps
from contextlib import contextmanager
def deco1(func): def deco1(func):
...@@ -84,6 +86,14 @@ def deco4(func=None, x=0): ...@@ -84,6 +86,14 @@ def deco4(func=None, x=0):
return decorated(func) return decorated(func)
def deco5():
return deco2
def deco6(x=0):
return deco2
@deco2 @deco2
def fun1(x, y=0): def fun1(x, y=0):
a = paddle.to_tensor(y) a = paddle.to_tensor(y)
...@@ -114,7 +124,7 @@ def fun4(x, y=0): ...@@ -114,7 +124,7 @@ def fun4(x, y=0):
@deco2 @deco2
@deco4(x=5) @deco4()
def fun5(x, y=0): def fun5(x, y=0):
a = paddle.to_tensor(y) a = paddle.to_tensor(y)
print('in fun5, x=%d' % (x)) print('in fun5, x=%d' % (x))
...@@ -129,15 +139,55 @@ def fun6(x, y=0): ...@@ -129,15 +139,55 @@ def fun6(x, y=0):
return a return a
@deco5()
def fun7(x, y=0):
a = paddle.to_tensor(y)
print('in fun7, x=%d' % (x))
return a
@deco6(2)
def fun8(x, y=0):
a = paddle.to_tensor(y)
print('in fun8, x=%d' % (x))
return a
@paddle.jit.to_static @paddle.jit.to_static
def forward(): def forward():
funcs = [fun1, fun2, fun3, fun4, fun5, fun6] funcs = [fun1, fun2, fun3, fun4, fun5, fun6, fun7, fun8]
out = [] out = []
for idx, fun in enumerate(funcs): for idx, fun in enumerate(funcs):
out.append(fun(idx + 1, idx + 1)) out.append(fun(idx + 1, idx + 1))
return out return out
@contextmanager
def contextmanager_warning():
yield
@contextmanager_warning()
def fun9():
print('in fun9 want contextmanager warning')
@paddle.jit.to_static
def warn1():
fun9()
@paddle.no_grad()
def fun10():
print('in fun10, paddle api decorated')
return True
@paddle.jit.to_static
def deco_with_paddle_api():
return fun10()
class TestDecoratorTransform(unittest.TestCase): class TestDecoratorTransform(unittest.TestCase):
def test_deco_transform(self): def test_deco_transform(self):
...@@ -146,8 +196,27 @@ class TestDecoratorTransform(unittest.TestCase): ...@@ -146,8 +196,27 @@ class TestDecoratorTransform(unittest.TestCase):
np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05) np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05)
np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05) np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05)
np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05) np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05)
np.testing.assert_allclose(outs[4], np.array(12), rtol=1e-05) np.testing.assert_allclose(outs[4], np.array(7), rtol=1e-05)
np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05) np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05)
np.testing.assert_allclose(outs[6], np.array(9), rtol=1e-05)
np.testing.assert_allclose(outs[7], np.array(10), rtol=1e-05)
def test_contextmanager_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
warn1()
flag = False
for warn in w:
if (issubclass(warn.category, UserWarning)
) and "A context manager decorator is used" in str(
warn.message):
flag = True
break
self.assertTrue(flag)
def test_deco_with_paddle_api(self):
self.assertTrue(deco_with_paddle_api())
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册