未验证 提交 ab85f87a 编写于 作者: A Aurelius84 提交者: GitHub

[Fluid Clean]Migrate utils files and delete dygraph_to_static dir (#48566)

* [Fluid Clean]Migrate utils files and delete dygraph_to_static dir

* fix setup.py.in

* fix import

* fix unittest

* fix code style

* fix unittest
上级 3d35aa80
...@@ -696,7 +696,7 @@ class IpuDynamicPatcher: ...@@ -696,7 +696,7 @@ class IpuDynamicPatcher:
ProgramCache, ProgramCache,
MAX_TRACED_PROGRAM_COUNT, MAX_TRACED_PROGRAM_COUNT,
) )
from ..fluid.dygraph.dygraph_to_static import logging_utils from paddle.jit.dy2static import logging_utils
from paddle.jit.dy2static.partial_program import ( from paddle.jit.dy2static.partial_program import (
partial_program_from, partial_program_from,
) )
......
...@@ -43,8 +43,6 @@ from .io import * ...@@ -43,8 +43,6 @@ from .io import *
from . import static_runner from . import static_runner
from .static_runner import StaticModelRunner from .static_runner import StaticModelRunner
from . import dygraph_to_static
from . import rnn from . import rnn
from .rnn import * from .rnn import *
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import static_analysis
from .static_analysis import *
from . import variable_trans_func
from .variable_trans_func import *
from . import logging_utils
from .logging_utils import *
__all__ = []
__all__ += static_analysis.__all__
__all__ += variable_trans_func.__all__
__all__ += logging_utils.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import textwrap
from paddle.utils import gast
from paddle.fluid import unique_name
from paddle.fluid.framework import Variable
from paddle.fluid.dygraph.dygraph_to_static.utils import (
UndefinedVar,
create_undefined_variable,
)
from paddle.fluid.layers.utils import map_structure, is_sequence
__all__ = [
'create_bool_as_type',
'create_fill_constant_node',
'to_static_variable',
'create_undefined_var',
]
def create_undefined_var(name):
func_code = "{} = _jst.UndefinedVar('{}')".format(name, name)
return gast.parse(func_code).body[0]
def create_fill_constant_node(name, value=0):
func_code = "{} = paddle.full(shape=[1], ".format(name)
if isinstance(value, bool):
func_code += "dtype='bool', fill_value={}, name='{}')".format(
value, name
)
return gast.parse(func_code).body[0]
if isinstance(value, float):
func_code += "dtype='float64', fill_value={}, name='{}')".format(
value, name
)
return gast.parse(func_code).body[0]
if isinstance(value, int):
func_code += "dtype='int64', fill_value={}, name='{}')".format(
value, name
)
return gast.parse(func_code).body[0]
def to_static_variable(x):
'''
Translate a Python Tensor to PaddlePaddle static graph Tensor
'''
if isinstance(x, bool):
return paddle.full(shape=[1], dtype='bool', fill_value=x)
if isinstance(x, float):
return paddle.full(shape=[1], dtype='float64', fill_value=x)
if isinstance(x, int):
return paddle.full(shape=[1], dtype='int64', fill_value=x)
if isinstance(x, UndefinedVar) or x is None:
"""
for early return case, we need a variable to represent None, current we use data_layer_not_check.
"""
return create_undefined_variable()
if is_sequence(x):
return map_structure(to_static_variable, x)
return x
def create_bool_as_type(x, value=True):
'''
Create a bool variable, which type is the same as x.
'''
if isinstance(x, Variable):
return paddle.full(shape=[1], fill_value=value, dtype="bool")
else:
return value
def create_bool_node(name, value):
'''
Create a assign stmt for name = value .
'''
assert isinstance(value, bool)
node = "{} = {}".format(name, value)
return gast.parse(node).body[0]
...@@ -159,10 +159,10 @@ def select_input(inputs, mask): ...@@ -159,10 +159,10 @@ def select_input(inputs, mask):
def select_input_with_buildin_type(inputs, mask, name): def select_input_with_buildin_type(inputs, mask, name):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import ( from paddle.jit.dy2static.variable_trans_func import (
to_static_variable, to_static_variable,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar from paddle.jit.dy2static.utils import UndefinedVar
false_var, true_var = inputs false_var, true_var = inputs
...@@ -1484,7 +1484,7 @@ def _deal_with_undefined_var(output_vars, loop_vars): ...@@ -1484,7 +1484,7 @@ def _deal_with_undefined_var(output_vars, loop_vars):
3. UndefinedVar = List(int) # create a list of variable 3. UndefinedVar = List(int) # create a list of variable
4. UndefinedVar = value # create a variable 4. UndefinedVar = value # create a variable
""" """
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
UndefinedVar, UndefinedVar,
create_undefined_variable, create_undefined_variable,
) )
...@@ -2552,7 +2552,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): ...@@ -2552,7 +2552,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
def change_none_to_undefinedvar(nest1, nest2): def change_none_to_undefinedvar(nest1, nest2):
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar from paddle.jit.dy2static.utils import UndefinedVar
def map_fn(x): def map_fn(x):
if x is None: if x is None:
...@@ -2588,7 +2588,7 @@ def expand_undefined_var(nest1, nest2, names): ...@@ -2588,7 +2588,7 @@ def expand_undefined_var(nest1, nest2, names):
nest2: Var2, ([1,2,3,4], UndefinedVar) nest2: Var2, ([1,2,3,4], UndefinedVar)
In this case, we should not expand recursively. In this case, we should not expand recursively.
""" """
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar from paddle.jit.dy2static.utils import UndefinedVar
from paddle.jit.dy2static.return_transformer import ( from paddle.jit.dy2static.return_transformer import (
RETURN_VALUE_PREFIX, RETURN_VALUE_PREFIX,
) )
......
...@@ -25,7 +25,7 @@ from ifelse_simple_func import ( ...@@ -25,7 +25,7 @@ from ifelse_simple_func import (
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.jit.dy2static.utils import ast_to_func
from paddle.utils import gast from paddle.utils import gast
......
...@@ -22,8 +22,8 @@ import paddle.fluid as fluid ...@@ -22,8 +22,8 @@ import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph import paddle.fluid.dygraph as dygraph
from paddle import to_tensor from paddle import to_tensor
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api
from paddle.jit.api import dygraph_to_static_func from paddle.jit.api import dygraph_to_static_func
from paddle.jit.dy2static.utils import is_dygraph_api
from paddle.utils import gast from paddle.utils import gast
SEED = 2020 SEED = 2020
......
...@@ -18,9 +18,9 @@ import numpy as np ...@@ -18,9 +18,9 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException
from paddle.jit.api import declarative from paddle.jit.api import declarative
from paddle.jit.dy2static.program_translator import ProgramTranslator from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.jit.dy2static.utils import Dygraph2StaticException
SEED = 2020 SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
......
...@@ -18,9 +18,7 @@ import unittest ...@@ -18,9 +18,7 @@ import unittest
from numpy import append from numpy import append
import paddle import paddle
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import FunctionNameLivenessAnalysis
FunctionNameLivenessAnalysis,
)
from paddle.utils import gast from paddle.utils import gast
global_a = [] global_a = []
......
...@@ -18,8 +18,8 @@ import unittest ...@@ -18,8 +18,8 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import error from paddle.jit.dy2static import error
from paddle.fluid.dygraph.dygraph_to_static.origin_info import unwrap from paddle.jit.dy2static.origin_info import unwrap
def inner_func(): def inner_func():
......
...@@ -17,7 +17,7 @@ import unittest ...@@ -17,7 +17,7 @@ import unittest
from test_declarative import foo_func from test_declarative import foo_func
import paddle import paddle
from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec from paddle.jit.dy2static.function_spec import FunctionSpec
from paddle.static import InputSpec from paddle.static import InputSpec
paddle.enable_static() paddle.enable_static()
......
...@@ -43,9 +43,9 @@ from ifelse_simple_func import ( ...@@ -43,9 +43,9 @@ from ifelse_simple_func import (
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException
from paddle.jit.api import declarative from paddle.jit.api import declarative
from paddle.jit.dy2static.program_translator import ProgramTranslator from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.jit.dy2static.utils import Dygraph2StaticException
np.random.seed(1) np.random.seed(1)
......
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
from unittest import mock from unittest import mock
import paddle import paddle
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.jit.dy2static import logging_utils
from paddle.utils import gast from paddle.utils import gast
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
import sys import sys
import unittest import unittest
from paddle.fluid.dygraph.dygraph_to_static.origin_info import ( from paddle.jit.api import declarative
from paddle.jit.dy2static import DygraphToStaticAst
from paddle.jit.dy2static.origin_info import (
ORIGI_INFO, ORIGI_INFO,
Location, Location,
OriginInfo, OriginInfo,
...@@ -25,9 +27,7 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import ( ...@@ -25,9 +27,7 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import (
inspect, inspect,
unwrap, unwrap,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.jit.dy2static.utils import ast_to_func
from paddle.jit.api import declarative
from paddle.jit.dy2static import DygraphToStaticAst
def simple_func(x): def simple_func(x):
......
...@@ -27,9 +27,9 @@ from ifelse_simple_func import ( ...@@ -27,9 +27,9 @@ from ifelse_simple_func import (
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.jit.dy2static as _jst import paddle.jit.dy2static as _jst
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.jit import ProgramTranslator from paddle.jit import ProgramTranslator
from paddle.jit.api import declarative from paddle.jit.api import declarative
from paddle.jit.dy2static.utils import func_to_source_code
from paddle.utils import gast from paddle.utils import gast
np.random.seed(0) np.random.seed(0)
......
...@@ -20,8 +20,8 @@ from ifelse_simple_func import dyfunc_with_if_else ...@@ -20,8 +20,8 @@ from ifelse_simple_func import dyfunc_with_if_else
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException
from paddle.jit import ProgramTranslator, to_static from paddle.jit import ProgramTranslator, to_static
from paddle.jit.dy2static.utils import Dygraph2StaticException
SEED = 2020 SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
......
...@@ -17,8 +17,8 @@ import unittest ...@@ -17,8 +17,8 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.jit.dy2static.program_translator import StaticFunction from paddle.jit.dy2static.program_translator import StaticFunction
from paddle.jit.dy2static.utils import func_to_source_code
class Net(paddle.nn.Layer): class Net(paddle.nn.Layer):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper from paddle.jit.dy2static.utils import GetterSetterHelper
vars = [1, 2, 3, 4, 5] vars = [1, 2, 3, 4, 5]
......
...@@ -19,10 +19,7 @@ import numpy as np ...@@ -19,10 +19,7 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ( from paddle.jit.dy2static import NodeVarType, StaticAnalysisVisitor
NodeVarType,
StaticAnalysisVisitor,
)
from paddle.utils import gast from paddle.utils import gast
......
...@@ -15,10 +15,7 @@ ...@@ -15,10 +15,7 @@
import types import types
import unittest import unittest
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import index_in_list, is_paddle_func
index_in_list,
is_paddle_func,
)
class TestIndexInList(unittest.TestCase): class TestIndexInList(unittest.TestCase):
......
...@@ -14,10 +14,8 @@ ...@@ -14,10 +14,8 @@
import unittest import unittest
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import ( from paddle.jit.dy2static.variable_trans_func import create_fill_constant_node
create_fill_constant_node,
)
class TestVariableTransFunc(unittest.TestCase): class TestVariableTransFunc(unittest.TestCase):
......
...@@ -29,7 +29,6 @@ import paddle.fluid as fluid ...@@ -29,7 +29,6 @@ import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.framework import ( from paddle.fluid.framework import (
OpProtoHolder, OpProtoHolder,
...@@ -43,6 +42,7 @@ from paddle.fluid.framework import ( ...@@ -43,6 +42,7 @@ from paddle.fluid.framework import (
_test_eager_guard, _test_eager_guard,
) )
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
from paddle.jit.dy2static.utils import parse_arg_and_kwargs
sys.path.append(os.path.abspath(os.path.dirname(__file__))) sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from testsuite import append_input_output, append_loss_ops, create_op, set_input from testsuite import append_input_output, append_loss_ops, create_op, set_input
......
...@@ -20,10 +20,8 @@ import numpy as np ...@@ -20,10 +20,8 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.utils import (
_compatible_non_tensor_spec,
)
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.jit.dy2static.utils import _compatible_non_tensor_spec
from paddle.static import InputSpec from paddle.static import InputSpec
......
...@@ -622,7 +622,7 @@ def _setitem_for_tensor_array(var, item, value): ...@@ -622,7 +622,7 @@ def _setitem_for_tensor_array(var, item, value):
not _non_static_mode() not _non_static_mode()
), "setitem for tensor_array must be called in static graph mode." ), "setitem for tensor_array must be called in static graph mode."
if isinstance(item, (Variable, int)): if isinstance(item, (Variable, int)):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import ( from paddle.jit.dy2static.variable_trans_func import (
to_static_variable, to_static_variable,
) )
from paddle import cast from paddle import cast
......
...@@ -34,16 +34,16 @@ from paddle.fluid.dygraph.base import ( ...@@ -34,16 +34,16 @@ from paddle.fluid.dygraph.base import (
program_desc_tracing_guard, program_desc_tracing_guard,
switch_to_static_graph, switch_to_static_graph,
) )
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from .dy2static import logging_utils
from paddle.jit.dy2static.convert_call_func import ( from .dy2static.convert_call_func import (
ConversionOptions, ConversionOptions,
CONVERSION_OPTIONS, CONVERSION_OPTIONS,
) )
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import ( from .dy2static.logging_utils import (
set_code_level, set_code_level,
set_verbosity, set_verbosity,
) )
from paddle.jit.dy2static.program_translator import ( from .dy2static.program_translator import (
ProgramTranslator, ProgramTranslator,
StaticFunction, StaticFunction,
unwrap_decorators, unwrap_decorators,
......
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .base import saw from .utils import (
from .base import UndefinedVar saw,
UndefinedVar,
)
from .convert_operators import convert_logical_and as And # noqa: F401 from .convert_operators import convert_logical_and as And # noqa: F401
from .convert_operators import convert_var_dtype as AsDtype # noqa: F401 from .convert_operators import convert_var_dtype as AsDtype # noqa: F401
from .convert_operators import convert_assert as Assert # noqa: F401 from .convert_operators import convert_assert as Assert # noqa: F401
...@@ -35,5 +37,6 @@ from .convert_operators import convert_shape_compare # noqa: F401 ...@@ -35,5 +37,6 @@ from .convert_operators import convert_shape_compare # noqa: F401
from .assert_transformer import AssertTransformer from .assert_transformer import AssertTransformer
from .ast_transformer import DygraphToStaticAst from .ast_transformer import DygraphToStaticAst
from .program_translator import convert_to_static from .program_translator import convert_to_static
from .static_analysis import * # noqa: F403
__all__ = [] __all__ = []
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
......
...@@ -61,7 +61,7 @@ from .return_transformer import ( ...@@ -61,7 +61,7 @@ from .return_transformer import (
from .create_variable_transformer import ( from .create_variable_transformer import (
CreateVariableTransformer, CreateVariableTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from .static_analysis import (
StaticAnalysisVisitor, StaticAnalysisVisitor,
) )
from .tensor_shape_transformer import ( from .tensor_shape_transformer import (
...@@ -71,8 +71,8 @@ from .decorator_transformer import ( ...@@ -71,8 +71,8 @@ from .decorator_transformer import (
DecoratorTransformer, DecoratorTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from . import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from .utils import ast_to_source_code
__all__ = ['DygraphToStaticAst'] __all__ = ['DygraphToStaticAst']
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...fluid.dygraph.dygraph_to_static.utils import saw # noqa: F401
from ...fluid.dygraph.dygraph_to_static.utils import UndefinedVar # noqa: F401
__all__ = []
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
ORIGI_INFO, ORIGI_INFO,
FOR_ITER_INDEX_PREFIX, FOR_ITER_INDEX_PREFIX,
FOR_ITER_VAR_LEN_PREFIX, FOR_ITER_VAR_LEN_PREFIX,
......
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
import astor import astor
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from .static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.fluid.dygraph.dygraph_to_static import utils from . import utils
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
class BasicApiTransformer(BaseTransformer): class BasicApiTransformer(BaseTransformer):
...@@ -166,7 +165,7 @@ class AttributeJstTransformer(BaseTransformer): ...@@ -166,7 +165,7 @@ class AttributeJstTransformer(BaseTransformer):
node = ( node = (
gast.parse( gast.parse(
"_jst.Attr({}, \"{}\")".format( "_jst.Attr({}, \"{}\")".format(
ast_to_source_code(value).strip(), attr utils.ast_to_source_code(value).strip(), attr
) )
) )
.body[0] .body[0]
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.jit.dy2static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor from paddle.jit.dy2static.utils import BaseNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import ( from paddle.jit.dy2static.variable_trans_func import (
create_bool_node, create_bool_node,
) )
from .base_transformer import ( from .base_transformer import (
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.jit.dy2static.utils import is_paddle_api
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
......
...@@ -32,11 +32,11 @@ from .convert_operators import ( ...@@ -32,11 +32,11 @@ from .convert_operators import (
convert_enumerate, convert_enumerate,
) )
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import ( from paddle.jit.dy2static.logging_utils import (
TranslatorLogger, TranslatorLogger,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap from paddle.jit.dy2static.utils import is_paddle_func, unwrap
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"] __all__ = ["convert_call"]
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import re import re
import paddle import paddle
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import ( from paddle.jit.dy2static.variable_trans_func import (
to_static_variable, to_static_variable,
) )
from paddle.fluid.framework import core, Variable from paddle.fluid.framework import core, Variable
...@@ -46,11 +46,11 @@ from paddle.fluid.layers.control_flow import ( ...@@ -46,11 +46,11 @@ from paddle.fluid.layers.control_flow import (
from .return_transformer import ( from .return_transformer import (
RETURN_NO_VALUE_VAR_NAME, RETURN_NO_VALUE_VAR_NAME,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
UndefinedVar, UndefinedVar,
Dygraph2StaticException, Dygraph2StaticException,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import GetterSetterHelper from paddle.jit.dy2static.utils import GetterSetterHelper
from paddle.fluid.layers.utils import copy_mutable_vars from paddle.fluid.layers.utils import copy_mutable_vars
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
FunctionNameLivenessAnalysis, FunctionNameLivenessAnalysis,
) )
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import ( from paddle.jit.dy2static.variable_trans_func import (
create_undefined_var, create_undefined_var,
) )
from .base_transformer import ( from .base_transformer import (
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
RE_PYNAME, RE_PYNAME,
RE_PYMODULE, RE_PYMODULE,
ast_to_source_code, ast_to_source_code,
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from .base_transformer import ( from .base_transformer import (
......
...@@ -17,17 +17,18 @@ import sys ...@@ -17,17 +17,18 @@ import sys
import traceback import traceback
import linecache import linecache
import re import re
import numpy as np import numpy as np # noqa: F401
from paddle.fluid.dygraph.dygraph_to_static.origin_info import ( from .origin_info import (
Location, Location,
OriginInfo, OriginInfo,
global_origin_info_map, global_origin_info_map,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from .utils import _is_api_in_module_helper # noqa: F401
_is_api_in_module_helper, from .utils import RE_PYMODULE
RE_PYMODULE,
)
__all__ = []
ERROR_DATA = "Error data about original source code information and traceback." ERROR_DATA = "Error data about original source code information and traceback."
......
...@@ -22,13 +22,16 @@ from paddle.fluid.dygraph import layers ...@@ -22,13 +22,16 @@ from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_varargs_name
from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.io import TranslatedLayer from paddle.fluid.dygraph.io import TranslatedLayer
from . import logging_utils
from .utils import (
parse_arg_and_kwargs,
parse_varargs_name,
type_name,
func_to_source_code,
)
class FunctionSpec: class FunctionSpec:
""" """
......
...@@ -22,27 +22,27 @@ from collections import defaultdict ...@@ -22,27 +22,27 @@ from collections import defaultdict
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
create_funcDef_node, create_funcDef_node,
ast_to_source_code, ast_to_source_code,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
FunctionNameLivenessAnalysis, FunctionNameLivenessAnalysis,
) )
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
create_nonlocal_stmt_nodes, create_nonlocal_stmt_nodes,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
create_get_args_node, create_get_args_node,
create_set_args_node, create_set_args_node,
) )
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
FOR_ITER_INDEX_PREFIX, FOR_ITER_INDEX_PREFIX,
FOR_ITER_TUPLE_PREFIX, FOR_ITER_TUPLE_PREFIX,
FOR_ITER_TUPLE_INDEX_PREFIX, FOR_ITER_TUPLE_INDEX_PREFIX,
...@@ -52,7 +52,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ( ...@@ -52,7 +52,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import (
FOR_ITER_TARGET_PREFIX, FOR_ITER_TARGET_PREFIX,
FOR_ITER_ITERATOR_PREFIX, FOR_ITER_ITERATOR_PREFIX,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
GetterSetterHelper, GetterSetterHelper,
create_name_str, create_name_str,
) )
......
...@@ -16,7 +16,7 @@ import os ...@@ -16,7 +16,7 @@ import os
import threading import threading
from paddle.fluid import log_helper from paddle.fluid import log_helper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from .utils import ast_to_source_code
__all__ = ["TranslatorLogger", "set_verbosity", "set_code_level"] __all__ = ["TranslatorLogger", "set_verbosity", "set_code_level"]
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
......
...@@ -17,21 +17,21 @@ from paddle.utils import gast ...@@ -17,21 +17,21 @@ from paddle.utils import gast
from collections import defaultdict from collections import defaultdict
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.jit.dy2static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
StaticAnalysisVisitor, StaticAnalysisVisitor,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.jit.dy2static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
create_nonlocal_stmt_nodes, create_nonlocal_stmt_nodes,
create_get_args_node, create_get_args_node,
create_set_args_node, create_set_args_node,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
FunctionNameLivenessAnalysis, FunctionNameLivenessAnalysis,
) )
from .ifelse_transformer import ARGS_NAME from .ifelse_transformer import ARGS_NAME
...@@ -41,7 +41,7 @@ from .base_transformer import ( ...@@ -41,7 +41,7 @@ from .base_transformer import (
ForNodeVisitor, ForNodeVisitor,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.jit.dy2static.utils import (
GetterSetterHelper, GetterSetterHelper,
create_name_str, create_name_str,
) )
......
...@@ -12,13 +12,14 @@ ...@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import inspect import inspect
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap from .utils import (
from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO unwrap,
ORIGI_INFO,
)
from paddle.fluid.framework import Program from paddle.fluid.framework import Program
from collections.abc import Sequence from collections.abc import Sequence
......
...@@ -22,7 +22,7 @@ from paddle.fluid.executor import ( ...@@ -22,7 +22,7 @@ from paddle.fluid.executor import (
) )
from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from . import logging_utils
from .return_transformer import ( from .return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM, RETURN_NO_VALUE_MAGIC_NUM,
) )
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
StaticAnalysisVisitor, StaticAnalysisVisitor,
) )
......
...@@ -26,39 +26,35 @@ from paddle.fluid.data_feeder import check_type ...@@ -26,39 +26,35 @@ from paddle.fluid.data_feeder import check_type
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import error from . import error
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from . import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.origin_info import ( from .origin_info import (
attach_origin_info, attach_origin_info,
)
from paddle.fluid.dygraph.dygraph_to_static.origin_info import (
create_and_update_origin_info_map, create_and_update_origin_info_map,
)
from paddle.fluid.dygraph.dygraph_to_static.origin_info import (
update_op_callstack_with_origin_info, update_op_callstack_with_origin_info,
) )
from .partial_program import ( from .partial_program import (
partial_program_from, partial_program_from,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from .utils import (
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code ast_to_func,
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code ast_to_source_code,
from paddle.fluid.dygraph.dygraph_to_static.utils import input_specs_compatible func_to_source_code,
from paddle.fluid.dygraph.dygraph_to_static.utils import type_name input_specs_compatible,
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap type_name,
from paddle.fluid.dygraph.dygraph_to_static.utils import ( unwrap,
make_hashable, make_hashable,
ALREADY_D2S, ALREADY_D2S,
) )
from paddle.fluid.dygraph.dygraph_to_static.function_spec import ( from .function_spec import (
FunctionSpec, FunctionSpec,
_hash_spec_names, _hash_spec_names,
)
from paddle.fluid.dygraph.dygraph_to_static.function_spec import (
get_buffers, get_buffers,
get_parameters, get_parameters,
) )
from .ast_transformer import DygraphToStaticAst from .ast_transformer import DygraphToStaticAst
__all__ = ['ProgramTranslator', 'convert_to_static'] __all__ = ['ProgramTranslator', 'convert_to_static']
......
...@@ -15,16 +15,16 @@ ...@@ -15,16 +15,16 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.jit.dy2static.utils import index_in_list
from .break_continue_transformer import ( from .break_continue_transformer import (
ForToWhileTransformer, ForToWhileTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException from paddle.jit.dy2static.utils import Dygraph2StaticException
from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO from paddle.jit.dy2static.utils import ORIGI_INFO
__all__ = [ __all__ = [
'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_MAGIC_NUM',
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from .base_transformer import ( from .base_transformer import (
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from .base_transformer import ( from .base_transformer import (
......
...@@ -16,7 +16,6 @@ import ast ...@@ -16,7 +16,6 @@ import ast
import astor import astor
import atexit import atexit
import copy import copy
import collections
from paddle.utils import gast from paddle.utils import gast
import inspect import inspect
import os import os
...@@ -32,15 +31,17 @@ from paddle.fluid.data_feeder import convert_dtype ...@@ -32,15 +31,17 @@ from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import assign from paddle.fluid.layers import assign
import collections
from functools import reduce from functools import reduce
import warnings import warnings
__all__ = []
# Note(Aurelius): Do not forget the dot `.` to distinguish other # Note(Aurelius): Do not forget the dot `.` to distinguish other
# module such as paddlenlp. # module such as paddlenlp.
PADDLE_MODULE_PREFIX = 'paddle.' PADDLE_MODULE_PREFIX = 'paddle.'
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph' DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static' DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.jit.dy2static'
GET_ARGS_FUNC_PREFIX = 'get_args' GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args' SET_ARGS_FUNC_PREFIX = 'set_args'
ALREADY_D2S = '__already_d2s' ALREADY_D2S = '__already_d2s'
...@@ -258,19 +259,13 @@ def is_api_in_module(node, module_prefix): ...@@ -258,19 +259,13 @@ def is_api_in_module(node, module_prefix):
func_str = astor.to_source(gast.gast_to_ast(func_node)).strip() func_str = astor.to_source(gast.gast_to_ast(func_node)).strip()
try: try:
# TODO(liym27): import paddle # noqa: F401
# Consider a better to import modules like: import paddle.fluid as fluid # noqa: F401
# source_file = inspect.getfile(dyfunc) import paddle.fluid.dygraph as dygraph # noqa: F401
# import_statements = ImportVisitor(source_file).transform() import paddle.fluid.layers as layers # noqa: F401
# import_str = "".join(import_statements) import paddle.jit.dy2static as _jst # noqa: F401
import paddle from paddle.fluid.dygraph import to_variable # noqa: F401
import paddle.fluid as fluid from paddle import to_tensor # noqa: F401
import paddle.fluid.dygraph as dygraph
import paddle.fluid.layers as layers
import paddle.jit.dy2static as _jst
from paddle.fluid.dygraph import to_variable
from paddle import to_tensor
return eval( return eval(
"_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix) "_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix)
...@@ -304,7 +299,7 @@ def is_numpy_api(node): ...@@ -304,7 +299,7 @@ def is_numpy_api(node):
assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api" assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
func_str = astor.to_source(gast.gast_to_ast(node.func)) func_str = astor.to_source(gast.gast_to_ast(node.func))
try: try:
import numpy as np import numpy as np # noqa: F401
module_result = eval( module_result = eval(
"_is_api_in_module_helper({}, '{}')".format(func_str, "numpy") "_is_api_in_module_helper({}, '{}')".format(func_str, "numpy")
...@@ -321,7 +316,7 @@ def is_numpy_api(node): ...@@ -321,7 +316,7 @@ def is_numpy_api(node):
def _delete_keywords_from(node): def _delete_keywords_from(node):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
func_src = astor.to_source(gast.gast_to_ast(node.func)) func_src = astor.to_source(gast.gast_to_ast(node.func))
import paddle.fluid as fluid import paddle.fluid as fluid # noqa: F401
full_args = eval(f"inspect.getfullargspec({func_src})") full_args = eval(f"inspect.getfullargspec({func_src})")
full_args_name = full_args[0] full_args_name = full_args[0]
...@@ -402,7 +397,7 @@ def update_args_of_func(node, dygraph_node, method_name): ...@@ -402,7 +397,7 @@ def update_args_of_func(node, dygraph_node, method_name):
) )
class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func)) class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func))
import paddle.fluid as fluid import paddle.fluid as fluid # noqa: F401
if method_name == "__init__" or eval( if method_name == "__init__" or eval(
"issubclass({}, fluid.dygraph.Layer)".format(class_src) "issubclass({}, fluid.dygraph.Layer)".format(class_src)
...@@ -894,7 +889,7 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -894,7 +889,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
return node return node
def _is_node_with_tensor(self, node, name_id): def _is_node_with_tensor(self, node, name_id):
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.jit.dy2static.static_analysis import (
NodeVarType, NodeVarType,
) )
...@@ -1213,7 +1208,6 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1213,7 +1208,6 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope. because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
""" """
from paddle.jit.dy2static.loop_transformer import ( from paddle.jit.dy2static.loop_transformer import (
WHILE_CONDITION_PREFIX,
WHILE_BODY_PREFIX, WHILE_BODY_PREFIX,
FOR_CONDITION_PREFIX, FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX, FOR_BODY_PREFIX,
......
...@@ -12,9 +12,82 @@ ...@@ -12,9 +12,82 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import ( # noqa: F401 import paddle
create_bool_as_type, from paddle.utils import gast
to_static_variable, from paddle.fluid.framework import Variable
from paddle.jit.dy2static.utils import (
UndefinedVar,
create_undefined_variable,
) )
from paddle.fluid.layers.utils import map_structure, is_sequence
__all__ = [] __all__ = [
'create_bool_as_type',
'create_fill_constant_node',
'to_static_variable',
'create_undefined_var',
]
def create_undefined_var(name):
func_code = "{} = _jst.UndefinedVar('{}')".format(name, name)
return gast.parse(func_code).body[0]
def create_fill_constant_node(name, value=0):
func_code = "{} = paddle.full(shape=[1], ".format(name)
if isinstance(value, bool):
func_code += "dtype='bool', fill_value={}, name='{}')".format(
value, name
)
return gast.parse(func_code).body[0]
if isinstance(value, float):
func_code += "dtype='float64', fill_value={}, name='{}')".format(
value, name
)
return gast.parse(func_code).body[0]
if isinstance(value, int):
func_code += "dtype='int64', fill_value={}, name='{}')".format(
value, name
)
return gast.parse(func_code).body[0]
def to_static_variable(x):
'''
Translate a Python Tensor to PaddlePaddle static graph Tensor
'''
if isinstance(x, bool):
return paddle.full(shape=[1], dtype='bool', fill_value=x)
if isinstance(x, float):
return paddle.full(shape=[1], dtype='float64', fill_value=x)
if isinstance(x, int):
return paddle.full(shape=[1], dtype='int64', fill_value=x)
if isinstance(x, UndefinedVar) or x is None:
"""
for early return case, we need a variable to represent None, current we use data_layer_not_check.
"""
return create_undefined_variable()
if is_sequence(x):
return map_structure(to_static_variable, x)
return x
def create_bool_as_type(x, value=True):
'''
Create a bool variable, which type is the same as x.
'''
if isinstance(x, Variable):
return paddle.full(shape=[1], fill_value=value, dtype="bool")
else:
return value
def create_bool_node(name, value):
'''
Create a assign stmt for name = value .
'''
assert isinstance(value, bool)
node = "{} = {}".format(name, value)
return gast.parse(node).body[0]
...@@ -331,7 +331,6 @@ packages=['paddle', ...@@ -331,7 +331,6 @@ packages=['paddle',
'paddle.inference.contrib.utils', 'paddle.inference.contrib.utils',
'paddle.fluid', 'paddle.fluid',
'paddle.fluid.dygraph', 'paddle.fluid.dygraph',
'paddle.fluid.dygraph.dygraph_to_static',
'paddle.fluid.dygraph.amp', 'paddle.fluid.dygraph.amp',
'paddle.fluid.proto', 'paddle.fluid.proto',
'paddle.fluid.proto.profiler', 'paddle.fluid.proto.profiler',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册