diff --git a/python/paddle/v2/fluid/data_feeder.py b/python/paddle/v2/fluid/data_feeder.py index 30a542af212926c93381aade426e25f2117e4662..24036c3e75b9594ba58cccb02825ab8020d1e107 100644 --- a/python/paddle/v2/fluid/data_feeder.py +++ b/python/paddle/v2/fluid/data_feeder.py @@ -3,7 +3,7 @@ import core import numpy import six.moves as six -from framework import Variable +from framework import Variable, default_main_program __all__ = ['DataFeeder'] @@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object): class DataFeeder(object): - def __init__(self, feed_list, place): + def __init__(self, feed_list, place, program=None): self.feed_dtypes = [] self.feed_names = [] self.feed_shapes = [] self.feed_lod_level = [] + if program is None: + program = default_main_program() for each_var in feed_list: + if isinstance(each_var, basestring): + each_var = program.block(0).var(each_var) if not isinstance(each_var, Variable): raise TypeError("Feed list should contain a list of variable") self.feed_dtypes.append(each_var.dtype) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 69a732fc45a1946f260cdd9a9c2da150b87c3ddd..c47ce82aba7fa5ac42ac26cd25fa3ebc93e96cb2 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -188,7 +188,7 @@ def save_inference_model(dirname, raise ValueError("'feed_var_names' should be a list of str.") if isinstance(target_vars, Variable): - feeded_var_names = [feeded_var_names] + target_vars = [target_vars] else: if not (bool(target_vars) and all( isinstance(var, Variable) for var in target_vars)):