From a647bcd355353a35ba247266e88452fb9faec8fe Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 26 Mar 2020 13:11:22 +0800 Subject: [PATCH] Add convert_function_with_cache in dygraph_to_static_func (#23190) * add unittest test=develop * add function cache test=develop --- .../dygraph_to_static/program_translator.py | 17 +++++++++++++++-- .../dygraph_to_static/test_cache_program.py | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index a3f9e0b9c23..3636d3c4752 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -29,7 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.framework import in_dygraph_mode -__all__ = ['ProgramTranslator'] +__all__ = ['ProgramTranslator', 'convert_function_with_cache'] class FunctionCache(object): @@ -66,6 +66,19 @@ class FunctionCache(object): self._get_dedent_code_string(func), None) is not None +_CACHE_LOCK = threading.Lock() +_FUNCTION_CACHE = FunctionCache() + + +def convert_function_with_cache(dygraph_func): + """ + Transform function of dygraph into static function using the cache mechanism. + """ + with _CACHE_LOCK: + static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func) + return static_func + + def synchronized(func): func.__lock__ = threading.Lock() @@ -273,7 +286,7 @@ class ProgramTranslator(object): "The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode." " Please use it in static mode.") return dygraph_func - static_func, ast_transformer = convert_to_static(dygraph_func) + static_func = convert_function_with_cache(dygraph_func) return static_func def get_code(self, dygraph_func): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py index b3027094528..588dd0a5f1f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py @@ -21,7 +21,7 @@ from collections import Counter import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator -from paddle.fluid.dygraph.jit import dygraph_to_static_output +from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache from test_fetch_feed import Pool2D, Linear @@ -111,5 +111,19 @@ class TestCacheProgramWithOptimizer(unittest.TestCase): static_loss)) +def simple_func(x): + inputs = fluid.dygraph.to_variable(x) + mean = fluid.layers.mean(inputs) + return mean + + +class TestConvertWithCache(unittest.TestCase): + def test_cache(self): + static_func = convert_function_with_cache(simple_func) + # Get transformed function from cache. + cached_func = convert_function_with_cache(simple_func) + self.assertTrue(id(static_func), id(cached_func)) + + if __name__ == '__main__': unittest.main() -- GitLab