diff --git a/python/paddle/v2/data_feeder.py b/python/paddle/v2/data_feeder.py index 74ee112f46efd78ff0ffcf95bad2580cc582fd29..632f1760ab5ff244f6d8d300d185b388d6643368 100644 --- a/python/paddle/v2/data_feeder.py +++ b/python/paddle/v2/data_feeder.py @@ -62,8 +62,8 @@ class DataFeeder(DataProviderConverter): self.reader_dict = reader_dict for each in data_types: self.input_names.append(each[0]) - self.input_types.append(each[1]) assert isinstance(each[1], data_type.InputType) + self.input_types.append(each[1]) DataProviderConverter.__init__(self, self.input_types) def convert(self, dat, argument=None): @@ -88,24 +88,16 @@ class DataFeeder(DataProviderConverter): :type argument: swig_paddle.Arguments """ - if argument is None: - argument = swig_paddle.Arguments.createArguments(0) - assert isinstance(argument, swig_paddle.Arguments) - argument.resize(len(self.input_types)) - - scanners = [ - DataProviderConverter.create_scanner(i, each_type) - for i, each_type in enumerate(self.input_types) - ] - - for each_sample in dat: - for name, scanner in zip(self.input_names, scanners): - scanner.scan(each_sample[self.reader_dict[name]]) - - for scanner in scanners: - scanner.finish_scan(argument) + def reorder_data(data): + retv = [] + for each in data: + reorder = [] + for name in self.input_names: + reorder.append(each[self.reader_dict[name]]) + retv.append(reorder) + return retv - return argument + return DataProviderConverter.convert(self, reorder_data(dat), argument) def __call__(self, dat, argument=None): return self.convert(dat, argument)