未验证 提交 68999b6c 编写于 作者: C Chen Weihang 提交者: GitHub

simplify dygraph data loader code, test=develop (#21722)

上级 9c481e12
...@@ -502,51 +502,3 @@ class DataFeeder(object): ...@@ -502,51 +502,3 @@ class DataFeeder(object):
"not implemented") "not implemented")
return __reader_creator__ return __reader_creator__
class NumpyToLoDTensorConverter(object):
def __init__(self, place):
self.place = place
self.data = []
self._reset()
def _reset(self):
self.data = []
def feed(self, data):
self.data.append(data)
def done(self):
arr = numpy.array(self.data)
t = core.LoDTensor()
t.set(arr, self.place)
self._reset()
return t
class DygraphListTensorProvider(object):
def __init__(self, generator, place):
self.generator = generator
self.converters = []
if place:
if isinstance(place, (core.CUDAPlace, core.CPUPlace)):
self.place = place
else:
raise ValueError("Please specify a valid place values \
such as core.CPUPlace or core.CUDAPlace")
else:
self.place = _current_expected_place()
def _read_data(self, iterable, place):
for items in iterable:
if len(self.converters) < len(items):
for _ in items:
self.converters.append(NumpyToLoDTensorConverter(place))
for each_converter, each_slot in six.moves.zip(self.converters,
items):
each_converter.feed(each_slot)
yield [c.done() for c in self.converters]
def __call__(self):
for batch in self.generator():
yield list(self._read_data(batch, self.place))
...@@ -21,7 +21,7 @@ import threading ...@@ -21,7 +21,7 @@ import threading
import paddle import paddle
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places
from .executor import global_scope from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider, DygraphListTensorProvider from .data_feeder import DataFeeder, BatchedTensorProvider
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
from .unique_name import UniqueNameGenerator from .unique_name import UniqueNameGenerator
import logging import logging
...@@ -545,11 +545,17 @@ class GeneratorLoader(DataLoaderBase): ...@@ -545,11 +545,17 @@ class GeneratorLoader(DataLoaderBase):
def set_sample_list_generator(self, reader, places=None): def set_sample_list_generator(self, reader, places=None):
if in_dygraph_mode(): if in_dygraph_mode():
provider = DygraphListTensorProvider(reader, places)
def __tensor_reader_impl__(): def __tensor_reader_impl__():
for slots in provider(): for batch in reader():
yield slots[0] slots = []
for items in batch:
for i, item in enumerate(items):
if len(slots) < len(items):
slots.append([item])
else:
slots[i].append(item)
yield slots
else: else:
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
feeder = DataFeeder( feeder = DataFeeder(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册