From 2f76932d7cee1aa0d32ff2ad6b195b9de678635a Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 27 Dec 2017 17:22:49 +0800 Subject: [PATCH] enhance DataFeeder --- python/paddle/v2/fluid/data_feeder.py | 8 ++++++-- python/paddle/v2/fluid/io.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/data_feeder.py b/python/paddle/v2/fluid/data_feeder.py index 30a542af21..24036c3e75 100644 --- a/python/paddle/v2/fluid/data_feeder.py +++ b/python/paddle/v2/fluid/data_feeder.py @@ -3,7 +3,7 @@ import core import numpy import six.moves as six -from framework import Variable +from framework import Variable, default_main_program __all__ = ['DataFeeder'] @@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object): class DataFeeder(object): - def __init__(self, feed_list, place): + def __init__(self, feed_list, place, program=None): self.feed_dtypes = [] self.feed_names = [] self.feed_shapes = [] self.feed_lod_level = [] + if program is None: + program = default_main_program() for each_var in feed_list: + if isinstance(each_var, basestring): + each_var = program.block(0).var(each_var) if not isinstance(each_var, Variable): raise TypeError("Feed list should contain a list of variable") self.feed_dtypes.append(each_var.dtype) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 69a732fc45..c47ce82aba 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -188,7 +188,7 @@ def save_inference_model(dirname, raise ValueError("'feed_var_names' should be a list of str.") if isinstance(target_vars, Variable): - feeded_var_names = [feeded_var_names] + target_vars = [target_vars] else: if not (bool(target_vars) and all( isinstance(var, Variable) for var in target_vars)): -- GitLab