diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 28ea7d1e287d6855cb877c55eabb19aafcb390a5..0815f61432f189bb8153ce692e68c3194206b55e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -19,8 +19,6 @@ from __future__ import print_function # as produced by ast.parse from the standard ast module. # See details in https://github.com/serge-sans-paille/gast/ import gast -import inspect -import textwrap from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer @@ -34,10 +32,9 @@ from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransf from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name -__all__ = ['DygraphToStaticAst', 'convert_to_static'] +__all__ = ['DygraphToStaticAst'] DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func'] @@ -54,7 +51,6 @@ class DygraphToStaticAst(gast.NodeTransformer): self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root( ) self.decorate_func_name = None - self.arg_name_to_idx = {} self.transfer_from_node_type(self.static_analysis_root) return self.static_analysis_root @@ -65,7 +61,6 @@ class DygraphToStaticAst(gast.NodeTransformer): # Transform basic api of dygraph to static graph and get feed_name_to_arg_name basic_api_trans = BasicApiTransformer(node_wrapper) basic_api_trans.transform() - self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() # Transform Tensor.shape into fluid.layers.shape(Tensor) TensorShapeTransformer(node_wrapper).transform() @@ -97,8 +92,6 @@ class DygraphToStaticAst(gast.NodeTransformer): def visit_FunctionDef(self, node): if self.decorate_func_name is None: self.decorate_func_name = node.name - for idx, arg in enumerate(node.args.args): - self.arg_name_to_idx[arg.id] = idx self.generic_visit(node) # Remove the decorated name of dygraph_to_static @@ -132,30 +125,3 @@ class DygraphToStaticAst(gast.NodeTransformer): # Should consider BaseAPITransformer which add new module name in Yamei's PR. assert self.decorate_func_name, "decorate_func_name shall not be None." return self.decorate_func_name - - def get_feed_name_to_idx(self): - feed_name_to_idx = {} - for feed_name, arg_name in self.feed_name_to_arg_name.items(): - feed_name_to_idx[feed_name] = self.arg_name_to_idx.get(arg_name) - return feed_name_to_idx - - -def convert_to_static(dyfunc): - """ - Converts dygraph function into static function. - """ - # Get AST from dygraph function - # Note: In Python2, it will raise OSError when inspect function - # with decorator directly and dyfunc.__wrapped__ holds the actual function. - dyfunc = getattr(dyfunc, '__wrapped__', dyfunc) - raw_code = inspect.getsource(dyfunc) - code = textwrap.dedent(raw_code) - root = gast.parse(code) - - # Transform AST - dygraph_to_static = DygraphToStaticAst() - root_wrapper = dygraph_to_static.get_static_ast(root) - - # Get static_func from AST - static_func, file_name = ast_to_func(root_wrapper.node, dyfunc) - return static_func, dygraph_to_static diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py index 681c744e6de51a8950b0a0b3d810b04f61150732..82f39ffd080ec803beca4e60695204b707f48210 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py @@ -15,7 +15,6 @@ import astor import gast -from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func @@ -36,8 +35,6 @@ class BasicApiTransformer(gast.NodeTransformer): self.root = wrapper_root.node self.class_node_dict = {} - # Used for transformation of data feed - self.feed_name_to_arg_id = {} self.name_to_tensor_shape = {} def transform(self): @@ -69,7 +66,6 @@ class BasicApiTransformer(gast.NodeTransformer): assert isinstance(node, gast.Call) # Replace API `to_variable` with `fluid.layers.assign` if is_to_variable(node): - self._update_feed_dict(node) node = to_assign_node(node) return node @@ -106,24 +102,3 @@ class BasicApiTransformer(gast.NodeTransformer): return True # TODO: node.value is not dygraph class return False - - def _update_feed_dict(self, node): - assert isinstance(node, gast.Call) - - value_node = None - for kw in node.keywords: - if kw.arg == 'value': - value_node = kw.value # eg: `a` for "value=a " - if not value_node: - value_node = node.args[0] - - if not isinstance(value_node, gast.Name): - return - else: - var_name = value_node.id - feed_var_name = unique_name.generate(var_name) # eg: "a_0" - self.feed_name_to_arg_id[ - feed_var_name] = var_name # eg: "a_0" : "a" - - def get_feed_name_to_arg_id(self): - return self.feed_name_to_arg_id diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index a74a9ff331053e3304b616d61559fcfee5d64ac3..463a968e56afa08c5d8159fbb35f6222c862e1ff 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -28,15 +28,16 @@ from paddle.fluid.dygraph import layers from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.dygraph.base import switch_to_static_graph -from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.fluid.wrapped_decorator import signature_safe_contextmanager from paddle.fluid.dygraph.base import param_guard from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from -__all__ = ['ProgramTranslator', 'convert_function_with_cache'] +__all__ = ['ProgramTranslator', 'convert_to_static'] logger = logging.getLogger("fluid") @@ -47,43 +48,76 @@ class FunctionCache(object): """ def __init__(self): - self._dycode_to_static_func = dict() - self._static_func_to_transformer = dict() + # Caches the converted static functions. {dygraph_func: static_func} + self._converted_static_func_caches = dict() + # Caches the converted ast node for same source code. {source_code: ast_root} + self._code_to_ast_caches = dict() + self._dygraph_to_static = DygraphToStaticAst() - def get_or_cache_func(self, func): - # code = self._get_dedent_code_string(func) - static_func = self._dycode_to_static_func.get(func, None) + def convert_with_cache(self, func): + """ + Returns the cached static function or converts it when first encounters the function. + """ + # If hit cache, return it directly. + static_func = self._converted_static_func_caches.get(func, None) if static_func is None: - static_func, dygraph_to_static_transformer = convert_to_static(func) - self._dycode_to_static_func[func] = static_func - self._static_func_to_transformer[ - func] = dygraph_to_static_transformer + static_func = self._convert(func) + self._converted_static_func_caches[func] = static_func return static_func - def get_transformer(self, func): - return self._static_func_to_transformer.get(func, None) + def _convert(self, func): + """ + Converts dygraph function into static function. For two functions with same dedent code, + the second function will reuse the transformed ast node of previous one. + + For example: + # A.py + def foo(x, y): + z = x + y + return z + + # B.py + def foo(x, y): + z = x + y + return z + + If the conversion of A.foo happens after B.foo, it will reuse the transformed ast node of B.foo + to speed up the conversion. + """ + # Note: In Python2, it will raise OSError when inspect function + # with decorator directly and function.__wrapped__ holds the actual function. + func = getattr(func, '__wrapped__', func) + source_code = func_to_source_code(func) + if source_code in self._code_to_ast_caches: + root_wrapper = self._code_to_ast_caches[source_code] + else: + root = gast.parse(source_code) + root_wrapper = self._dygraph_to_static.get_static_ast(root) + self._code_to_ast_caches[source_code] = root_wrapper - def _get_dedent_code_string(self, func): - raw_code = inspect.getsource(func) - dedent_code = textwrap.dedent(raw_code) - return dedent_code + # Get static function from AST + static_func, file_name = ast_to_func(root_wrapper.node, func) + return static_func def exist(self, func): - return self._dycode_to_static_func.get(func, None) is not None + return func in self._converted_static_func_caches _CACHE_LOCK = threading.Lock() _FUNCTION_CACHE = FunctionCache() -def convert_function_with_cache(dygraph_func): +def convert_to_static(function): """ Transforms function of dygraph into static function using the cache mechanism. + + Args: + function(callable): The function with dygraph layers that will be converted into static layers. """ with _CACHE_LOCK: - static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func) + static_func = _FUNCTION_CACHE.convert_with_cache(function) return static_func @@ -202,7 +236,7 @@ class ConcreteProgram(object): """ # Transforms dygraph function into static function and caches it. dygraph_function = func_spec.dyfunc - static_func = convert_function_with_cache(dygraph_function) + static_func = convert_to_static(dygraph_function) main_program, startup_program = framework.Program(), framework.Program() # Note: The random seed should be synchronized into cached program @@ -461,7 +495,7 @@ class ProgramTranslator(object): "just return dygraph output.") return dygraph_func - static_func = convert_function_with_cache(dygraph_func) + static_func = convert_to_static(dygraph_func) return static_func def get_program(self, dygraph_func, *args, **kwargs): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 0e49fb2b2e78cec07ff88daee17e006be13d6675..4b489c7d2847dc49df86645f69c61573a4adfbcb 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -24,6 +24,7 @@ import inspect import os import six import tempfile +import textwrap from paddle.fluid import unique_name @@ -407,9 +408,24 @@ def recover_globals_attribute(src_obj, dst_obj): dst_globals[k] = v +def func_to_source_code(function, dedent=True): + """ + Transforms function into raw string of source code. + """ + if not (inspect.isfunction(function) or inspect.ismethod(function)): + raise TypeError( + "The type of 'function' should be a function or method, but received {}.". + format(type(function).__name__)) + source_code = inspect.getsource(function) + if dedent: + source_code = textwrap.dedent(source_code) + + return source_code + + def ast_to_source_code(ast_node): """ - Transformers ast node into source code. + Transforms ast node into source code. """ if not isinstance(ast_node, (gast.AST, ast.AST)): raise TypeError( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py index 192e1a6b114b09ec0f4ef645f56e853c0ecb5d9e..8e35dd78457bb59bb4882bc1deeb23539f47012a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py @@ -22,7 +22,7 @@ import paddle.fluid as fluid from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator -from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache +from paddle.fluid.dygraph.dygraph_to_static import convert_to_static from test_fetch_feed import Pool2D, Linear @@ -116,9 +116,9 @@ def simple_func(x): class TestConvertWithCache(unittest.TestCase): def test_cache(self): - static_func = convert_function_with_cache(simple_func) + static_func = convert_to_static(simple_func) # Get transformed function from cache. - cached_func = convert_function_with_cache(simple_func) + cached_func = convert_to_static(simple_func) self.assertTrue(id(static_func), id(cached_func))