提交 2f76932d 编写于 作者: F fengjiayi

enhance DataFeeder

上级 94096ae5
...@@ -3,7 +3,7 @@ import core ...@@ -3,7 +3,7 @@ import core
import numpy import numpy
import six.moves as six import six.moves as six
from framework import Variable from framework import Variable, default_main_program
__all__ = ['DataFeeder'] __all__ = ['DataFeeder']
...@@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object): ...@@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object):
class DataFeeder(object): class DataFeeder(object):
def __init__(self, feed_list, place): def __init__(self, feed_list, place, program=None):
self.feed_dtypes = [] self.feed_dtypes = []
self.feed_names = [] self.feed_names = []
self.feed_shapes = [] self.feed_shapes = []
self.feed_lod_level = [] self.feed_lod_level = []
if program is None:
program = default_main_program()
for each_var in feed_list: for each_var in feed_list:
if isinstance(each_var, basestring):
each_var = program.block(0).var(each_var)
if not isinstance(each_var, Variable): if not isinstance(each_var, Variable):
raise TypeError("Feed list should contain a list of variable") raise TypeError("Feed list should contain a list of variable")
self.feed_dtypes.append(each_var.dtype) self.feed_dtypes.append(each_var.dtype)
......
...@@ -188,7 +188,7 @@ def save_inference_model(dirname, ...@@ -188,7 +188,7 @@ def save_inference_model(dirname,
raise ValueError("'feed_var_names' should be a list of str.") raise ValueError("'feed_var_names' should be a list of str.")
if isinstance(target_vars, Variable): if isinstance(target_vars, Variable):
feeded_var_names = [feeded_var_names] target_vars = [target_vars]
else: else:
if not (bool(target_vars) and all( if not (bool(target_vars) and all(
isinstance(var, Variable) for var in target_vars)): isinstance(var, Variable) for var in target_vars)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册