提交 72c13278 编写于 作者: D dangqingqing

follow comments

上级 bb7db754
...@@ -62,8 +62,8 @@ class DataFeeder(DataProviderConverter): ...@@ -62,8 +62,8 @@ class DataFeeder(DataProviderConverter):
self.reader_dict = reader_dict self.reader_dict = reader_dict
for each in data_types: for each in data_types:
self.input_names.append(each[0]) self.input_names.append(each[0])
self.input_types.append(each[1])
assert isinstance(each[1], data_type.InputType) assert isinstance(each[1], data_type.InputType)
self.input_types.append(each[1])
DataProviderConverter.__init__(self, self.input_types) DataProviderConverter.__init__(self, self.input_types)
def convert(self, dat, argument=None): def convert(self, dat, argument=None):
...@@ -88,24 +88,16 @@ class DataFeeder(DataProviderConverter): ...@@ -88,24 +88,16 @@ class DataFeeder(DataProviderConverter):
:type argument: swig_paddle.Arguments :type argument: swig_paddle.Arguments
""" """
if argument is None: def reorder_data(data):
argument = swig_paddle.Arguments.createArguments(0) retv = []
assert isinstance(argument, swig_paddle.Arguments) for each in data:
argument.resize(len(self.input_types)) reorder = []
for name in self.input_names:
scanners = [ reorder.append(each[self.reader_dict[name]])
DataProviderConverter.create_scanner(i, each_type) retv.append(reorder)
for i, each_type in enumerate(self.input_types) return retv
]
for each_sample in dat:
for name, scanner in zip(self.input_names, scanners):
scanner.scan(each_sample[self.reader_dict[name]])
for scanner in scanners:
scanner.finish_scan(argument)
return argument return DataProviderConverter.convert(self, reorder_data(dat), argument)
def __call__(self, dat, argument=None): def __call__(self, dat, argument=None):
return self.convert(dat, argument) return self.convert(dat, argument)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册