未验证 提交 d926b30b 编写于 作者: A Aurelius84 提交者: GitHub

[Fluid Clean]Migrate if/while/return/break transformer into paddle.jit (#48449)

* [Fluid Clean]Migrate if/while/return/break transformer into paddle.jit

* migrate call_transformer

* migrate call_transformer
上级 f41ccbd5
......@@ -697,7 +697,7 @@ class IpuDynamicPatcher:
MAX_TRACED_PROGRAM_COUNT,
)
from ..fluid.dygraph.dygraph_to_static import logging_utils
from ..fluid.dygraph.dygraph_to_static.partial_program import (
from paddle.jit.dy2static.partial_program import (
partial_program_from,
)
......
......@@ -15,23 +15,14 @@
from . import static_analysis
from .static_analysis import *
from . import loop_transformer
from .loop_transformer import *
from . import variable_trans_func
from .variable_trans_func import *
from . import convert_call_func
from .convert_call_func import *
from . import convert_operators
from . import logging_utils
from .logging_utils import *
__all__ = []
__all__ += loop_transformer.__all__
__all__ += static_analysis.__all__
__all__ += variable_trans_func.__all__
__all__ += convert_call_func.__all__
__all__ += logging_utils.__all__
......@@ -51,7 +51,7 @@ class CallTransformer(BaseTransformer):
func_str = ast_to_source_code(node.func).strip()
try:
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import (
from paddle.jit.dy2static.convert_call_func import (
is_builtin,
)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import functools
import logging
import inspect
import pdb
import re
import types
import numpy
import builtins
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import (
convert_len,
convert_zip,
)
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import (
convert_range,
convert_enumerate,
)
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import (
TranslatorLogger,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap
from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"]
# The api(s) should be considered as plain function and convert
# them into static layer code.
PADDLE_NEED_CONVERT_APIS = [Sequential]
translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "An attribute for a function that indicates conversion flags of the function in dynamic-to-static."
class ConversionOptions:
"""
A container for conversion flags of a function in dynamic-to-static.
Attributes:
not_convert(bool): An attribute indicates that the function won't be converted in dynamic-to-static.
NOTE(liym27): More attributes and methods can be added in this class.
"""
def __init__(self, not_convert=False):
self.not_convert = not_convert
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in builtins.__dict__.values() and name_judge():
return True
else:
return False
def builtin_modules():
"""
Return builtin modules.
"""
modules = [
collections,
pdb,
copy,
inspect,
re,
numpy,
logging,
]
try:
import six
modules.append(six)
except ImportError:
pass # do nothing
return modules
BUILTIN_LIKELY_MODULES = builtin_modules()
def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
"""
for m in BUILTIN_LIKELY_MODULES:
for v in m.__dict__.values():
func_in_dict = func == v
if isinstance(func_in_dict, (list, numpy.ndarray)):
func_in_dict = numpy.array(func_in_dict).any()
if func_in_dict:
translator_logger.log(
2,
"Whitelist: {} is part of built-in module and does not have to be transformed.".format(
func
),
)
return True
# NOTE: should be placed before `is_paddle_func`
if type(func) in PADDLE_NEED_CONVERT_APIS:
return False
if is_paddle_func(func):
translator_logger.log(
2,
"Whitelist: {} is part of Paddle module and does not have to be transformed.".format(
func
),
)
return True
def convert_call(func):
"""
Converts a function call which needs to be transformed to static function.
Args:
func (callable): A callable function or method to convert.
Returns:
Callable: A converted function.
Examples:
.. code-block:: python
import paddle
from paddle.jit.dy2static import convert_call
paddle.enable_static()
def dyfunc(x):
if paddle.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
new_func = convert_call(dyfunc)
x = paddle.tensor.manipulation.fill_constant(shape=[3, 3], value=0, dtype='float64')
x_v = new_func(x)
exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(fetch_list=[x_v])
print(out[0])
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
"""
# NOTE(Aurelius84): Fix it after all files migrating into jit.
from paddle.jit.dy2static.program_translator import (
convert_to_static,
unwrap_decorators,
StaticFunction,
)
translator_logger.log(
1, "Convert callable object: convert {}.".format(func)
)
func_self = None
converted_call = None
# Function in convert_call may be decorated by another `@to_static`,
# in this case, unwraps it into a raw method or function.
_, func = unwrap_decorators(func)
options = getattr(func, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
translator_logger.log(
2,
"{} is not converted when it is decorated by 'paddle.jit.not_to_static'.".format(
func
),
)
return func
if is_builtin(func, "len"):
return convert_len
if is_builtin(func, "zip"):
return convert_zip
if is_builtin(func, "range"):
return convert_range
if is_builtin(func, "enumerate"):
return convert_enumerate
if is_builtin(func) or is_unsupported(func):
return func
if inspect.isgeneratorfunction(func):
# NOTE(xiongkun03): inspect.isfunction() will return True even though func is a generator function.
# If we don't deal generatorfunction here, we will regard it as normal function and get errors in some
# occasion.
number_of_stars = 30
translator_logger.warn(
"\n\n"
+ "*" * number_of_stars
+ "\nYour function:`{}` doesn't support to transform to static function because it is a generator function, it will be run as-is.".format(
func.__name__
)
+ "\n"
+ "*" * number_of_stars
+ "\n\n"
)
return func
if inspect.isfunction(func):
# TODO(liym27): If func is a lambda function, special conversion is needed.
if func.__name__ == '<lambda>':
return func
try:
# Note(Aurelius84): Because `@declarative` returns a class instance instead of
# a function. This will modify the value referring to itself in `__globals__`.
# For example:
#
# @declarative
# def foo(x):
# return x
#
# `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticFunction)` is added here.
_origfunc = unwrap(func)
global_functions = set()
for fn in _origfunc.__globals__.values():
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticFunction):
_, fn = unwrap_decorators(fn)
global_functions.add(fn)
elif inspect.isclass(fn):
if isinstance(
fn.__dict__.get(func.__name__, None), staticmethod
):
global_functions.add(
func
) # Add func to ensure that we will convert
if func in global_functions:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
else:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
translator_logger.warn(
"{} doesn't have to be transformed to static function because it has been transformed before, it will be run as-is.".format(
func
)
)
converted_call = func
except AttributeError:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
converted_call = None
except (IOError, OSError):
# NOTE:
# If func has been decorated, its source code can not be get
# so that it can not be transformed to static function.
converted_call = None
elif inspect.ismethod(func):
try:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have been decorated.
converted_call = None
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
_, forward_func = unwrap_decorators(func.forward)
func._original_funcs['forward'] = forward_func.__func__
forward_func = convert_to_static(forward_func)
# Bound mothod will be convert into plain function after `convert_to_static`.
# So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func))
except (IOError, OSError, TypeError):
# NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self
converted_call = func
else:
try:
call_func = func.__class__.__call__
converted_call = convert_to_static(call_func)
func_self = func
except (IOError, OSError, TypeError):
# NOTE:
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
# it doesn't need to be transformed
func_self = None if func_self else func_self
else:
raise NotImplementedError(
"Callable {} can not be transformed at present.".format(func)
)
if converted_call is None:
translator_logger.warn(
"{} doesn't have to be transformed to static function, and it will be run as-is.".format(
func
)
)
return func
if func_self:
converted_call = functools.partial(converted_call, func_self)
return converted_call
......@@ -145,7 +145,7 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
from paddle.jit.dy2static.return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM,
)
......@@ -1212,13 +1212,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
"""NOTE: why we need merge w_vars and push_pop_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
"""
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import (
from paddle.jit.dy2static.loop_transformer import (
WHILE_CONDITION_PREFIX,
WHILE_BODY_PREFIX,
FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX,
)
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import (
from paddle.jit.dy2static.ifelse_transformer import (
TRUE_FUNC_PREFIX,
FALSE_FUNC_PREFIX,
)
......
......@@ -30,7 +30,7 @@ from paddle.fluid.executor import (
_is_enable_standalone_executor,
_is_dy2st_enable_standalone_executor,
)
from paddle.fluid.dygraph.dygraph_to_static.partial_program import (
from paddle.jit.dy2static.partial_program import (
add_build_strategy_for,
LazyInitialized,
)
......
......@@ -2589,7 +2589,7 @@ def expand_undefined_var(nest1, nest2, names):
In this case, we should not expand recursively.
"""
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
from paddle.jit.dy2static.return_transformer import (
RETURN_VALUE_PREFIX,
)
......
......@@ -275,7 +275,7 @@ def monkey_patch_variable():
Returns:
Variable: self[index]
"""
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import (
from paddle.jit.dy2static.convert_operators import (
_run_paddle_pop,
)
......
......@@ -20,7 +20,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.jit import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import (
from paddle.jit.dy2static.convert_call_func import (
CONVERSION_OPTIONS,
)
from test_program_translator import get_source_code
......
......@@ -18,7 +18,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.jit.api import declarative
from paddle.fluid.dygraph.dygraph_to_static import convert_call
from paddle.jit.dy2static import Call
SEED = 2020
np.random.seed(SEED)
......@@ -90,11 +90,11 @@ def len_with_selected_rows(place):
)
# y is Variable(SelectedRows)
y = fluid.layers.merge_selected_rows(var)
y_len = convert_call(len)(y)
y_len = Call(len)(y)
# z is inner tensor with shape [4, 2]
z = fluid.layers.get_tensor_from_selected_rows(y)
z_len = convert_call(len)(z)
z_len = Call(len)(z)
# set data for selected_rows
x_rows = [0, 2, 2, 4, 19]
......
......@@ -19,7 +19,7 @@ import paddle
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import NameVisitor
from paddle.jit.dy2static.loop_transformer import NameVisitor
from paddle.jit.api import declarative
SEED = 2020
......
......@@ -21,7 +21,7 @@ import paddle
import paddle.fluid as fluid
from paddle.jit import ProgramTranslator
from paddle.jit.api import declarative
from paddle.fluid.dygraph.dygraph_to_static.partial_program import (
from paddle.jit.dy2static.partial_program import (
partial_program_from,
)
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
......
......@@ -35,7 +35,7 @@ from paddle.fluid.dygraph.base import (
switch_to_static_graph,
)
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import (
from paddle.jit.dy2static.convert_call_func import (
ConversionOptions,
CONVERSION_OPTIONS,
)
......
......@@ -22,6 +22,8 @@ from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer,
)
__all__ = ['AssertTransformer']
class AssertTransformer(BaseTransformer):
"""
......
......@@ -21,7 +21,7 @@ import os
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import (
from .early_return_transformer import (
EarlyReturnTransformer,
)
from .assert_transformer import (
......@@ -30,10 +30,8 @@ from .assert_transformer import (
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import (
BasicApiTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import (
from .break_continue_transformer import (
BreakContinueTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import (
BreakTransformOptimizer,
)
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import (
......@@ -45,19 +43,19 @@ from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import (
from paddle.fluid.dygraph.dygraph_to_static.typehint_transformer import (
TypeHintTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import (
from .ifelse_transformer import (
IfElseTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import (
LogicalTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import (
from .loop_transformer import (
LoopTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import (
PrintTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
from .return_transformer import (
ReturnTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import (
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,8 +12,330 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ...fluid.dygraph.dygraph_to_static.convert_call_func import ( # noqa: F401
convert_call,
import collections
import copy
import functools
import logging
import inspect
import pdb
import re
import types
import numpy
import builtins
from paddle.fluid.dygraph.container import Sequential
from .convert_operators import (
convert_len,
convert_zip,
convert_range,
convert_enumerate,
)
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import (
TranslatorLogger,
)
__all__ = []
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap
from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"]
# The api(s) should be considered as plain function and convert
# them into static layer code.
PADDLE_NEED_CONVERT_APIS = [Sequential]
translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "An attribute for a function that indicates conversion flags of the function in dynamic-to-static."
class ConversionOptions:
"""
A container for conversion flags of a function in dynamic-to-static.
Attributes:
not_convert(bool): An attribute indicates that the function won't be converted in dynamic-to-static.
NOTE(liym27): More attributes and methods can be added in this class.
"""
def __init__(self, not_convert=False):
self.not_convert = not_convert
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in builtins.__dict__.values() and name_judge():
return True
else:
return False
def builtin_modules():
"""
Return builtin modules.
"""
modules = [
collections,
pdb,
copy,
inspect,
re,
numpy,
logging,
]
try:
import six
modules.append(six)
except ImportError:
pass # do nothing
return modules
BUILTIN_LIKELY_MODULES = builtin_modules()
def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
"""
for m in BUILTIN_LIKELY_MODULES:
for v in m.__dict__.values():
func_in_dict = func == v
if isinstance(func_in_dict, (list, numpy.ndarray)):
func_in_dict = numpy.array(func_in_dict).any()
if func_in_dict:
translator_logger.log(
2,
"Whitelist: {} is part of built-in module and does not have to be transformed.".format(
func
),
)
return True
# NOTE: should be placed before `is_paddle_func`
if type(func) in PADDLE_NEED_CONVERT_APIS:
return False
if is_paddle_func(func):
translator_logger.log(
2,
"Whitelist: {} is part of Paddle module and does not have to be transformed.".format(
func
),
)
return True
def convert_call(func):
"""
Converts a function call which needs to be transformed to static function.
Args:
func (callable): A callable function or method to convert.
Returns:
Callable: A converted function.
Examples:
.. code-block:: python
import paddle
from paddle.jit.dy2static import Call
paddle.enable_static()
def dyfunc(x):
if paddle.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
new_func = Call(dyfunc)
x = paddle.tensor.manipulation.fill_constant(shape=[3, 3], value=0, dtype='float64')
x_v = new_func(x)
exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(fetch_list=[x_v])
print(out[0])
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
"""
# NOTE(Aurelius84): Fix it after all files migrating into jit.
from paddle.jit.dy2static.program_translator import (
convert_to_static,
unwrap_decorators,
StaticFunction,
)
translator_logger.log(
1, "Convert callable object: convert {}.".format(func)
)
func_self = None
converted_call = None
# Function in convert_call may be decorated by another `@to_static`,
# in this case, unwraps it into a raw method or function.
_, func = unwrap_decorators(func)
options = getattr(func, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
translator_logger.log(
2,
"{} is not converted when it is decorated by 'paddle.jit.not_to_static'.".format(
func
),
)
return func
if is_builtin(func, "len"):
return convert_len
if is_builtin(func, "zip"):
return convert_zip
if is_builtin(func, "range"):
return convert_range
if is_builtin(func, "enumerate"):
return convert_enumerate
if is_builtin(func) or is_unsupported(func):
return func
if inspect.isgeneratorfunction(func):
# NOTE(xiongkun03): inspect.isfunction() will return True even though func is a generator function.
# If we don't deal generatorfunction here, we will regard it as normal function and get errors in some
# occasion.
number_of_stars = 30
translator_logger.warn(
"\n\n"
+ "*" * number_of_stars
+ "\nYour function:`{}` doesn't support to transform to static function because it is a generator function, it will be run as-is.".format(
func.__name__
)
+ "\n"
+ "*" * number_of_stars
+ "\n\n"
)
return func
if inspect.isfunction(func):
# TODO(liym27): If func is a lambda function, special conversion is needed.
if func.__name__ == '<lambda>':
return func
try:
# Note(Aurelius84): Because `@declarative` returns a class instance instead of
# a function. This will modify the value referring to itself in `__globals__`.
# For example:
#
# @declarative
# def foo(x):
# return x
#
# `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticFunction)` is added here.
_origfunc = unwrap(func)
global_functions = set()
for fn in _origfunc.__globals__.values():
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticFunction):
_, fn = unwrap_decorators(fn)
global_functions.add(fn)
elif inspect.isclass(fn):
if isinstance(
fn.__dict__.get(func.__name__, None), staticmethod
):
global_functions.add(
func
) # Add func to ensure that we will convert
if func in global_functions:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
else:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
translator_logger.warn(
"{} doesn't have to be transformed to static function because it has been transformed before, it will be run as-is.".format(
func
)
)
converted_call = func
except AttributeError:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
converted_call = None
except (IOError, OSError):
# NOTE:
# If func has been decorated, its source code can not be get
# so that it can not be transformed to static function.
converted_call = None
elif inspect.ismethod(func):
try:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have been decorated.
converted_call = None
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
_, forward_func = unwrap_decorators(func.forward)
func._original_funcs['forward'] = forward_func.__func__
forward_func = convert_to_static(forward_func)
# Bound mothod will be convert into plain function after `convert_to_static`.
# So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func))
except (IOError, OSError, TypeError):
# NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self
converted_call = func
else:
try:
call_func = func.__class__.__call__
converted_call = convert_to_static(call_func)
func_self = func
except (IOError, OSError, TypeError):
# NOTE:
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
# it doesn't need to be transformed
func_self = None if func_self else func_self
else:
raise NotImplementedError(
"Callable {} can not be transformed at present.".format(func)
)
if converted_call is None:
translator_logger.warn(
"{} doesn't have to be transformed to static function, and it will be run as-is.".format(
func
)
)
return func
if func_self:
converted_call = functools.partial(converted_call, func_self)
return converted_call
......@@ -20,6 +20,8 @@ from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer,
)
__all__ = ['EarlyReturnTransformer']
class EarlyReturnTransformer(BaseTransformer):
"""
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import copy
import textwrap
from collections import defaultdict
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
......@@ -28,18 +27,11 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import (
ast_to_source_code,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
create_assign_node,
FunctionNameLivenessAnalysis,
)
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
StaticAnalysisVisitor,
)
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
AstNodeWrapper,
)
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import (
create_undefined_var,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
create_nonlocal_stmt_nodes,
)
......@@ -65,6 +57,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import (
create_name_str,
)
__all__ = ['IfElseTransformer']
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
GET_ARGS_FUNC_PREFIX = 'get_args'
......
......@@ -25,11 +25,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
StaticAnalysisVisitor,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import (
create_undefined_var,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import (
create_nonlocal_stmt_nodes,
create_get_args_node,
......@@ -38,13 +34,10 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import (
from paddle.fluid.dygraph.dygraph_to_static.utils import (
FunctionNameLivenessAnalysis,
)
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
from .ifelse_transformer import ARGS_NAME
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
RenameTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
ForLoopTuplePreTransformer,
)
......@@ -217,7 +210,7 @@ class NameVisitor(gast.NodeVisitor):
# If this var is a basic variable and read-only and not
# condition var, it may not be loop_var else it should
# be in loop_var as input
if (not name in condition_names) and (not name in write_names):
if (name not in condition_names) and (name not in write_names):
continue
loop_var_names.add(name)
......
......@@ -23,7 +23,7 @@ from paddle.fluid.executor import (
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
from .return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM,
)
from paddle.fluid.layers.utils import flatten
......
......@@ -37,7 +37,7 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import (
from paddle.fluid.dygraph.dygraph_to_static.origin_info import (
update_op_callstack_with_origin_info,
)
from paddle.fluid.dygraph.dygraph_to_static.partial_program import (
from .partial_program import (
partial_program_from,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
......
......@@ -16,12 +16,9 @@ from paddle.utils import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import (
from .break_continue_transformer import (
ForToWhileTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import (
create_fill_constant_node,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册