Unverified commit 767acc6c, authored by Qiao Longfei, committed by GitHub

Merge pull request #8744 from jacquesqiao/add-program-cache-for-executor

Add program cache for executor.py
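
This change makes Executor.run cache the cloned program, with feed and fetch operators already inserted, between steps; the cache is keyed by the feed and fetch variable names, so repeated runs of an unchanged program skip the clone-and-rewrite work. Below is a minimal standalone sketch of the get-or-build pattern; the names ProgramCacheDemo and build are hypothetical illustrations, not Paddle API (the real logic is in the diff that follows).

    class ProgramCacheDemo(object):
        """Hypothetical sketch of the get-or-build caching used by this PR."""

        def __init__(self):
            self.caches = dict()

        def cache_key(self, feed, fetch_names):
            # Key on feed names plus fetch names, as get_program_cache_key does.
            return str(list(feed.keys()) + list(fetch_names))

        def get_or_build(self, feed, fetch_names, build, use_cache):
            key = self.cache_key(feed, fetch_names)
            if use_cache:
                cached = self.caches.get(key, None)
                if cached is not None:
                    return cached  # reuse the already-prepared program
            else:
                self.caches.pop(key, None)  # drop any stale entry
            prepared = build()  # stands in for clone + feed/fetch op insertion
            if use_cache:
                self.caches[key] = prepared
            return prepared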
@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True):
    return tensor

def get_program_cache_key(feed, fetch_list):
    feed_var_names = feed.keys()

    def to_name_str(var):
        if isinstance(var, Variable):
            return var.desc.name()
        elif isinstance(var, str):
            return var
        else:
            raise TypeError(str(var) + " should be Variable or str")

    fetch_var_names = map(to_name_str, fetch_list)
    return str(feed_var_names + fetch_var_names)
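
# Editorial note: under the Python 2 of this era, feed.keys() and map()
# both return lists, so the key is simply their concatenation, stringified.
# E.g. feed={'image': img} with fetch_list=[loss] yields a key such as
# "['image', 'mean_0.tmp_0']" (the fetch name is illustrative; it is
# whatever var.desc.name() returns).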

class Executor(object):
    def __init__(self, places):
        if not isinstance(places, list) and not isinstance(places, tuple):
@@ -177,6 +193,7 @@ class Executor(object):
        # TODO(dzhwinter) : only use the first place
        self.executor = core.Executor(act_places[0])
        self.places = places
        self.program_caches = dict()

    def aslodtensor(self, data):
        def accumulate(data):
@@ -225,9 +242,30 @@ class Executor(object):
            feed_var_name='feed',
            fetch_var_name='fetch',
            scope=None,
            return_numpy=True,
            use_program_cache=False):
""" Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list
:param program: the program that need to run, if not provied, then default_main_program will be used.
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
:param fetch_list: a list of variable or variable names that user want to get, run will return them according
to this list.
:param feed_var_name: the name for the input variable of feed Operator.
:param fetch_var_name: the name for the output variable of feed Operator.
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
:param return_numpy: if convert the fetched tensor to numpy
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
:return: result according to fetch_list.
"""
        if feed is None:
            feed = {}
        if not isinstance(feed, dict):
            raise TypeError("feed should be a map")
        if fetch_list is None:
            fetch_list = []
@@ -240,8 +278,23 @@ class Executor(object):
        if scope is None:
            scope = global_scope()

        program_cache = None
        program_cache_key = get_program_cache_key(feed, fetch_list)
        if use_program_cache:
            # find program cache by cache_key
            program_cache = self.program_caches.get(program_cache_key, None)
            # TODO(qiao): Should check program_cache and program are exactly the same.
        else:
            self.program_caches.pop(program_cache_key, None)

        if program_cache is None:
            program_cache = program.clone()
            if use_program_cache:
                self.program_caches[program_cache_key] = program_cache

            global_block = program_cache.global_block()
            if feed_var_name in global_block.vars:
                feed_var = global_block.var(feed_var_name)
@@ -259,6 +312,7 @@ class Executor(object):
                    type=core.VarDesc.VarType.FETCH_LIST,
                    persistable=True)
            # prepend feed operators
            if not has_feed_operators(global_block, feed, feed_var_name):
                for i, name in enumerate(feed):
                    out = global_block.var(name)
@@ -268,7 +322,20 @@ class Executor(object):
                        outputs={'Out': [out]},
                        attrs={'col': i})
            # append fetch_operators
            if not has_fetch_operators(global_block, fetch_list,
                                       fetch_var_name):
                for i, var in enumerate(fetch_list):
                    assert isinstance(var, Variable) or isinstance(var, str), (
                        "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
                    global_block.append_op(
                        type='fetch',
                        inputs={'X': [var]},
                        outputs={'Out': [fetch_var]},
                        attrs={'col': i})

        # feed var to framework
        for op in program_cache.global_block().ops:
            if op.desc.type() == 'feed':
                feed_target_name = op.desc.output('Out')[0]
                cur_feed = feed[feed_target_name]
@@ -279,17 +346,7 @@ class Executor(object):
            else:
                break
        self.executor.run(program_cache.desc, scope, 0, True, True)
        outs = [
            core.get_fetch_variable(scope, fetch_var_name, i)
            for i in xrange(len(fetch_list))
...
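
For context, here is a hedged usage sketch of the new use_program_cache flag. The network (x, y, loss) and the feed data are placeholders, and it assumes the fluid API of this era:

    import numpy
    import paddle.fluid as fluid

    # Placeholder network; the names x, y, loss are illustrative.
    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(y)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    for step in range(10):
        # The feed/fetch names are identical every step, so the cloned
        # program with feed/fetch operators is built once and then reused.
        loss_val, = exe.run(fluid.default_main_program(),
                            feed={'x': numpy.ones((1, 1), dtype='float32')},
                            fetch_list=[loss],
                            use_program_cache=True)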
@@ -89,7 +89,7 @@ class TestLearningRateDecay(unittest.TestCase):
        exe.run(fluid.default_startup_program())
        for step in range(10):
            lr_val, = exe.run(fluid.default_main_program(),
                              feed={},
                              fetch_list=[decayed_lr])
            python_decayed_lr = python_decay_fn(
                global_step=float(step), **kwargs)
...
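
The one-line test change above follows from the new type check in run(): feed must now be a dict, so the old feed=[] would raise TypeError("feed should be a map"); the empty-feed case is therefore written as feed={}.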