未验证 提交 d96acc33 编写于 作者: C Chen Weihang 提交者: GitHub

Refine dygraph DataLoader implementation (#21634)

* refine dygraph dataloader & polish related code, test=develop

* refine code based review comment, test=develop
上级 5eec8cf5
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "Python.h" #include "Python.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -207,6 +209,31 @@ void BindReader(py::module *module) { ...@@ -207,6 +209,31 @@ void BindReader(py::module *module) {
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("read_next_list", &MultiDeviceFeedReader::ReadNextList, .def("read_next_list", &MultiDeviceFeedReader::ReadNextList,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
// Bind "read_next_var_list": the dygraph-mode fast path of the reader.
// Reads the next batch and wraps each framework::LoDTensor into an
// imperative::VarBase so dygraph Python code receives Variables directly,
// without a numpy round-trip. Runs with the GIL released (call_guard below).
.def("read_next_var_list",
[](MultiDeviceFeedReader &self) {
auto result_list = self.ReadNextList();
// NOTE(review): only the first entry is consumed — presumably because
// dygraph mode allows exactly one place, so ReadNextList() yields one
// device's tensors. Also assumes the result is non-empty here; confirm
// that ReadNextList() signals exhaustion (e.g. by raising StopIteration
// upstream) rather than returning an empty vector, since result_list[0]
// on an empty vector would be undefined behavior.
auto &tensor_list = result_list[0];
std::vector<std::shared_ptr<imperative::VarBase>> var_list;
var_list.reserve(tensor_list.size());
// Convert one LoDTensor into a fresh, non-persistable VarBase. The
// tensor payload is moved (not copied) into the new variable.
auto func = [](framework::LoDTensor &lod_tensor) {
// Name the variable via the current tracer so it is unique within
// the dygraph program ("generated_var_<n>").
std::string act_name =
imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_var");
auto new_var = std::make_shared<imperative::VarBase>(act_name);
new_var->SetPersistable(false);
new_var->SetType(framework::proto::VarType::LOD_TENSOR);
new_var->SetDataType(lod_tensor.type());
auto *tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
// Move: steals lod_tensor's buffer, leaving the source tensor in a
// moved-from state — safe because tensor_list is local to this call.
*tensor = std::move(lod_tensor);
return new_var;
};
for (auto &tensor : tensor_list) {
var_list.emplace_back(func(tensor));
}
return var_list;
},
py::call_guard<py::gil_scoped_release>())
.def("reset", &MultiDeviceFeedReader::Reset, .def("reset", &MultiDeviceFeedReader::Reset,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
......
...@@ -524,41 +524,29 @@ class NumpyToLoDTensorConverter(object): ...@@ -524,41 +524,29 @@ class NumpyToLoDTensorConverter(object):
return t return t
class ListTensorProvider(object): class DygraphListTensorProvider(object):
def __init__(self, generator, places): def __init__(self, generator, place):
self.generator = generator self.generator = generator
self.converters = [] self.converters = []
self.places = [] if place:
if places:
if not isinstance(places, (list, tuple)):
places = [places]
assert len(
places) == 1, "dygraph mode CAN NOT specify multiple places."
for place in places:
if isinstance(place, (core.CUDAPlace, core.CPUPlace)): if isinstance(place, (core.CUDAPlace, core.CPUPlace)):
self.places.append(place) self.place = place
else: else:
raise ValueError( raise ValueError("Please specify a valid place values \
"Please specify a valid place values such as core.CPUPlace or core.CUDAPlace" such as core.CPUPlace or core.CUDAPlace")
) else:
if len(self.places) == 0: self.place = _current_expected_place()
self.places.append(_current_expected_place())
def _read_data(self, iterable, place):
def _readData(self, iterable, places): for items in iterable:
for place, each_sample in six.moves.zip(places, iterable): if len(self.converters) < len(items):
for item in each_sample: for _ in items:
if len(self.converters) < len(item):
for i in item:
self.converters.append(NumpyToLoDTensorConverter(place)) self.converters.append(NumpyToLoDTensorConverter(place))
for each_converter, each_slot in six.moves.zip(self.converters, for each_converter, each_slot in six.moves.zip(self.converters,
item): items):
each_converter.feed(each_slot) each_converter.feed(each_slot)
yield [c.done() for c in self.converters] yield [c.done() for c in self.converters]
def __call__(self): def __call__(self):
item = []
for batch in self.generator(): for batch in self.generator():
item.append(batch) yield list(self._read_data(batch, self.place))
if len(item) == len(self.places):
yield list(self._readData(item, self.places))
item = []
...@@ -21,7 +21,7 @@ import threading ...@@ -21,7 +21,7 @@ import threading
import paddle import paddle
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places
from .executor import global_scope from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider, ListTensorProvider from .data_feeder import DataFeeder, BatchedTensorProvider, DygraphListTensorProvider
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
from .unique_name import UniqueNameGenerator from .unique_name import UniqueNameGenerator
import logging import logging
...@@ -442,14 +442,13 @@ class GeneratorLoader(DataLoaderBase): ...@@ -442,14 +442,13 @@ class GeneratorLoader(DataLoaderBase):
def __next__(self): def __next__(self):
try: try:
if not in_dygraph_mode(): if in_dygraph_mode():
return self._reader.read_next_var_list()
else:
if self._return_list: if self._return_list:
return self._reader.read_next_list() return self._reader.read_next_list()
else: else:
return self._reader.read_next() return self._reader.read_next()
else:
ret = self._reader.read_next_list()[0]
return [dygraph.base.to_variable(np.array(v)) for v in ret]
except StopIteration: except StopIteration:
self._queue.close() self._queue.close()
self._reset() self._reset()
...@@ -517,7 +516,12 @@ class GeneratorLoader(DataLoaderBase): ...@@ -517,7 +516,12 @@ class GeneratorLoader(DataLoaderBase):
drop_last=True, drop_last=True,
places=None): places=None):
assert batch_size > 0, "batch_size must be larger than 0" assert batch_size > 0, "batch_size must be larger than 0"
if not in_dygraph_mode(): if in_dygraph_mode():
self.set_sample_list_generator(
paddle.batch(
reader, batch_size=batch_size, drop_last=drop_last),
places=places)
else:
has_lod = False has_lod = False
for f in self._feed_list: for f in self._feed_list:
if f.lod_level != 0: if f.lod_level != 0:
...@@ -537,15 +541,16 @@ class GeneratorLoader(DataLoaderBase): ...@@ -537,15 +541,16 @@ class GeneratorLoader(DataLoaderBase):
generator=reader, generator=reader,
drop_last=drop_last) drop_last=drop_last)
self.set_batch_generator(reader, places=places) self.set_batch_generator(reader, places=places)
else:
self.set_sample_list_generator(
paddle.batch(
reader, batch_size=batch_size, drop_last=drop_last),
places=places)
return self return self
def set_sample_list_generator(self, reader, places=None): def set_sample_list_generator(self, reader, places=None):
if not in_dygraph_mode(): if in_dygraph_mode():
provider = DygraphListTensorProvider(reader, places)
def __tensor_reader_impl__():
for slots in provider():
yield slots[0]
else:
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
feeder = DataFeeder( feeder = DataFeeder(
feed_list=self._feed_list, place=core.CPUPlace()) feed_list=self._feed_list, place=core.CPUPlace())
...@@ -555,12 +560,6 @@ class GeneratorLoader(DataLoaderBase): ...@@ -555,12 +560,6 @@ class GeneratorLoader(DataLoaderBase):
def __tensor_reader_impl__(): def __tensor_reader_impl__():
for slots in paddle_reader(): for slots in paddle_reader():
yield [slots[var.name] for var in self._feed_list] yield [slots[var.name] for var in self._feed_list]
else:
provider = ListTensorProvider(reader, places)
def __tensor_reader_impl__():
for slots in provider():
yield slots[0]
self.set_batch_generator(__tensor_reader_impl__, places) self.set_batch_generator(__tensor_reader_impl__, places)
return self return self
...@@ -571,8 +570,8 @@ class GeneratorLoader(DataLoaderBase): ...@@ -571,8 +570,8 @@ class GeneratorLoader(DataLoaderBase):
assert places is not None, "Places cannot be None when DataLoader is iterable" assert places is not None, "Places cannot be None when DataLoader is iterable"
self._places = _convert_places(places) self._places = _convert_places(places)
if in_dygraph_mode(): if in_dygraph_mode():
assert len(self._places assert len(self._places) == 1, \
) == 1, "Number of places must be 1 in dygraph mode" "Number of places must be 1 in dygraph mode"
else: else:
if places is not None: if places is not None:
logging.info( logging.info(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册