未验证 提交 34358de5 编写于 作者: 张春乔 提交者: GitHub

【Hackathon No.89】 Remove circle import Part2 (#51199)

* fix the only one circle import in call_transformer.py

* move define of CONVERSION_OPTIONS from convert_call_func.py to program_translator.py

* delete the self import of program_translator.py

* fix import failed problem

* define variable in utils.py

* move is_builtin to utils.py

* move is_builtin to utils.py

* fix import errors

* fix import errors

* fix something

* Update python/paddle/jit/dy2static/call_transformer.py
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>

* Update python/paddle/jit/dy2static/call_transformer.py

---------
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>
上级 ab17f988
...@@ -17,6 +17,7 @@ from paddle.jit.dy2static.utils import ast_to_source_code, is_paddle_api ...@@ -17,6 +17,7 @@ from paddle.jit.dy2static.utils import ast_to_source_code, is_paddle_api
from paddle.utils import gast from paddle.utils import gast
from .base_transformer import BaseTransformer from .base_transformer import BaseTransformer
from .utils import is_builtin # noqa: F401
PDB_SET = "pdb.set_trace" PDB_SET = "pdb.set_trace"
...@@ -48,8 +49,6 @@ class CallTransformer(BaseTransformer): ...@@ -48,8 +49,6 @@ class CallTransformer(BaseTransformer):
func_str = ast_to_source_code(node.func).strip() func_str = ast_to_source_code(node.func).strip()
try: try:
from paddle.jit.dy2static.convert_call_func import is_builtin
need_convert_builtin_func_list = { need_convert_builtin_func_list = {
'len', 'len',
'zip', 'zip',
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import builtins
import collections import collections
import copy import copy
import functools import functools
...@@ -20,14 +19,11 @@ import inspect ...@@ -20,14 +19,11 @@ import inspect
import logging import logging
import pdb import pdb
import re import re
import types
from typing import Any, List from typing import Any, List
import numpy import numpy
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
from paddle.jit.dy2static.logging_utils import TranslatorLogger
from paddle.jit.dy2static.utils import is_paddle_func, unwrap
from .convert_operators import ( from .convert_operators import (
convert_enumerate, convert_enumerate,
...@@ -36,14 +32,20 @@ from .convert_operators import ( ...@@ -36,14 +32,20 @@ from .convert_operators import (
convert_range, convert_range,
convert_zip, convert_zip,
) )
from .logging_utils import TranslatorLogger
from .program_translator import (
CONVERSION_OPTIONS,
StaticFunction,
convert_to_static,
unwrap_decorators,
)
from .utils import is_builtin, is_paddle_func, unwrap
__all__ = [] __all__ = []
translator_logger = TranslatorLogger() translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "__jst_not_to_static"
class ConversionOptions: class ConversionOptions:
""" """
...@@ -72,22 +74,6 @@ class ConversionOptions: ...@@ -72,22 +74,6 @@ class ConversionOptions:
) )
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in builtins.__dict__.values() and name_judge():
return True
else:
return False
def builtin_modules(): def builtin_modules():
""" """
Return builtin modules. Return builtin modules.
...@@ -198,13 +184,6 @@ def convert_call(func): ...@@ -198,13 +184,6 @@ def convert_call(func):
# [1. 1. 1.]] # [1. 1. 1.]]
""" """
# NOTE(Aurelius84): Fix it after all files migrating into jit.
from paddle.jit.dy2static.program_translator import (
StaticFunction,
convert_to_static,
unwrap_decorators,
)
translator_logger.log( translator_logger.log(
1, "Convert callable object: convert {}.".format(func) 1, "Convert callable object: convert {}.".format(func)
) )
......
...@@ -20,13 +20,13 @@ from paddle.fluid.dygraph.base import _convert_into_variable ...@@ -20,13 +20,13 @@ from paddle.fluid.dygraph.base import _convert_into_variable
from paddle.fluid.framework import Variable, core from paddle.fluid.framework import Variable, core
from paddle.fluid.layers import Print, control_flow, fill_constant from paddle.fluid.layers import Print, control_flow, fill_constant
from paddle.fluid.layers.control_flow import while_loop from paddle.fluid.layers.control_flow import while_loop
from paddle.jit.dy2static.utils import (
from .utils import (
RETURN_NO_VALUE_VAR_NAME,
Dygraph2StaticException, Dygraph2StaticException,
GetterSetterHelper, GetterSetterHelper,
UndefinedVar, UndefinedVar,
) )
from .return_transformer import RETURN_NO_VALUE_VAR_NAME
from .variable_trans_func import to_static_variable from .variable_trans_func import to_static_variable
__all__ = [] __all__ = []
......
...@@ -43,11 +43,10 @@ from paddle.jit.dy2static.utils import ( ...@@ -43,11 +43,10 @@ from paddle.jit.dy2static.utils import (
from paddle.utils import gast from paddle.utils import gast
from .base_transformer import BaseTransformer from .base_transformer import BaseTransformer
from .utils import FALSE_FUNC_PREFIX, TRUE_FUNC_PREFIX
__all__ = [] __all__ = []
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
GET_ARGS_FUNC_PREFIX = 'get_args' GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args' SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args' ARGS_NAME = '__args'
......
...@@ -26,6 +26,10 @@ from .base_transformer import ( ...@@ -26,6 +26,10 @@ from .base_transformer import (
from .ifelse_transformer import ARGS_NAME from .ifelse_transformer import ARGS_NAME
from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from .utils import ( from .utils import (
FOR_BODY_PREFIX,
FOR_CONDITION_PREFIX,
WHILE_BODY_PREFIX,
WHILE_CONDITION_PREFIX,
FunctionNameLivenessAnalysis, FunctionNameLivenessAnalysis,
GetterSetterHelper, GetterSetterHelper,
ast_to_source_code, ast_to_source_code,
...@@ -38,12 +42,6 @@ from .utils import ( ...@@ -38,12 +42,6 @@ from .utils import (
__all__ = [] __all__ = []
WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
def create_while_nodes( def create_while_nodes(
condition_name, condition_name,
......
...@@ -25,8 +25,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph ...@@ -25,8 +25,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass from paddle.fluid.framework import _apply_pass
from . import logging_utils from . import logging_utils
from .return_transformer import RETURN_NO_VALUE_MAGIC_NUM from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
from .utils import _out_grad_names, _param_grad_names
__all__ = [] __all__ = []
......
...@@ -28,7 +28,6 @@ from paddle.utils import flatten, gast ...@@ -28,7 +28,6 @@ from paddle.utils import flatten, gast
from . import error, logging_utils from . import error, logging_utils
from .ast_transformer import DygraphToStaticAst from .ast_transformer import DygraphToStaticAst
from .convert_call_func import CONVERSION_OPTIONS
from .function_spec import ( from .function_spec import (
FunctionSpec, FunctionSpec,
_hash_spec_names, _hash_spec_names,
...@@ -60,6 +59,8 @@ __all__ = [] ...@@ -60,6 +59,8 @@ __all__ = []
# Once exceeding the threshold, we will raise warning to users to make sure the conversion is as expected. # Once exceeding the threshold, we will raise warning to users to make sure the conversion is as expected.
MAX_TRACED_PROGRAM_COUNT = 10 MAX_TRACED_PROGRAM_COUNT = 10
CONVERSION_OPTIONS = "__jst_not_to_static"
def synchronized(func): def synchronized(func):
func.__lock__ = threading.Lock() func.__lock__ = threading.Lock()
...@@ -1031,10 +1032,6 @@ class ConcreteProgram: ...@@ -1031,10 +1032,6 @@ class ConcreteProgram:
error_data.raise_new_exception() error_data.raise_new_exception()
raise raise
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
# 3. Gets all ParamBases and buffered VarBases in the function # 3. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = ( all_parameters_and_buffers = (
ProgramTranslator.get_instance()._params_recorder.pop( ProgramTranslator.get_instance()._params_recorder.pop(
......
...@@ -43,10 +43,6 @@ RETURN_VALUE_INIT_NAME = '__return_value_init' ...@@ -43,10 +43,6 @@ RETURN_VALUE_INIT_NAME = '__return_value_init'
# graph as a place holder to indicate the returning placeholder means no value # graph as a place holder to indicate the returning placeholder means no value
# should return. # should return.
# Assign not support float64, use float32 value as magic number.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e27
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
def get_return_size(return_node): def get_return_size(return_node):
assert isinstance(return_node, gast.Return), "Input is not gast.Return node" assert isinstance(return_node, gast.Return), "Input is not gast.Return node"
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import ast import ast
import atexit import atexit
import builtins
import copy import copy
import functools import functools
import importlib.util import importlib.util
...@@ -23,6 +24,7 @@ import shutil ...@@ -23,6 +24,7 @@ import shutil
import sys import sys
import tempfile import tempfile
import textwrap import textwrap
import types
import warnings import warnings
from importlib.machinery import SourceFileLoader from importlib.machinery import SourceFileLoader
...@@ -49,6 +51,31 @@ ARGS_NAME = '__args' ...@@ -49,6 +51,31 @@ ARGS_NAME = '__args'
# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. # NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node.
ORIGI_INFO = "Original information of source code for ast node." ORIGI_INFO = "Original information of source code for ast node."
DEL_TEMP_DIR = True # A flag to avoid atexit.register more than once
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
FOR_ITER_ITERATOR_PREFIX = '__for_loop_iter_iterator'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip'
RE_PYNAME = '[a-zA-Z0-9_]+'
RE_PYMODULE = r'[a-zA-Z0-9_]+\.'
# Assign not support float64, use float32 value as magic number.
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e27
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
class BaseNodeVisitor(gast.NodeVisitor): class BaseNodeVisitor(gast.NodeVisitor):
""" """
...@@ -81,19 +108,6 @@ dygraph_class_to_static_api = { ...@@ -81,19 +108,6 @@ dygraph_class_to_static_api = {
"PolynomialDecay": "polynomial_decay", "PolynomialDecay": "polynomial_decay",
} }
DEL_TEMP_DIR = True # A flag to avoid atexit.register more than once
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TARGET_PREFIX = '__for_loop_iter_target'
FOR_ITER_ITERATOR_PREFIX = '__for_loop_iter_iterator'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip'
RE_PYNAME = '[a-zA-Z0-9_]+'
RE_PYMODULE = r'[a-zA-Z0-9_]+\.'
def data_layer_not_check(name, shape, dtype='float32', lod_level=0): def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
""" """
...@@ -143,10 +157,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ...@@ -143,10 +157,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
def create_undefined_variable(): def create_undefined_variable():
from paddle.jit.dy2static.return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM,
)
var = data_layer_not_check( var = data_layer_not_check(
unique_name.generate("undefined_var"), [1], "float64" unique_name.generate("undefined_var"), [1], "float64"
) )
...@@ -1221,16 +1231,6 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1221,16 +1231,6 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
"""NOTE: why we need merge w_vars and push_pop_vars here ? """NOTE: why we need merge w_vars and push_pop_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope. because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
""" """
from paddle.jit.dy2static.ifelse_transformer import (
FALSE_FUNC_PREFIX,
TRUE_FUNC_PREFIX,
)
from paddle.jit.dy2static.loop_transformer import (
FOR_BODY_PREFIX,
FOR_CONDITION_PREFIX,
WHILE_BODY_PREFIX,
)
control_flow_function_def = [ control_flow_function_def = [
WHILE_BODY_PREFIX, WHILE_BODY_PREFIX,
WHILE_BODY_PREFIX, WHILE_BODY_PREFIX,
...@@ -1568,3 +1568,19 @@ def prim_or_cinn_is_enabled(build_strategy): ...@@ -1568,3 +1568,19 @@ def prim_or_cinn_is_enabled(build_strategy):
elif value.lower() in ['true', '1']: elif value.lower() in ['true', '1']:
return True return True
return False return False
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in builtins.__dict__.values() and name_judge():
return True
else:
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册