提交 4977d99b 编写于 作者: Q qiaolongfei

add program cache for executor

上级 01654212
...@@ -177,6 +177,7 @@ class Executor(object): ...@@ -177,6 +177,7 @@ class Executor(object):
# TODO(dzhwinter) : only use the first place # TODO(dzhwinter) : only use the first place
self.executor = core.Executor(act_places[0]) self.executor = core.Executor(act_places[0])
self.places = places self.places = places
self.program_caches = dict()
def aslodtensor(self, data): def aslodtensor(self, data):
def accumulate(data): def accumulate(data):
...@@ -240,56 +241,63 @@ class Executor(object): ...@@ -240,56 +241,63 @@ class Executor(object):
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
program = program.clone() program_cache_key = str(feed.keys() + fetch_list)
global_block = program.global_block() program_cache = self.program_caches.get(program_cache_key, None)
if feed_var_name in global_block.vars: if program_cache is None:
feed_var = global_block.var(feed_var_name) program_cache = program.clone()
else: self.program_caches[program_cache_key] = program_cache
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)
if fetch_var_name in global_block.vars: global_block = program_cache.global_block()
fetch_var = global_block.var(fetch_var_name)
else: if feed_var_name in global_block.vars:
fetch_var = global_block.create_var( feed_var = global_block.var(feed_var_name)
name=fetch_var_name, else:
type=core.VarDesc.VarType.FETCH_LIST, feed_var = global_block.create_var(
persistable=True) name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
if not has_feed_operators(global_block, feed, feed_var_name): persistable=True)
for i, name in enumerate(feed):
out = global_block.var(name) if fetch_var_name in global_block.vars:
global_block.prepend_op( fetch_var = global_block.var(fetch_var_name)
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
for op in global_block.ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else: else:
break fetch_var = global_block.create_var(
name=fetch_var_name,
if not has_fetch_operators(global_block, fetch_list, fetch_var_name): type=core.VarDesc.VarType.FETCH_LIST,
for i, var in enumerate(fetch_list): persistable=True)
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var))) if not has_feed_operators(global_block, feed, feed_var_name):
global_block.append_op( for i, name in enumerate(feed):
type='fetch', out = global_block.var(name)
inputs={'X': [var]}, global_block.prepend_op(
outputs={'Out': [fetch_var]}, type='feed',
attrs={'col': i}) inputs={'X': [feed_var]},
outputs={'Out': [out]},
self.executor.run(program.desc, scope, 0, True, True) attrs={'col': i})
for op in global_block.ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
break
if not has_fetch_operators(global_block, fetch_list,
fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
self.executor.run(program_cache.desc, scope, 0, True, True)
outs = [ outs = [
core.get_fetch_variable(scope, fetch_var_name, i) core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list)) for i in xrange(len(fetch_list))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册