提交 a8fd6d58 编写于 作者: Q qiaolongfei

add use_program_cache to executor.run

上级 0876fc14
...@@ -226,7 +226,19 @@ class Executor(object): ...@@ -226,7 +226,19 @@ class Executor(object):
feed_var_name='feed', feed_var_name='feed',
fetch_var_name='fetch', fetch_var_name='fetch',
scope=None, scope=None,
return_numpy=True): return_numpy=True,
use_program_cache=False):
"""
:param program: the program that need to run
:param feed: feed variable list
:param fetch_list: fetch variable list
:param feed_var_name: feed_var_name default to 'feed'
:param fetch_var_name: fetch_var_name default to 'fetch'
:param scope: the scope used to run this program, you can switch it to different scope.
:param return_numpy: convert the fetched tensor to numpy
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
:return:
"""
if feed is None: if feed is None:
feed = {} feed = {}
if fetch_list is None: if fetch_list is None:
...@@ -244,7 +256,7 @@ class Executor(object): ...@@ -244,7 +256,7 @@ class Executor(object):
program_cache_key = str(feed.keys() + fetch_list) program_cache_key = str(feed.keys() + fetch_list)
program_cache = self.program_caches.get(program_cache_key, None) program_cache = self.program_caches.get(program_cache_key, None)
if program_cache is None: if program_cache is None or not use_program_cache:
program_cache = program.clone() program_cache = program.clone()
self.program_caches[program_cache_key] = program_cache self.program_caches[program_cache_key] = program_cache
...@@ -266,6 +278,7 @@ class Executor(object): ...@@ -266,6 +278,7 @@ class Executor(object):
type=core.VarDesc.VarType.FETCH_LIST, type=core.VarDesc.VarType.FETCH_LIST,
persistable=True) persistable=True)
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name): if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed): for i, name in enumerate(feed):
out = global_block.var(name) out = global_block.var(name)
...@@ -275,6 +288,7 @@ class Executor(object): ...@@ -275,6 +288,7 @@ class Executor(object):
outputs={'Out': [out]}, outputs={'Out': [out]},
attrs={'col': i}) attrs={'col': i})
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, if not has_fetch_operators(global_block, fetch_list,
fetch_var_name): fetch_var_name):
for i, var in enumerate(fetch_list): for i, var in enumerate(fetch_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册