未验证 提交 e8134e87 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Refine caches of converting function (#25085)

* Refine caches of converting func test=develop

* fix if statement test=develop

* refine cache code test=develop

* rm unuse import statement test=develop

* Polish code comment test=develop
上级 1eb9ee24
...@@ -19,8 +19,6 @@ from __future__ import print_function ...@@ -19,8 +19,6 @@ from __future__ import print_function
# as produced by ast.parse from the standard ast module. # as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/ # See details in https://github.com/serge-sans-paille/gast/
import gast import gast
import inspect
import textwrap
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
...@@ -34,10 +32,9 @@ from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransf ...@@ -34,10 +32,9 @@ from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransf
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
__all__ = ['DygraphToStaticAst', 'convert_to_static'] __all__ = ['DygraphToStaticAst']
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func'] DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
...@@ -54,7 +51,6 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -54,7 +51,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root( self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
) )
self.decorate_func_name = None self.decorate_func_name = None
self.arg_name_to_idx = {}
self.transfer_from_node_type(self.static_analysis_root) self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root return self.static_analysis_root
...@@ -65,7 +61,6 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -65,7 +61,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Transform basic api of dygraph to static graph and get feed_name_to_arg_name # Transform basic api of dygraph to static graph and get feed_name_to_arg_name
basic_api_trans = BasicApiTransformer(node_wrapper) basic_api_trans = BasicApiTransformer(node_wrapper)
basic_api_trans.transform() basic_api_trans.transform()
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
# Transform Tensor.shape into fluid.layers.shape(Tensor) # Transform Tensor.shape into fluid.layers.shape(Tensor)
TensorShapeTransformer(node_wrapper).transform() TensorShapeTransformer(node_wrapper).transform()
...@@ -97,8 +92,6 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -97,8 +92,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
if self.decorate_func_name is None: if self.decorate_func_name is None:
self.decorate_func_name = node.name self.decorate_func_name = node.name
for idx, arg in enumerate(node.args.args):
self.arg_name_to_idx[arg.id] = idx
self.generic_visit(node) self.generic_visit(node)
# Remove the decorated name of dygraph_to_static # Remove the decorated name of dygraph_to_static
...@@ -132,30 +125,3 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -132,30 +125,3 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Should consider BaseAPITransformer which add new module name in Yamei's PR. # Should consider BaseAPITransformer which add new module name in Yamei's PR.
assert self.decorate_func_name, "decorate_func_name shall not be None." assert self.decorate_func_name, "decorate_func_name shall not be None."
return self.decorate_func_name return self.decorate_func_name
def get_feed_name_to_idx(self):
feed_name_to_idx = {}
for feed_name, arg_name in self.feed_name_to_arg_name.items():
feed_name_to_idx[feed_name] = self.arg_name_to_idx.get(arg_name)
return feed_name_to_idx
def convert_to_static(dyfunc):
"""
Converts dygraph function into static function.
"""
# Get AST from dygraph function
# Note: In Python2, it will raise OSError when inspect function
# with decorator directly and dyfunc.__wrapped__ holds the actual function.
dyfunc = getattr(dyfunc, '__wrapped__', dyfunc)
raw_code = inspect.getsource(dyfunc)
code = textwrap.dedent(raw_code)
root = gast.parse(code)
# Transform AST
dygraph_to_static = DygraphToStaticAst()
root_wrapper = dygraph_to_static.get_static_ast(root)
# Get static_func from AST
static_func, file_name = ast_to_func(root_wrapper.node, dyfunc)
return static_func, dygraph_to_static
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import astor import astor
import gast import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
...@@ -36,8 +35,6 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -36,8 +35,6 @@ class BasicApiTransformer(gast.NodeTransformer):
self.root = wrapper_root.node self.root = wrapper_root.node
self.class_node_dict = {} self.class_node_dict = {}
# Used for transformation of data feed
self.feed_name_to_arg_id = {}
self.name_to_tensor_shape = {} self.name_to_tensor_shape = {}
def transform(self): def transform(self):
...@@ -69,7 +66,6 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -69,7 +66,6 @@ class BasicApiTransformer(gast.NodeTransformer):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign` # Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node): if is_to_variable(node):
self._update_feed_dict(node)
node = to_assign_node(node) node = to_assign_node(node)
return node return node
...@@ -106,24 +102,3 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -106,24 +102,3 @@ class BasicApiTransformer(gast.NodeTransformer):
return True return True
# TODO: node.value is not dygraph class # TODO: node.value is not dygraph class
return False return False
def _update_feed_dict(self, node):
assert isinstance(node, gast.Call)
value_node = None
for kw in node.keywords:
if kw.arg == 'value':
value_node = kw.value # eg: `a` for "value=a "
if not value_node:
value_node = node.args[0]
if not isinstance(value_node, gast.Name):
return
else:
var_name = value_node.id
feed_var_name = unique_name.generate(var_name) # eg: "a_0"
self.feed_name_to_arg_id[
feed_var_name] = var_name # eg: "a_0" : "a"
def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id
...@@ -28,15 +28,16 @@ from paddle.fluid.dygraph import layers ...@@ -28,15 +28,16 @@ from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
__all__ = ['ProgramTranslator', 'convert_function_with_cache'] __all__ = ['ProgramTranslator', 'convert_to_static']
logger = logging.getLogger("fluid") logger = logging.getLogger("fluid")
...@@ -47,43 +48,76 @@ class FunctionCache(object): ...@@ -47,43 +48,76 @@ class FunctionCache(object):
""" """
def __init__(self): def __init__(self):
self._dycode_to_static_func = dict() # Caches the converted static functions. {dygraph_func: static_func}
self._static_func_to_transformer = dict() self._converted_static_func_caches = dict()
# Caches the converted ast node for same source code. {source_code: ast_root}
self._code_to_ast_caches = dict()
self._dygraph_to_static = DygraphToStaticAst()
def get_or_cache_func(self, func): def convert_with_cache(self, func):
# code = self._get_dedent_code_string(func) """
static_func = self._dycode_to_static_func.get(func, None) Returns the cached static function or converts it when first encounters the function.
"""
# If hit cache, return it directly.
static_func = self._converted_static_func_caches.get(func, None)
if static_func is None: if static_func is None:
static_func, dygraph_to_static_transformer = convert_to_static(func) static_func = self._convert(func)
self._dycode_to_static_func[func] = static_func self._converted_static_func_caches[func] = static_func
self._static_func_to_transformer[
func] = dygraph_to_static_transformer
return static_func return static_func
def get_transformer(self, func): def _convert(self, func):
return self._static_func_to_transformer.get(func, None) """
Converts dygraph function into static function. For two functions with same dedent code,
the second function will reuse the transformed ast node of previous one.
For example:
# A.py
def foo(x, y):
z = x + y
return z
# B.py
def foo(x, y):
z = x + y
return z
If the conversion of A.foo happens after B.foo, it will reuse the transformed ast node of B.foo
to speed up the conversion.
"""
# Note: In Python2, it will raise OSError when inspect function
# with decorator directly and function.__wrapped__ holds the actual function.
func = getattr(func, '__wrapped__', func)
source_code = func_to_source_code(func)
if source_code in self._code_to_ast_caches:
root_wrapper = self._code_to_ast_caches[source_code]
else:
root = gast.parse(source_code)
root_wrapper = self._dygraph_to_static.get_static_ast(root)
self._code_to_ast_caches[source_code] = root_wrapper
def _get_dedent_code_string(self, func): # Get static function from AST
raw_code = inspect.getsource(func) static_func, file_name = ast_to_func(root_wrapper.node, func)
dedent_code = textwrap.dedent(raw_code) return static_func
return dedent_code
def exist(self, func): def exist(self, func):
return self._dycode_to_static_func.get(func, None) is not None return func in self._converted_static_func_caches
_CACHE_LOCK = threading.Lock() _CACHE_LOCK = threading.Lock()
_FUNCTION_CACHE = FunctionCache() _FUNCTION_CACHE = FunctionCache()
def convert_function_with_cache(dygraph_func): def convert_to_static(function):
""" """
Transforms function of dygraph into static function using the cache mechanism. Transforms function of dygraph into static function using the cache mechanism.
Args:
function(callable): The function with dygraph layers that will be converted into static layers.
""" """
with _CACHE_LOCK: with _CACHE_LOCK:
static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func) static_func = _FUNCTION_CACHE.convert_with_cache(function)
return static_func return static_func
...@@ -202,7 +236,7 @@ class ConcreteProgram(object): ...@@ -202,7 +236,7 @@ class ConcreteProgram(object):
""" """
# Transforms dygraph function into static function and caches it. # Transforms dygraph function into static function and caches it.
dygraph_function = func_spec.dyfunc dygraph_function = func_spec.dyfunc
static_func = convert_function_with_cache(dygraph_function) static_func = convert_to_static(dygraph_function)
main_program, startup_program = framework.Program(), framework.Program() main_program, startup_program = framework.Program(), framework.Program()
# Note: The random seed should be synchronized into cached program # Note: The random seed should be synchronized into cached program
...@@ -461,7 +495,7 @@ class ProgramTranslator(object): ...@@ -461,7 +495,7 @@ class ProgramTranslator(object):
"just return dygraph output.") "just return dygraph output.")
return dygraph_func return dygraph_func
static_func = convert_function_with_cache(dygraph_func) static_func = convert_to_static(dygraph_func)
return static_func return static_func
def get_program(self, dygraph_func, *args, **kwargs): def get_program(self, dygraph_func, *args, **kwargs):
......
...@@ -24,6 +24,7 @@ import inspect ...@@ -24,6 +24,7 @@ import inspect
import os import os
import six import six
import tempfile import tempfile
import textwrap
from paddle.fluid import unique_name from paddle.fluid import unique_name
...@@ -407,9 +408,24 @@ def recover_globals_attribute(src_obj, dst_obj): ...@@ -407,9 +408,24 @@ def recover_globals_attribute(src_obj, dst_obj):
dst_globals[k] = v dst_globals[k] = v
def func_to_source_code(function, dedent=True):
"""
Transforms function into raw string of source code.
"""
if not (inspect.isfunction(function) or inspect.ismethod(function)):
raise TypeError(
"The type of 'function' should be a function or method, but received {}.".
format(type(function).__name__))
source_code = inspect.getsource(function)
if dedent:
source_code = textwrap.dedent(source_code)
return source_code
def ast_to_source_code(ast_node): def ast_to_source_code(ast_node):
""" """
Transformers ast node into source code. Transforms ast node into source code.
""" """
if not isinstance(ast_node, (gast.AST, ast.AST)): if not isinstance(ast_node, (gast.AST, ast.AST)):
raise TypeError( raise TypeError(
......
...@@ -22,7 +22,7 @@ import paddle.fluid as fluid ...@@ -22,7 +22,7 @@ import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache from paddle.fluid.dygraph.dygraph_to_static import convert_to_static
from test_fetch_feed import Pool2D, Linear from test_fetch_feed import Pool2D, Linear
...@@ -116,9 +116,9 @@ def simple_func(x): ...@@ -116,9 +116,9 @@ def simple_func(x):
class TestConvertWithCache(unittest.TestCase): class TestConvertWithCache(unittest.TestCase):
def test_cache(self): def test_cache(self):
static_func = convert_function_with_cache(simple_func) static_func = convert_to_static(simple_func)
# Get transformed function from cache. # Get transformed function from cache.
cached_func = convert_function_with_cache(simple_func) cached_func = convert_to_static(simple_func)
self.assertTrue(id(static_func), id(cached_func)) self.assertTrue(id(static_func), id(cached_func))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册