未验证 提交 17eb43bc 编写于 作者: 张春乔 提交者: GitHub

【Hackathon No.89】 Remove circle import Part3 (#51433)

* fix the circular import of NodeVarType

* rollback sth.

* rename the ast

* add utils_helper.py
上级 0e9a48c7
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import astor
from paddle.utils import gast
def ast_to_source_code(ast_node):
    """
    Render an AST node back into Python source text.

    Args:
        ast_node: a ``gast.AST`` or ``ast.AST`` node. A gast node is converted
            with ``gast.gast_to_ast`` before being rendered.

    Returns:
        The source code string produced by ``astor.to_source``.

    Raises:
        TypeError: if ``ast_node`` is neither a ``gast.AST`` nor an ``ast.AST``.
    """
    accepted_types = (gast.AST, ast.AST)
    if not isinstance(ast_node, accepted_types):
        raise TypeError(
            "Type of ast_root should be gast.AST or ast.AST, but received %s."
            % type(ast_node)
        )

    if isinstance(ast_node, gast.AST):
        python_node = gast.gast_to_ast(ast_node)
    else:
        python_node = ast_node

    # Concatenate the generated fragments as-is so that long lines are never
    # wrapped in the output.
    return astor.to_source(python_node, pretty_source=''.join)
......@@ -17,7 +17,7 @@ import threading
from paddle.fluid import log_helper
from .utils import ast_to_source_code
from .ast_utils import ast_to_source_code
__all__ = []
......
......@@ -14,9 +14,8 @@
from paddle.utils import gast
from .logging_utils import warn
from .utils import (
ast_to_source_code,
from .utils_helper import (
NodeVarType,
index_in_list,
is_dygraph_api,
is_numpy_api,
......@@ -26,93 +25,6 @@ from .utils import (
__all__ = []
class NodeVarType:
    """
    Integer codes for the kinds of Python values tracked by static analysis.

    Some variable types must be known while the AST is being transformed. For
    example, a string variable and a tensor variable in an ``if`` clause may
    lead to different conversions from dygraph to static graph.
    """

    # Bookkeeping codes.
    ERROR = -1  # Returned when static analysis hits an error
    UNKNOWN = 0  # Reserved for AST nodes whose type is not known yet
    STATEMENT = 1  # Nodes that represent a statement (non-variable type)
    CALLABLE = 2

    # Python data types.
    NONE = 100
    BOOLEAN = 101
    INT = 102
    FLOAT = 103
    STRING = 104
    TENSOR = 105
    NUMPY_NDARRAY = 106

    # Python collections.
    LIST = 200
    SET = 201
    DICT = 202

    PADDLE_DYGRAPH_API = 300
    PADDLE_CONTROL_IF = 301
    PADDLE_CONTROL_WHILE = 302
    PADDLE_CONTROL_FOR = 303
    # The source of a Paddle API may not be visible, so this value stands for
    # "whatever type a Paddle API returns".
    PADDLE_RETURN_TYPES = 304

    # A node whose node_var_type is in TENSOR_TYPES can be considered
    # tensor-dependent.
    TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES}

    # Annotation source text -> type code, used by type_from_annotation.
    Annotation_map = {
        "Tensor": TENSOR,
        "paddle.Tensor": TENSOR,
        "int": INT,
        "float": FLOAT,
        "bool": BOOLEAN,
        "str": STRING,
    }

    @staticmethod
    def binary_op_output_type(in_type1, in_type2):
        """
        Combine the type codes of the two operands of a binary operation.

        Rules, applied in order:
          * identical codes yield that code;
          * if one side is UNKNOWN, the other side is returned;
          * if either side is not BOOLEAN / INT / FLOAT / NUMPY_NDARRAY /
            TENSOR / PADDLE_RETURN_TYPES, the result is UNKNOWN;
          * mixing NUMPY_NDARRAY with TENSOR yields UNKNOWN;
          * otherwise the numerically larger code is returned.
        """
        unknown = NodeVarType.UNKNOWN

        if in_type1 == in_type2:
            return in_type1
        if in_type1 == unknown:
            return in_type2
        if in_type2 == unknown:
            return in_type1

        combinable = (
            NodeVarType.BOOLEAN,
            NodeVarType.INT,
            NodeVarType.FLOAT,
            NodeVarType.NUMPY_NDARRAY,
            NodeVarType.TENSOR,
            NodeVarType.PADDLE_RETURN_TYPES,
        )
        if in_type1 not in combinable or in_type2 not in combinable:
            return unknown

        not_mixable = (NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR)
        if in_type1 in not_mixable and in_type2 in not_mixable:
            return unknown

        return max(in_type1, in_type2)

    @staticmethod
    def type_from_annotation(annotation):
        """
        Map an annotation AST node to a type code via ``Annotation_map``.

        The annotation is rendered to source text and stripped; a warning is
        emitted and UNKNOWN returned when the text is not in the map.
        """
        annotation_str = ast_to_source_code(annotation).strip()
        known_annotations = NodeVarType.Annotation_map
        if annotation_str not in known_annotations:
            # Warn about annotations that are not supported.
            warn("Currently we don't support annotation: %s" % annotation_str)
            return NodeVarType.UNKNOWN
        return known_annotations[annotation_str]
class AstNodeWrapper:
"""
Wrapper for python gast.node. We need a node wrapper because gast.node
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import atexit
import builtins
import copy
......@@ -37,13 +36,23 @@ from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layer_helper import LayerHelper
from paddle.utils import gast
from .ast_utils import ast_to_source_code
from .static_analysis import StaticAnalysisVisitor
from .utils_helper import DYGRAPH_MODULE_PREFIX # noqa: F401
from .utils_helper import DYGRAPH_TO_STATIC_MODULE_PREFIX # noqa: F401
from .utils_helper import PADDLE_MODULE_PREFIX # noqa: F401
from .utils_helper import NodeVarType # noqa: F401
from .utils_helper import _is_api_in_module_helper # noqa: F401
from .utils_helper import index_in_list # noqa: F401
from .utils_helper import is_api_in_module # noqa: F401
from .utils_helper import is_dygraph_api # noqa: F401
from .utils_helper import is_numpy_api # noqa: F401;
from .utils_helper import is_paddle_api # noqa: F401
__all__ = []
# Note(Aurelius): Keep the trailing dot `.` so that this prefix does not also
# match other modules whose names start with `paddle`, such as paddlenlp.
PADDLE_MODULE_PREFIX = 'paddle.'
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.jit.dy2static'
GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
ALREADY_D2S = '__already_d2s'
......@@ -250,52 +259,6 @@ def make_hashable(x, error_msg=None):
return x
def _is_api_in_module_helper(obj, module_prefix):
    """
    Return True when ``obj`` belongs to a module whose name starts with
    ``module_prefix``.

    The owning module is looked up with ``inspect.getmodule``; when no module
    can be determined the result is False.
    """
    owner = inspect.getmodule(obj)
    if owner is None:
        return False
    return owner.__name__.startswith(module_prefix)
def is_api_in_module(node, module_prefix):
    """
    Check whether the function called by ``node`` belongs to a module whose
    name starts with ``module_prefix``.

    Args:
        node (gast.Call): the call expression to inspect.
        module_prefix (str): module-name prefix passed on to
            ``_is_api_in_module_helper``.

    Returns:
        The result of ``_is_api_in_module_helper`` for the called object, or
        False if the imports below or the evaluation raise any exception.
    """
    assert isinstance(
        node, gast.Call
    ), "Input non-Call node for is_api_in_module"

    # The callee itself can be a gast.Call, for example: convert_call(func)(x).
    # Unwrap nested calls until the callee is no longer a call expression.
    func_node = node.func
    while isinstance(func_node, gast.Call):
        func_node = func_node.func

    func_str = astor.to_source(gast.gast_to_ast(func_node)).strip()
    try:
        # These names are imported locally on purpose: eval() below resolves
        # the callee expression against this frame's locals, so aliases such
        # as `fluid`, `layers` or `_jst` must exist here. For the same reason
        # the local variable names in this function should not be renamed
        # casually.
        import paddle  # noqa: F401
        import paddle.fluid as fluid  # noqa: F401
        import paddle.fluid.dygraph as dygraph  # noqa: F401
        import paddle.fluid.layers as layers  # noqa: F401
        import paddle.jit.dy2static as _jst  # noqa: F401
        from paddle import to_tensor  # noqa: F401
        from paddle.fluid.dygraph import to_variable  # noqa: F401

        # NOTE(security): eval() executes source text regenerated from the AST
        # being analysed; it must only be used on trusted code.
        return eval(
            "_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix)
        )
    except Exception:
        # Best effort: a callee that cannot be resolved here is treated as
        # not belonging to the module.
        return False
def is_dygraph_api(node):
    """
    Tell whether the call ``node`` targets a dygraph API.

    Returns False when the callee matches ``DYGRAPH_TO_STATIC_MODULE_PREFIX``;
    otherwise returns the result of matching ``DYGRAPH_MODULE_PREFIX``.
    """
    # Note: an API in module dygraph_to_static is not a real dygraph API.
    in_dy2static = is_api_in_module(node, DYGRAPH_TO_STATIC_MODULE_PREFIX)

    # TODO(liym27): A better way to determine whether it is a dygraph api.
    #  Consider the decorator @dygraph_only
    return (not in_dy2static) and is_api_in_module(node, DYGRAPH_MODULE_PREFIX)
def is_paddle_api(node):
    """
    Tell whether the call ``node`` targets an API whose module name starts
    with ``PADDLE_MODULE_PREFIX``.
    """
    prefix = PADDLE_MODULE_PREFIX
    return is_api_in_module(node, prefix)
# NOTE(Aurelius84): Treat the following Paddle internal APIs as ordinary code and
# apply the @to_static code transformation to them as usual, because they contain
# user-defined layers, e.g. paddle.distributed.auto_parallel.helper.ProxyLayer.
......@@ -341,25 +304,6 @@ def is_paddle_func(func, ignore_white_list=True):
return False
# is_numpy_api cannot reuse is_api_in_module because of a numpy module problem
def is_numpy_api(node):
    """
    Tell whether the call ``node`` targets a numpy API.

    The callee is first checked with ``_is_api_in_module_helper`` against the
    ``numpy`` prefix; if that is falsy, the callee's source text is checked for
    a leading ``numpy.`` or ``np.``. Any exception during the import or the
    evaluation yields False.
    """
    assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
    func_str = astor.to_source(gast.gast_to_ast(node.func))

    try:
        # `np` has to be a local name: eval() resolves the callee expression
        # against this frame, which is also why the local names are kept.
        import numpy as np  # noqa: F401

        module_result = eval(
            "_is_api_in_module_helper({}, '{}')".format(func_str, "numpy")
        )
    except Exception:
        return False

    # BUG: np.random.uniform doesn't have module and cannot be analyzed
    # TODO: find a better way
    return module_result or func_str.startswith(("numpy.", "np."))
def _delete_keywords_from(node):
assert isinstance(node, gast.Call)
func_src = astor.to_source(gast.gast_to_ast(node.func))
......@@ -558,14 +502,6 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
return func_def_node
def index_in_list(array_list, item):
    """
    Return the index of the first occurrence of ``item`` in ``array_list``,
    or -1 when ``item`` is not present.
    """
    try:
        position = array_list.index(item)
    except ValueError:
        # ``index`` reports a missing item by raising ValueError.
        position = -1
    return position
def create_assign_node(name, node):
"""
Creates a `gast.Assign` node by given name_id as target and node as value.
......@@ -708,26 +644,6 @@ def func_to_source_code(function, dedent=True):
return source_code
def ast_to_source_code(ast_node):
    """
    Render an AST node back into Python source text.

    Args:
        ast_node: a ``gast.AST`` or ``ast.AST`` node. A gast node is converted
            with ``gast.gast_to_ast`` before being rendered.

    Returns:
        The source code string produced by ``astor.to_source``.

    Raises:
        TypeError: if ``ast_node`` is neither a ``gast.AST`` nor an ``ast.AST``.
    """
    accepted_types = (gast.AST, ast.AST)
    if not isinstance(ast_node, accepted_types):
        raise TypeError(
            "Type of ast_root should be gast.AST or ast.AST, but received %s."
            % type(ast_node)
        )

    if isinstance(ast_node, gast.AST):
        python_node = gast.gast_to_ast(ast_node)
    else:
        python_node = ast_node

    # Concatenate the generated fragments as-is so that long lines are never
    # wrapped in the output.
    return astor.to_source(python_node, pretty_source=''.join)
def is_candidate_node(node):
"""
Nodes with specified type will be dependent on tensor.
......@@ -805,8 +721,6 @@ class IsControlFlowVisitor(gast.NodeVisitor):
)
self.ast_root = ast_node
if static_analysis_visitor is None:
from .static_analysis import StaticAnalysisVisitor
static_analysis_visitor = StaticAnalysisVisitor(ast_node)
self.static_analysis_visitor = static_analysis_visitor
self.node_to_wrapper_map = (
......@@ -941,8 +855,6 @@ class IsControlFlowVisitor(gast.NodeVisitor):
return node
def _is_node_with_tensor(self, node, name_id):
from paddle.jit.dy2static.static_analysis import NodeVarType
# Look up the node_var_type_map by name_id.
if self.node_var_type_map:
if name_id and isinstance(name_id, str):
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import astor
from paddle.utils import gast
from .ast_utils import ast_to_source_code
from .logging_utils import warn
def index_in_list(array_list, item):
    """
    Return the index of the first occurrence of ``item`` in ``array_list``,
    or -1 when ``item`` is not present.
    """
    try:
        position = array_list.index(item)
    except ValueError:
        # ``index`` reports a missing item by raising ValueError.
        position = -1
    return position
# Note(Aurelius): Keep the trailing dot `.` so that this prefix does not also
# match other modules whose names start with `paddle`, such as paddlenlp.
PADDLE_MODULE_PREFIX = 'paddle.'
# Prefix that is_dygraph_api checks first; a match makes it return False.
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.jit.dy2static'
# Prefix that is_dygraph_api uses to recognise dygraph APIs.
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
def is_dygraph_api(node):
    """
    Tell whether the call ``node`` targets a dygraph API.

    Returns False when the callee matches ``DYGRAPH_TO_STATIC_MODULE_PREFIX``;
    otherwise returns the result of matching ``DYGRAPH_MODULE_PREFIX``.
    """
    # Note: an API in module dygraph_to_static is not a real dygraph API.
    in_dy2static = is_api_in_module(node, DYGRAPH_TO_STATIC_MODULE_PREFIX)

    # TODO(liym27): A better way to determine whether it is a dygraph api.
    #  Consider the decorator @dygraph_only
    return (not in_dy2static) and is_api_in_module(node, DYGRAPH_MODULE_PREFIX)
def is_api_in_module(node, module_prefix):
    """
    Check whether the function called by ``node`` belongs to a module whose
    name starts with ``module_prefix``.

    Args:
        node (gast.Call): the call expression to inspect.
        module_prefix (str): module-name prefix passed on to
            ``_is_api_in_module_helper``.

    Returns:
        The result of ``_is_api_in_module_helper`` for the called object, or
        False if the imports below or the evaluation raise any exception.
    """
    assert isinstance(
        node, gast.Call
    ), "Input non-Call node for is_api_in_module"

    # The callee itself can be a gast.Call, for example: convert_call(func)(x).
    # Unwrap nested calls until the callee is no longer a call expression.
    func_node = node.func
    while isinstance(func_node, gast.Call):
        func_node = func_node.func

    func_str = astor.to_source(gast.gast_to_ast(func_node)).strip()
    try:
        # These names are imported locally on purpose: eval() below resolves
        # the callee expression against this frame's locals, so aliases such
        # as `fluid`, `layers` or `_jst` must exist here. For the same reason
        # the local variable names in this function should not be renamed
        # casually.
        import paddle  # noqa: F401
        import paddle.fluid as fluid  # noqa: F401
        import paddle.fluid.dygraph as dygraph  # noqa: F401
        import paddle.fluid.layers as layers  # noqa: F401
        import paddle.jit.dy2static as _jst  # noqa: F401
        from paddle import to_tensor  # noqa: F401
        from paddle.fluid.dygraph import to_variable  # noqa: F401

        # NOTE(security): eval() executes source text regenerated from the AST
        # being analysed; it must only be used on trusted code.
        return eval(
            "_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix)
        )
    except Exception:
        # Best effort: a callee that cannot be resolved here is treated as
        # not belonging to the module.
        return False
def _is_api_in_module_helper(obj, module_prefix):
    """
    Return True when ``obj`` belongs to a module whose name starts with
    ``module_prefix``.

    The owning module is looked up with ``inspect.getmodule``; when no module
    can be determined the result is False.
    """
    owner = inspect.getmodule(obj)
    if owner is None:
        return False
    return owner.__name__.startswith(module_prefix)
# is_numpy_api cannot reuse is_api_in_module because of a numpy module problem
def is_numpy_api(node):
    """
    Tell whether the call ``node`` targets a numpy API.

    The callee is first checked with ``_is_api_in_module_helper`` against the
    ``numpy`` prefix; if that is falsy, the callee's source text is checked for
    a leading ``numpy.`` or ``np.``. Any exception during the import or the
    evaluation yields False.
    """
    assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
    func_str = astor.to_source(gast.gast_to_ast(node.func))

    try:
        # `np` has to be a local name: eval() resolves the callee expression
        # against this frame, which is also why the local names are kept.
        import numpy as np  # noqa: F401

        module_result = eval(
            "_is_api_in_module_helper({}, '{}')".format(func_str, "numpy")
        )
    except Exception:
        return False

    # BUG: np.random.uniform doesn't have module and cannot be analyzed
    # TODO: find a better way
    return module_result or func_str.startswith(("numpy.", "np."))
def is_paddle_api(node):
    """
    Tell whether the call ``node`` targets an API whose module name starts
    with ``PADDLE_MODULE_PREFIX``.
    """
    prefix = PADDLE_MODULE_PREFIX
    return is_api_in_module(node, prefix)
class NodeVarType:
    """
    Integer codes for the kinds of Python values tracked by static analysis.

    Some variable types must be known while the AST is being transformed. For
    example, a string variable and a tensor variable in an ``if`` clause may
    lead to different conversions from dygraph to static graph.
    """

    # Bookkeeping codes.
    ERROR = -1  # Returned when static analysis hits an error
    UNKNOWN = 0  # Reserved for AST nodes whose type is not known yet
    STATEMENT = 1  # Nodes that represent a statement (non-variable type)
    CALLABLE = 2

    # Python data types.
    NONE = 100
    BOOLEAN = 101
    INT = 102
    FLOAT = 103
    STRING = 104
    TENSOR = 105
    NUMPY_NDARRAY = 106

    # Python collections.
    LIST = 200
    SET = 201
    DICT = 202

    PADDLE_DYGRAPH_API = 300
    PADDLE_CONTROL_IF = 301
    PADDLE_CONTROL_WHILE = 302
    PADDLE_CONTROL_FOR = 303
    # The source of a Paddle API may not be visible, so this value stands for
    # "whatever type a Paddle API returns".
    PADDLE_RETURN_TYPES = 304

    # A node whose node_var_type is in TENSOR_TYPES can be considered
    # tensor-dependent.
    TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES}

    # Annotation source text -> type code, used by type_from_annotation.
    Annotation_map = {
        "Tensor": TENSOR,
        "paddle.Tensor": TENSOR,
        "int": INT,
        "float": FLOAT,
        "bool": BOOLEAN,
        "str": STRING,
    }

    @staticmethod
    def binary_op_output_type(in_type1, in_type2):
        """
        Combine the type codes of the two operands of a binary operation.

        Rules, applied in order:
          * identical codes yield that code;
          * if one side is UNKNOWN, the other side is returned;
          * if either side is not BOOLEAN / INT / FLOAT / NUMPY_NDARRAY /
            TENSOR / PADDLE_RETURN_TYPES, the result is UNKNOWN;
          * mixing NUMPY_NDARRAY with TENSOR yields UNKNOWN;
          * otherwise the numerically larger code is returned.
        """
        unknown = NodeVarType.UNKNOWN

        if in_type1 == in_type2:
            return in_type1
        if in_type1 == unknown:
            return in_type2
        if in_type2 == unknown:
            return in_type1

        combinable = (
            NodeVarType.BOOLEAN,
            NodeVarType.INT,
            NodeVarType.FLOAT,
            NodeVarType.NUMPY_NDARRAY,
            NodeVarType.TENSOR,
            NodeVarType.PADDLE_RETURN_TYPES,
        )
        if in_type1 not in combinable or in_type2 not in combinable:
            return unknown

        not_mixable = (NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR)
        if in_type1 in not_mixable and in_type2 in not_mixable:
            return unknown

        return max(in_type1, in_type2)

    @staticmethod
    def type_from_annotation(annotation):
        """
        Map an annotation AST node to a type code via ``Annotation_map``.

        The annotation is rendered to source text and stripped; a warning is
        emitted and UNKNOWN returned when the text is not in the map.
        """
        annotation_str = ast_to_source_code(annotation).strip()
        known_annotations = NodeVarType.Annotation_map
        if annotation_str not in known_annotations:
            # Warn about annotations that are not supported.
            warn("Currently we don't support annotation: %s" % annotation_str)
            return NodeVarType.UNKNOWN
        return known_annotations[annotation_str]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册