diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index ceff04286c6ac3f51f4eb6bad0b840770686c98a..4490f2bf153f672464ec8bca2a44109c9fe0dd04 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True): return tensor +def get_program_cache_key(feed, fetch_list): + feed_var_names = feed.keys() + + def to_name_str(var): + if isinstance(var, Variable): + return var.desc.name() + elif isinstance(var, str): + return var + else: + raise TypeError(str(var) + " should be Variable or str") + + fetch_var_names = map(to_name_str, fetch_list) + + return str(feed_var_names + fetch_var_names) + + class Executor(object): def __init__(self, places): if not isinstance(places, list) and not isinstance(places, tuple): @@ -232,12 +248,13 @@ class Executor(object): Python executor takes a program, add feed operators and fetch operators to this program according to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides - the variables that user want to get after program run. Note: the executor will run all + the variables(or names) that user want to get after program run. Note: the executor will run all operators in the program but not only the operators dependent by the fetch_list :param program: the program that need to run, if not provied, then default_main_program will be used. :param feed: feed variable map, e.g. {"image": ImageData, "label": LableData} - :param fetch_list: a list of variable that user want to get, run will return them according to this list. + :param fetch_list: a list of variable or variable names that user want to get, run will return them according + to this list. :param feed_var_name: the name for the input variable of feed Operator. :param fetch_var_name: the name for the output variable of feed Operator. :param scope: the scope used to run this program, you can switch it to different scope. default is global_scope @@ -247,6 +264,8 @@ class Executor(object): """ if feed is None: feed = {} + if not isinstance(feed, dict): + raise TypeError("feed should be a map") if fetch_list is None: fetch_list = [] @@ -260,10 +279,7 @@ class Executor(object): scope = global_scope() program_cache = None - - feed_var_names = feed.keys() - fetch_var_names = [var.desc.name() for var in fetch_list] - program_cache_key = str(feed_var_names + fetch_var_names) + program_cache_key = get_program_cache_key(feed, fetch_list) if use_program_cache: # find program cache by cache_key diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_decay.py b/python/paddle/fluid/tests/unittests/test_learning_rate_decay.py index 5c221a0325b6cdc27ec22e5a8b02ae8eec9f6d80..8954a8619578aa1e05eb856c2d1de43152f9c9e5 100644 --- a/python/paddle/fluid/tests/unittests/test_learning_rate_decay.py +++ b/python/paddle/fluid/tests/unittests/test_learning_rate_decay.py @@ -89,7 +89,7 @@ class TestLearningRateDecay(unittest.TestCase): exe.run(fluid.default_startup_program()) for step in range(10): lr_val, = exe.run(fluid.default_main_program(), - feed=[], + feed={}, fetch_list=[decayed_lr]) python_decayed_lr = python_decay_fn( global_step=float(step), **kwargs)