From d96acc336342f1c1216795d1e21693b70ee17a7e Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 10 Dec 2019 21:56:55 +0800 Subject: [PATCH] Refine dygraph DataLoader implementation (#21634) * refine dygraph dataloader & polish related code, test=develop * refine code based on review comments, test=develop --- paddle/fluid/pybind/reader_py.cc | 27 +++++++++++++++ python/paddle/fluid/data_feeder.py | 54 ++++++++++++------------------ python/paddle/fluid/reader.py | 39 +++++++++++---------- 3 files changed, 67 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/pybind/reader_py.cc b/paddle/fluid/pybind/reader_py.cc index 4f15c574bd2..2d39af2faa5 100644 --- a/paddle/fluid/pybind/reader_py.cc +++ b/paddle/fluid/pybind/reader_py.cc @@ -22,6 +22,8 @@ #include "Python.h" #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/platform/place.h" @@ -207,6 +209,31 @@ void BindReader(py::module *module) { py::call_guard()) .def("read_next_list", &MultiDeviceFeedReader::ReadNextList, py::call_guard()) + .def("read_next_var_list", + [](MultiDeviceFeedReader &self) { + auto result_list = self.ReadNextList(); + auto &tensor_list = result_list[0]; + std::vector> var_list; + var_list.reserve(tensor_list.size()); + auto func = [](framework::LoDTensor &lod_tensor) { + std::string act_name = + imperative::GetCurrentTracer()->GenerateUniqueName( + "generated_var"); + auto new_var = std::make_shared(act_name); + new_var->SetPersistable(false); + new_var->SetType(framework::proto::VarType::LOD_TENSOR); + new_var->SetDataType(lod_tensor.type()); + auto *tensor = + new_var->MutableVar()->GetMutable(); + *tensor = std::move(lod_tensor); + return new_var; + }; + for (auto &tensor : tensor_list) { + var_list.emplace_back(func(tensor)); + } 
+ return var_list; + }, + py::call_guard()) .def("reset", &MultiDeviceFeedReader::Reset, py::call_guard()); diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 493eb5c18ff..d82a925c757 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -524,41 +524,29 @@ class NumpyToLoDTensorConverter(object): return t -class ListTensorProvider(object): - def __init__(self, generator, places): +class DygraphListTensorProvider(object): + def __init__(self, generator, place): self.generator = generator self.converters = [] - self.places = [] - if places: - if not isinstance(places, (list, tuple)): - places = [places] - assert len( - places) == 1, "dygraph mode CAN NOT specify multiple places." - for place in places: - if isinstance(place, (core.CUDAPlace, core.CPUPlace)): - self.places.append(place) - else: - raise ValueError( - "Please specify a valid place values such as core.CPUPlace or core.CUDAPlace" - ) - if len(self.places) == 0: - self.places.append(_current_expected_place()) - - def _readData(self, iterable, places): - for place, each_sample in six.moves.zip(places, iterable): - for item in each_sample: - if len(self.converters) < len(item): - for i in item: - self.converters.append(NumpyToLoDTensorConverter(place)) - for each_converter, each_slot in six.moves.zip(self.converters, - item): - each_converter.feed(each_slot) - yield [c.done() for c in self.converters] + if place: + if isinstance(place, (core.CUDAPlace, core.CPUPlace)): + self.place = place + else: + raise ValueError("Please specify a valid place values \ + such as core.CPUPlace or core.CUDAPlace") + else: + self.place = _current_expected_place() + + def _read_data(self, iterable, place): + for items in iterable: + if len(self.converters) < len(items): + for _ in items: + self.converters.append(NumpyToLoDTensorConverter(place)) + for each_converter, each_slot in six.moves.zip(self.converters, + items): + 
each_converter.feed(each_slot) + yield [c.done() for c in self.converters] def __call__(self): - item = [] for batch in self.generator(): - item.append(batch) - if len(item) == len(self.places): - yield list(self._readData(item, self.places)) - item = [] + yield list(self._read_data(batch, self.place)) diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 81d75870714..45764af6147 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -21,7 +21,7 @@ import threading import paddle from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places from .executor import global_scope -from .data_feeder import DataFeeder, BatchedTensorProvider, ListTensorProvider +from .data_feeder import DataFeeder, BatchedTensorProvider, DygraphListTensorProvider from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer from .unique_name import UniqueNameGenerator import logging @@ -442,14 +442,13 @@ class GeneratorLoader(DataLoaderBase): def __next__(self): try: - if not in_dygraph_mode(): + if in_dygraph_mode(): + return self._reader.read_next_var_list() + else: if self._return_list: return self._reader.read_next_list() else: return self._reader.read_next() - else: - ret = self._reader.read_next_list()[0] - return [dygraph.base.to_variable(np.array(v)) for v in ret] except StopIteration: self._queue.close() self._reset() @@ -517,7 +516,12 @@ class GeneratorLoader(DataLoaderBase): drop_last=True, places=None): assert batch_size > 0, "batch_size must be larger than 0" - if not in_dygraph_mode(): + if in_dygraph_mode(): + self.set_sample_list_generator( + paddle.batch( + reader, batch_size=batch_size, drop_last=drop_last), + places=places) + else: has_lod = False for f in self._feed_list: if f.lod_level != 0: @@ -537,15 +541,16 @@ class GeneratorLoader(DataLoaderBase): generator=reader, drop_last=drop_last) self.set_batch_generator(reader, 
places=places) - else: - self.set_sample_list_generator( - paddle.batch( - reader, batch_size=batch_size, drop_last=drop_last), - places=places) return self def set_sample_list_generator(self, reader, places=None): - if not in_dygraph_mode(): + if in_dygraph_mode(): + provider = DygraphListTensorProvider(reader, places) + + def __tensor_reader_impl__(): + for slots in provider(): + yield slots[0] + else: with program_guard(Program(), Program()): feeder = DataFeeder( feed_list=self._feed_list, place=core.CPUPlace()) @@ -555,12 +560,6 @@ class GeneratorLoader(DataLoaderBase): def __tensor_reader_impl__(): for slots in paddle_reader(): yield [slots[var.name] for var in self._feed_list] - else: - provider = ListTensorProvider(reader, places) - - def __tensor_reader_impl__(): - for slots in provider(): - yield slots[0] self.set_batch_generator(__tensor_reader_impl__, places) return self @@ -571,8 +570,8 @@ class GeneratorLoader(DataLoaderBase): assert places is not None, "Places cannot be None when DataLoader is iterable" self._places = _convert_places(places) if in_dygraph_mode(): - assert len(self._places - ) == 1, "Number of places must be 1 in dygraph mode" + assert len(self._places) == 1, \ + "Number of places must be 1 in dygraph mode" else: if places is not None: logging.info( -- GitLab