未验证 提交 f3ea6156 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Fix bug: unwrap func in dy2Stat. (#26279)

上级 1fbee267
......@@ -18,8 +18,8 @@ import collections
import inspect
import gast
from paddle.fluid import core
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap
from paddle.fluid.framework import Program
# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node.
......@@ -197,18 +197,6 @@ def attach_origin_info(ast_node, func):
return ast_node
# NOTE: inspect.unwrap() exits in PY3 but not in PY2.
def unwrap(func):
def _is_wrapped(f):
return hasattr(f, '__wrapped__')
unwrapped_f = func
while (_is_wrapped(unwrapped_f)):
unwrapped_f = unwrapped_f.__wrapped__
return unwrapped_f
def ast_walk(transformed_node, static_node):
"""
Recursively yield all descendant nodes in the trees starting at transformed_node and static_node (including itself) in parallel.
......
......@@ -13,32 +13,38 @@
# limitations under the License.
from __future__ import print_function
import gast
import collections
import inspect
import warnings
import textwrap
import threading
import collections
import warnings
import gast
import numpy as np
from paddle.fluid import core, scope_guard
from paddle.fluid import framework
from paddle.fluid import core
from paddle.fluid import executor
from paddle.fluid import framework
from paddle.fluid import scope_guard
from paddle.fluid import unique_name
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info
from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA
__all__ = ['ProgramTranslator', 'convert_to_static']
......@@ -89,7 +95,7 @@ class FunctionCache(object):
"""
# Note: In Python2, it will raise OSError when inspect function
# with decorator directly and function.__wrapped__ holds the actual function.
func = getattr(func, '__wrapped__', func)
func = unwrap(func)
source_code = func_to_source_code(func)
# TODO(liym27):
......@@ -669,7 +675,9 @@ class ProgramTranslator(object):
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_code"
# Gets AST from dygraph function
raw_code = inspect.getsource(dygraph_func)
unwrap_func = unwrap(dygraph_func)
raw_code = inspect.getsource(unwrap_func)
code = textwrap.dedent(raw_code)
root = gast.parse(code)
......
......@@ -1052,3 +1052,19 @@ class SplitAssignTransformer(gast.NodeTransformer):
value_node = target
return new_nodes
# NOTE: inspect.unwrap() exits in PY3 but not in PY2.
def unwrap(func):
"""
Returns the object wrapped by decorators.
"""
def _is_wrapped(f):
return hasattr(f, '__wrapped__')
unwrapped_f = func
while (_is_wrapped(unwrapped_f)):
unwrapped_f = unwrapped_f.__wrapped__
return unwrapped_f
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册