From a32e8bf1e7fb45b9bae85e80fe7742eae8739fac Mon Sep 17 00:00:00 2001
From: Kaipeng Deng
Date: Mon, 15 Mar 2021 10:29:54 +0800
Subject: [PATCH] DataLoader support dict str (#31481)

* add dict/str/list support for DataLoader. test=develop
---
 paddle/fluid/imperative/data_loader.cc | 24 +-
 .../fluid/operators/reader/blocking_queue.h | 12 +-
 paddle/fluid/pybind/reader_py.cc | 10 +-
 python/paddle/fluid/dataloader/collate.py | 87 +++++
 .../fluid/dataloader/dataloader_iter.py | 342 +++---------------
 python/paddle/fluid/dataloader/flat.py | 150 ++++++++
 python/paddle/fluid/dataloader/worker.py | 253 +++++++++++++
 python/paddle/fluid/multiprocess_utils.py | 4 +
 .../test_multiprocess_dataloader_dataset.py | 57 +++
 ...ocess_dataloader_iterable_dataset_split.py | 4 +-
 10 files changed, 646 insertions(+), 297 deletions(-)
 create mode 100644 python/paddle/fluid/dataloader/collate.py
 create mode 100644 python/paddle/fluid/dataloader/flat.py
 create mode 100644 python/paddle/fluid/dataloader/worker.py

diff --git a/paddle/fluid/imperative/data_loader.cc b/paddle/fluid/imperative/data_loader.cc
index 71ea82e9a1..c43149c9b5 100644
--- a/paddle/fluid/imperative/data_loader.cc
+++ b/paddle/fluid/imperative/data_loader.cc
@@ -71,9 +71,12 @@ void EraseLoadProcessPIDs(int64_t key) {
   }                                                               \
   } while (0)
 
-#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME)                 \
-  static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) {     \
-    SIGNAL_HANDLE(SIGNAL);                                            \
+#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)          \
+  static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) {         \
+    auto _w =                                                             \
+        write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \
+    (void)_w;                                                             \
+    SIGNAL_HANDLE(SIGNAL);                                                 \
   }
 
 #define REGISTER_SPEC_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
@@ -84,8 +87,18 @@ void EraseLoadProcessPIDs(int64_t key) {
     SIGNAL_HANDLE(SIGNAL);                                 \
   }
 
-REGISTER_SIGNAL_HANDLER(SIGSEGV, SIGSEGV_handler);
-REGISTER_SIGNAL_HANDLER(SIGBUS, SIGBUS_handler);
+REGISTER_SIGNAL_HANDLER(SIGSEGV, SIGSEGV_handler,
+                        "ERROR: Unexpected segmentation fault encountered in "
+                        "DataLoader workers.\n");
+REGISTER_SIGNAL_HANDLER(
+    SIGBUS, SIGBUS_handler,
+    "ERROR: Unexpected BUS error encountered in DataLoader worker. "
" + "This might be caused by insufficient shared memory (shm), " + "please check whether use_shared_memory is set and storage space " + "in /dev/shm is enough\n"); +REGISTER_SIGNAL_HANDLER(SIGFPE, SIGFPE_handler, + "ERROR: Unexpected floating-point exception " + "encountered in DataLoader worker.\n") REGISTER_SPEC_SIGNAL_HANDLER(SIGTERM, SIGTERM_handler); static inline void setSignalHandler(int signal, @@ -105,6 +118,7 @@ static inline void setSignalHandler(int signal, void SetLoadProcessSignalHandler() { setSignalHandler(SIGSEGV, &SIGSEGV_handler, nullptr); setSignalHandler(SIGBUS, &SIGBUS_handler, nullptr); + setSignalHandler(SIGFPE, &SIGFPE_handler, nullptr); setSignalHandler(SIGTERM, &SIGTERM_handler, nullptr); } diff --git a/paddle/fluid/operators/reader/blocking_queue.h b/paddle/fluid/operators/reader/blocking_queue.h index 8929da20b5..f126070a7e 100644 --- a/paddle/fluid/operators/reader/blocking_queue.h +++ b/paddle/fluid/operators/reader/blocking_queue.h @@ -45,7 +45,11 @@ class BlockingQueue { std::unique_lock lock(mutex_); send_cv_.wait( lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; }); - EnforceNotKilled(); + if (killed_) { + VLOG(3) + << "WARNING:: Sending an element to a killed reader::BlokcingQueue"; + return false; + } if (closed_) { VLOG(5) << "WARNING: Sending an element to a closed reader::BlokcingQueue."; @@ -66,7 +70,11 @@ class BlockingQueue { std::unique_lock lock(mutex_); send_cv_.wait( lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; }); - EnforceNotKilled(); + if (killed_) { + VLOG(3) + << "WARNING:: Sending an element to a killed reader::BlokcingQueue"; + return false; + } if (closed_) { VLOG(5) << "WARNING: Sending an element to a closed reader::BlokcingQueue."; diff --git a/paddle/fluid/pybind/reader_py.cc b/paddle/fluid/pybind/reader_py.cc index 856c5aac5e..abe1977eb6 100644 --- a/paddle/fluid/pybind/reader_py.cc +++ b/paddle/fluid/pybind/reader_py.cc @@ -223,6 +223,10 @@ class MultiDeviceFeedReader { ReadAsync(); } + void Shutdown() { + for (auto &r : readers_) r->Shutdown(); + } + ~MultiDeviceFeedReader() { queue_->Close(); pool_.reset(); @@ -266,10 +270,6 @@ class MultiDeviceFeedReader { } } - void Shutdown() { - for (auto &r : readers_) r->Shutdown(); - } - void Start() { for (auto &r : readers_) r->Start(); } @@ -362,6 +362,8 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) { }, py::call_guard()) .def("reset", &ReaderType::Reset, + py::call_guard()) + .def("shutdown", &ReaderType::Shutdown, py::call_guard()); } diff --git a/python/paddle/fluid/dataloader/collate.py b/python/paddle/fluid/dataloader/collate.py new file mode 100644 index 0000000000..ddc010d042 --- /dev/null +++ b/python/paddle/fluid/dataloader/collate.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numbers +import numpy as np +from ..framework import in_dygraph_mode +from .. 
import core, layers
+
+try:
+    from collections.abc import Sequence, Mapping
+except:
+    from collections import Sequence, Mapping
+
+
+def default_collate_fn(batch):
+    """
+    Default batch collating function for :code:`paddle.io.DataLoader`,
+    batch should be a list of samples, and each sample should be a list
+    of fields as follows:
+
+    [[field1, field2, ...], [field1, field2, ...], ...]
+
+    This default collate function zips each field together and stacks
+    each field as the batch field as follows:
+
+    [batch_field1, batch_field2, ...]
+
+    Args:
+        batch(list of list of numpy array|paddle.Tensor): the batch data, each field
+            should be a numpy array, each sample should be a list of
+            fields, and batch should be a list of samples.
+
+    Returns:
+        a list of numpy array|paddle.Tensor: collated batch of input batch data,
+            with the same field data types as the fields in each sample.
+    """
+    sample = batch[0]
+    if isinstance(sample, np.ndarray):
+        batch = np.stack(batch, axis=0)
+        return batch
+    elif isinstance(sample, paddle.Tensor):
+        return layers.stack(batch, axis=0)
+    elif isinstance(sample, numbers.Number):
+        batch = np.array(batch)
+        return batch
+    elif isinstance(sample, (str, bytes)):
+        return batch
+    elif isinstance(sample, Mapping):
+        return {
+            key: default_collate_fn([d[key] for d in batch])
+            for key in sample
+        }
+    elif isinstance(sample, Sequence):
+        sample_fields_num = len(sample)
+        if not all(len(sample) == sample_fields_num for sample in iter(batch)):
+            raise RuntimeError(
+                "field number not the same among samples in a batch")
+        return [default_collate_fn(fields) for fields in zip(*batch)]
+
+    raise TypeError("batch data can only contain: tensor, numpy.ndarray, "
+                    "dict, list, number, but got {}".format(type(sample)))
+
+
+def default_convert_fn(batch):
+    if isinstance(batch, (paddle.Tensor, np.ndarray)):
+        return batch
+    elif isinstance(batch, (str, bytes)):
+        return batch
+    elif isinstance(batch, Mapping):
+        return {key: default_convert_fn(batch[key]) for key in batch}
+    elif isinstance(batch, Sequence):
+        return [default_convert_fn(d) for d in batch]
+    else:
+        return batch
diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py
index 0dd2420691..0cd12e874d 100644
--- a/python/paddle/fluid/dataloader/dataloader_iter.py
+++ b/python/paddle/fluid/dataloader/dataloader_iter.py
@@ -35,181 +35,16 @@ else:
 import paddle
 from .. import core, layers
 from ..framework import in_dygraph_mode
-from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler
+from ..multiprocess_utils import _set_SIGCHLD_handler, MP_STATUS_CHECK_INTERVAL
 from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
 from .batch_sampler import _InfiniteIterableSampler
+from .collate import default_collate_fn, default_convert_fn
+from .worker import ParentWatchDog, get_worker_info, _worker_loop, \
+    _DatasetKind, _IterableDatasetStopIteration, _WorkerException
+from .flat import _flatten_batch, _restore_batch
 
 __all__ = ['get_worker_info']
 
-# multi-process worker check indices queue interval, avoid
-# hanging in subprocess data loading
-MP_INDICES_CHECK_INTERVAL = 5
-
-_IterableDatasetStopIteration = namedtuple('_IterableDatasetStopIteration',
-                                           ['worker_id'])
-
-
-def default_collate_fn(batch):
-    """
-    Default batch collating function for :code:`fluid.io.DataLoader`,
-    batch should be a list of samples, and each sample should be a list
-    of fields as follows:
-
-    [[filed1, filed2, ...], [filed1, filed2, ...], ...]
- - This default collate function zipped each filed together and stack - each filed as the batch field as follows: - - [batch_filed1, batch_filed2, ...] - - Args: - batch(list of list of numpy array): the batch data, each fields - should be a numpy array, each sample should be a list of - fileds, and batch should be a list of sample. - - Returns: - a list of numpy array: collated batch - """ - sample = batch[0] - # dataset has only 1 field - if isinstance(sample, np.ndarray): - return [np.stack(batch, axis=0)] - - # batch each field - slots = [] - for items in batch: - for i, item in enumerate(items): - if len(slots) < len(items): - slots.append([item]) - else: - slots[i].append(item) - - outputs = [] - for slot in slots: - if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)): - tmp = np.stack(slot, axis=0) - outputs.append(tmp) - elif isinstance(slot[0], paddle.Tensor): - tmp = layers.stack(slot, axis=0) - outputs.append(tmp) - else: - raise RuntimeError("Unknown data type {}".format(type(slot[0]))) - return outputs - - -class _DatasetKind(object): - MAP = 0 - ITER = 1 - - @staticmethod - def create_fetcher(kind, dataset, auto_collate_batch, collate_fn, - drop_last): - if kind == _DatasetKind.MAP: - return _MapDatasetFetcher(dataset, auto_collate_batch, collate_fn, - drop_last) - elif kind == _DatasetKind.ITER: - return _IterableDatasetFetcher(dataset, auto_collate_batch, - collate_fn, drop_last) - else: - raise NotImplementedError("unknown Dataset kind {}".format(kind)) - - -class ParentWatchDog(object): - def __init__(self): - self._parent_pid = os.getppid() - self._parent_alive = True - - def is_alive(self): - if self._parent_alive: - self._parent_alive = os.getppid() == self._parent_pid - return self._parent_alive - - -# worker information for each workers, used for splitting data copy -# for IteratorDataset in worker processes. -_worker_info = None - - -def get_worker_info(): - """ - Get DataLoader worker process information function, this function is - used to split data copy in worker process for IterableDataset - (see :code:`paddle.io.IterableDataset`), worker information contains - following fields: - - :attr:`num_workers`: total worker process number, see `paddle.io.DataLoader` - - :attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1` - - :attr:`dataset`: the dataset object in this worker process - - Returns: - WorkerInfo: an instance of WorkerInfo which contains fields above. - - .. note:: - For mode usage and exampls, please see :code:`paddle.io.IterableDataset` - - Example: - - .. 
code-block:: python - - import math - import paddle - import numpy as np - from paddle.io import IterableDataset, DataLoader, get_worker_info - - class SplitedIterableDataset(IterableDataset): - def __init__(self, start, end): - self.start = start - self.end = end - - def __iter__(self): - worker_info = get_worker_info() - if worker_info is None: - iter_start = self.start - iter_end = self.end - else: - per_worker = int( - math.ceil((self.end - self.start) / float( - worker_info.num_workers))) - worker_id = worker_info.id - iter_start = self.start + worker_id * per_worker - iter_end = min(iter_start + per_worker, self.end) - - for i in range(iter_start, iter_end): - yield np.array([i]) - - place = paddle.CPUPlace() - dataset = SplitedIterableDataset(start=2, end=9) - dataloader = DataLoader( - dataset, - places=place, - num_workers=2, - batch_size=1, - drop_last=True) - - for data in dataloader: - print(data) - # outputs: [2, 5, 3, 6, 4, 7] - - """ - return _worker_info - - -class WorkerInfo(object): - __initialized = False - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - self.__initialized = True - - def __setattr__(self, key, val): - if self.__initialized: - raise RuntimeError("Cannot assign attributes to {} objects".format( - self.__class__.__name__)) - return super(WorkerInfo, self).__setattr__(key, val) - class _DataLoaderIterBase(object): """ @@ -230,7 +65,7 @@ class _DataLoaderIterBase(object): self._num_workers = loader.num_workers self._use_buffer_reader = loader.use_buffer_reader self._use_shared_memory = loader.use_shared_memory - self._timeout = loader.timeout if loader.timeout > 0 else MP_INDICES_CHECK_INTERVAL + self._timeout = loader.timeout if loader.timeout > 0 else MP_STATUS_CHECK_INTERVAL self._worker_init_fn = loader.worker_init_fn self._dataset_kind = loader.dataset_kind self._pin_memory = loader.pin_memory @@ -244,7 +79,7 @@ class _DataLoaderIterBase(object): else: self._sampler_iter = iter( _InfiniteIterableSampler(self._dataset, 1)) - self._collate_fn = loader.collate_fn + self._collate_fn = loader.collate_fn or default_convert_fn # LoDTensorBlockingQueue instance for create_py_reader and a thread # to put mini-batch data to self._blocking_queue, mini-batch data @@ -275,6 +110,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): self._dataset_kind, self._dataset, self._auto_collate_batch, self._collate_fn, True) + # NOTE: _structrue_infos used to record the data structure of + # batch to restore batch structure after reading Tensor + # from blocking_queue in single-process mode. 
Note that + # only single process is used in single-process mode, we + # can record the data structure sequencely in a list without + # recording the send and recv index + self._structure_infos = [] + # NOTE: len(self._places) batch data compose as an output # iteration, set blocking_queue can cache 2 iteration datas # at most here @@ -316,16 +159,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): # read data from dataset in mini-batch batch = self._dataset_fetcher.fetch(indices) + # flat batch and record structure infos + batch, structure = _flatten_batch(batch) + self._structure_infos.append(structure) + # pack as LoDTensorArray array = core.LoDTensorArray() for slot in batch: if not isinstance(slot, core.LoDTensor): - # FIXME(dkp): blocking_queue only support - # core.LoDTensorArray as input now, read - # numpy data into a LoDTensorArray here, - # should support paddle.Tensor list later - if isinstance(slot, paddle.Tensor): - slot = slot.numpy() tmp = core.LoDTensor() tmp.set(slot, core.CPUPlace()) slot = tmp @@ -348,20 +189,29 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): def __next__(self): try: if in_dygraph_mode(): - return self._reader.read_next_var_list() + data = self._reader.read_next_var_list() + data = _restore_batch(data, self._structure_infos.pop(0)) else: if self._return_list: + data = self._reader.read_next_list() + data = [ + _restore_batch(d, s) + for d, s in zip(data, self._structure_infos[:len( + self._places)]) + ] + self._structure_infos = self._structure_infos[len( + self._places):] # static graph organized data on multi-device with list, if # place number is 1, there is only 1 device, extra the data # from list for devices to be compatible with dygraph mode if len(self._places) == 1: - return self._reader.read_next_list()[0] - else: - return self._reader.read_next_list() + data = data[0] else: - return self._reader.read_next() + data = self._reader.read_next() + + return data except StopIteration: - self._reader.reset() + self._reader.shutdown() six.reraise(*sys.exc_info()) # python2 compatibility @@ -375,97 +225,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): self._blocking_queue.close() -# NOTE(chenweihang): _worker_loop must be top level method to be pickled -def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, - auto_collate_batch, collate_fn, init_fn, worker_id, - num_workers, use_shared_memory): - try: - # NOTE: [ mmap files clear ] When the child process exits unexpectedly, - # some shared memory objects may have been applied for but have not yet - # been put into the inter-process Queue. This part of the object needs - # to be cleaned up when the process ends. 
- CleanupFuncRegistrar.register(_cleanup_mmap) - - # set signal handler - core._set_process_signal_handler() - - global _worker_info - _worker_info = WorkerInfo( - id=worker_id, num_workers=num_workers, dataset=dataset) - - init_exception = None - try: - if init_fn is not None: - init_fn(worker_id) - fetcher = _DatasetKind.create_fetcher( - dataset_kind, dataset, auto_collate_batch, collate_fn, True) - except: - init_exception = Exception("init_fn failed in worker {}: " \ - "{}".format(worker_id, sys.exc_info())) - - iterator_drained = False - parent_watch_dog = ParentWatchDog() - - while parent_watch_dog.is_alive(): - try: - data = indices_queue.get(MP_INDICES_CHECK_INTERVAL) - except queue.Empty: - continue - - # None as poison piil, so worker event should be set - if data is None: - assert done_event.is_set() or iterator_drained, \ - "get None when worker done_event set" - break - # If worker done event is set but get still get data in - # indices_queue, remaining data should be get and skipped. - if done_event.is_set() or iterator_drained: - continue - - idx, indices = data - try: - if init_exception is not None: - batch = init_exception - init_exception = None - else: - batch = fetcher.fetch(indices) - except Exception as e: - if isinstance( - e, StopIteration) and dataset_kind == _DatasetKind.ITER: - out_queue.put(_IterableDatasetStopIteration(worker_id)) - iterator_drained = True - else: - out_queue.put((idx, e)) - else: - if use_shared_memory: - # FIXME(dkp): _convert_to_tensor_list only support np.array - # list now, should support paddle.Tensor list - new_batch = [] - for sample in batch: - new_sample = [] - for s in sample: - if isinstance(s, paddle.Tensor): - new_sample.append(s.numpy()) - else: - new_sample.append(s) - new_batch.append(new_sample) - batch = new_batch - - tensor_list = core._convert_to_tensor_list(batch) - out_queue.put((idx, tensor_list)) - core._remove_tensor_list_mmap_fds(tensor_list) - else: - out_queue.put((idx, batch)) - except KeyboardInterrupt: - # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process - pass - except: - six.reraise(*sys.exc_info()) - finally: - if use_shared_memory: - _cleanup_mmap() - - class _DataLoaderIterMultiProcess(_DataLoaderIterBase): def __init__(self, loader): super(_DataLoaderIterMultiProcess, self).__init__(loader) @@ -483,6 +242,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._rcvd_idx = 0 self._batches_outstanding = 0 self._task_infos = {} + self._structure_infos = [] # indices outstand as _outstanding_capacity at first, and # blocking_queue capacity is also _outstanding_capacity. 
@@ -617,8 +377,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): if not self._thread_done_event.is_set(): if batch is None: self._exit_thread_expectedly() - elif isinstance(batch, Exception): - self._exit_thread_unexpectedly() else: try: # pack as LoDTensorArray @@ -654,8 +412,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): # batch indices and increase _rcvd_idx if self._dataset_kind == _DatasetKind.ITER: while self._rcvd_idx < self._send_idx: + sys.stdout.flush() info = self._task_infos[self._rcvd_idx] - if len(info) == 2 or self._worker_status[info[0]]: + if len(info) == 3 or self._worker_status[info[0]]: break del self._task_infos[self._rcvd_idx] self._rcvd_idx += 1 @@ -669,13 +428,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): continue if self._rcvd_idx in self._task_infos and \ - len(self._task_infos[self._rcvd_idx]) == 2: - return self._task_infos.pop(self._rcvd_idx)[1] + len(self._task_infos[self._rcvd_idx]) == 3: + info = self._task_infos.pop(self._rcvd_idx) + self._structure_infos.append(info[2]) + return info[1] try: # [ avoid hang ]: main process may blocking at _reader.read_next when # KeyboardInterrupt, we do following tradeoff: - # 1. get data with timeout, MP_INDICES_CHECK_INTERVAL(5s) as timeout + # 1. get data with timeout, MP_STATUS_CHECK_INTERVAL(5s) as timeout # default, if KeyboardInterrupt blocking, failed workers will be # checked and raise RuntimeError to quit DataLoader in timeout # exception handling. @@ -721,12 +482,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._try_put_indices() continue - idx, batch = data + idx, batch, structure = data + if isinstance(batch, _WorkerException): + self._exit_thread_unexpectedly() + batch.reraise() + if idx == self._rcvd_idx: del self._task_infos[idx] + self._structure_infos.append(structure) return batch else: - self._task_infos[idx] += (batch, ) + self._task_infos[idx] += (batch, structure) continue def _try_put_indices(self): @@ -777,9 +543,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): if in_dygraph_mode(): data = self._reader.read_next_var_list() + data = _restore_batch(data, self._structure_infos.pop(0)) else: if self._return_list: data = self._reader.read_next_list() + data = [ + _restore_batch(d, s) + for d, s in zip(data, self._structure_infos[:len( + self._places)]) + ] + self._structure_infos = self._structure_infos[len( + self._places):] # static graph organized data on multi-device with list, if # place number is 1, there is only 1 device, extra the data # from list for devices to be compatible with dygraph mode @@ -790,7 +564,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._on_output_batch() return data except StopIteration: - self._reader.reset() + self._reader.shutdown() self._try_shutdown_all() six.reraise(*sys.exc_info()) diff --git a/python/paddle/fluid/dataloader/flat.py b/python/paddle/fluid/dataloader/flat.py new file mode 100644 index 0000000000..6cccbc7ee4 --- /dev/null +++ b/python/paddle/fluid/dataloader/flat.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numbers +import numpy as np + +try: + from collections.abc import Sequence, Mapping +except: + from collections import Sequence, Mapping + +FIELD_PREFIX = "_paddle_field_" + + +def _flatten_batch(batch): + """ + For lod_blocking_queue only receive tensor array, flatten batch + data, extract numpy.array data out as a list of numpy.array to + send to lod_blocking_queue, and save the batch data structure + such as fields in other types (str, int, etc) or key-value map + of dictionaries + """ + + def _flatten(batch, flat_batch, structure, field_idx): + if isinstance(batch, Sequence): + for field in batch: + if isinstance(field, np.ndarray): + structure.append('{}{}'.format(FIELD_PREFIX, field_idx)) + flat_batch.append(field) + field_idx += 1 + elif isinstance(field, paddle.Tensor): + structure.append('{}{}'.format(FIELD_PREFIX, field_idx)) + flat_batch.append(field.numpy()) + field_idx += 1 + elif isinstance(field, (str, bytes, numbers.Number)): + structure.append(field) + elif isinstance(field, Sequence): + field_struct, field_idx = _flatten(field, flat_batch, [], + field_idx) + structure.append(field_struct) + elif isinstance(field, Mapping): + field_struct, field_idx = _flatten(field, flat_batch, {}, + field_idx) + structure.append(field_struct) + else: + structure.append(field) + elif isinstance(batch, Mapping): + for k, field in batch.items(): + if isinstance(field, np.ndarray): + structure[k] = '{}{}'.format(FIELD_PREFIX, field_idx) + flat_batch.append(field) + field_idx += 1 + elif isinstance(field, paddle.Tensor): + structure[k] = '{}{}'.format(FIELD_PREFIX, field_idx) + flat_batch.append(field.numpy()) + field_idx += 1 + elif isinstance(field, (str, bytes, numbers.Number)): + structure[k] = field + elif isinstance(field, Sequence): + field_struct, field_idx = _flatten(field, flat_batch, [], + field_idx) + structure[k] = field_struct + elif isinstance(field, Mapping): + field_struct, field_idx = _flatten(field, flat_batch, {}, + field_idx) + structure[k] = field_struct + else: + structure[k] = field + else: + raise TypeError("wrong flat data type: {}".format(type(batch))) + + return structure, field_idx + + # sample only contains single fields + if not isinstance(batch, Sequence): + flat_batch = [] + structure, _ = _flatten([batch], flat_batch, [], 0) + return flat_batch, structure[0] + flat_batch = [] + structure, _ = _flatten(batch, flat_batch, [], 0) + return flat_batch, structure + + +def _restore_batch(flat_batch, structure): + """ + After reading list of Tensor data from lod_blocking_queue outputs, + use this function to restore the batch data structrue, replace + :attr:`_paddle_field_x` with data from flat_batch + """ + + def _restore(structure, field_idx): + if isinstance(structure, Sequence): + for i, field in enumerate(structure): + if isinstance(field, str) and field.startswith(FIELD_PREFIX): + cur_field_idx = int(field.replace(FIELD_PREFIX, '')) + field_idx = max(field_idx, cur_field_idx) + assert flat_batch[cur_field_idx] is not None, \ + "flat_batch[{}] parsed repeatly" + structure[i] = flat_batch[cur_field_idx] + flat_batch[cur_field_idx] 
= None
+                elif isinstance(field, (str, bytes, numbers.Number)):
+                    continue
+                elif isinstance(field, (Sequence, Mapping)):
+                    field_idx = _restore(structure[i], field_idx)
+        elif isinstance(structure, Mapping):
+            for k, field in structure.items():
+                if isinstance(field, str) and field.startswith(FIELD_PREFIX):
+                    cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
+                    field_idx = max(field_idx, cur_field_idx)
+                    assert flat_batch[cur_field_idx] is not None, \
+                        "flat_batch[{}] parsed repeatedly"
+                    structure[k] = flat_batch[cur_field_idx]
+                    flat_batch[cur_field_idx] = None
+                elif isinstance(field, (str, bytes, numbers.Number)):
+                    continue
+                elif isinstance(field, (Sequence, Mapping)):
+                    field_idx = _restore(structure[k], field_idx)
+        else:
+            raise TypeError("wrong flat data type: {}".format(type(structure)))
+
+        return field_idx
+
+    assert isinstance(flat_batch, Sequence), \
+        "flat_batch is not a list or tuple"
+
+    # no np.array in dataset, no output tensor from blocking queue
+    # simply return structure
+    if len(flat_batch) == 0:
+        return structure
+
+    # sample only contains single fields
+    if isinstance(structure, (str, bytes)):
+        assert structure == '{}{}'.format(FIELD_PREFIX, 0), \
+            "invalid structure: {}".format(structure)
+        return flat_batch[0]
+    field_idx = _restore(structure, 0)
+    assert field_idx + 1 == len(flat_batch), "Tensor parse incomplete"
+    return structure
diff --git a/python/paddle/fluid/dataloader/worker.py b/python/paddle/fluid/dataloader/worker.py
new file mode 100644
index 0000000000..2d1b554e53
--- /dev/null
+++ b/python/paddle/fluid/dataloader/worker.py
@@ -0,0 +1,253 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import six
+import sys
+import paddle
+import numpy as np
+import traceback
+from collections import namedtuple
+from .. import core
+from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
+from ..multiprocess_utils import _cleanup_mmap, CleanupFuncRegistrar, MP_STATUS_CHECK_INTERVAL
+from ..framework import in_dygraph_mode
+from .flat import _flatten_batch
+
+# NOTE: queue has a different name in python2 and python3
+if six.PY2:
+    import Queue as queue
+else:
+    import queue
+
+__all__ = ['get_worker_info']
+
+
+class _IterableDatasetStopIteration(object):
+    def __init__(self, worker_id):
+        self.worker_id = worker_id
+
+
+class _DatasetKind(object):
+    MAP = 0
+    ITER = 1
+
+    @staticmethod
+    def create_fetcher(kind, dataset, auto_collate_batch, collate_fn,
+                       drop_last):
+        if kind == _DatasetKind.MAP:
+            return _MapDatasetFetcher(dataset, auto_collate_batch, collate_fn,
+                                      drop_last)
+        elif kind == _DatasetKind.ITER:
+            return _IterableDatasetFetcher(dataset, auto_collate_batch,
+                                           collate_fn, drop_last)
+        else:
+            raise NotImplementedError("unknown Dataset kind {}".format(kind))
+
+
+class ParentWatchDog(object):
+    def __init__(self):
+        self._parent_pid = os.getppid()
+        self._parent_alive = True
+
+    def is_alive(self):
+        if self._parent_alive:
+            self._parent_alive = os.getppid() == self._parent_pid
+        return self._parent_alive
+
+
+# worker information for each worker, used for splitting data copy
+# for IterableDataset in worker processes.
+_worker_info = None
+
+
+def get_worker_info():
+    """
+    Get DataLoader worker process information. This function is used to
+    split the data copy in each worker process for IterableDataset
+    (see :code:`paddle.io.IterableDataset`). Worker information contains
+    the following fields:
+
+    :attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
+
+    :attr:`id`: the worker process id, counted from 0 to :attr:`num_workers - 1`
+
+    :attr:`dataset`: the dataset object in this worker process
+
+    Returns:
+        WorkerInfo: an instance of WorkerInfo which contains fields above.
+
+    .. note::
+        For more usage and examples, please see :code:`paddle.io.IterableDataset`
+
+    Example:
+
+        .. code-block:: python
+
+            import math
+            import paddle
+            import numpy as np
+            from paddle.io import IterableDataset, DataLoader, get_worker_info
+
+            class SplitedIterableDataset(IterableDataset):
+                def __init__(self, start, end):
+                    self.start = start
+                    self.end = end
+
+                def __iter__(self):
+                    worker_info = get_worker_info()
+                    if worker_info is None:
+                        iter_start = self.start
+                        iter_end = self.end
+                    else:
+                        per_worker = int(
+                            math.ceil((self.end - self.start) / float(
+                                worker_info.num_workers)))
+                        worker_id = worker_info.id
+                        iter_start = self.start + worker_id * per_worker
+                        iter_end = min(iter_start + per_worker, self.end)
+
+                    for i in range(iter_start, iter_end):
+                        yield np.array([i])
+
+            place = paddle.CPUPlace()
+            dataset = SplitedIterableDataset(start=2, end=9)
+            dataloader = DataLoader(
+                dataset,
+                places=place,
+                num_workers=2,
+                batch_size=1,
+                drop_last=True)
+
+            for data in dataloader:
+                print(data)
+                # outputs: [2, 5, 3, 6, 4, 7]
+
+    """
+    return _worker_info
+
+
+class WorkerInfo(object):
+    __initialized = False
+
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        self.__initialized = True
+
+    def __setattr__(self, key, val):
+        if self.__initialized:
+            raise RuntimeError("Cannot assign attributes to {} objects".format(
+                self.__class__.__name__))
+        return super(WorkerInfo, self).__setattr__(key, val)
+
+
+class _WorkerException(object):
+    def __init__(self, worker_id, exc_info=None):
+        self.worker_id = worker_id
+        exc_info = exc_info or sys.exc_info()
+        self.exc_type = exc_info[0]
+        self.exc_msg = "".join(traceback.format_exception(*exc_info))
+
+    def reraise(self):
+        msg = "DataLoader worker({}) caught {} with message:\n{}".format(
+            self.worker_id, self.exc_type.__name__, self.exc_msg)
+        if getattr(self.exc_type, "message", None):
+            raise self.exc_type(message=msg)
+        raise self.exc_type(msg)
+
+
+def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
+                 auto_collate_batch, collate_fn, init_fn, worker_id,
+                 num_workers, use_shared_memory):
+    try:
+        # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
+        # some shared memory objects may have been applied for but have not yet
+        # been put into the inter-process Queue. This part of the object needs
+        # to be cleaned up when the process ends.
+        CleanupFuncRegistrar.register(_cleanup_mmap)
+
+        # set signal handler
+        core._set_process_signal_handler()
+
+        global _worker_info
+        _worker_info = WorkerInfo(
+            id=worker_id, num_workers=num_workers, dataset=dataset)
+
+        init_exception = None
+        try:
+            if init_fn is not None:
+                init_fn(worker_id)
+            fetcher = _DatasetKind.create_fetcher(
+                dataset_kind, dataset, auto_collate_batch, collate_fn, True)
+        except:
+            init_exception = _WorkerException(worker_id)
+
+        iterator_drained = False
+        parent_watch_dog = ParentWatchDog()
+
+        while parent_watch_dog.is_alive():
+            try:
+                data = indices_queue.get(MP_STATUS_CHECK_INTERVAL)
+            except queue.Empty:
+                continue
+
+            # None as poison pill, so worker event should be set
+            if data is None:
+                assert done_event.is_set() or iterator_drained, \
+                    "get None when worker done_event set"
+                break
+            # If worker done event is set but we still get data from
+            # indices_queue, remaining data should be fetched and skipped.
+ if done_event.is_set() or iterator_drained: + continue + + idx, indices = data + try: + if init_exception is not None: + batch = init_exception + init_exception = None + else: + # NOTE: GPU tensor operation is not supported in sub-process + # but default device is GPU in paddle-gpu version, which + # may copy CPU tensor to GPU even if users want to use + # CPU tensor operation, so we add CPUPlace guard here + # to make sure tensor will be operated only on CPU + with paddle.fluid.dygraph.guard(place=paddle.CPUPlace()): + batch = fetcher.fetch(indices) + except Exception as e: + if isinstance( + e, StopIteration) and dataset_kind == _DatasetKind.ITER: + out_queue.put(_IterableDatasetStopIteration(worker_id)) + iterator_drained = True + else: + out_queue.put((idx, _WorkerException(worker_id), None)) + else: + if isinstance(batch, _WorkerException): + out_queue.put((idx, batch, None)) + batch, structure = _flatten_batch(batch) + if use_shared_memory: + tensor_list = core._convert_to_tensor_list(batch) + out_queue.put((idx, tensor_list, structure)) + core._remove_tensor_list_mmap_fds(tensor_list) + else: + out_queue.put((idx, batch, structure)) + except KeyboardInterrupt: + # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process + pass + except: + six.reraise(*sys.exc_info()) + finally: + if use_shared_memory: + _cleanup_mmap() diff --git a/python/paddle/fluid/multiprocess_utils.py b/python/paddle/fluid/multiprocess_utils.py index a63825e736..82fb0f60b0 100644 --- a/python/paddle/fluid/multiprocess_utils.py +++ b/python/paddle/fluid/multiprocess_utils.py @@ -25,6 +25,10 @@ if six.PY2: else: import queue +# multi-process worker check indices queue interval, avoid +# hanging in subprocess data loading +MP_STATUS_CHECK_INTERVAL = 5. + # NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading, # the data in the queue needs to be popped. Then the LoDTensor read by the main process # from the child process will automatically clear the memory-mapped file. 
diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py index 39fc965e5e..977882543a 100755 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py @@ -273,5 +273,62 @@ class TestNumpyMixTensorDataset(TestTensorDataset): assert isinstance(label, paddle.Tensor) +class ComplextDataset(Dataset): + def __init__(self, sample_num): + self.sample_num = sample_num + + def __len__(self): + return self.sample_num + + def __getitem__(self, idx): + return (3.1, 'abc', paddle.to_tensor( + np.random.random([IMAGE_SIZE]).astype('float32'), + place=paddle.CPUPlace()), + [1, np.random.random([2]).astype('float32')], { + 'a': 2.0, + 'b': np.random.random([2]).astype('float32') + }) + + +class TestComplextDataset(unittest.TestCase): + def run_main(self, num_workers): + paddle.static.default_startup_program().random_seed = 1 + paddle.static.default_main_program().random_seed = 1 + place = paddle.CPUPlace() + with fluid.dygraph.guard(place): + dataset = ComplextDataset(16) + assert len(dataset) == 16 + dataloader = DataLoader( + dataset, + places=place, + num_workers=num_workers, + batch_size=2, + drop_last=True) + + for i, data in enumerate(dataloader()): + assert len(data) == 5 + # data[0]: collate 3.1 + assert data[0].shape == [2] + assert isinstance(data[1], list) + # data[1]: collate 'abc' + assert len(data[1]) == 2 + assert isinstance(data[1][0], str) + assert isinstance(data[1][1], str) + # data[2]: collate tensor + assert data[2].shape == [2, IMAGE_SIZE] + # data[3]: collate list + assert isinstance(data[3], list) + assert data[3][0].shape == [2] + assert data[3][1].shape == [2, 2] + # data[4]: collate dict + assert isinstance(data[4], dict) + assert data[4]['a'].shape == [2] + assert data[4]['b'].shape == [2, 2] + + def test_main(self): + for num_workers in [0, 2]: + self.run_main(num_workers) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py index 5620513358..d2b7971a85 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py @@ -58,7 +58,7 @@ class TestDynamicDataLoaderIterSplit(unittest.TestCase): rets = [] for d in dataloader: - rets.append(d[0].numpy()[0][0]) + rets.append(d.numpy()[0][0]) assert tuple(sorted(rets)) == tuple(range(0, 10)) @@ -102,7 +102,7 @@ class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase): rets = [] for d in dataloader: - rets.append(d[0].numpy()[0][0]) + rets.append(d.numpy()[0][0]) assert tuple(sorted(rets)) == tuple(range(0, 10)) -- GitLab
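
A minimal usage sketch of the dict/str batch support added above, modeled on the ComplextDataset unit test in this diff; the DictSample dataset and its field names are hypothetical, not part of the patch:

    import numpy as np
    import paddle
    from paddle.io import Dataset, DataLoader

    class DictSample(Dataset):
        """Each sample mixes a number, a string and a numpy array in a dict."""

        def __init__(self, n=8):
            self.n = n

        def __len__(self):
            return self.n

        def __getitem__(self, idx):
            return {
                'label': float(idx),
                'name': 'sample_{}'.format(idx),
                'image': np.random.random([4]).astype('float32'),
            }

    loader = DataLoader(
        DictSample(),
        places=paddle.CPUPlace(),
        batch_size=2,
        num_workers=0,
        drop_last=True)

    for batch in loader():
        # default_collate_fn keeps the dict structure: number and array fields
        # are stacked into Tensors with leading dimension 2, while string
        # fields are returned as a list of the per-sample strings
        print(batch['label'].shape, batch['name'], batch['image'].shape)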