未验证 提交 ea9e4085 编写于 作者: A Aurelius84 提交者: GitHub

[API Clean]Clean __all__ to avoid exposing usless API (#48713)

* [API Clean]Clean __all__ to avoid exposing usless API

* fix import

* fix typo

* remove tracedLayer unittest
上级 b91bbd32
...@@ -48,7 +48,6 @@ class TestDirectory(unittest.TestCase): ...@@ -48,7 +48,6 @@ class TestDirectory(unittest.TestCase):
'paddle.distributed.ParallelEnv', 'paddle.distributed.ParallelEnv',
'paddle.DataParallel', 'paddle.DataParallel',
'paddle.jit', 'paddle.jit',
'paddle.jit.TracedLayer',
'paddle.jit.to_static', 'paddle.jit.to_static',
'paddle.jit.ProgramTranslator', 'paddle.jit.ProgramTranslator',
'paddle.jit.TranslatedLayer', 'paddle.jit.TranslatedLayer',
......
...@@ -21,7 +21,7 @@ from utils import DyGraphProgramDescTracerTestHelper ...@@ -21,7 +21,7 @@ from utils import DyGraphProgramDescTracerTestHelper
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.framework import _in_legacy_dygraph, _test_eager_guard from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.optimizer import SGDOptimizer from paddle.fluid.optimizer import SGDOptimizer
from paddle.nn import Linear from paddle.nn import Linear
...@@ -153,19 +153,7 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -153,19 +153,7 @@ class TestImperativeMnist(unittest.TestCase):
dy_x_data = img.numpy() dy_x_data = img.numpy()
label = data[1] label = data[1]
label.stop_gradient = True label.stop_gradient = True
cost = mnist(img)
if batch_id % 10 == 0 and _in_legacy_dygraph():
cost, traced_layer = paddle.jit.TracedLayer.trace(
mnist, inputs=img
)
if program is not None:
self.assertTrue(program, traced_layer.program)
program = traced_layer.program
traced_layer.save_inference_model(
'./infer_imperative_mnist'
)
else:
cost = mnist(img)
if traced_layer is not None: if traced_layer is not None:
cost_static = traced_layer([img]) cost_static = traced_layer([img])
......
...@@ -16,7 +16,7 @@ import unittest ...@@ -16,7 +16,7 @@ import unittest
import numpy as np import numpy as np
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program from utils import DyGraphProgramDescTracerTestHelper
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -24,9 +24,8 @@ import paddle.fluid.core as core ...@@ -24,9 +24,8 @@ import paddle.fluid.core as core
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import Embedding from paddle.fluid.dygraph.nn import Embedding
from paddle.fluid.framework import _in_legacy_dygraph, _test_eager_guard from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.optimizer import SGDOptimizer from paddle.fluid.optimizer import SGDOptimizer
from paddle.jit import TracedLayer
class SimpleLSTMRNN(fluid.Layer): class SimpleLSTMRNN(fluid.Layer):
...@@ -298,25 +297,8 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -298,25 +297,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
y = to_variable(y_data) y = to_variable(y_data)
init_hidden = to_variable(init_hidden_data) init_hidden = to_variable(init_hidden_data)
init_cell = to_variable(init_cell_data) init_cell = to_variable(init_cell_data)
if i % 5 == 0 and _in_legacy_dygraph():
outs, traced_layer = TracedLayer.trace(
ptb_model, [x, y, init_hidden, init_cell]
)
outs_static = traced_layer([x, y, init_hidden, init_cell])
helper.assertEachVar(outs, outs_static)
if program is not None:
self.assertTrue(
is_equal_program(traced_layer.program, program)
)
program = traced_layer.program outs = ptb_model(x, y, init_hidden, init_cell)
traced_layer.save_inference_model(
'./infe_imperative_ptb_rnn', feed=list(range(4))
)
else:
outs = ptb_model(x, y, init_hidden, init_cell)
dy_loss, last_hidden, last_cell = outs dy_loss, last_hidden, last_cell = outs
......
...@@ -16,15 +16,14 @@ import unittest ...@@ -16,15 +16,14 @@ import unittest
import numpy as np import numpy as np
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program from utils import DyGraphProgramDescTracerTestHelper
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import BatchNorm, core from paddle.fluid import BatchNorm, core
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.framework import _in_legacy_dygraph, _test_eager_guard from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.jit import TracedLayer
# NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1 # NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1
...@@ -301,20 +300,7 @@ class TestDygraphResnet(unittest.TestCase): ...@@ -301,20 +300,7 @@ class TestDygraphResnet(unittest.TestCase):
label.stop_gradient = True label.stop_gradient = True
out = None out = None
if batch_id % 5 == 0 and _in_legacy_dygraph(): out = resnet(img)
out, traced_layer = TracedLayer.trace(resnet, img)
if program is not None:
self.assertTrue(
is_equal_program(program, traced_layer.program)
)
traced_layer.save_inference_model(
'./infer_imperative_resnet'
)
program = traced_layer.program
else:
out = resnet(img)
if traced_layer is not None: if traced_layer is not None:
resnet.eval() resnet.eval()
......
...@@ -23,12 +23,11 @@ import paddle.nn.functional as F ...@@ -23,12 +23,11 @@ import paddle.nn.functional as F
from paddle.fluid import Embedding, Layer, core from paddle.fluid import Embedding, Layer, core
from paddle.fluid.dygraph import guard, to_variable from paddle.fluid.dygraph import guard, to_variable
from paddle.fluid.framework import _in_legacy_dygraph, _test_eager_guard from paddle.fluid.framework import _in_legacy_dygraph, _test_eager_guard
from paddle.jit import TracedLayer
from paddle.nn import Linear from paddle.nn import Linear
np.set_printoptions(suppress=True) np.set_printoptions(suppress=True)
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program from utils import DyGraphProgramDescTracerTestHelper
# Copy from models # Copy from models
...@@ -1171,27 +1170,7 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): ...@@ -1171,27 +1170,7 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
for i in range(batch_num): for i in range(batch_num):
enc_inputs, dec_inputs, label, weights = create_data() enc_inputs, dec_inputs, label, weights = create_data()
if False: outs = transformer(enc_inputs, dec_inputs, label, weights)
outs, traced_layer = TracedLayer.trace(
transformer, [enc_inputs, dec_inputs, label, weights]
)
ins_static = enc_inputs + dec_inputs + [label, weights]
outs_static = traced_layer(ins_static)
helper.assertEachVar(outs, outs_static)
if program is not None:
self.assertTrue(
is_equal_program(program, traced_layer.program)
)
program = traced_layer.program
traced_layer.save_inference_model(
'./infer_imperative_transformer',
feed=list(range(len(ins_static))),
fetch=list(range(len(outs_static))),
)
else:
outs = transformer(enc_inputs, dec_inputs, label, weights)
dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = outs dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = outs
......
...@@ -20,8 +20,6 @@ import paddle ...@@ -20,8 +20,6 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle import _legacy_C_ops from paddle import _legacy_C_ops
from paddle.fluid.framework import in_dygraph_mode
from paddle.jit.api import TracedLayer
class TestTracedLayer(fluid.dygraph.Layer): class TestTracedLayer(fluid.dygraph.Layer):
...@@ -93,20 +91,6 @@ class TestVariable(unittest.TestCase): ...@@ -93,20 +91,6 @@ class TestVariable(unittest.TestCase):
np.testing.assert_array_equal(y_grad, loss.gradient() * a) np.testing.assert_array_equal(y_grad, loss.gradient() * a)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_traced_layer(self):
if in_dygraph_mode():
return
with fluid.dygraph.guard():
layer = TestTracedLayer("test_traced_layer")
a = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
x = fluid.dygraph.to_variable(a)
res_dygraph, static_layer = TracedLayer.trace(
layer, inputs=x
) # dygraph out
res_static_graph = static_layer([x])[0]
np.testing.assert_array_equal(res_dygraph.numpy(), res_static_graph)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from .api import save from .api import save
from .api import load from .api import load
from .api import TracedLayer
from .api import set_code_level from .api import set_code_level
from .api import set_verbosity from .api import set_verbosity
from .api import declarative as to_static from .api import declarative as to_static
...@@ -34,5 +33,4 @@ __all__ = [ # noqa ...@@ -34,5 +33,4 @@ __all__ = [ # noqa
'set_code_level', 'set_code_level',
'set_verbosity', 'set_verbosity',
'not_to_static', 'not_to_static',
'TracedLayer',
] ]
...@@ -74,9 +74,7 @@ from paddle.fluid.framework import dygraph_only, _non_static_mode ...@@ -74,9 +74,7 @@ from paddle.fluid.framework import dygraph_only, _non_static_mode
from paddle.fluid.wrapped_decorator import wrap_decorator from paddle.fluid.wrapped_decorator import wrap_decorator
__all__ = [ __all__ = [
'TracedLayer',
'declarative', 'declarative',
'dygraph_to_static_func',
'set_code_level', 'set_code_level',
'set_verbosity', 'set_verbosity',
'save', 'save',
......
...@@ -36,6 +36,6 @@ from .convert_operators import convert_shape_compare # noqa: F401 ...@@ -36,6 +36,6 @@ from .convert_operators import convert_shape_compare # noqa: F401
from .assert_transformer import AssertTransformer from .assert_transformer import AssertTransformer
from .ast_transformer import DygraphToStaticAst from .ast_transformer import DygraphToStaticAst
from .program_translator import convert_to_static from .program_translator import convert_to_static
from .static_analysis import * # noqa: F403 from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
__all__ = [] __all__ = []
...@@ -22,7 +22,7 @@ from .base_transformer import ( ...@@ -22,7 +22,7 @@ from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = ['AssertTransformer'] __all__ = []
class AssertTransformer(BaseTransformer): class AssertTransformer(BaseTransformer):
......
...@@ -71,7 +71,7 @@ from .decorator_transformer import ( ...@@ -71,7 +71,7 @@ from .decorator_transformer import (
from . import logging_utils from . import logging_utils
from .utils import ast_to_source_code from .utils import ast_to_source_code
__all__ = ['DygraphToStaticAst'] __all__ = []
def apply_optimization(transformers): def apply_optimization(transformers):
......
...@@ -27,6 +27,8 @@ from paddle.jit.dy2static.utils import ( ...@@ -27,6 +27,8 @@ from paddle.jit.dy2static.utils import (
get_attribute_full_name, get_attribute_full_name,
) )
__all__ = []
class BaseTransformer(gast.NodeTransformer): class BaseTransformer(gast.NodeTransformer):
def visit(self, node): def visit(self, node):
......
...@@ -23,6 +23,8 @@ from .base_transformer import ( ...@@ -23,6 +23,8 @@ from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = []
class BasicApiTransformer(BaseTransformer): class BasicApiTransformer(BaseTransformer):
""" """
......
...@@ -27,7 +27,7 @@ from .base_transformer import ( ...@@ -27,7 +27,7 @@ from .base_transformer import (
ForNodeVisitor, ForNodeVisitor,
) )
__all__ = ['BreakContinueTransformer'] __all__ = []
BREAK_NAME_PREFIX = '__break' BREAK_NAME_PREFIX = '__break'
CONTINUE_NAME_PREFIX = '__continue' CONTINUE_NAME_PREFIX = '__continue'
......
...@@ -25,6 +25,8 @@ from .base_transformer import ( ...@@ -25,6 +25,8 @@ from .base_transformer import (
PDB_SET = "pdb.set_trace" PDB_SET = "pdb.set_trace"
__all__ = []
class CallTransformer(BaseTransformer): class CallTransformer(BaseTransformer):
""" """
......
...@@ -22,6 +22,8 @@ from .base_transformer import ( ...@@ -22,6 +22,8 @@ from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = []
class CastTransformer(BaseTransformer): class CastTransformer(BaseTransformer):
""" """
......
...@@ -40,7 +40,7 @@ from paddle.jit.dy2static.logging_utils import ( ...@@ -40,7 +40,7 @@ from paddle.jit.dy2static.logging_utils import (
from paddle.jit.dy2static.utils import is_paddle_func, unwrap from paddle.jit.dy2static.utils import is_paddle_func, unwrap
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"] __all__ = []
# The api(s) should be considered as plain function and convert # The api(s) should be considered as plain function and convert
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import re import re
import paddle import paddle
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
from paddle.jit.dy2static.variable_trans_func import ( from .variable_trans_func import (
to_static_variable, to_static_variable,
) )
from paddle.fluid.framework import core, Variable from paddle.fluid.framework import core, Variable
...@@ -43,10 +43,13 @@ from .return_transformer import ( ...@@ -43,10 +43,13 @@ from .return_transformer import (
from paddle.jit.dy2static.utils import ( from paddle.jit.dy2static.utils import (
UndefinedVar, UndefinedVar,
Dygraph2StaticException, Dygraph2StaticException,
GetterSetterHelper,
) )
from paddle.jit.dy2static.utils import GetterSetterHelper
from paddle.fluid.layers.utils import copy_mutable_vars from paddle.fluid.layers.utils import copy_mutable_vars
__all__ = []
def convert_attr(x, attr): def convert_attr(x, attr):
if isinstance(x, Variable) and attr == "size": if isinstance(x, Variable) and attr == "size":
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.jit.dy2static.static_analysis import ( from .static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.jit.dy2static.utils import ( from .utils import (
FunctionNameLivenessAnalysis, FunctionNameLivenessAnalysis,
) )
from paddle.jit.dy2static.variable_trans_func import ( from .variable_trans_func import (
create_undefined_var, create_undefined_var,
) )
from .base_transformer import ( from .base_transformer import (
...@@ -26,6 +26,9 @@ from .base_transformer import ( ...@@ -26,6 +26,9 @@ from .base_transformer import (
) )
__all__ = []
class CreateVariableTransformer(BaseTransformer): class CreateVariableTransformer(BaseTransformer):
""" """ """ """
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
from paddle.utils import gast from paddle.utils import gast
from paddle.jit.dy2static.static_analysis import ( from .static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
from paddle.jit.dy2static.utils import ( from .utils import (
RE_PYNAME, RE_PYNAME,
RE_PYMODULE, RE_PYMODULE,
ast_to_source_code, ast_to_source_code,
...@@ -29,6 +29,8 @@ import warnings ...@@ -29,6 +29,8 @@ import warnings
import re import re
__all__ = []
IGNORE_NAMES = [ IGNORE_NAMES = [
'declarative', 'declarative',
'to_static', 'to_static',
......
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
from paddle.utils import gast from paddle.utils import gast
from paddle.jit.dy2static.static_analysis import ( from .static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = ['EarlyReturnTransformer'] __all__ = []
class EarlyReturnTransformer(BaseTransformer): class EarlyReturnTransformer(BaseTransformer):
......
...@@ -32,6 +32,8 @@ from .utils import ( ...@@ -32,6 +32,8 @@ from .utils import (
func_to_source_code, func_to_source_code,
) )
__all__ = []
class FunctionSpec: class FunctionSpec:
""" """
......
...@@ -18,7 +18,7 @@ import threading ...@@ -18,7 +18,7 @@ import threading
from paddle.fluid import log_helper from paddle.fluid import log_helper
from .utils import ast_to_source_code from .utils import ast_to_source_code
__all__ = ["TranslatorLogger", "set_verbosity", "set_code_level"] __all__ = []
VERBOSITY_ENV_NAME = 'TRANSLATOR_VERBOSITY' VERBOSITY_ENV_NAME = 'TRANSLATOR_VERBOSITY'
CODE_LEVEL_ENV_NAME = 'TRANSLATOR_CODE_LEVEL' CODE_LEVEL_ENV_NAME = 'TRANSLATOR_CODE_LEVEL'
......
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
# limitations under the License. # limitations under the License.
from paddle.utils import gast from paddle.utils import gast
from paddle.jit.dy2static.utils import ast_to_source_code from .utils import ast_to_source_code
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = []
cmpop_type_to_str = { cmpop_type_to_str = {
gast.Eq: "==", gast.Eq: "==",
gast.NotEq: "!=", gast.NotEq: "!=",
......
...@@ -17,22 +17,16 @@ from paddle.utils import gast ...@@ -17,22 +17,16 @@ from paddle.utils import gast
from collections import defaultdict from collections import defaultdict
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.jit.dy2static.static_analysis import ( from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
AstNodeWrapper, from .utils import (
) ast_to_source_code,
from paddle.jit.dy2static.static_analysis import NodeVarType get_attribute_full_name,
from paddle.jit.dy2static.static_analysis import (
StaticAnalysisVisitor,
)
from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.jit.dy2static.utils import get_attribute_full_name
from paddle.jit.dy2static.utils import (
create_nonlocal_stmt_nodes, create_nonlocal_stmt_nodes,
create_get_args_node, create_get_args_node,
create_set_args_node, create_set_args_node,
)
from paddle.jit.dy2static.utils import (
FunctionNameLivenessAnalysis, FunctionNameLivenessAnalysis,
GetterSetterHelper,
create_name_str,
) )
from .ifelse_transformer import ARGS_NAME from .ifelse_transformer import ARGS_NAME
from .base_transformer import ( from .base_transformer import (
...@@ -41,12 +35,8 @@ from .base_transformer import ( ...@@ -41,12 +35,8 @@ from .base_transformer import (
ForNodeVisitor, ForNodeVisitor,
) )
from paddle.jit.dy2static.utils import (
GetterSetterHelper,
create_name_str,
)
__all__ = ['LoopTransformer', 'NameVisitor'] __all__ = []
WHILE_CONDITION_PREFIX = 'while_condition' WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body' WHILE_BODY_PREFIX = 'while_body'
......
...@@ -24,6 +24,8 @@ from paddle.fluid.framework import Program ...@@ -24,6 +24,8 @@ from paddle.fluid.framework import Program
from collections.abc import Sequence from collections.abc import Sequence
__all__ = []
class Location: class Location:
""" """
......
...@@ -44,6 +44,8 @@ from paddle.fluid.dygraph.amp.auto_cast import ( ...@@ -44,6 +44,8 @@ from paddle.fluid.dygraph.amp.auto_cast import (
) )
from paddle import _legacy_C_ops from paddle import _legacy_C_ops
__all__ = []
class NestSequence: class NestSequence:
""" """
......
...@@ -57,7 +57,7 @@ from .function_spec import ( ...@@ -57,7 +57,7 @@ from .function_spec import (
from .ast_transformer import DygraphToStaticAst from .ast_transformer import DygraphToStaticAst
__all__ = ['ProgramTranslator', 'convert_to_static'] __all__ = []
# For each traced function, we set `max_traced_program_count` = 10 to consider caching performance. # For each traced function, we set `max_traced_program_count` = 10 to consider caching performance.
# Once exceeding the threshold, we will raise warning to users to make sure the conversion is as expected. # Once exceeding the threshold, we will raise warning to users to make sure the conversion is as expected.
......
...@@ -15,22 +15,18 @@ ...@@ -15,22 +15,18 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.jit.dy2static.utils import index_in_list from .utils import (
from .break_continue_transformer import ( index_in_list,
ForToWhileTransformer, ast_to_source_code,
Dygraph2StaticException,
ORIGI_INFO,
) )
from paddle.jit.dy2static.utils import ast_to_source_code from .break_continue_transformer import ForToWhileTransformer
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
from paddle.jit.dy2static.utils import Dygraph2StaticException
from paddle.jit.dy2static.utils import ORIGI_INFO __all__ = []
__all__ = [
'RETURN_NO_VALUE_MAGIC_NUM',
'RETURN_NO_VALUE_VAR_NAME',
'ReturnTransformer',
]
# Constant for the name of the variable which stores the boolean state that we # Constant for the name of the variable which stores the boolean state that we
# should return # should return
......
...@@ -22,7 +22,7 @@ from .utils import ( ...@@ -22,7 +22,7 @@ from .utils import (
ast_to_source_code, ast_to_source_code,
) )
__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor'] __all__ = []
class NodeVarType: class NodeVarType:
......
...@@ -14,14 +14,16 @@ ...@@ -14,14 +14,16 @@
from paddle.utils import gast from paddle.utils import gast
from paddle.jit.dy2static.utils import ast_to_source_code from .utils import ast_to_source_code
from paddle.jit.dy2static.static_analysis import ( from .static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from .base_transformer import ( from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = []
class TensorShapeTransformer(BaseTransformer): class TensorShapeTransformer(BaseTransformer):
""" """
......
...@@ -20,6 +20,8 @@ from .base_transformer import ( ...@@ -20,6 +20,8 @@ from .base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = []
class TypeHintTransformer(BaseTransformer): class TypeHintTransformer(BaseTransformer):
""" """
......
...@@ -15,18 +15,13 @@ ...@@ -15,18 +15,13 @@
import paddle import paddle
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.jit.dy2static.utils import ( from .utils import (
UndefinedVar, UndefinedVar,
create_undefined_variable, create_undefined_variable,
) )
from paddle.fluid.layers.utils import map_structure, is_sequence from paddle.fluid.layers.utils import map_structure, is_sequence
__all__ = [ __all__ = []
'create_bool_as_type',
'create_fill_constant_node',
'to_static_variable',
'create_undefined_var',
]
def create_undefined_var(name): def create_undefined_var(name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册