From f3ea615673f13323eec89ccb6a50c759c9909201 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Wed, 19 Aug 2020 11:31:29 +0800 Subject: [PATCH] [Dy2Stat] Fix bug: unwrap func in dy2Stat. (#26279) --- .../dygraph/dygraph_to_static/origin_info.py | 14 +------ .../dygraph_to_static/program_translator.py | 40 +++++++++++-------- .../fluid/dygraph/dygraph_to_static/utils.py | 16 ++++++++ 3 files changed, 41 insertions(+), 29 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py index aeece9513b5..13f38b0726c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py @@ -18,8 +18,8 @@ import collections import inspect import gast - from paddle.fluid import core +from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap from paddle.fluid.framework import Program # NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. @@ -197,18 +197,6 @@ def attach_origin_info(ast_node, func): return ast_node -# NOTE: inspect.unwrap() exits in PY3 but not in PY2. -def unwrap(func): - def _is_wrapped(f): - return hasattr(f, '__wrapped__') - - unwrapped_f = func - while (_is_wrapped(unwrapped_f)): - unwrapped_f = unwrapped_f.__wrapped__ - - return unwrapped_f - - def ast_walk(transformed_node, static_node): """ Recursively yield all descendant nodes in the trees starting at transformed_node and static_node (including itself) in parallel. diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 88562dd40a6..ceacba25375 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -13,32 +13,38 @@ # limitations under the License. from __future__ import print_function -import gast + +import collections import inspect -import warnings import textwrap import threading -import collections +import warnings + +import gast import numpy as np -from paddle.fluid import core, scope_guard -from paddle.fluid import framework +from paddle.fluid import core from paddle.fluid import executor +from paddle.fluid import framework +from paddle.fluid import scope_guard from paddle.fluid import unique_name +from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph import layers -from paddle.fluid.layers.utils import flatten -from paddle.fluid.layers.utils import pack_sequence_as +from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst +from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA +from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data +from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info +from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map +from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info +from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func +from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap +from paddle.fluid.layers.utils import flatten +from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.wrapped_decorator import signature_safe_contextmanager -from paddle.fluid.dygraph.base import param_guard -from paddle.fluid.data_feeder import check_type -from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from -from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map -from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info -from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA __all__ = ['ProgramTranslator', 'convert_to_static'] @@ -89,7 +95,7 @@ class FunctionCache(object): """ # Note: In Python2, it will raise OSError when inspect function # with decorator directly and function.__wrapped__ holds the actual function. - func = getattr(func, '__wrapped__', func) + func = unwrap(func) source_code = func_to_source_code(func) # TODO(liym27): @@ -669,7 +675,9 @@ class ProgramTranslator(object): dygraph_func ), "Input dygraph_func is not a callable in ProgramTranslator.get_code" # Gets AST from dygraph function - raw_code = inspect.getsource(dygraph_func) + + unwrap_func = unwrap(dygraph_func) + raw_code = inspect.getsource(unwrap_func) code = textwrap.dedent(raw_code) root = gast.parse(code) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index def201cedc2..58cad1cfa42 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -1052,3 +1052,19 @@ class SplitAssignTransformer(gast.NodeTransformer): value_node = target return new_nodes + + +# NOTE: inspect.unwrap() exits in PY3 but not in PY2. +def unwrap(func): + """ + Returns the object wrapped by decorators. + """ + + def _is_wrapped(f): + return hasattr(f, '__wrapped__') + + unwrapped_f = func + while (_is_wrapped(unwrapped_f)): + unwrapped_f = unwrapped_f.__wrapped__ + + return unwrapped_f -- GitLab