提交 1ee77841 编写于 作者: Q qiaolongfei

add get_program_cache_key function

上级 b63901f5
......@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True):
return tensor
def get_program_cache_key(feed, fetch_list):
feed_var_names = feed.keys()
def to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
else:
raise TypeError(str(var) + " should be Variable or str")
fetch_var_names = map(to_name_str, fetch_list)
return str(feed_var_names + fetch_var_names)
class Executor(object):
def __init__(self, places):
if not isinstance(places, list) and not isinstance(places, tuple):
......@@ -232,12 +248,13 @@ class Executor(object):
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables that user want to get after program run. Note: the executor will run all
the variables(or names) that user want to get after program run. Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list
:param program: the program that need to run, if not provied, then default_main_program will be used.
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
:param fetch_list: a list of variable that user want to get, run will return them according to this list.
:param fetch_list: a list of variable or variable names that user want to get, run will return them according
to this list.
:param feed_var_name: the name for the input variable of feed Operator.
:param fetch_var_name: the name for the output variable of feed Operator.
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
......@@ -247,6 +264,8 @@ class Executor(object):
"""
if feed is None:
feed = {}
if not isinstance(feed, dict):
raise TypeError("feed should be a map")
if fetch_list is None:
fetch_list = []
......@@ -260,10 +279,7 @@ class Executor(object):
scope = global_scope()
program_cache = None
feed_var_names = feed.keys()
fetch_var_names = [var.desc.name() for var in fetch_list]
program_cache_key = str(feed_var_names + fetch_var_names)
program_cache_key = get_program_cache_key(feed, fetch_list)
if use_program_cache:
# find program cache by cache_key
......
......@@ -89,7 +89,7 @@ class TestLearningRateDecay(unittest.TestCase):
exe.run(fluid.default_startup_program())
for step in range(10):
lr_val, = exe.run(fluid.default_main_program(),
feed=[],
feed={},
fetch_list=[decayed_lr])
python_decayed_lr = python_decay_fn(
global_step=float(step), **kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册