未验证 提交 a647bcd3 编写于 作者: A Aurelius84 提交者: GitHub

Add convert_function_with_cache in dygraph_to_static_func (#23190)

* add unittest test=develop

* add function cache test=develop
上级 bd809033
......@@ -29,7 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.framework import in_dygraph_mode
__all__ = ['ProgramTranslator']
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
class FunctionCache(object):
......@@ -66,6 +66,19 @@ class FunctionCache(object):
self._get_dedent_code_string(func), None) is not None
_CACHE_LOCK = threading.Lock()
_FUNCTION_CACHE = FunctionCache()
def convert_function_with_cache(dygraph_func):
"""
Transform function of dygraph into static function using the cache mechanism.
"""
with _CACHE_LOCK:
static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func)
return static_func
def synchronized(func):
func.__lock__ = threading.Lock()
......@@ -273,7 +286,7 @@ class ProgramTranslator(object):
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
" Please use it in static mode.")
return dygraph_func
static_func, ast_transformer = convert_to_static(dygraph_func)
static_func = convert_function_with_cache(dygraph_func)
return static_func
def get_code(self, dygraph_func):
......
......@@ -21,7 +21,7 @@ from collections import Counter
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_output
from paddle.fluid.dygraph.dygraph_to_static import convert_function_with_cache
from test_fetch_feed import Pool2D, Linear
......@@ -111,5 +111,19 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
static_loss))
def simple_func(x):
inputs = fluid.dygraph.to_variable(x)
mean = fluid.layers.mean(inputs)
return mean
class TestConvertWithCache(unittest.TestCase):
def test_cache(self):
static_func = convert_function_with_cache(simple_func)
# Get transformed function from cache.
cached_func = convert_function_with_cache(simple_func)
self.assertTrue(id(static_func), id(cached_func))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册