提交 72c13278 编写于 作者: D dangqingqing

follow comments

上级 bb7db754
......@@ -62,8 +62,8 @@ class DataFeeder(DataProviderConverter):
self.reader_dict = reader_dict
for each in data_types:
self.input_names.append(each[0])
self.input_types.append(each[1])
assert isinstance(each[1], data_type.InputType)
self.input_types.append(each[1])
DataProviderConverter.__init__(self, self.input_types)
def convert(self, dat, argument=None):
......@@ -88,24 +88,16 @@ class DataFeeder(DataProviderConverter):
:type argument: swig_paddle.Arguments
"""
if argument is None:
argument = swig_paddle.Arguments.createArguments(0)
assert isinstance(argument, swig_paddle.Arguments)
argument.resize(len(self.input_types))
scanners = [
DataProviderConverter.create_scanner(i, each_type)
for i, each_type in enumerate(self.input_types)
]
for each_sample in dat:
for name, scanner in zip(self.input_names, scanners):
scanner.scan(each_sample[self.reader_dict[name]])
for scanner in scanners:
scanner.finish_scan(argument)
def reorder_data(data):
retv = []
for each in data:
reorder = []
for name in self.input_names:
reorder.append(each[self.reader_dict[name]])
retv.append(reorder)
return retv
return argument
return DataProviderConverter.convert(self, reorder_data(dat), argument)
def __call__(self, dat, argument=None):
return self.convert(dat, argument)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册