From 4977d99b055a5ba50baa620e9b24da5c71201a51 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 5 Mar 2018 10:50:46 +0800 Subject: [PATCH] add program cache for executor --- python/paddle/fluid/executor.py | 102 +++++++++++++++++--------------- 1 file changed, 55 insertions(+), 47 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index e5fb0e5d6..984c4c73a 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -177,6 +177,7 @@ class Executor(object): # TODO(dzhwinter) : only use the first place self.executor = core.Executor(act_places[0]) self.places = places + self.program_caches = dict() def aslodtensor(self, data): def accumulate(data): @@ -240,56 +241,63 @@ class Executor(object): if scope is None: scope = global_scope() - program = program.clone() - global_block = program.global_block() + program_cache_key = str(feed.keys() + fetch_list) + program_cache = self.program_caches.get(program_cache_key, None) - if feed_var_name in global_block.vars: - feed_var = global_block.var(feed_var_name) - else: - feed_var = global_block.create_var( - name=feed_var_name, - type=core.VarDesc.VarType.FEED_MINIBATCH, - persistable=True) + if program_cache is None: + program_cache = program.clone() + self.program_caches[program_cache_key] = program_cache - if fetch_var_name in global_block.vars: - fetch_var = global_block.var(fetch_var_name) - else: - fetch_var = global_block.create_var( - name=fetch_var_name, - type=core.VarDesc.VarType.FETCH_LIST, - persistable=True) - - if not has_feed_operators(global_block, feed, feed_var_name): - for i, name in enumerate(feed): - out = global_block.var(name) - global_block.prepend_op( - type='feed', - inputs={'X': [feed_var]}, - outputs={'Out': [out]}, - attrs={'col': i}) - - for op in global_block.ops: - if op.desc.type() == 'feed': - feed_target_name = op.desc.output('Out')[0] - cur_feed = feed[feed_target_name] - if not isinstance(cur_feed, core.LoDTensor): - cur_feed = self.aslodtensor(cur_feed) - idx = op.desc.attr('col') - core.set_feed_variable(scope, cur_feed, feed_var_name, idx) + global_block = program_cache.global_block() + + if feed_var_name in global_block.vars: + feed_var = global_block.var(feed_var_name) + else: + feed_var = global_block.create_var( + name=feed_var_name, + type=core.VarDesc.VarType.FEED_MINIBATCH, + persistable=True) + + if fetch_var_name in global_block.vars: + fetch_var = global_block.var(fetch_var_name) else: - break - - if not has_fetch_operators(global_block, fetch_list, fetch_var_name): - for i, var in enumerate(fetch_list): - assert isinstance(var, Variable) or isinstance(var, str), ( - "Wrong type for fetch_list[%s]: %s" % (i, type(var))) - global_block.append_op( - type='fetch', - inputs={'X': [var]}, - outputs={'Out': [fetch_var]}, - attrs={'col': i}) - - self.executor.run(program.desc, scope, 0, True, True) + fetch_var = global_block.create_var( + name=fetch_var_name, + type=core.VarDesc.VarType.FETCH_LIST, + persistable=True) + + if not has_feed_operators(global_block, feed, feed_var_name): + for i, name in enumerate(feed): + out = global_block.var(name) + global_block.prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}) + + for op in global_block.ops: + if op.desc.type() == 'feed': + feed_target_name = op.desc.output('Out')[0] + cur_feed = feed[feed_target_name] + if not isinstance(cur_feed, core.LoDTensor): + cur_feed = self.aslodtensor(cur_feed) + idx = op.desc.attr('col') + core.set_feed_variable(scope, cur_feed, feed_var_name, idx) + else: + break + + if not has_fetch_operators(global_block, fetch_list, + fetch_var_name): + for i, var in enumerate(fetch_list): + assert isinstance(var, Variable) or isinstance(var, str), ( + "Wrong type for fetch_list[%s]: %s" % (i, type(var))) + global_block.append_op( + type='fetch', + inputs={'X': [var]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}) + + self.executor.run(program_cache.desc, scope, 0, True, True) outs = [ core.get_fetch_variable(scope, fetch_var_name, i) for i in xrange(len(fetch_list)) -- GitLab