未验证 提交 e8134e87 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Refine caches of converting function (#25085)

* Refine caches of converting func test=develop

* fix if statement test=develop

* refine cache code test=develop

* rm unuse import statement test=develop

* Polish code comment test=develop
上级 1eb9ee24
......@@ -19,8 +19,6 @@ from __future__ import print_function
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
import gast
import inspect
import textwrap
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
......@@ -34,10 +32,9 @@ from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransf
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
__all__ = ['DygraphToStaticAst', 'convert_to_static']
__all__ = ['DygraphToStaticAst']
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
......@@ -54,7 +51,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
)
self.decorate_func_name = None
self.arg_name_to_idx = {}
self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root
......@@ -65,7 +61,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Transform basic api of dygraph to static graph and get feed_name_to_arg_name
basic_api_trans = BasicApiTransformer(node_wrapper)
basic_api_trans.transform()
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
# Transform Tensor.shape into fluid.layers.shape(Tensor)
TensorShapeTransformer(node_wrapper).transform()
......@@ -97,8 +92,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
def visit_FunctionDef(self, node):
if self.decorate_func_name is None:
self.decorate_func_name = node.name
for idx, arg in enumerate(node.args.args):
self.arg_name_to_idx[arg.id] = idx
self.generic_visit(node)
# Remove the decorated name of dygraph_to_static
......@@ -132,30 +125,3 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Should consider BaseAPITransformer which add new module name in Yamei's PR.
assert self.decorate_func_name, "decorate_func_name shall not be None."
return self.decorate_func_name
def get_feed_name_to_idx(self):
feed_name_to_idx = {}
for feed_name, arg_name in self.feed_name_to_arg_name.items():
feed_name_to_idx[feed_name] = self.arg_name_to_idx.get(arg_name)
return feed_name_to_idx
def convert_to_static(dyfunc):
"""
Converts dygraph function into static function.
"""
# Get AST from dygraph function
# Note: In Python2, it will raise OSError when inspect function
# with decorator directly and dyfunc.__wrapped__ holds the actual function.
dyfunc = getattr(dyfunc, '__wrapped__', dyfunc)
raw_code = inspect.getsource(dyfunc)
code = textwrap.dedent(raw_code)
root = gast.parse(code)
# Transform AST
dygraph_to_static = DygraphToStaticAst()
root_wrapper = dygraph_to_static.get_static_ast(root)
# Get static_func from AST
static_func, file_name = ast_to_func(root_wrapper.node, dyfunc)
return static_func, dygraph_to_static
......@@ -15,7 +15,6 @@
import astor
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
......@@ -36,8 +35,6 @@ class BasicApiTransformer(gast.NodeTransformer):
self.root = wrapper_root.node
self.class_node_dict = {}
# Used for transformation of data feed
self.feed_name_to_arg_id = {}
self.name_to_tensor_shape = {}
def transform(self):
......@@ -69,7 +66,6 @@ class BasicApiTransformer(gast.NodeTransformer):
assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node):
self._update_feed_dict(node)
node = to_assign_node(node)
return node
......@@ -106,24 +102,3 @@ class BasicApiTransformer(gast.NodeTransformer):
return True
# TODO: node.value is not dygraph class
return False
def _update_feed_dict(self, node):
assert isinstance(node, gast.Call)
value_node = None
for kw in node.keywords:
if kw.arg == 'value':
value_node = kw.value # eg: `a` for "value=a "
if not value_node:
value_node = node.args[0]
if not isinstance(value_node, gast.Name):
return
else:
var_name = value_node.id
feed_var_name = unique_name.generate(var_name) # eg: "a_0"
self.feed_name_to_arg_id[
feed_var_name] = var_name # eg: "a_0" : "a"
def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id
......@@ -28,15 +28,16 @@ from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
__all__ = ['ProgramTranslator', 'convert_to_static']
logger = logging.getLogger("fluid")
......@@ -47,43 +48,76 @@ class FunctionCache(object):
"""
def __init__(self):
self._dycode_to_static_func = dict()
self._static_func_to_transformer = dict()
# Caches the converted static functions. {dygraph_func: static_func}
self._converted_static_func_caches = dict()
# Caches the converted ast node for same source code. {source_code: ast_root}
self._code_to_ast_caches = dict()
self._dygraph_to_static = DygraphToStaticAst()
def get_or_cache_func(self, func):
# code = self._get_dedent_code_string(func)
static_func = self._dycode_to_static_func.get(func, None)
def convert_with_cache(self, func):
"""
Returns the cached static function or converts it when first encounters the function.
"""
# If hit cache, return it directly.
static_func = self._converted_static_func_caches.get(func, None)
if static_func is None:
static_func, dygraph_to_static_transformer = convert_to_static(func)
self._dycode_to_static_func[func] = static_func
self._static_func_to_transformer[
func] = dygraph_to_static_transformer
static_func = self._convert(func)
self._converted_static_func_caches[func] = static_func
return static_func
def get_transformer(self, func):
return self._static_func_to_transformer.get(func, None)
def _convert(self, func):
"""
Converts dygraph function into static function. For two functions with same dedent code,
the second function will reuse the transformed ast node of previous one.
For example:
# A.py
def foo(x, y):
z = x + y
return z
# B.py
def foo(x, y):
z = x + y
return z
If the conversion of A.foo happens after B.foo, it will reuse the transformed ast node of B.foo
to speed up the conversion.
"""
# Note: In Python2, it will raise OSError when inspect function
# with decorator directly and function.__wrapped__ holds the actual function.
func = getattr(func, '__wrapped__', func)
source_code = func_to_source_code(func)
if source_code in self._code_to_ast_caches:
root_wrapper = self._code_to_ast_caches[source_code]
else:
root = gast.parse(source_code)
root_wrapper = self._dygraph_to_static.get_static_ast(root)
self._code_to_ast_caches[source_code] = root_wrapper
def _get_dedent_code_string(self, func):
raw_code = inspect.getsource(func)
dedent_code = textwrap.dedent(raw_code)
return dedent_code
# Get static function from AST
static_func, file_name = ast_to_func(root_wrapper.node, func)
return static_func
def exist(self, func):
return self._dycode_to_static_func.get(func, None) is not None
return func in self._converted_static_func_caches
_CACHE_LOCK = threading.Lock()
_FUNCTION_CACHE = FunctionCache()
def convert_function_with_cache(dygraph_func):
def convert_to_static(function):
"""
Transforms function of dygraph into static function using the cache mechanism.
Args:
function(callable): The function with dygraph layers that will be converted into static layers.
"""
with _CACHE_LOCK:
static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func)
static_func = _FUNCTION_CACHE.convert_with_cache(function)
return static_func
......@@ -202,7 +236,7 @@ class ConcreteProgram(object):
"""
# Transforms dygraph function into static function and caches it.
dygraph_function = func_spec.dyfunc
static_func = convert_function_with_cache(dygraph_function)
static_func = convert_to_static(dygraph_function)
main_program, startup_program = framework.Program(), framework.Program()
# Note: The random seed should be synchronized into cached program
......@@ -461,7 +495,7 @@ class ProgramTranslator(object):
"just return dygraph output.")
return dygraph_func
static_func = convert_function_with_cache(dygraph_func)
static_func = convert_to_static(dygraph_func)
return static_func
def get_program(self, dygraph_func, *args, **kwargs):
......
......@@ -24,6 +24,7 @@ import inspect
import os
import six
import tempfile
import textwrap
from paddle.fluid import unique_name
......@@ -407,9 +408,24 @@ def recover_globals_attribute(src_obj, dst_obj):
dst_globals[k] = v
def func_to_source_code(function, dedent=True):
"""
Transforms function into raw string of source code.
"""
if not (inspect.isfunction(function) or inspect.ismethod(function)):
raise TypeError(
"The type of 'function' should be a function or method, but received {}.".
format(type(function).__name__))
source_code = inspect.getsource(function)
if dedent:
source_code = textwrap.dedent(source_code)
return source_code
def ast_to_source_code(ast_node):
"""
Transformers ast node into source code.
Transforms ast node into source code.
"""
if not isinstance(ast_node, (gast.AST, ast.AST)):
raise TypeError(
......
......@@ -22,7 +22,7 @@ import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache
from paddle.fluid.dygraph.dygraph_to_static import convert_to_static
from test_fetch_feed import Pool2D, Linear
......@@ -116,9 +116,9 @@ def simple_func(x):
class TestConvertWithCache(unittest.TestCase):
def test_cache(self):
static_func = convert_function_with_cache(simple_func)
static_func = convert_to_static(simple_func)
# Get transformed function from cache.
cached_func = convert_function_with_cache(simple_func)
cached_func = convert_to_static(simple_func)
self.assertTrue(id(static_func), id(cached_func))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册