未验证 提交 945b286b 编写于 作者: F feifei-111 提交者: GitHub

[dy2static] support user to use decorator in their program (#45768)

* support deco

* fix deco ast type

* arg_str

* 1

* support callable deco

* code style

* codestyle

* test_error

* fix decos in another file

* recover conflict codes
上级 d3366853
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
# It provides a compatibility layer between the AST of various Python versions, # It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module. # as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/ # See details in https://github.com/serge-sans-paille/gast/
import os import os
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
...@@ -38,6 +39,7 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTran ...@@ -38,6 +39,7 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTran
from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.decorator_transformer import DecoratorTransformer
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
...@@ -105,6 +107,7 @@ class DygraphToStaticAst(BaseTransformer): ...@@ -105,6 +107,7 @@ class DygraphToStaticAst(BaseTransformer):
CallTransformer, # transform call recursively CallTransformer, # transform call recursively
CastTransformer, # type casting statement CastTransformer, # type casting statement
GradTransformer, # transform paddle.grad to paddle.gradients GradTransformer, # transform paddle.grad to paddle.gradients
DecoratorTransformer, # transform decorators to function call
] ]
apply_optimization(transformers) apply_optimization(transformers)
...@@ -120,30 +123,6 @@ class DygraphToStaticAst(BaseTransformer): ...@@ -120,30 +123,6 @@ class DygraphToStaticAst(BaseTransformer):
self.decorate_func_name = node.name self.decorate_func_name = node.name
self.generic_visit(node) self.generic_visit(node)
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
decorator_list = []
ignore_list = ["staticmethod"]
for d in node.decorator_list:
if isinstance(d, gast.Name) and d.id in ignore_list:
continue
if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
+ d.id + " in " + self.decorate_func_name)
if isinstance(d, gast.Attribute):
full_attribute_name = get_attribute_full_name(d)
has_translate_decorator = False
for deco in DECORATOR_NAMES:
if deco in full_attribute_name:
has_translate_decorator = True
break
if not has_translate_decorator:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
+ full_attribute_name + " in " +
self.decorate_func_name)
node.decorator_list = decorator_list
return node return node
def get_module_name(self): def get_module_name(self):
......
...@@ -33,7 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogge ...@@ -33,7 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogge
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"] __all__ = ["convert_call"]
...@@ -206,8 +206,9 @@ def convert_call(func): ...@@ -206,8 +206,9 @@ def convert_call(func):
# `foo` will be converted into a wrapper class, suppose as `StaticFunction`. # `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticFunction` instead of # And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticFunction)` is added here. # `foo` function. So `isinstance(fn, StaticFunction)` is added here.
_origfunc = unwrap(func)
global_functions = set() global_functions = set()
for fn in func.__globals__.values(): for fn in _origfunc.__globals__.values():
if inspect.isfunction(fn): if inspect.isfunction(fn):
global_functions.add(fn) global_functions.add(fn)
elif isinstance(fn, StaticFunction): elif isinstance(fn, StaticFunction):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code
import re
IGNORE_NAMES = [
'declarative', 'to_static', 'dygraph_to_static_func', 'wraps',
'staticmethod', 'classmethod'
]
class DecoratorTransformer(BaseTransformer):
"""
Transform decorators.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node
self.ancestor_nodes = []
def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)
def visit_FunctionDef(self, node):
assert isinstance(node, gast.FunctionDef)
self.generic_visit(node)
deco_list = node.decorator_list
node.decorator_list = []
# every decorator will append a node
decofun_nodes = []
# func to be decoed next time
deco_target = '_orig_' + node.name
# last decoed func
decoed_func = ''
for deco in reversed(deco_list):
# skip INGNORE_NAMES
if isinstance(deco, gast.Attribute):
deco_name = deco.attr
elif isinstance(deco, gast.Call):
if hasattr(deco.func, 'args'):
deco_name = deco.func.args[0].id
elif hasattr(deco.func, 'attr'):
deco_name = deco.func.attr
else:
deco_name = deco.func.id
else:
deco_name = deco.id
if deco_name in IGNORE_NAMES:
continue
# get function after decoration
deco_full_name = ast_to_source_code(deco).strip()
decoed_func = '_decoby_' + deco_name
if isinstance(deco, gast.Call):
# in this case , the deco_full_name will be like:
# '_jst.Call(deco)(5)'
rematch = re.match(r'\_jst\.Call\((.+?)\)\((.+?)\)',
deco_full_name)
re_name = rematch.group(1)
re_args = rematch.group(2)
re_args_with_func = deco_target + ', ' + re_args
decofun_str = 'try:\n\t{0} = _jst.Call({1})({2})\nexcept:\n\t{0} = _jst.Call({1})({3})({4})'\
.format(decoed_func, re_name, re_args_with_func, re_args, deco_target)
else:
decofun_str = '{} = _jst.Call({})({})'.format(
decoed_func, deco_full_name, deco_target)
decofun_nodes.extend(gast.parse(decofun_str).body)
deco_target = decoed_func
if not decofun_nodes:
return node
orig_func_node = gast.FunctionDef(name='_orig_' + node.name,
args=node.args,
body=node.body,
decorator_list=[],
returns=None,
type_comment=None)
args = [arg.id for arg in node.args.args]
arg_str = ','.join(args)
callfun_str = 'return {}({})'.format(decoed_func, arg_str)
callfun_node = gast.parse(callfun_str).body[0]
node.body = [orig_func_node] + decofun_nodes + [callfun_node]
return node
...@@ -228,6 +228,20 @@ class ReturnTransformer(BaseTransformer): ...@@ -228,6 +228,20 @@ class ReturnTransformer(BaseTransformer):
# Prepend no value placeholders # Prepend no value placeholders
self.function_def.pop() self.function_def.pop()
# Need update self.pre_analysis after pop
# For fix this case:
'''
def fun(cond):
def inner():
pass
if cond:
return True
else:
return False
'''
if self.function_def:
self.pre_analysis = ReturnAnalysisVisitor(self.function_def[-1])
return node return node
def visit_Return(self, node): def visit_Return(self, node):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import paddle
from functools import wraps
def deco1(fun):
@wraps(fun)
def inner(*args, **kwargs):
print('in decos.deco1, added 1')
_t = paddle.to_tensor([1])
_tt = fun(*args, **kwargs)
return paddle.add(_t, _tt)
return inner
def deco2(x=0):
def inner_deco(func):
@wraps(func)
def inner(*args, **kwargs):
print('in decos.deco2, added {}'.format(x))
_t = paddle.to_tensor(x)
_tt = func(*args, **kwargs)
return paddle.add(_t, _tt)
return inner
return inner_deco
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import unittest
import numpy as np
import decos
from functools import wraps
def deco1(func):
@wraps(func)
def inner(*args, **kwargs):
print('in deco1, added 1')
_x = 2
if (_x < 1):
_x += 1
else:
_x -= 1
_t = paddle.to_tensor([1])
_tt = func(*args, **kwargs)
return paddle.add(_t, _tt)
return inner
def deco2(fun):
@wraps(fun)
def inner(*args, **kwargs):
print('in deco2, added 2')
_t = paddle.to_tensor([2])
_tt = fun(*args, **kwargs)
return paddle.add(_t, _tt)
return inner
def deco3(x=3):
def inner_deco(func):
@wraps(func)
def inner(*args, **kwargs):
print('in deco3, added {}'.format(x))
_t = paddle.to_tensor(x)
_tt = func(*args, **kwargs)
return paddle.add(_t, _tt)
return inner
return inner_deco
def deco4(func=None, x=0):
def decorated(pyfunc):
@wraps(pyfunc)
def inner_deco(*args, **kwargs):
print('in deco4, added {}'.format(x))
_t = paddle.to_tensor(x)
_tt = pyfunc(*args, **kwargs)
return paddle.add(_t, _tt)
return inner_deco
if func == None:
return decorated
return decorated(func)
@deco2
def fun1(x, y=0):
a = paddle.to_tensor(y)
print('in fun1, x=%d' % (x))
return a
@deco1
@deco2
def fun2(x, y=0):
a = paddle.to_tensor(y)
print('in fun2, x=%d' % (x))
return a
@deco3(3)
def fun3(x, y=0):
a = paddle.to_tensor(y)
print('in fun3, x=%d' % (x))
return a
@deco4(x=4)
def fun4(x, y=0):
a = paddle.to_tensor(y)
print('in fun4, x=%d' % (x))
return a
@deco2
@deco4(x=5)
def fun5(x, y=0):
a = paddle.to_tensor(y)
print('in fun5, x=%d' % (x))
return a
@decos.deco1
@decos.deco2(2)
def fun6(x, y=0):
a = paddle.to_tensor(y)
print('in fun6, x=%d' % (x))
return a
@paddle.jit.to_static
def forward():
funcs = [fun1, fun2, fun3, fun4, fun5, fun6]
out = []
for idx, fun in enumerate(funcs):
out.append(fun(idx + 1, idx + 1))
return out
class TestDecoratorTransform(unittest.TestCase):
def test_deco_transform(self):
outs = forward()
np.testing.assert_allclose(outs[0], np.array(3), rtol=1e-05)
np.testing.assert_allclose(outs[1], np.array(5), rtol=1e-05)
np.testing.assert_allclose(outs[2], np.array(6), rtol=1e-05)
np.testing.assert_allclose(outs[3], np.array(8), rtol=1e-05)
np.testing.assert_allclose(outs[4], np.array(12), rtol=1e-05)
np.testing.assert_allclose(outs[5], np.array(9), rtol=1e-05)
if __name__ == '__main__':
unittest.main()
...@@ -399,16 +399,6 @@ class TestJitSaveInCompiletime(TestErrorBase): ...@@ -399,16 +399,6 @@ class TestJitSaveInCompiletime(TestErrorBase):
# # Situation 4: NotImplementedError # # Situation 4: NotImplementedError
class TestErrorInOther(unittest.TestCase):
def test(self):
paddle.disable_static()
prog_trans = paddle.jit.ProgramTranslator()
with self.assertRaises(NotImplementedError):
prog_trans.get_output(func_decorated_by_other_1)
with self.assertRaises(NotImplementedError):
func_decorated_by_other_2()
class TestSuggestionErrorInRuntime(TestErrorBase): class TestSuggestionErrorInRuntime(TestErrorBase):
def set_func(self): def set_func(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册