未验证 提交 a32e8bf1 编写于 作者: K Kaipeng Deng 提交者: GitHub

DataLoader supprot dict str (#31481)

* add dict/str/list supprot for DataLoader. test=develop
上级 30a627aa
...@@ -71,9 +71,12 @@ void EraseLoadProcessPIDs(int64_t key) { ...@@ -71,9 +71,12 @@ void EraseLoadProcessPIDs(int64_t key) {
} \ } \
} while (0) } while (0)
#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \ #define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \ static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
SIGNAL_HANDLE(SIGNAL); \ auto _w = \
write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \
(void)_w; \
SIGNAL_HANDLE(SIGNAL); \
} }
#define REGISTER_SPEC_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \ #define REGISTER_SPEC_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
...@@ -84,8 +87,18 @@ void EraseLoadProcessPIDs(int64_t key) { ...@@ -84,8 +87,18 @@ void EraseLoadProcessPIDs(int64_t key) {
SIGNAL_HANDLE(SIGNAL); \ SIGNAL_HANDLE(SIGNAL); \
} }
REGISTER_SIGNAL_HANDLER(SIGSEGV, SIGSEGV_handler); REGISTER_SIGNAL_HANDLER(SIGSEGV, SIGSEGV_handler,
REGISTER_SIGNAL_HANDLER(SIGBUS, SIGBUS_handler); "ERROR: Unexpected segmentation fault encountered in "
"DataLoader workers.\n");
REGISTER_SIGNAL_HANDLER(
SIGBUS, SIGBUS_handler,
"ERROR: Unexpected BUS error encountered in DataLoader worker. "
"This might be caused by insufficient shared memory (shm), "
"please check whether use_shared_memory is set and storage space "
"in /dev/shm is enough\n");
REGISTER_SIGNAL_HANDLER(SIGFPE, SIGFPE_handler,
"ERROR: Unexpected floating-point exception "
"encountered in DataLoader worker.\n")
REGISTER_SPEC_SIGNAL_HANDLER(SIGTERM, SIGTERM_handler); REGISTER_SPEC_SIGNAL_HANDLER(SIGTERM, SIGTERM_handler);
static inline void setSignalHandler(int signal, static inline void setSignalHandler(int signal,
...@@ -105,6 +118,7 @@ static inline void setSignalHandler(int signal, ...@@ -105,6 +118,7 @@ static inline void setSignalHandler(int signal,
void SetLoadProcessSignalHandler() { void SetLoadProcessSignalHandler() {
setSignalHandler(SIGSEGV, &SIGSEGV_handler, nullptr); setSignalHandler(SIGSEGV, &SIGSEGV_handler, nullptr);
setSignalHandler(SIGBUS, &SIGBUS_handler, nullptr); setSignalHandler(SIGBUS, &SIGBUS_handler, nullptr);
setSignalHandler(SIGFPE, &SIGFPE_handler, nullptr);
setSignalHandler(SIGTERM, &SIGTERM_handler, nullptr); setSignalHandler(SIGTERM, &SIGTERM_handler, nullptr);
} }
......
...@@ -45,7 +45,11 @@ class BlockingQueue { ...@@ -45,7 +45,11 @@ class BlockingQueue {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait( send_cv_.wait(
lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; }); lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; });
EnforceNotKilled(); if (killed_) {
VLOG(3)
<< "WARNING:: Sending an element to a killed reader::BlokcingQueue";
return false;
}
if (closed_) { if (closed_) {
VLOG(5) VLOG(5)
<< "WARNING: Sending an element to a closed reader::BlokcingQueue."; << "WARNING: Sending an element to a closed reader::BlokcingQueue.";
...@@ -66,7 +70,11 @@ class BlockingQueue { ...@@ -66,7 +70,11 @@ class BlockingQueue {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait( send_cv_.wait(
lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; }); lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; });
EnforceNotKilled(); if (killed_) {
VLOG(3)
<< "WARNING:: Sending an element to a killed reader::BlokcingQueue";
return false;
}
if (closed_) { if (closed_) {
VLOG(5) VLOG(5)
<< "WARNING: Sending an element to a closed reader::BlokcingQueue."; << "WARNING: Sending an element to a closed reader::BlokcingQueue.";
......
...@@ -223,6 +223,10 @@ class MultiDeviceFeedReader { ...@@ -223,6 +223,10 @@ class MultiDeviceFeedReader {
ReadAsync(); ReadAsync();
} }
void Shutdown() {
for (auto &r : readers_) r->Shutdown();
}
~MultiDeviceFeedReader() { ~MultiDeviceFeedReader() {
queue_->Close(); queue_->Close();
pool_.reset(); pool_.reset();
...@@ -266,10 +270,6 @@ class MultiDeviceFeedReader { ...@@ -266,10 +270,6 @@ class MultiDeviceFeedReader {
} }
} }
void Shutdown() {
for (auto &r : readers_) r->Shutdown();
}
void Start() { void Start() {
for (auto &r : readers_) r->Start(); for (auto &r : readers_) r->Start();
} }
...@@ -362,6 +362,8 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) { ...@@ -362,6 +362,8 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
}, },
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("reset", &ReaderType::Reset, .def("reset", &ReaderType::Reset,
py::call_guard<py::gil_scoped_release>())
.def("shutdown", &ReaderType::Shutdown,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} }
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numbers
import numpy as np
from ..framework import in_dygraph_mode
from .. import core, layers
try:
from collections.abc import Sequence, Mapping
except:
from collections import Sequence, Mapping
def default_collate_fn(batch):
"""
Default batch collating function for :code:`paddle.io.DataLoader`,
batch should be a list of samples, and each sample should be a list
of fields as follows:
[[filed1, filed2, ...], [filed1, filed2, ...], ...]
This default collate function zipped each filed together and stack
each filed as the batch field as follows:
[batch_filed1, batch_filed2, ...]
Args:
batch(list of list of numpy array|paddle.Tensor): the batch data, each fields
should be a numpy array, each sample should be a list of
fileds, and batch should be a list of sample.
Returns:
a list of numpy array|Paddle.Tensor: collated batch of input batch data,
fields data type as same as fields in each sample.
"""
sample = batch[0]
if isinstance(sample, np.ndarray):
batch = np.stack(batch, axis=0)
return batch
elif isinstance(sample, paddle.Tensor):
return layers.stack(batch, axis=0)
elif isinstance(sample, numbers.Number):
batch = np.array(batch)
return batch
elif isinstance(sample, (str, bytes)):
return batch
elif isinstance(sample, Mapping):
return {
key: default_collate_fn([d[key] for d in batch])
for key in sample
}
elif isinstance(sample, Sequence):
sample_fields_num = len(sample)
if not all(len(sample) == sample_fields_num for sample in iter(batch)):
raise RuntimeError(
"fileds number not same among samples in a batch")
return [default_collate_fn(fields) for fields in zip(*batch)]
raise TypeError("batch data con only contains: tensor, numpy.ndarray, "
"dict, list, number, but got {}".format(type(sample)))
return outputs
def default_convert_fn(batch):
if isinstance(batch, (paddle.Tensor, np.ndarray)):
return batch
elif isinstance(batch, (str, bytes)):
return batch
elif isinstance(batch, Mapping):
return {key: default_convert_fn(batch[key]) for key in batch}
elif isinstance(batch, Sequence):
return [default_convert_fn(d) for d in batch]
else:
return batch
...@@ -35,181 +35,16 @@ else: ...@@ -35,181 +35,16 @@ else:
import paddle import paddle
from .. import core, layers from .. import core, layers
from ..framework import in_dygraph_mode from ..framework import in_dygraph_mode
from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler from ..multiprocess_utils import _set_SIGCHLD_handler, MP_STATUS_CHECK_INTERVAL
from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
from .batch_sampler import _InfiniteIterableSampler from .batch_sampler import _InfiniteIterableSampler
from .collate import default_collate_fn, default_convert_fn
from .worker import ParentWatchDog, get_worker_info, _worker_loop, \
_DatasetKind, _IterableDatasetStopIteration, _WorkerException
from .flat import _flatten_batch, _restore_batch
__all__ = ['get_worker_info'] __all__ = ['get_worker_info']
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
MP_INDICES_CHECK_INTERVAL = 5
_IterableDatasetStopIteration = namedtuple('_IterableDatasetStopIteration',
['worker_id'])
def default_collate_fn(batch):
"""
Default batch collating function for :code:`fluid.io.DataLoader`,
batch should be a list of samples, and each sample should be a list
of fields as follows:
[[filed1, filed2, ...], [filed1, filed2, ...], ...]
This default collate function zipped each filed together and stack
each filed as the batch field as follows:
[batch_filed1, batch_filed2, ...]
Args:
batch(list of list of numpy array): the batch data, each fields
should be a numpy array, each sample should be a list of
fileds, and batch should be a list of sample.
Returns:
a list of numpy array: collated batch
"""
sample = batch[0]
# dataset has only 1 field
if isinstance(sample, np.ndarray):
return [np.stack(batch, axis=0)]
# batch each field
slots = []
for items in batch:
for i, item in enumerate(items):
if len(slots) < len(items):
slots.append([item])
else:
slots[i].append(item)
outputs = []
for slot in slots:
if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):
tmp = np.stack(slot, axis=0)
outputs.append(tmp)
elif isinstance(slot[0], paddle.Tensor):
tmp = layers.stack(slot, axis=0)
outputs.append(tmp)
else:
raise RuntimeError("Unknown data type {}".format(type(slot[0])))
return outputs
class _DatasetKind(object):
MAP = 0
ITER = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collate_batch, collate_fn,
drop_last):
if kind == _DatasetKind.MAP:
return _MapDatasetFetcher(dataset, auto_collate_batch, collate_fn,
drop_last)
elif kind == _DatasetKind.ITER:
return _IterableDatasetFetcher(dataset, auto_collate_batch,
collate_fn, drop_last)
else:
raise NotImplementedError("unknown Dataset kind {}".format(kind))
class ParentWatchDog(object):
def __init__(self):
self._parent_pid = os.getppid()
self._parent_alive = True
def is_alive(self):
if self._parent_alive:
self._parent_alive = os.getppid() == self._parent_pid
return self._parent_alive
# worker information for each workers, used for splitting data copy
# for IteratorDataset in worker processes.
_worker_info = None
def get_worker_info():
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
(see :code:`paddle.io.IterableDataset`), worker information contains
following fields:
:attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
:attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1`
:attr:`dataset`: the dataset object in this worker process
Returns:
WorkerInfo: an instance of WorkerInfo which contains fields above.
.. note::
For mode usage and exampls, please see :code:`paddle.io.IterableDataset`
Example:
.. code-block:: python
import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = paddle.CPUPlace()
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
for data in dataloader:
print(data)
# outputs: [2, 5, 3, 6, 4, 7]
"""
return _worker_info
class WorkerInfo(object):
__initialized = False
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
self.__initialized = True
def __setattr__(self, key, val):
if self.__initialized:
raise RuntimeError("Cannot assign attributes to {} objects".format(
self.__class__.__name__))
return super(WorkerInfo, self).__setattr__(key, val)
class _DataLoaderIterBase(object): class _DataLoaderIterBase(object):
""" """
...@@ -230,7 +65,7 @@ class _DataLoaderIterBase(object): ...@@ -230,7 +65,7 @@ class _DataLoaderIterBase(object):
self._num_workers = loader.num_workers self._num_workers = loader.num_workers
self._use_buffer_reader = loader.use_buffer_reader self._use_buffer_reader = loader.use_buffer_reader
self._use_shared_memory = loader.use_shared_memory self._use_shared_memory = loader.use_shared_memory
self._timeout = loader.timeout if loader.timeout > 0 else MP_INDICES_CHECK_INTERVAL self._timeout = loader.timeout if loader.timeout > 0 else MP_STATUS_CHECK_INTERVAL
self._worker_init_fn = loader.worker_init_fn self._worker_init_fn = loader.worker_init_fn
self._dataset_kind = loader.dataset_kind self._dataset_kind = loader.dataset_kind
self._pin_memory = loader.pin_memory self._pin_memory = loader.pin_memory
...@@ -244,7 +79,7 @@ class _DataLoaderIterBase(object): ...@@ -244,7 +79,7 @@ class _DataLoaderIterBase(object):
else: else:
self._sampler_iter = iter( self._sampler_iter = iter(
_InfiniteIterableSampler(self._dataset, 1)) _InfiniteIterableSampler(self._dataset, 1))
self._collate_fn = loader.collate_fn self._collate_fn = loader.collate_fn or default_convert_fn
# LoDTensorBlockingQueue instance for create_py_reader and a thread # LoDTensorBlockingQueue instance for create_py_reader and a thread
# to put mini-batch data to self._blocking_queue, mini-batch data # to put mini-batch data to self._blocking_queue, mini-batch data
...@@ -275,6 +110,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -275,6 +110,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._dataset_kind, self._dataset, self._auto_collate_batch, self._dataset_kind, self._dataset, self._auto_collate_batch,
self._collate_fn, True) self._collate_fn, True)
# NOTE: _structrue_infos used to record the data structure of
# batch to restore batch structure after reading Tensor
# from blocking_queue in single-process mode. Note that
# only single process is used in single-process mode, we
# can record the data structure sequencely in a list without
# recording the send and recv index
self._structure_infos = []
# NOTE: len(self._places) batch data compose as an output # NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas # iteration, set blocking_queue can cache 2 iteration datas
# at most here # at most here
...@@ -316,16 +159,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -316,16 +159,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
# read data from dataset in mini-batch # read data from dataset in mini-batch
batch = self._dataset_fetcher.fetch(indices) batch = self._dataset_fetcher.fetch(indices)
# flat batch and record structure infos
batch, structure = _flatten_batch(batch)
self._structure_infos.append(structure)
# pack as LoDTensorArray # pack as LoDTensorArray
array = core.LoDTensorArray() array = core.LoDTensorArray()
for slot in batch: for slot in batch:
if not isinstance(slot, core.LoDTensor): if not isinstance(slot, core.LoDTensor):
# FIXME(dkp): blocking_queue only support
# core.LoDTensorArray as input now, read
# numpy data into a LoDTensorArray here,
# should support paddle.Tensor list later
if isinstance(slot, paddle.Tensor):
slot = slot.numpy()
tmp = core.LoDTensor() tmp = core.LoDTensor()
tmp.set(slot, core.CPUPlace()) tmp.set(slot, core.CPUPlace())
slot = tmp slot = tmp
...@@ -348,20 +189,29 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -348,20 +189,29 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
def __next__(self): def __next__(self):
try: try:
if in_dygraph_mode(): if in_dygraph_mode():
return self._reader.read_next_var_list() data = self._reader.read_next_var_list()
data = _restore_batch(data, self._structure_infos.pop(0))
else: else:
if self._return_list: if self._return_list:
data = self._reader.read_next_list()
data = [
_restore_batch(d, s)
for d, s in zip(data, self._structure_infos[:len(
self._places)])
]
self._structure_infos = self._structure_infos[len(
self._places):]
# static graph organized data on multi-device with list, if # static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data # place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode # from list for devices to be compatible with dygraph mode
if len(self._places) == 1: if len(self._places) == 1:
return self._reader.read_next_list()[0] data = data[0]
else:
return self._reader.read_next_list()
else: else:
return self._reader.read_next() data = self._reader.read_next()
return data
except StopIteration: except StopIteration:
self._reader.reset() self._reader.shutdown()
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
# python2 compatibility # python2 compatibility
...@@ -375,97 +225,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -375,97 +225,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._blocking_queue.close() self._blocking_queue.close()
# NOTE(chenweihang): _worker_loop must be top level method to be pickled
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
auto_collate_batch, collate_fn, init_fn, worker_id,
num_workers, use_shared_memory):
try:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet
# been put into the inter-process Queue. This part of the object needs
# to be cleaned up when the process ends.
CleanupFuncRegistrar.register(_cleanup_mmap)
# set signal handler
core._set_process_signal_handler()
global _worker_info
_worker_info = WorkerInfo(
id=worker_id, num_workers=num_workers, dataset=dataset)
init_exception = None
try:
if init_fn is not None:
init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collate_batch, collate_fn, True)
except:
init_exception = Exception("init_fn failed in worker {}: " \
"{}".format(worker_id, sys.exc_info()))
iterator_drained = False
parent_watch_dog = ParentWatchDog()
while parent_watch_dog.is_alive():
try:
data = indices_queue.get(MP_INDICES_CHECK_INTERVAL)
except queue.Empty:
continue
# None as poison piil, so worker event should be set
if data is None:
assert done_event.is_set() or iterator_drained, \
"get None when worker done_event set"
break
# If worker done event is set but get still get data in
# indices_queue, remaining data should be get and skipped.
if done_event.is_set() or iterator_drained:
continue
idx, indices = data
try:
if init_exception is not None:
batch = init_exception
init_exception = None
else:
batch = fetcher.fetch(indices)
except Exception as e:
if isinstance(
e, StopIteration) and dataset_kind == _DatasetKind.ITER:
out_queue.put(_IterableDatasetStopIteration(worker_id))
iterator_drained = True
else:
out_queue.put((idx, e))
else:
if use_shared_memory:
# FIXME(dkp): _convert_to_tensor_list only support np.array
# list now, should support paddle.Tensor list
new_batch = []
for sample in batch:
new_sample = []
for s in sample:
if isinstance(s, paddle.Tensor):
new_sample.append(s.numpy())
else:
new_sample.append(s)
new_batch.append(new_sample)
batch = new_batch
tensor_list = core._convert_to_tensor_list(batch)
out_queue.put((idx, tensor_list))
core._remove_tensor_list_mmap_fds(tensor_list)
else:
out_queue.put((idx, batch))
except KeyboardInterrupt:
# NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
pass
except:
six.reraise(*sys.exc_info())
finally:
if use_shared_memory:
_cleanup_mmap()
class _DataLoaderIterMultiProcess(_DataLoaderIterBase): class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
def __init__(self, loader): def __init__(self, loader):
super(_DataLoaderIterMultiProcess, self).__init__(loader) super(_DataLoaderIterMultiProcess, self).__init__(loader)
...@@ -483,6 +242,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -483,6 +242,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._rcvd_idx = 0 self._rcvd_idx = 0
self._batches_outstanding = 0 self._batches_outstanding = 0
self._task_infos = {} self._task_infos = {}
self._structure_infos = []
# indices outstand as _outstanding_capacity at first, and # indices outstand as _outstanding_capacity at first, and
# blocking_queue capacity is also _outstanding_capacity. # blocking_queue capacity is also _outstanding_capacity.
...@@ -617,8 +377,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -617,8 +377,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if not self._thread_done_event.is_set(): if not self._thread_done_event.is_set():
if batch is None: if batch is None:
self._exit_thread_expectedly() self._exit_thread_expectedly()
elif isinstance(batch, Exception):
self._exit_thread_unexpectedly()
else: else:
try: try:
# pack as LoDTensorArray # pack as LoDTensorArray
...@@ -654,8 +412,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -654,8 +412,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# batch indices and increase _rcvd_idx # batch indices and increase _rcvd_idx
if self._dataset_kind == _DatasetKind.ITER: if self._dataset_kind == _DatasetKind.ITER:
while self._rcvd_idx < self._send_idx: while self._rcvd_idx < self._send_idx:
sys.stdout.flush()
info = self._task_infos[self._rcvd_idx] info = self._task_infos[self._rcvd_idx]
if len(info) == 2 or self._worker_status[info[0]]: if len(info) == 3 or self._worker_status[info[0]]:
break break
del self._task_infos[self._rcvd_idx] del self._task_infos[self._rcvd_idx]
self._rcvd_idx += 1 self._rcvd_idx += 1
...@@ -669,13 +428,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -669,13 +428,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
continue continue
if self._rcvd_idx in self._task_infos and \ if self._rcvd_idx in self._task_infos and \
len(self._task_infos[self._rcvd_idx]) == 2: len(self._task_infos[self._rcvd_idx]) == 3:
return self._task_infos.pop(self._rcvd_idx)[1] info = self._task_infos.pop(self._rcvd_idx)
self._structure_infos.append(info[2])
return info[1]
try: try:
# [ avoid hang ]: main process may blocking at _reader.read_next when # [ avoid hang ]: main process may blocking at _reader.read_next when
# KeyboardInterrupt, we do following tradeoff: # KeyboardInterrupt, we do following tradeoff:
# 1. get data with timeout, MP_INDICES_CHECK_INTERVAL(5s) as timeout # 1. get data with timeout, MP_STATUS_CHECK_INTERVAL(5s) as timeout
# default, if KeyboardInterrupt blocking, failed workers will be # default, if KeyboardInterrupt blocking, failed workers will be
# checked and raise RuntimeError to quit DataLoader in timeout # checked and raise RuntimeError to quit DataLoader in timeout
# exception handling. # exception handling.
...@@ -721,12 +482,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -721,12 +482,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._try_put_indices() self._try_put_indices()
continue continue
idx, batch = data idx, batch, structure = data
if isinstance(batch, _WorkerException):
self._exit_thread_unexpectedly()
batch.reraise()
if idx == self._rcvd_idx: if idx == self._rcvd_idx:
del self._task_infos[idx] del self._task_infos[idx]
self._structure_infos.append(structure)
return batch return batch
else: else:
self._task_infos[idx] += (batch, ) self._task_infos[idx] += (batch, structure)
continue continue
def _try_put_indices(self): def _try_put_indices(self):
...@@ -777,9 +543,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -777,9 +543,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if in_dygraph_mode(): if in_dygraph_mode():
data = self._reader.read_next_var_list() data = self._reader.read_next_var_list()
data = _restore_batch(data, self._structure_infos.pop(0))
else: else:
if self._return_list: if self._return_list:
data = self._reader.read_next_list() data = self._reader.read_next_list()
data = [
_restore_batch(d, s)
for d, s in zip(data, self._structure_infos[:len(
self._places)])
]
self._structure_infos = self._structure_infos[len(
self._places):]
# static graph organized data on multi-device with list, if # static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data # place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode # from list for devices to be compatible with dygraph mode
...@@ -790,7 +564,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -790,7 +564,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._on_output_batch() self._on_output_batch()
return data return data
except StopIteration: except StopIteration:
self._reader.reset() self._reader.shutdown()
self._try_shutdown_all() self._try_shutdown_all()
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numbers
import numpy as np
try:
from collections.abc import Sequence, Mapping
except:
from collections import Sequence, Mapping
FIELD_PREFIX = "_paddle_field_"
def _flatten_batch(batch):
"""
For lod_blocking_queue only receive tensor array, flatten batch
data, extract numpy.array data out as a list of numpy.array to
send to lod_blocking_queue, and save the batch data structure
such as fields in other types (str, int, etc) or key-value map
of dictionaries
"""
def _flatten(batch, flat_batch, structure, field_idx):
if isinstance(batch, Sequence):
for field in batch:
if isinstance(field, np.ndarray):
structure.append('{}{}'.format(FIELD_PREFIX, field_idx))
flat_batch.append(field)
field_idx += 1
elif isinstance(field, paddle.Tensor):
structure.append('{}{}'.format(FIELD_PREFIX, field_idx))
flat_batch.append(field.numpy())
field_idx += 1
elif isinstance(field, (str, bytes, numbers.Number)):
structure.append(field)
elif isinstance(field, Sequence):
field_struct, field_idx = _flatten(field, flat_batch, [],
field_idx)
structure.append(field_struct)
elif isinstance(field, Mapping):
field_struct, field_idx = _flatten(field, flat_batch, {},
field_idx)
structure.append(field_struct)
else:
structure.append(field)
elif isinstance(batch, Mapping):
for k, field in batch.items():
if isinstance(field, np.ndarray):
structure[k] = '{}{}'.format(FIELD_PREFIX, field_idx)
flat_batch.append(field)
field_idx += 1
elif isinstance(field, paddle.Tensor):
structure[k] = '{}{}'.format(FIELD_PREFIX, field_idx)
flat_batch.append(field.numpy())
field_idx += 1
elif isinstance(field, (str, bytes, numbers.Number)):
structure[k] = field
elif isinstance(field, Sequence):
field_struct, field_idx = _flatten(field, flat_batch, [],
field_idx)
structure[k] = field_struct
elif isinstance(field, Mapping):
field_struct, field_idx = _flatten(field, flat_batch, {},
field_idx)
structure[k] = field_struct
else:
structure[k] = field
else:
raise TypeError("wrong flat data type: {}".format(type(batch)))
return structure, field_idx
# sample only contains single fields
if not isinstance(batch, Sequence):
flat_batch = []
structure, _ = _flatten([batch], flat_batch, [], 0)
return flat_batch, structure[0]
flat_batch = []
structure, _ = _flatten(batch, flat_batch, [], 0)
return flat_batch, structure
def _restore_batch(flat_batch, structure):
"""
After reading list of Tensor data from lod_blocking_queue outputs,
use this function to restore the batch data structrue, replace
:attr:`_paddle_field_x` with data from flat_batch
"""
def _restore(structure, field_idx):
if isinstance(structure, Sequence):
for i, field in enumerate(structure):
if isinstance(field, str) and field.startswith(FIELD_PREFIX):
cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
field_idx = max(field_idx, cur_field_idx)
assert flat_batch[cur_field_idx] is not None, \
"flat_batch[{}] parsed repeatly"
structure[i] = flat_batch[cur_field_idx]
flat_batch[cur_field_idx] = None
elif isinstance(field, (str, bytes, numbers.Number)):
continue
elif isinstance(field, (Sequence, Mapping)):
field_idx = _restore(structure[i], field_idx)
elif isinstance(structure, Mapping):
for k, field in structure.items():
if isinstance(field, str) and field.startswith(FIELD_PREFIX):
cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
field_idx = max(field_idx, cur_field_idx)
assert flat_batch[cur_field_idx] is not None, \
"flat_batch[{}] parsed repeatly"
structure[k] = flat_batch[cur_field_idx]
flat_batch[cur_field_idx] = None
elif isinstance(field, (str, bytes, numbers.Number)):
continue
elif isinstance(field, (Sequence, Mapping)):
field_idx = _restore(structure[k], field_idx)
else:
raise TypeError("wrong flat data type: {}".format(type(batch)))
return field_idx
assert isinstance(flat_batch, Sequence), \
"flat_batch is not a list or tuple"
# no np.array in dataset, no output tensor from blocking queue
# simply return structure
if len(flat_batch) == 0:
return structure
# sample only contains single fields
if isinstance(structure, (str, bytes)):
assert structure == '{}{}'.format(FIELD_PREFIX, 0), \
"invalid structure: {}".format(structure)
return flat_batch[0]
field_idx = _restore(structure, 0)
assert field_idx + 1 == len(flat_batch), "Tensor parse incomplete"
return structure
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import sys
import paddle
import numpy as np
import traceback
from collections import namedtuple
from .. import core
from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
from ..multiprocess_utils import _cleanup_mmap, CleanupFuncRegistrar, MP_STATUS_CHECK_INTERVAL
from ..framework import in_dygraph_mode
from .flat import _flatten_batch
# NOTE: queue has a different name in python2 and python3
if six.PY2:
import Queue as queue
else:
import queue
__all__ = ['get_worker_info']
class _IterableDatasetStopIteration(object):
def __init__(self, worker_id):
self.worker_id = worker_id
class _DatasetKind(object):
MAP = 0
ITER = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collate_batch, collate_fn,
drop_last):
if kind == _DatasetKind.MAP:
return _MapDatasetFetcher(dataset, auto_collate_batch, collate_fn,
drop_last)
elif kind == _DatasetKind.ITER:
return _IterableDatasetFetcher(dataset, auto_collate_batch,
collate_fn, drop_last)
else:
raise NotImplementedError("unknown Dataset kind {}".format(kind))
class ParentWatchDog(object):
def __init__(self):
self._parent_pid = os.getppid()
self._parent_alive = True
def is_alive(self):
if self._parent_alive:
self._parent_alive = os.getppid() == self._parent_pid
return self._parent_alive
# worker information for each workers, used for splitting data copy
# for IteratorDataset in worker processes.
_worker_info = None
def get_worker_info():
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
(see :code:`paddle.io.IterableDataset`), worker information contains
following fields:
:attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
:attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1`
:attr:`dataset`: the dataset object in this worker process
Returns:
WorkerInfo: an instance of WorkerInfo which contains fields above.
.. note::
For more usage and examples, please see :code:`paddle.io.IterableDataset`
Example:
.. code-block:: python
import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = paddle.CPUPlace()
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
for data in dataloader:
print(data)
# outputs: [2, 5, 3, 6, 4, 7]
"""
return _worker_info
class WorkerInfo(object):
__initialized = False
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
self.__initialized = True
def __setattr__(self, key, val):
if self.__initialized:
raise RuntimeError("Cannot assign attributes to {} objects".format(
self.__class__.__name__))
return super(WorkerInfo, self).__setattr__(key, val)
class _WorkerException(object):
def __init__(self, worker_id, exc_info=None):
self.worker_id = worker_id
exc_info = exc_info or sys.exc_info()
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
def reraise(self):
msg = "DataLoader worker({}) caught {} with message:\n{}".format(
self.worker_id, self.exc_type.__name__, self.exc_msg)
if getattr(self.exc_type, "message", None):
raise self.exc_type(message=msg)
raise self.exc_type(msg)
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
auto_collate_batch, collate_fn, init_fn, worker_id,
num_workers, use_shared_memory):
try:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet
# been put into the inter-process Queue. This part of the object needs
# to be cleaned up when the process ends.
CleanupFuncRegistrar.register(_cleanup_mmap)
# set signal handler
core._set_process_signal_handler()
global _worker_info
_worker_info = WorkerInfo(
id=worker_id, num_workers=num_workers, dataset=dataset)
init_exception = None
try:
if init_fn is not None:
init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collate_batch, collate_fn, True)
except:
init_exception = _WorkerException(worker_id)
iterator_drained = False
parent_watch_dog = ParentWatchDog()
while parent_watch_dog.is_alive():
try:
data = indices_queue.get(MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
# None as poison piil, so worker event should be set
if data is None:
assert done_event.is_set() or iterator_drained, \
"get None when worker done_event set"
break
# If worker done event is set but get still get data in
# indices_queue, remaining data should be get and skipped.
if done_event.is_set() or iterator_drained:
continue
idx, indices = data
try:
if init_exception is not None:
batch = init_exception
init_exception = None
else:
# NOTE: GPU tensor operation is not supported in sub-process
# but default device is GPU in paddle-gpu version, which
# may copy CPU tensor to GPU even if users want to use
# CPU tensor operation, so we add CPUPlace guard here
# to make sure tensor will be operated only on CPU
with paddle.fluid.dygraph.guard(place=paddle.CPUPlace()):
batch = fetcher.fetch(indices)
except Exception as e:
if isinstance(
e, StopIteration) and dataset_kind == _DatasetKind.ITER:
out_queue.put(_IterableDatasetStopIteration(worker_id))
iterator_drained = True
else:
out_queue.put((idx, _WorkerException(worker_id), None))
else:
if isinstance(batch, _WorkerException):
out_queue.put((idx, batch, None))
batch, structure = _flatten_batch(batch)
if use_shared_memory:
tensor_list = core._convert_to_tensor_list(batch)
out_queue.put((idx, tensor_list, structure))
core._remove_tensor_list_mmap_fds(tensor_list)
else:
out_queue.put((idx, batch, structure))
except KeyboardInterrupt:
# NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
pass
except:
six.reraise(*sys.exc_info())
finally:
if use_shared_memory:
_cleanup_mmap()
...@@ -25,6 +25,10 @@ if six.PY2: ...@@ -25,6 +25,10 @@ if six.PY2:
else: else:
import queue import queue
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
MP_STATUS_CHECK_INTERVAL = 5.
# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading, # NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
# the data in the queue needs to be popped. Then the LoDTensor read by the main process # the data in the queue needs to be popped. Then the LoDTensor read by the main process
# from the child process will automatically clear the memory-mapped file. # from the child process will automatically clear the memory-mapped file.
......
...@@ -273,5 +273,62 @@ class TestNumpyMixTensorDataset(TestTensorDataset): ...@@ -273,5 +273,62 @@ class TestNumpyMixTensorDataset(TestTensorDataset):
assert isinstance(label, paddle.Tensor) assert isinstance(label, paddle.Tensor)
class ComplextDataset(Dataset):
def __init__(self, sample_num):
self.sample_num = sample_num
def __len__(self):
return self.sample_num
def __getitem__(self, idx):
return (3.1, 'abc', paddle.to_tensor(
np.random.random([IMAGE_SIZE]).astype('float32'),
place=paddle.CPUPlace()),
[1, np.random.random([2]).astype('float32')], {
'a': 2.0,
'b': np.random.random([2]).astype('float32')
})
class TestComplextDataset(unittest.TestCase):
def run_main(self, num_workers):
paddle.static.default_startup_program().random_seed = 1
paddle.static.default_main_program().random_seed = 1
place = paddle.CPUPlace()
with fluid.dygraph.guard(place):
dataset = ComplextDataset(16)
assert len(dataset) == 16
dataloader = DataLoader(
dataset,
places=place,
num_workers=num_workers,
batch_size=2,
drop_last=True)
for i, data in enumerate(dataloader()):
assert len(data) == 5
# data[0]: collate 3.1
assert data[0].shape == [2]
assert isinstance(data[1], list)
# data[1]: collate 'abc'
assert len(data[1]) == 2
assert isinstance(data[1][0], str)
assert isinstance(data[1][1], str)
# data[2]: collate tensor
assert data[2].shape == [2, IMAGE_SIZE]
# data[3]: collate list
assert isinstance(data[3], list)
assert data[3][0].shape == [2]
assert data[3][1].shape == [2, 2]
# data[4]: collate dict
assert isinstance(data[4], dict)
assert data[4]['a'].shape == [2]
assert data[4]['b'].shape == [2, 2]
def test_main(self):
for num_workers in [0, 2]:
self.run_main(num_workers)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -58,7 +58,7 @@ class TestDynamicDataLoaderIterSplit(unittest.TestCase): ...@@ -58,7 +58,7 @@ class TestDynamicDataLoaderIterSplit(unittest.TestCase):
rets = [] rets = []
for d in dataloader: for d in dataloader:
rets.append(d[0].numpy()[0][0]) rets.append(d.numpy()[0][0])
assert tuple(sorted(rets)) == tuple(range(0, 10)) assert tuple(sorted(rets)) == tuple(range(0, 10))
...@@ -102,7 +102,7 @@ class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase):
rets = [] rets = []
for d in dataloader: for d in dataloader:
rets.append(d[0].numpy()[0][0]) rets.append(d.numpy()[0][0])
assert tuple(sorted(rets)) == tuple(range(0, 10)) assert tuple(sorted(rets)) == tuple(range(0, 10))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册