diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index a3f9e0b9c236b7092311ad36ab3a6e72e0ec2145..3636d3c475243d6f8e5635a26493c0ae845137bd 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -29,7 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.framework import in_dygraph_mode -__all__ = ['ProgramTranslator'] +__all__ = ['ProgramTranslator', 'convert_function_with_cache'] class FunctionCache(object): @@ -66,6 +66,19 @@ class FunctionCache(object): self._get_dedent_code_string(func), None) is not None +_CACHE_LOCK = threading.Lock() +_FUNCTION_CACHE = FunctionCache() + + +def convert_function_with_cache(dygraph_func): + """ + Transform function of dygraph into static function using the cache mechanism. + """ + with _CACHE_LOCK: + static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func) + return static_func + + def synchronized(func): func.__lock__ = threading.Lock() @@ -273,7 +286,7 @@ class ProgramTranslator(object): "The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode." " Please use it in static mode.") return dygraph_func - static_func, ast_transformer = convert_to_static(dygraph_func) + static_func = convert_function_with_cache(dygraph_func) return static_func def get_code(self, dygraph_func): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py index b30270945289bfb98440b9c9ccb896db17c798ed..588dd0a5f1f9c8bb5fc8bad21f9e4ad82cf032af 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py @@ -21,7 +21,7 @@ from collections import Counter import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator -from paddle.fluid.dygraph.jit import dygraph_to_static_output +from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache from test_fetch_feed import Pool2D, Linear @@ -111,5 +111,19 @@ class TestCacheProgramWithOptimizer(unittest.TestCase): static_loss)) +def simple_func(x): + inputs = fluid.dygraph.to_variable(x) + mean = fluid.layers.mean(inputs) + return mean + + +class TestConvertWithCache(unittest.TestCase): + def test_cache(self): + static_func = convert_function_with_cache(simple_func) + # Get transformed function from cache. + cached_func = convert_function_with_cache(simple_func) + self.assertTrue(id(static_func), id(cached_func)) + + if __name__ == '__main__': unittest.main()