提交 2f76932d 编写于 作者: F fengjiayi

enhance DataFeeder

上级 94096ae5
......@@ -3,7 +3,7 @@ import core
import numpy
import six.moves as six
from framework import Variable
from framework import Variable, default_main_program
__all__ = ['DataFeeder']
......@@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object):
class DataFeeder(object):
def __init__(self, feed_list, place):
def __init__(self, feed_list, place, program=None):
self.feed_dtypes = []
self.feed_names = []
self.feed_shapes = []
self.feed_lod_level = []
if program is None:
program = default_main_program()
for each_var in feed_list:
if isinstance(each_var, basestring):
each_var = program.block(0).var(each_var)
if not isinstance(each_var, Variable):
raise TypeError("Feed list should contain a list of variable")
self.feed_dtypes.append(each_var.dtype)
......
......@@ -188,7 +188,7 @@ def save_inference_model(dirname,
raise ValueError("'feed_var_names' should be a list of str.")
if isinstance(target_vars, Variable):
feeded_var_names = [feeded_var_names]
target_vars = [target_vars]
else:
if not (bool(target_vars) and all(
isinstance(var, Variable) for var in target_vars)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册