未验证 提交 d926b30b 编写于 作者: A Aurelius84 提交者: GitHub

[Fluid Clean]Migrate if/while/return/break transformer into paddle.jit (#48449)

* [Fluid Clean]Migrate if/while/return/break transformer into paddle.jit

* migrate call_transformer

* migrate call_transformer
上级 f41ccbd5
...@@ -697,7 +697,7 @@ class IpuDynamicPatcher: ...@@ -697,7 +697,7 @@ class IpuDynamicPatcher:
MAX_TRACED_PROGRAM_COUNT, MAX_TRACED_PROGRAM_COUNT,
) )
from ..fluid.dygraph.dygraph_to_static import logging_utils from ..fluid.dygraph.dygraph_to_static import logging_utils
from ..fluid.dygraph.dygraph_to_static.partial_program import ( from paddle.jit.dy2static.partial_program import (
partial_program_from, partial_program_from,
) )
......
...@@ -15,23 +15,14 @@ ...@@ -15,23 +15,14 @@
from . import static_analysis from . import static_analysis
from .static_analysis import * from .static_analysis import *
from . import loop_transformer
from .loop_transformer import *
from . import variable_trans_func from . import variable_trans_func
from .variable_trans_func import * from .variable_trans_func import *
from . import convert_call_func
from .convert_call_func import *
from . import convert_operators
from . import logging_utils from . import logging_utils
from .logging_utils import * from .logging_utils import *
__all__ = [] __all__ = []
__all__ += loop_transformer.__all__
__all__ += static_analysis.__all__ __all__ += static_analysis.__all__
__all__ += variable_trans_func.__all__ __all__ += variable_trans_func.__all__
__all__ += convert_call_func.__all__
__all__ += logging_utils.__all__ __all__ += logging_utils.__all__
...@@ -51,7 +51,7 @@ class CallTransformer(BaseTransformer): ...@@ -51,7 +51,7 @@ class CallTransformer(BaseTransformer):
func_str = ast_to_source_code(node.func).strip() func_str = ast_to_source_code(node.func).strip()
try: try:
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import ( from paddle.jit.dy2static.convert_call_func import (
is_builtin, is_builtin,
) )
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import functools
import logging
import inspect
import pdb
import re
import types
import numpy
import builtins
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import (
convert_len,
convert_zip,
)
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import (
convert_range,
convert_enumerate,
)
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import (
TranslatorLogger,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap
from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"]
# The api(s) should be considered as plain function and convert
# them into static layer code.
PADDLE_NEED_CONVERT_APIS = [Sequential]
translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "An attribute for a function that indicates conversion flags of the function in dynamic-to-static."
class ConversionOptions:
"""
A container for conversion flags of a function in dynamic-to-static.
Attributes:
not_convert(bool): An attribute indicates that the function won't be converted in dynamic-to-static.
NOTE(liym27): More attributes and methods can be added in this class.
"""
def __init__(self, not_convert=False):
self.not_convert = not_convert
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in builtins.__dict__.values() and name_judge():
return True
else:
return False
def builtin_modules():
"""
Return builtin modules.
"""
modules = [
collections,
pdb,
copy,
inspect,
re,
numpy,
logging,
]
try:
import six
modules.append(six)
except ImportError:
pass # do nothing
return modules
BUILTIN_LIKELY_MODULES = builtin_modules()
def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
"""
for m in BUILTIN_LIKELY_MODULES:
for v in m.__dict__.values():
func_in_dict = func == v
if isinstance(func_in_dict, (list, numpy.ndarray)):
func_in_dict = numpy.array(func_in_dict).any()
if func_in_dict:
translator_logger.log(
2,
"Whitelist: {} is part of built-in module and does not have to be transformed.".format(
func
),
)
return True
# NOTE: should be placed before `is_paddle_func`
if type(func) in PADDLE_NEED_CONVERT_APIS:
return False
if is_paddle_func(func):
translator_logger.log(
2,
"Whitelist: {} is part of Paddle module and does not have to be transformed.".format(
func
),
)
return True
def convert_call(func):
"""
Converts a function call which needs to be transformed to static function.
Args:
func (callable): A callable function or method to convert.
Returns:
Callable: A converted function.
Examples:
.. code-block:: python
import paddle
from paddle.jit.dy2static import convert_call
paddle.enable_static()
def dyfunc(x):
if paddle.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
new_func = convert_call(dyfunc)
x = paddle.tensor.manipulation.fill_constant(shape=[3, 3], value=0, dtype='float64')
x_v = new_func(x)
exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(fetch_list=[x_v])
print(out[0])
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
"""
# NOTE(Aurelius84): Fix it after all files migrating into jit.
from paddle.jit.dy2static.program_translator import (
convert_to_static,
unwrap_decorators,
StaticFunction,
)
translator_logger.log(
1, "Convert callable object: convert {}.".format(func)
)
func_self = None
converted_call = None
# Function in convert_call may be decorated by another `@to_static`,
# in this case, unwraps it into a raw method or function.
_, func = unwrap_decorators(func)
options = getattr(func, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
translator_logger.log(
2,
"{} is not converted when it is decorated by 'paddle.jit.not_to_static'.".format(
func
),
)
return func
if is_builtin(func, "len"):
return convert_len
if is_builtin(func, "zip"):
return convert_zip
if is_builtin(func, "range"):
return convert_range
if is_builtin(func, "enumerate"):
return convert_enumerate
if is_builtin(func) or is_unsupported(func):
return func
if inspect.isgeneratorfunction(func):
# NOTE(xiongkun03): inspect.isfunction() will return True even though func is a generator function.
# If we don't deal generatorfunction here, we will regard it as normal function and get errors in some
# occasion.
number_of_stars = 30
translator_logger.warn(
"\n\n"
+ "*" * number_of_stars
+ "\nYour function:`{}` doesn't support to transform to static function because it is a generator function, it will be run as-is.".format(
func.__name__
)
+ "\n"
+ "*" * number_of_stars
+ "\n\n"
)
return func
if inspect.isfunction(func):
# TODO(liym27): If func is a lambda function, special conversion is needed.
if func.__name__ == '<lambda>':
return func
try:
# Note(Aurelius84): Because `@declarative` returns a class instance instead of
# a function. This will modify the value referring to itself in `__globals__`.
# For example:
#
# @declarative
# def foo(x):
# return x
#
# `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticFunction)` is added here.
_origfunc = unwrap(func)
global_functions = set()
for fn in _origfunc.__globals__.values():
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticFunction):
_, fn = unwrap_decorators(fn)
global_functions.add(fn)
elif inspect.isclass(fn):
if isinstance(
fn.__dict__.get(func.__name__, None), staticmethod
):
global_functions.add(
func
) # Add func to ensure that we will convert
if func in global_functions:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
else:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
translator_logger.warn(
"{} doesn't have to be transformed to static function because it has been transformed before, it will be run as-is.".format(
func
)
)
converted_call = func
except AttributeError:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
converted_call = None
except (IOError, OSError):
# NOTE:
# If func has been decorated, its source code can not be get
# so that it can not be transformed to static function.
converted_call = None
elif inspect.ismethod(func):
try:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have been decorated.
converted_call = None
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
_, forward_func = unwrap_decorators(func.forward)
func._original_funcs['forward'] = forward_func.__func__
forward_func = convert_to_static(forward_func)
# Bound mothod will be convert into plain function after `convert_to_static`.
# So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func))
except (IOError, OSError, TypeError):
# NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self
converted_call = func
else:
try:
call_func = func.__class__.__call__
converted_call = convert_to_static(call_func)
func_self = func
except (IOError, OSError, TypeError):
# NOTE:
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
# it doesn't need to be transformed
func_self = None if func_self else func_self
else:
raise NotImplementedError(
"Callable {} can not be transformed at present.".format(func)
)
if converted_call is None:
translator_logger.warn(
"{} doesn't have to be transformed to static function, and it will be run as-is.".format(
func
)
)
return func
if func_self:
converted_call = functools.partial(converted_call, func_self)
return converted_call
...@@ -145,7 +145,7 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ...@@ -145,7 +145,7 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
def create_undefined_variable(): def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ( from paddle.jit.dy2static.return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM, RETURN_NO_VALUE_MAGIC_NUM,
) )
...@@ -1212,13 +1212,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1212,13 +1212,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
"""NOTE: why we need merge w_vars and push_pop_vars here ? """NOTE: why we need merge w_vars and push_pop_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope. because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
""" """
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import ( from paddle.jit.dy2static.loop_transformer import (
WHILE_CONDITION_PREFIX, WHILE_CONDITION_PREFIX,
WHILE_BODY_PREFIX, WHILE_BODY_PREFIX,
FOR_CONDITION_PREFIX, FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX, FOR_BODY_PREFIX,
) )
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ( from paddle.jit.dy2static.ifelse_transformer import (
TRUE_FUNC_PREFIX, TRUE_FUNC_PREFIX,
FALSE_FUNC_PREFIX, FALSE_FUNC_PREFIX,
) )
......
...@@ -30,7 +30,7 @@ from paddle.fluid.executor import ( ...@@ -30,7 +30,7 @@ from paddle.fluid.executor import (
_is_enable_standalone_executor, _is_enable_standalone_executor,
_is_dy2st_enable_standalone_executor, _is_dy2st_enable_standalone_executor,
) )
from paddle.fluid.dygraph.dygraph_to_static.partial_program import ( from paddle.jit.dy2static.partial_program import (
add_build_strategy_for, add_build_strategy_for,
LazyInitialized, LazyInitialized,
) )
......
...@@ -2589,7 +2589,7 @@ def expand_undefined_var(nest1, nest2, names): ...@@ -2589,7 +2589,7 @@ def expand_undefined_var(nest1, nest2, names):
In this case, we should not expand recursively. In this case, we should not expand recursively.
""" """
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ( from paddle.jit.dy2static.return_transformer import (
RETURN_VALUE_PREFIX, RETURN_VALUE_PREFIX,
) )
......
...@@ -275,7 +275,7 @@ def monkey_patch_variable(): ...@@ -275,7 +275,7 @@ def monkey_patch_variable():
Returns: Returns:
Variable: self[index] Variable: self[index]
""" """
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import ( from paddle.jit.dy2static.convert_operators import (
_run_paddle_pop, _run_paddle_pop,
) )
......
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.jit import ProgramTranslator from paddle.jit import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import ( from paddle.jit.dy2static.convert_call_func import (
CONVERSION_OPTIONS, CONVERSION_OPTIONS,
) )
from test_program_translator import get_source_code from test_program_translator import get_source_code
......
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.jit.api import declarative from paddle.jit.api import declarative
from paddle.fluid.dygraph.dygraph_to_static import convert_call from paddle.jit.dy2static import Call
SEED = 2020 SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
...@@ -90,11 +90,11 @@ def len_with_selected_rows(place): ...@@ -90,11 +90,11 @@ def len_with_selected_rows(place):
) )
# y is Variable(SelectedRows) # y is Variable(SelectedRows)
y = fluid.layers.merge_selected_rows(var) y = fluid.layers.merge_selected_rows(var)
y_len = convert_call(len)(y) y_len = Call(len)(y)
# z is inner tensor with shape [4, 2] # z is inner tensor with shape [4, 2]
z = fluid.layers.get_tensor_from_selected_rows(y) z = fluid.layers.get_tensor_from_selected_rows(y)
z_len = convert_call(len)(z) z_len = Call(len)(z)
# set data for selected_rows # set data for selected_rows
x_rows = [0, 2, 2, 4, 19] x_rows = [0, 2, 2, 4, 19]
......
...@@ -19,7 +19,7 @@ import paddle ...@@ -19,7 +19,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import unittest import unittest
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import NameVisitor from paddle.jit.dy2static.loop_transformer import NameVisitor
from paddle.jit.api import declarative from paddle.jit.api import declarative
SEED = 2020 SEED = 2020
......
...@@ -21,7 +21,7 @@ import paddle ...@@ -21,7 +21,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.jit import ProgramTranslator from paddle.jit import ProgramTranslator
from paddle.jit.api import declarative from paddle.jit.api import declarative
from paddle.fluid.dygraph.dygraph_to_static.partial_program import ( from paddle.jit.dy2static.partial_program import (
partial_program_from, partial_program_from,
) )
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
......
...@@ -35,7 +35,7 @@ from paddle.fluid.dygraph.base import ( ...@@ -35,7 +35,7 @@ from paddle.fluid.dygraph.base import (
switch_to_static_graph, switch_to_static_graph,
) )
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import ( from paddle.jit.dy2static.convert_call_func import (
ConversionOptions, ConversionOptions,
CONVERSION_OPTIONS, CONVERSION_OPTIONS,
) )
......
...@@ -22,6 +22,8 @@ from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ( ...@@ -22,6 +22,8 @@ from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = ['AssertTransformer']
class AssertTransformer(BaseTransformer): class AssertTransformer(BaseTransformer):
""" """
......
...@@ -21,7 +21,7 @@ import os ...@@ -21,7 +21,7 @@ import os
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer, BaseTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import ( from .early_return_transformer import (
EarlyReturnTransformer, EarlyReturnTransformer,
) )
from .assert_transformer import ( from .assert_transformer import (
...@@ -30,10 +30,8 @@ from .assert_transformer import ( ...@@ -30,10 +30,8 @@ from .assert_transformer import (
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import (
BasicApiTransformer, BasicApiTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ( from .break_continue_transformer import (
BreakContinueTransformer, BreakContinueTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import (
BreakTransformOptimizer, BreakTransformOptimizer,
) )
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.call_transformer import (
...@@ -45,19 +43,19 @@ from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import ( ...@@ -45,19 +43,19 @@ from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import (
from paddle.fluid.dygraph.dygraph_to_static.typehint_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.typehint_transformer import (
TypeHintTransformer, TypeHintTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ( from .ifelse_transformer import (
IfElseTransformer, IfElseTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import (
LogicalTransformer, LogicalTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import ( from .loop_transformer import (
LoopTransformer, LoopTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.print_transformer import (
PrintTransformer, PrintTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ( from .return_transformer import (
ReturnTransformer, ReturnTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import (
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,8 +12,330 @@ ...@@ -12,8 +12,330 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...fluid.dygraph.dygraph_to_static.convert_call_func import ( # noqa: F401 import collections
convert_call, import copy
import functools
import logging
import inspect
import pdb
import re
import types
import numpy
import builtins
from paddle.fluid.dygraph.container import Sequential
from .convert_operators import (
convert_len,
convert_zip,
convert_range,
convert_enumerate,
)
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import (
TranslatorLogger,
) )
__all__ = [] from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func, unwrap
from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"]
# The api(s) should be considered as plain function and convert
# them into static layer code.
PADDLE_NEED_CONVERT_APIS = [Sequential]
translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "An attribute for a function that indicates conversion flags of the function in dynamic-to-static."
class ConversionOptions:
"""
A container for conversion flags of a function in dynamic-to-static.
Attributes:
not_convert(bool): An attribute indicates that the function won't be converted in dynamic-to-static.
NOTE(liym27): More attributes and methods can be added in this class.
"""
def __init__(self, not_convert=False):
self.not_convert = not_convert
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""
def name_judge():
return name is None or func.__name__ == name
if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in builtins.__dict__.values() and name_judge():
return True
else:
return False
def builtin_modules():
"""
Return builtin modules.
"""
modules = [
collections,
pdb,
copy,
inspect,
re,
numpy,
logging,
]
try:
import six
modules.append(six)
except ImportError:
pass # do nothing
return modules
BUILTIN_LIKELY_MODULES = builtin_modules()
def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
"""
for m in BUILTIN_LIKELY_MODULES:
for v in m.__dict__.values():
func_in_dict = func == v
if isinstance(func_in_dict, (list, numpy.ndarray)):
func_in_dict = numpy.array(func_in_dict).any()
if func_in_dict:
translator_logger.log(
2,
"Whitelist: {} is part of built-in module and does not have to be transformed.".format(
func
),
)
return True
# NOTE: should be placed before `is_paddle_func`
if type(func) in PADDLE_NEED_CONVERT_APIS:
return False
if is_paddle_func(func):
translator_logger.log(
2,
"Whitelist: {} is part of Paddle module and does not have to be transformed.".format(
func
),
)
return True
def convert_call(func):
"""
Converts a function call which needs to be transformed to static function.
Args:
func (callable): A callable function or method to convert.
Returns:
Callable: A converted function.
Examples:
.. code-block:: python
import paddle
from paddle.jit.dy2static import Call
paddle.enable_static()
def dyfunc(x):
if paddle.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
new_func = Call(dyfunc)
x = paddle.tensor.manipulation.fill_constant(shape=[3, 3], value=0, dtype='float64')
x_v = new_func(x)
exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(fetch_list=[x_v])
print(out[0])
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
"""
# NOTE(Aurelius84): Fix it after all files migrating into jit.
from paddle.jit.dy2static.program_translator import (
convert_to_static,
unwrap_decorators,
StaticFunction,
)
translator_logger.log(
1, "Convert callable object: convert {}.".format(func)
)
func_self = None
converted_call = None
# Function in convert_call may be decorated by another `@to_static`,
# in this case, unwraps it into a raw method or function.
_, func = unwrap_decorators(func)
options = getattr(func, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
translator_logger.log(
2,
"{} is not converted when it is decorated by 'paddle.jit.not_to_static'.".format(
func
),
)
return func
if is_builtin(func, "len"):
return convert_len
if is_builtin(func, "zip"):
return convert_zip
if is_builtin(func, "range"):
return convert_range
if is_builtin(func, "enumerate"):
return convert_enumerate
if is_builtin(func) or is_unsupported(func):
return func
if inspect.isgeneratorfunction(func):
# NOTE(xiongkun03): inspect.isfunction() will return True even though func is a generator function.
# If we don't deal generatorfunction here, we will regard it as normal function and get errors in some
# occasion.
number_of_stars = 30
translator_logger.warn(
"\n\n"
+ "*" * number_of_stars
+ "\nYour function:`{}` doesn't support to transform to static function because it is a generator function, it will be run as-is.".format(
func.__name__
)
+ "\n"
+ "*" * number_of_stars
+ "\n\n"
)
return func
if inspect.isfunction(func):
# TODO(liym27): If func is a lambda function, special conversion is needed.
if func.__name__ == '<lambda>':
return func
try:
# Note(Aurelius84): Because `@declarative` returns a class instance instead of
# a function. This will modify the value referring to itself in `__globals__`.
# For example:
#
# @declarative
# def foo(x):
# return x
#
# `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticFunction)` is added here.
_origfunc = unwrap(func)
global_functions = set()
for fn in _origfunc.__globals__.values():
if inspect.isfunction(fn):
global_functions.add(fn)
elif isinstance(fn, StaticFunction):
_, fn = unwrap_decorators(fn)
global_functions.add(fn)
elif inspect.isclass(fn):
if isinstance(
fn.__dict__.get(func.__name__, None), staticmethod
):
global_functions.add(
func
) # Add func to ensure that we will convert
if func in global_functions:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
else:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
translator_logger.warn(
"{} doesn't have to be transformed to static function because it has been transformed before, it will be run as-is.".format(
func
)
)
converted_call = func
except AttributeError:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
converted_call = None
except (IOError, OSError):
# NOTE:
# If func has been decorated, its source code can not be get
# so that it can not be transformed to static function.
converted_call = None
elif inspect.ismethod(func):
try:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have been decorated.
converted_call = None
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
_, forward_func = unwrap_decorators(func.forward)
func._original_funcs['forward'] = forward_func.__func__
forward_func = convert_to_static(forward_func)
# Bound mothod will be convert into plain function after `convert_to_static`.
# So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func))
except (IOError, OSError, TypeError):
# NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self
converted_call = func
else:
try:
call_func = func.__class__.__call__
converted_call = convert_to_static(call_func)
func_self = func
except (IOError, OSError, TypeError):
# NOTE:
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
# it doesn't need to be transformed
func_self = None if func_self else func_self
else:
raise NotImplementedError(
"Callable {} can not be transformed at present.".format(func)
)
if converted_call is None:
translator_logger.warn(
"{} doesn't have to be transformed to static function, and it will be run as-is.".format(
func
)
)
return func
if func_self:
converted_call = functools.partial(converted_call, func_self)
return converted_call
...@@ -20,6 +20,8 @@ from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ( ...@@ -20,6 +20,8 @@ from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer, BaseTransformer,
) )
__all__ = ['EarlyReturnTransformer']
class EarlyReturnTransformer(BaseTransformer): class EarlyReturnTransformer(BaseTransformer):
""" """
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import copy import copy
import textwrap
from collections import defaultdict from collections import defaultdict
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
...@@ -28,18 +27,11 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ( ...@@ -28,18 +27,11 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import (
ast_to_source_code, ast_to_source_code,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.fluid.dygraph.dygraph_to_static.utils import (
create_assign_node,
FunctionNameLivenessAnalysis, FunctionNameLivenessAnalysis,
) )
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
StaticAnalysisVisitor,
)
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
AstNodeWrapper, AstNodeWrapper,
) )
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import (
create_undefined_var,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.fluid.dygraph.dygraph_to_static.utils import (
create_nonlocal_stmt_nodes, create_nonlocal_stmt_nodes,
) )
...@@ -65,6 +57,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ( ...@@ -65,6 +57,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import (
create_name_str, create_name_str,
) )
__all__ = ['IfElseTransformer']
TRUE_FUNC_PREFIX = 'true_fn' TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn' FALSE_FUNC_PREFIX = 'false_fn'
GET_ARGS_FUNC_PREFIX = 'get_args' GET_ARGS_FUNC_PREFIX = 'get_args'
......
...@@ -25,11 +25,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( ...@@ -25,11 +25,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
StaticAnalysisVisitor, StaticAnalysisVisitor,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import (
create_undefined_var,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.fluid.dygraph.dygraph_to_static.utils import (
create_nonlocal_stmt_nodes, create_nonlocal_stmt_nodes,
create_get_args_node, create_get_args_node,
...@@ -38,13 +34,10 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ( ...@@ -38,13 +34,10 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import (
from paddle.fluid.dygraph.dygraph_to_static.utils import ( from paddle.fluid.dygraph.dygraph_to_static.utils import (
FunctionNameLivenessAnalysis, FunctionNameLivenessAnalysis,
) )
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME from .ifelse_transformer import ARGS_NAME
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer, BaseTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
RenameTransformer,
)
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
ForLoopTuplePreTransformer, ForLoopTuplePreTransformer,
) )
...@@ -217,7 +210,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -217,7 +210,7 @@ class NameVisitor(gast.NodeVisitor):
# If this var is a basic variable and read-only and not # If this var is a basic variable and read-only and not
# condition var, it may not be loop_var else it should # condition var, it may not be loop_var else it should
# be in loop_var as input # be in loop_var as input
if (not name in condition_names) and (not name in write_names): if (name not in condition_names) and (name not in write_names):
continue continue
loop_var_names.add(name) loop_var_names.add(name)
......
...@@ -23,7 +23,7 @@ from paddle.fluid.executor import ( ...@@ -23,7 +23,7 @@ from paddle.fluid.executor import (
from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ( from .return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM, RETURN_NO_VALUE_MAGIC_NUM,
) )
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
......
...@@ -37,7 +37,7 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import ( ...@@ -37,7 +37,7 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import (
from paddle.fluid.dygraph.dygraph_to_static.origin_info import ( from paddle.fluid.dygraph.dygraph_to_static.origin_info import (
update_op_callstack_with_origin_info, update_op_callstack_with_origin_info,
) )
from paddle.fluid.dygraph.dygraph_to_static.partial_program import ( from .partial_program import (
partial_program_from, partial_program_from,
) )
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
......
...@@ -16,12 +16,9 @@ from paddle.utils import gast ...@@ -16,12 +16,9 @@ from paddle.utils import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ( from .break_continue_transformer import (
ForToWhileTransformer, ForToWhileTransformer,
) )
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import (
create_fill_constant_node,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ( from paddle.fluid.dygraph.dygraph_to_static.base_transformer import (
BaseTransformer, BaseTransformer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册