未验证 提交 34358de5 编写于 作者: 张春乔 提交者: GitHub

【Hackathon No.89】 Remove circle import Part2 (#51199)

* fix the only one circle import in call_transformer.py

* move define of CONVERSION_OPTIONS from convert_call_func.py to program_translator.py

* delete the self import of program_translator.py

* fix import failed problem

* define variable in utils.py

* move is_builtin to utils.py

* move is_builtin to utils.py

* fix import errors

* fix import errors

* fix something

* Update python/paddle/jit/dy2static/call_transformer.py
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>

* Update python/paddle/jit/dy2static/call_transformer.py

---------
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>
上级 ab17f988
......@@ -17,6 +17,7 @@ from paddle.jit.dy2static.utils import ast_to_source_code, is_paddle_api
from paddle.utils import gast
from .base_transformer import BaseTransformer
from .utils import is_builtin # noqa: F401
PDB_SET = "pdb.set_trace"
......@@ -48,8 +49,6 @@ class CallTransformer(BaseTransformer):
func_str = ast_to_source_code(node.func).strip()
try:
from paddle.jit.dy2static.convert_call_func import is_builtin
need_convert_builtin_func_list = {
'len',
'zip',
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import collections
import copy
import functools
......@@ -20,14 +19,11 @@ import inspect
import logging
import pdb
import re
import types
from typing import Any, List
import numpy
from paddle.fluid.dygraph.layers import Layer
from paddle.jit.dy2static.logging_utils import TranslatorLogger
from paddle.jit.dy2static.utils import is_paddle_func, unwrap
from .convert_operators import (
convert_enumerate,
......@@ -36,14 +32,20 @@ from .convert_operators import (
convert_range,
convert_zip,
)
from .logging_utils import TranslatorLogger
from .program_translator import (
CONVERSION_OPTIONS,
StaticFunction,
convert_to_static,
unwrap_decorators,
)
from .utils import is_builtin, is_paddle_func, unwrap
__all__ = []
translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "__jst_not_to_static"
class ConversionOptions:
"""
......@@ -72,22 +74,6 @@ class ConversionOptions:
)
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in builtins.__dict__.values() and name_judge():
return True
else:
return False
def builtin_modules():
"""
Return builtin modules.
......@@ -198,13 +184,6 @@ def convert_call(func):
# [1. 1. 1.]]
"""
# NOTE(Aurelius84): Fix it after all files migrating into jit.
from paddle.jit.dy2static.program_translator import (
StaticFunction,
convert_to_static,
unwrap_decorators,
)
translator_logger.log(
1, "Convert callable object: convert {}.".format(func)
)
......
......@@ -20,13 +20,13 @@ from paddle.fluid.dygraph.base import _convert_into_variable
from paddle.fluid.framework import Variable, core
from paddle.fluid.layers import Print, control_flow, fill_constant
from paddle.fluid.layers.control_flow import while_loop
from paddle.jit.dy2static.utils import (
from .utils import (
RETURN_NO_VALUE_VAR_NAME,
Dygraph2StaticException,
GetterSetterHelper,
UndefinedVar,
)
from .return_transformer import RETURN_NO_VALUE_VAR_NAME
from .variable_trans_func import to_static_variable
__all__ = []
......
......@@ -43,11 +43,10 @@ from paddle.jit.dy2static.utils import (
from paddle.utils import gast
from .base_transformer import BaseTransformer
from .utils import FALSE_FUNC_PREFIX, TRUE_FUNC_PREFIX
__all__ = []
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args'
......
......@@ -26,6 +26,10 @@ from .base_transformer import (
from .ifelse_transformer import ARGS_NAME
from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from .utils import (
FOR_BODY_PREFIX,
FOR_CONDITION_PREFIX,
WHILE_BODY_PREFIX,
WHILE_CONDITION_PREFIX,
FunctionNameLivenessAnalysis,
GetterSetterHelper,
ast_to_source_code,
......@@ -38,12 +42,6 @@ from .utils import (
__all__ = []
WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
def create_while_nodes(
condition_name,
......
......@@ -25,8 +25,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
from . import logging_utils
from .return_transformer import RETURN_NO_VALUE_MAGIC_NUM
from .utils import _out_grad_names, _param_grad_names
from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
__all__ = []
......
......@@ -28,7 +28,6 @@ from paddle.utils import flatten, gast
from . import error, logging_utils
from .ast_transformer import DygraphToStaticAst
from .convert_call_func import CONVERSION_OPTIONS
from .function_spec import (
FunctionSpec,
_hash_spec_names,
......@@ -60,6 +59,8 @@ __all__ = []
# Once exceeding the threshold, we will raise warning to users to make sure the conversion is as expected.
MAX_TRACED_PROGRAM_COUNT = 10
CONVERSION_OPTIONS = "__jst_not_to_static"
def synchronized(func):
func.__lock__ = threading.Lock()
......@@ -1031,10 +1032,6 @@ class ConcreteProgram:
error_data.raise_new_exception()
raise
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
# 3. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = (
ProgramTranslator.get_instance()._params_recorder.pop(
......
......@@ -43,10 +43,6 @@ RETURN_VALUE_INIT_NAME = '__return_value_init'
# graph as a place holder to indicate the returning placeholder means no value
# should return.
# Assign not support float64, use float32 value as magic number.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e27
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
def get_return_size(return_node):
assert isinstance(return_node, gast.Return), "Input is not gast.Return node"
......
......@@ -14,6 +14,7 @@
import ast
import atexit
import builtins
import copy
import functools
import importlib.util
......@@ -23,6 +24,7 @@ import shutil
import sys
import tempfile
import textwrap
import types
import warnings
from importlib.machinery import SourceFileLoader
......@@ -49,6 +51,31 @@ ARGS_NAME = '__args'
# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node.
ORIGI_INFO = "Original information of source code for ast node."
DEL_TEMP_DIR = True # A flag to avoid atexit.register more than once
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
FOR_ITER_ITERATOR_PREFIX = '__for_loop_iter_iterator'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip'
RE_PYNAME = '[a-zA-Z0-9_]+'
RE_PYMODULE = r'[a-zA-Z0-9_]+\.'
# Assign not support float64, use float32 value as magic number.
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e27
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
class BaseNodeVisitor(gast.NodeVisitor):
"""
......@@ -81,19 +108,6 @@ dygraph_class_to_static_api = {
"PolynomialDecay": "polynomial_decay",
}
DEL_TEMP_DIR = True # A flag to avoid atexit.register more than once
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
FOR_ITER_ITERATOR_PREFIX = '__for_loop_iter_iterator'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip'
RE_PYNAME = '[a-zA-Z0-9_]+'
RE_PYMODULE = r'[a-zA-Z0-9_]+\.'
def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
"""
......@@ -143,10 +157,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
def create_undefined_variable():
from paddle.jit.dy2static.return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM,
)
var = data_layer_not_check(
unique_name.generate("undefined_var"), [1], "float64"
)
......@@ -1221,16 +1231,6 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
"""NOTE: why we need merge w_vars and push_pop_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
"""
from paddle.jit.dy2static.ifelse_transformer import (
FALSE_FUNC_PREFIX,
TRUE_FUNC_PREFIX,
)
from paddle.jit.dy2static.loop_transformer import (
FOR_BODY_PREFIX,
FOR_CONDITION_PREFIX,
WHILE_BODY_PREFIX,
)
control_flow_function_def = [
WHILE_BODY_PREFIX,
WHILE_BODY_PREFIX,
......@@ -1568,3 +1568,19 @@ def prim_or_cinn_is_enabled(build_strategy):
elif value.lower() in ['true', '1']:
return True
return False
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in builtins.__dict__.values() and name_judge():
return True
else:
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册