diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index e5fb0e5d628b2df14355aec2718cf46aa641b6cf..4490f2bf153f672464ec8bca2a44109c9fe0dd04 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True):
     return tensor
 
 
+def get_program_cache_key(feed, fetch_list):
+    feed_var_names = feed.keys()
+
+    def to_name_str(var):
+        if isinstance(var, Variable):
+            return var.desc.name()
+        elif isinstance(var, str):
+            return var
+        else:
+            raise TypeError(str(var) + " should be Variable or str")
+
+    fetch_var_names = map(to_name_str, fetch_list)
+
+    return str(feed_var_names + fetch_var_names)
+
+
 class Executor(object):
     def __init__(self, places):
         if not isinstance(places, list) and not isinstance(places, tuple):
@@ -177,6 +193,7 @@ class Executor(object):
         # TODO(dzhwinter) : only use the first place
         self.executor = core.Executor(act_places[0])
         self.places = places
+        self.program_caches = dict()
 
     def aslodtensor(self, data):
         def accumulate(data):
@@ -225,9 +242,30 @@ class Executor(object):
             feed_var_name='feed',
             fetch_var_name='fetch',
             scope=None,
-            return_numpy=True):
+            return_numpy=True,
+            use_program_cache=False):
+        """ Run the program by this Executor. Feed data by the feed map, fetch results by fetch_list.
+
+        The Python executor takes a program and adds feed operators and fetch operators to it
+        according to the feed map and fetch_list. The feed map provides input data for the program;
+        fetch_list holds the variables (or names) that the user wants to fetch after the program runs.
+        Note: the executor will run all operators in the program, not just those the fetch_list depends on.
+
+        :param program: the program to run; if not provided, default_main_program will be used.
+        :param feed: feed variable map, e.g. {"image": ImageData, "label": LabelData}
+        :param fetch_list: a list of variables or variable names to fetch; run will return them
+        in the order of this list.
+        :param feed_var_name: the name of the input variable of the feed Operator.
+        :param fetch_var_name: the name of the output variable of the fetch Operator.
+        :param scope: the scope in which to run this program; you can switch it to a different scope. Defaults to global_scope.
+        :param return_numpy: whether to convert the fetched tensors to numpy arrays.
+        :param use_program_cache: set to True if the program has not changed since the last step.
+        :return: fetch result according to fetch_list.
+        """
         if feed is None:
             feed = {}
+        if not isinstance(feed, dict):
+            raise TypeError("feed should be a map")
         if fetch_list is None:
             fetch_list = []
 
@@ -240,35 +278,64 @@ class Executor(object):
         if scope is None:
             scope = global_scope()
 
-        program = program.clone()
-        global_block = program.global_block()
+        program_cache = None
+        program_cache_key = get_program_cache_key(feed, fetch_list)
 
-        if feed_var_name in global_block.vars:
-            feed_var = global_block.var(feed_var_name)
+        if use_program_cache:
+            # find the cached program by its cache key
+            program_cache = self.program_caches.get(program_cache_key, None)
+            # TODO(qiao): Should check that program_cache and program are exactly the same.
         else:
-            feed_var = global_block.create_var(
-                name=feed_var_name,
-                type=core.VarDesc.VarType.FEED_MINIBATCH,
-                persistable=True)
+            self.program_caches.pop(program_cache_key, None)
 
-        if fetch_var_name in global_block.vars:
-            fetch_var = global_block.var(fetch_var_name)
-        else:
-            fetch_var = global_block.create_var(
-                name=fetch_var_name,
-                type=core.VarDesc.VarType.FETCH_LIST,
-                persistable=True)
-
-        if not has_feed_operators(global_block, feed, feed_var_name):
-            for i, name in enumerate(feed):
-                out = global_block.var(name)
-                global_block.prepend_op(
-                    type='feed',
-                    inputs={'X': [feed_var]},
-                    outputs={'Out': [out]},
-                    attrs={'col': i})
-
-        for op in global_block.ops:
+        if program_cache is None:
+            program_cache = program.clone()
+
+            if use_program_cache:
+                self.program_caches[program_cache_key] = program_cache
+
+            global_block = program_cache.global_block()
+
+            if feed_var_name in global_block.vars:
+                feed_var = global_block.var(feed_var_name)
+            else:
+                feed_var = global_block.create_var(
+                    name=feed_var_name,
+                    type=core.VarDesc.VarType.FEED_MINIBATCH,
+                    persistable=True)
+
+            if fetch_var_name in global_block.vars:
+                fetch_var = global_block.var(fetch_var_name)
+            else:
+                fetch_var = global_block.create_var(
+                    name=fetch_var_name,
+                    type=core.VarDesc.VarType.FETCH_LIST,
+                    persistable=True)
+
+            # prepend feed operators
+            if not has_feed_operators(global_block, feed, feed_var_name):
+                for i, name in enumerate(feed):
+                    out = global_block.var(name)
+                    global_block.prepend_op(
+                        type='feed',
+                        inputs={'X': [feed_var]},
+                        outputs={'Out': [out]},
+                        attrs={'col': i})
+
+            # append fetch operators
+            if not has_fetch_operators(global_block, fetch_list,
+                                       fetch_var_name):
+                for i, var in enumerate(fetch_list):
+                    assert isinstance(var, Variable) or isinstance(var, str), (
+                        "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
+                    global_block.append_op(
+                        type='fetch',
+                        inputs={'X': [var]},
+                        outputs={'Out': [fetch_var]},
+                        attrs={'col': i})
+
+        # feed variables into the framework
+        for op in program_cache.global_block().ops:
             if op.desc.type() == 'feed':
                 feed_target_name = op.desc.output('Out')[0]
                 cur_feed = feed[feed_target_name]
@@ -279,17 +346,7 @@ class Executor(object):
             else:
                 break
 
-        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
-            for i, var in enumerate(fetch_list):
-                assert isinstance(var, Variable) or isinstance(var, str), (
-                    "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
-                global_block.append_op(
-                    type='fetch',
-                    inputs={'X': [var]},
-                    outputs={'Out': [fetch_var]},
-                    attrs={'col': i})
-
-        self.executor.run(program.desc, scope, 0, True, True)
+        self.executor.run(program_cache.desc, scope, 0, True, True)
         outs = [
             core.get_fetch_variable(scope, fetch_var_name, i)
             for i in xrange(len(fetch_list))
diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
index ab25bfffaa45020cc854e44b593776e90638cf72..e75a6529e9fa265121ba187f3ed6bc0273c058d7 100644
--- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
@@ -89,7 +89,7 @@ class TestLearningRateDecay(unittest.TestCase):
         exe.run(fluid.default_startup_program())
         for step in range(10):
             lr_val, = exe.run(fluid.default_main_program(),
-                              feed=[],
+                              feed={},
                               fetch_list=[decayed_lr])
             python_decayed_lr = python_decay_fn(
                 global_step=float(step), **kwargs)
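
Below is a minimal usage sketch of the new use_program_cache flag; it is not part of the patch, and the toy network and feed names in it are hypothetical. Because get_program_cache_key builds the key only from the feed names and the fetch_list (see the TODO(qiao) above: the program itself is not yet compared), the flag should only be enabled when the same program is run on every step, as in this loop:

import numpy
import paddle.fluid as fluid

# hypothetical one-layer network; only the feed/fetch names matter for the cache key
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.mean(x=y_predict)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

for step in range(100):
    x_data = numpy.random.random((8, 13)).astype('float32')
    # the feed names and fetch_list are identical on every step, so the
    # executor reuses the cached clone instead of re-cloning the program
    cost_val, = exe.run(fluid.default_main_program(),
                        feed={'x': x_data},
                        fetch_list=[avg_cost],
                        use_program_cache=True)

With use_program_cache left at its default of False, each run call clones the program and pops any stale cache entry for the key, preserving the old behavior.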