未验证 提交 d96acc33 编写于 作者: C Chen Weihang 提交者: GitHub

Refine dygraph DataLoader implementation (#21634)

* refine dygraph dataloader & polish related code, test=develop

* refine code based review comment, test=develop
上级 5eec8cf5
......@@ -22,6 +22,8 @@
#include "Python.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h"
......@@ -207,6 +209,31 @@ void BindReader(py::module *module) {
py::call_guard<py::gil_scoped_release>())
.def("read_next_list", &MultiDeviceFeedReader::ReadNextList,
py::call_guard<py::gil_scoped_release>())
.def("read_next_var_list",
[](MultiDeviceFeedReader &self) {
auto result_list = self.ReadNextList();
auto &tensor_list = result_list[0];
std::vector<std::shared_ptr<imperative::VarBase>> var_list;
var_list.reserve(tensor_list.size());
auto func = [](framework::LoDTensor &lod_tensor) {
std::string act_name =
imperative::GetCurrentTracer()->GenerateUniqueName(
"generated_var");
auto new_var = std::make_shared<imperative::VarBase>(act_name);
new_var->SetPersistable(false);
new_var->SetType(framework::proto::VarType::LOD_TENSOR);
new_var->SetDataType(lod_tensor.type());
auto *tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
*tensor = std::move(lod_tensor);
return new_var;
};
for (auto &tensor : tensor_list) {
var_list.emplace_back(func(tensor));
}
return var_list;
},
py::call_guard<py::gil_scoped_release>())
.def("reset", &MultiDeviceFeedReader::Reset,
py::call_guard<py::gil_scoped_release>());
......
......@@ -524,41 +524,29 @@ class NumpyToLoDTensorConverter(object):
return t
class ListTensorProvider(object):
def __init__(self, generator, places):
class DygraphListTensorProvider(object):
def __init__(self, generator, place):
self.generator = generator
self.converters = []
self.places = []
if places:
if not isinstance(places, (list, tuple)):
places = [places]
assert len(
places) == 1, "dygraph mode CAN NOT specify multiple places."
for place in places:
if place:
if isinstance(place, (core.CUDAPlace, core.CPUPlace)):
self.places.append(place)
self.place = place
else:
raise ValueError(
"Please specify a valid place values such as core.CPUPlace or core.CUDAPlace"
)
if len(self.places) == 0:
self.places.append(_current_expected_place())
def _readData(self, iterable, places):
for place, each_sample in six.moves.zip(places, iterable):
for item in each_sample:
if len(self.converters) < len(item):
for i in item:
raise ValueError("Please specify a valid place values \
such as core.CPUPlace or core.CUDAPlace")
else:
self.place = _current_expected_place()
def _read_data(self, iterable, place):
for items in iterable:
if len(self.converters) < len(items):
for _ in items:
self.converters.append(NumpyToLoDTensorConverter(place))
for each_converter, each_slot in six.moves.zip(self.converters,
item):
items):
each_converter.feed(each_slot)
yield [c.done() for c in self.converters]
def __call__(self):
item = []
for batch in self.generator():
item.append(batch)
if len(item) == len(self.places):
yield list(self._readData(item, self.places))
item = []
yield list(self._read_data(batch, self.place))
......@@ -21,7 +21,7 @@ import threading
import paddle
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places
from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider, ListTensorProvider
from .data_feeder import DataFeeder, BatchedTensorProvider, DygraphListTensorProvider
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
from .unique_name import UniqueNameGenerator
import logging
......@@ -442,14 +442,13 @@ class GeneratorLoader(DataLoaderBase):
def __next__(self):
try:
if not in_dygraph_mode():
if in_dygraph_mode():
return self._reader.read_next_var_list()
else:
if self._return_list:
return self._reader.read_next_list()
else:
return self._reader.read_next()
else:
ret = self._reader.read_next_list()[0]
return [dygraph.base.to_variable(np.array(v)) for v in ret]
except StopIteration:
self._queue.close()
self._reset()
......@@ -517,7 +516,12 @@ class GeneratorLoader(DataLoaderBase):
drop_last=True,
places=None):
assert batch_size > 0, "batch_size must be larger than 0"
if not in_dygraph_mode():
if in_dygraph_mode():
self.set_sample_list_generator(
paddle.batch(
reader, batch_size=batch_size, drop_last=drop_last),
places=places)
else:
has_lod = False
for f in self._feed_list:
if f.lod_level != 0:
......@@ -537,15 +541,16 @@ class GeneratorLoader(DataLoaderBase):
generator=reader,
drop_last=drop_last)
self.set_batch_generator(reader, places=places)
else:
self.set_sample_list_generator(
paddle.batch(
reader, batch_size=batch_size, drop_last=drop_last),
places=places)
return self
def set_sample_list_generator(self, reader, places=None):
if not in_dygraph_mode():
if in_dygraph_mode():
provider = DygraphListTensorProvider(reader, places)
def __tensor_reader_impl__():
for slots in provider():
yield slots[0]
else:
with program_guard(Program(), Program()):
feeder = DataFeeder(
feed_list=self._feed_list, place=core.CPUPlace())
......@@ -555,12 +560,6 @@ class GeneratorLoader(DataLoaderBase):
def __tensor_reader_impl__():
for slots in paddle_reader():
yield [slots[var.name] for var in self._feed_list]
else:
provider = ListTensorProvider(reader, places)
def __tensor_reader_impl__():
for slots in provider():
yield slots[0]
self.set_batch_generator(__tensor_reader_impl__, places)
return self
......@@ -571,8 +570,8 @@ class GeneratorLoader(DataLoaderBase):
assert places is not None, "Places cannot be None when DataLoader is iterable"
self._places = _convert_places(places)
if in_dygraph_mode():
assert len(self._places
) == 1, "Number of places must be 1 in dygraph mode"
assert len(self._places) == 1, \
"Number of places must be 1 in dygraph mode"
else:
if places is not None:
logging.info(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册