未验证 提交 68999b6c 编写于 作者: C Chen Weihang 提交者: GitHub

simplify dygraph data loader code, test=develop (#21722)

上级 9c481e12
......@@ -502,51 +502,3 @@ class DataFeeder(object):
"not implemented")
return __reader_creator__
class NumpyToLoDTensorConverter(object):
def __init__(self, place):
self.place = place
self.data = []
self._reset()
def _reset(self):
self.data = []
def feed(self, data):
self.data.append(data)
def done(self):
arr = numpy.array(self.data)
t = core.LoDTensor()
t.set(arr, self.place)
self._reset()
return t
class DygraphListTensorProvider(object):
def __init__(self, generator, place):
self.generator = generator
self.converters = []
if place:
if isinstance(place, (core.CUDAPlace, core.CPUPlace)):
self.place = place
else:
raise ValueError("Please specify a valid place values \
such as core.CPUPlace or core.CUDAPlace")
else:
self.place = _current_expected_place()
def _read_data(self, iterable, place):
for items in iterable:
if len(self.converters) < len(items):
for _ in items:
self.converters.append(NumpyToLoDTensorConverter(place))
for each_converter, each_slot in six.moves.zip(self.converters,
items):
each_converter.feed(each_slot)
yield [c.done() for c in self.converters]
def __call__(self):
for batch in self.generator():
yield list(self._read_data(batch, self.place))
......@@ -21,7 +21,7 @@ import threading
import paddle
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places
from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider, DygraphListTensorProvider
from .data_feeder import DataFeeder, BatchedTensorProvider
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
from .unique_name import UniqueNameGenerator
import logging
......@@ -545,11 +545,17 @@ class GeneratorLoader(DataLoaderBase):
def set_sample_list_generator(self, reader, places=None):
if in_dygraph_mode():
provider = DygraphListTensorProvider(reader, places)
def __tensor_reader_impl__():
for slots in provider():
yield slots[0]
for batch in reader():
slots = []
for items in batch:
for i, item in enumerate(items):
if len(slots) < len(items):
slots.append([item])
else:
slots[i].append(item)
yield slots
else:
with program_guard(Program(), Program()):
feeder = DataFeeder(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册