diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index d82a925c757e7dadce844c025aaa85953f55a017..4848d23d27877b711a4ae495bebc9171459aab27 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -502,51 +502,3 @@ class DataFeeder(object): "not implemented") return __reader_creator__ - - -class NumpyToLoDTensorConverter(object): - def __init__(self, place): - self.place = place - self.data = [] - self._reset() - - def _reset(self): - self.data = [] - - def feed(self, data): - self.data.append(data) - - def done(self): - arr = numpy.array(self.data) - t = core.LoDTensor() - t.set(arr, self.place) - self._reset() - return t - - -class DygraphListTensorProvider(object): - def __init__(self, generator, place): - self.generator = generator - self.converters = [] - if place: - if isinstance(place, (core.CUDAPlace, core.CPUPlace)): - self.place = place - else: - raise ValueError("Please specify a valid place values \ - such as core.CPUPlace or core.CUDAPlace") - else: - self.place = _current_expected_place() - - def _read_data(self, iterable, place): - for items in iterable: - if len(self.converters) < len(items): - for _ in items: - self.converters.append(NumpyToLoDTensorConverter(place)) - for each_converter, each_slot in six.moves.zip(self.converters, - items): - each_converter.feed(each_slot) - yield [c.done() for c in self.converters] - - def __call__(self): - for batch in self.generator(): - yield list(self._read_data(batch, self.place)) diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 45764af6147417e144b89fac84f8f7c2f7332aa4..57ffc62617525abf34fc41aab23e2de55e4c8528 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -21,7 +21,7 @@ import threading import paddle from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places from .executor import global_scope -from .data_feeder import DataFeeder, BatchedTensorProvider, DygraphListTensorProvider +from .data_feeder import DataFeeder, BatchedTensorProvider from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer from .unique_name import UniqueNameGenerator import logging @@ -545,11 +545,17 @@ class GeneratorLoader(DataLoaderBase): def set_sample_list_generator(self, reader, places=None): if in_dygraph_mode(): - provider = DygraphListTensorProvider(reader, places) def __tensor_reader_impl__(): - for slots in provider(): - yield slots[0] + for batch in reader(): + slots = [] + for items in batch: + for i, item in enumerate(items): + if len(slots) < len(items): + slots.append([item]) + else: + slots[i].append(item) + yield slots else: with program_guard(Program(), Program()): feeder = DataFeeder(