diff --git a/paddle/fluid/imperative/data_loader.cc b/paddle/fluid/imperative/data_loader.cc index 18836ed07a0b4888f4ee59466e789350bf4a2f2f..3b8239e566d21b2a15f0829c52b92ba2aa23d4f3 100644 --- a/paddle/fluid/imperative/data_loader.cc +++ b/paddle/fluid/imperative/data_loader.cc @@ -22,21 +22,23 @@ #include #include #include +#include +#include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace imperative { -static std::map load_process_pids; +static std::map> load_process_pids; -void SetLoadProcessPID(int64_t key, pid_t pid) { - VLOG(3) << "Dygraph Data Loader: set loader child process PID (" << key - << ", " << pid << ")"; - load_process_pids[key] = pid; +void SetLoadProcessPIDs(int64_t key, std::set pids) { + VLOG(3) << "DataLoader: set loader child process PID (" << key + << ", pid number: " << pids.size() << ")"; + load_process_pids[key] = pids; } -void EraseLoadProcessPID(int64_t key) { +void EraseLoadProcessPIDs(int64_t key) { auto it = load_process_pids.find(key); // Note: Can not find key also possible if (it != load_process_pids.end()) { @@ -54,17 +56,21 @@ void EraseLoadProcessPID(int64_t key) { // siginfo_t doc: https://www.mkssoftware.com/docs/man5/siginfo_t.5.asp // waitid doc: https://linux.die.net/man/2/waitid -#define SIGNAL_HANDLE(SIGNAL) \ - do { \ - struct sigaction sa; \ - sa.sa_handler = SIG_DFL; \ - sa.sa_flags = 0; \ - if (sigemptyset(&sa.sa_mask) != 0 || \ - sigaction(SIGNAL, &sa, nullptr) != 0) { \ - _exit(EXIT_FAILURE); \ - } else { \ - raise(SIGNAL); \ - } \ +// clear mmap fds on signal handler, make sure mmap clear will be called +// on signal handling and no need to register mmap clear up handler on +// python side. If shared memory is not used Clear() will do nothing. +#define SIGNAL_HANDLE(SIGNAL) \ + do { \ + memory::allocation::MemoryMapFdSet::Instance().Clear(); \ + struct sigaction sa; \ + sa.sa_handler = SIG_DFL; \ + sa.sa_flags = 0; \ + if (sigemptyset(&sa.sa_mask) != 0 || \ + sigaction(SIGNAL, &sa, nullptr) != 0) { \ + _exit(EXIT_FAILURE); \ + } else { \ + raise(SIGNAL); \ + } \ } while (0) #define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \ @@ -106,44 +112,62 @@ void SetLoadProcessSignalHandler() { void ThrowErrorIfLoadProcessFailed() { int error; + std::set *pids_set; pid_t process_pid; siginfo_t infop; - for (auto &w : load_process_pids) { - process_pid = w.second; - // Use waitid rather than waitpid so that we can set NOWAIT, and that Python - // and other handlers can get whatever info they want about the child. - infop.si_pid = 0; - VLOG(3) << "Dygraph Data Loader: monitor loader child process " - << process_pid; - error = waitid(P_PID, process_pid, &infop, WEXITED | WNOHANG | WNOWAIT); - // ignore errors and case with no waitable child - if (error < 0 || infop.si_pid == 0) continue; - if (infop.si_code == CLD_EXITED && - infop.si_status != EXIT_SUCCESS) { // exit with error - PADDLE_THROW(platform::errors::Fatal( - "DataLoader process (pid %ld) exited unexpectedly with code %d. " - "Error detailed are lost due to multiprocessing. Rerunning with " - "DataLoader.from_generator(..., use_multiprocess=False) may give " - "better error trace.", - process_pid, infop.si_status)); - } else if (infop.si_code == CLD_KILLED || - infop.si_code == CLD_DUMPED) { // killed by signal - if (infop.si_status == SIGBUS) { + for (auto &p : load_process_pids) { + pids_set = &(p.second); + for (auto pid_it = pids_set->begin(); pid_it != pids_set->end(); ++pid_it) { + process_pid = *pid_it; + // Use waitid rather than waitpid so that we can set NOWAIT, and that + // Python and other handlers can get whatever info they want about the + // child. + infop.si_pid = 0; + VLOG(3) << "DataLoader: monitor loader child process " << process_pid; + error = waitid(P_PID, process_pid, &infop, WEXITED | WNOHANG | WNOWAIT); + // ignore errors and case with no waitable child + if (error < 0 || infop.si_pid == 0) continue; + if (infop.si_code == CLD_EXITED && + infop.si_status != EXIT_SUCCESS) { // exit with error + pids_set->clear(); PADDLE_THROW(platform::errors::Fatal( - "DataLoader process (pid %ld) exited is killed by signal: %s.\n" - " It may be caused by insufficient shared storage space. This " - "problem usually occurs when using docker as a development " - "environment.\n Please use command `df -h` to check the storage " - "space of `/dev/shm`. Shared storage space needs to be greater " - "than (DataLoader Num * DataLoader queue capacity * 1 batch data " - "size).\n You can solve this problem by increasing the shared " - "storage space or reducing the queue capacity appropriately.", - process_pid, strsignal(infop.si_status))); - } else { - PADDLE_THROW(platform::errors::Fatal( - "DataLoader process (pid %ld) exited is killed by signal: %s.", - process_pid, strsignal(infop.si_status))); + "DataLoader process (pid %ld) exited unexpectedly with code %d. " + "Error detailed are lost due to multiprocessing. Rerunning with:\n" + " 1. If run DataLoader by DataLoader.from_generator(...), run " + "with " + "DataLoader.from_generator(..., use_multiprocess=False) may give " + "better error trace.\n" + " 2. If run DataLoader by DataLoader(dataset, ...), run with " + "DataLoader(dataset, ..., num_workers=0) may give better error " + "trace", + process_pid, infop.si_status)); + } else if (infop.si_code == CLD_KILLED || + infop.si_code == CLD_DUMPED) { // killed by signal + if (infop.si_status == SIGBUS) { + pids_set->clear(); + PADDLE_THROW(platform::errors::Fatal( + "DataLoader process (pid %ld) exited is killed by signal: %s.\n" + " It may be caused by insufficient shared storage space. This " + "problem usually occurs when using docker as a development " + "environment.\n Please use command `df -h` to check the storage " + "space of `/dev/shm`. Shared storage space needs to be greater " + "than (DataLoader Num * DataLoader queue capacity * 1 batch data " + "size).\n You can solve this problem by increasing the shared " + "storage space or reducing the queue capacity appropriately.\n", + " 1. If run DataLoader by DataLoader.from_generator(...), queue " + "capacity is set by from_generator(..., capacity=xx, ...).\n" + " 2. If run DataLoader by DataLoader(dataset, ...), queue " + "capacity is set as 2 times of the max value of num_workers and " + "len(places).\n" + " 3. If run by DataLoader(dataset, ..., use_shared_memory=True)," + " set use_shared_memory=False for not using shared memory.", + process_pid, strsignal(infop.si_status))); + } else { + PADDLE_THROW(platform::errors::Fatal( + "DataLoader process (pid %ld) exited is killed by signal: %s.", + process_pid, strsignal(infop.si_status))); + } } } } diff --git a/paddle/fluid/imperative/data_loader.h b/paddle/fluid/imperative/data_loader.h index 99dce7a2e39d89d995d78fb533eae0110bf1aea7..fdfa117eafe762c58e55e4b2eecc42beafa11dc6 100644 --- a/paddle/fluid/imperative/data_loader.h +++ b/paddle/fluid/imperative/data_loader.h @@ -18,12 +18,13 @@ #include #include +#include namespace paddle { namespace imperative { -extern void SetLoadProcessPID(int64_t key, pid_t pid); -extern void EraseLoadProcessPID(int64_t key); +extern void SetLoadProcessPIDs(int64_t key, std::set pids); +extern void EraseLoadProcessPIDs(int64_t key); extern void SetLoadProcessSignalHandler(); extern void ThrowErrorIfLoadProcessFailed(); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index df40b56df1e2c388fca4cdc4c8670462a628dede..93a3137b9991036e161c7863e4acc4b6ccf5a711 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include #include #include +#include #include #include #include @@ -290,11 +291,22 @@ void BindImperative(py::module *m_ptr) { #ifndef _WIN32 // Dygraph DataLoader signal handler - m.def("_set_process_pid", [](int64_t key, pid_t pid) { - imperative::SetLoadProcessPID(key, pid); + m.def("_set_process_pids", [](int64_t key, py::object &obj) { + PADDLE_ENFORCE_EQ( + py::isinstance(obj) || py::isinstance(obj), true, + platform::errors::InvalidArgument( + "The subprocess ids set in DataLoader is illegal." + "Expected data type is tuple or list, but received %s", + obj.get_type())); + py::list pids = py::cast(obj); + std::set pids_set = {}; + for (size_t i = 0; i < pids.size(); i++) { + pids_set.insert(pids[i].cast()); + } + imperative::SetLoadProcessPIDs(key, pids_set); }); - m.def("_erase_process_pid", - [](int64_t key) { imperative::EraseLoadProcessPID(key); }); + m.def("_erase_process_pids", + [](int64_t key) { imperative::EraseLoadProcessPIDs(key); }); m.def("_set_process_signal_handler", []() { imperative::SetLoadProcessSignalHandler(); }); m.def("_throw_error_if_process_failed", diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index b6295e41198b23a3ae270013b3e1895c7f08c0d8..3b55ec7ffce4df7e96034df067f667f2c807943f 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -252,7 +252,9 @@ def get_cluster(node_ips, node_ip, paddle_ports, selected_gpus): def terminate_local_procs(procs): for p in procs: if p.proc.poll() is None: - p.proc.terminate() + # subprocess need to release resource(e.g. shared memory) + # use join to wait subprocess releasing + p.proc.join(timeout=1) p.log_fn.close() logger.debug("terminate process id:{}".format(p.proc.pid)) diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 71ed95a90673a6768338a17f17de0052c962e6d9..c3fbb7b51b5ad2b0f2701f325a0d81df0b0ede79 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -185,8 +185,8 @@ if avx_supported(): from .core_avx import _load_dygraph_dict from .core_avx import _create_loaded_parameter if sys.platform != 'win32': - from .core_avx import _set_process_pid - from .core_avx import _erase_process_pid + from .core_avx import _set_process_pids + from .core_avx import _erase_process_pids from .core_avx import _set_process_signal_handler from .core_avx import _throw_error_if_process_failed from .core_avx import _convert_to_tensor_list @@ -229,8 +229,8 @@ if load_noavx: from .core_noavx import _load_dygraph_dict from .core_noavx import _create_loaded_parameter if sys.platform != 'win32': - from .core_noavx import _set_process_pid - from .core_noavx import _erase_process_pid + from .core_noavx import _set_process_pids + from .core_noavx import _erase_process_pids from .core_noavx import _set_process_signal_handler from .core_noavx import _throw_error_if_process_failed from .core_noavx import _convert_to_tensor_list diff --git a/python/paddle/fluid/dataloader/__init__.py b/python/paddle/fluid/dataloader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62aefd6aec8cb92faf4595a821381f35c2abd5bd --- /dev/null +++ b/python/paddle/fluid/dataloader/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from . import dataset +from .dataset import * + +from . import batch_sampler +from .batch_sampler import * + +__all__ = dataset.__all__ \ + + batch_sampler.__all__ diff --git a/python/paddle/fluid/dataloader/batch_sampler.py b/python/paddle/fluid/dataloader/batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c6163f7da1ee6de139b7503e8e6a6c3722ea8a7b --- /dev/null +++ b/python/paddle/fluid/dataloader/batch_sampler.py @@ -0,0 +1,143 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from __future__ import division + +import numpy as np +from .dataset import Dataset + +__all__ = ["BatchSampler"] + + +class BatchSampler(object): + """ + A base implement of batch sampler used by `paddle.io.DataLoader` + which yield mini-batch indices(a list/tuple with length as + mini-batch size and holds sample indices) iterably. + + Batch sampler used by :code:`paddle.io.DataLoader` should be a subclass + of :code:`paddle.io.BatchSampler`, BatchSampler subclasses should + implement following methods: + + :code:`__iter__`: return mini-batch indices iterably. + + :code:`__len__`: get mini-batch number in an epoch. + + + Args: + dataset(Dataset): this could be a :code:`paddle.io.Dataset` + implement or other python object which implemented + :code:`__len__` for BatchSampler to get indices as the + range of :attr:`dataset` length. Default None. + indices (list|tuple): a substitution parameter for + :attr:`dataset` either :attr:`dataset` or + :attr:`indices` should be set, give the whole + indices to sampler from directly. Default None. + shuffle(bool): whether to shuffle indices order before genrating + batch indices. Default False. + batch_size(int): sample indice number in a mini-batch indices. + drop_last(bool): whether drop the last incomplete batch dataset size + is not divisible by the batch size. Default False + + Returns: + BatchSampler: an iterable object for indices iterating + + Examples: + + .. code-block:: python + + from paddle.io import BatchSampler, Dataset + + # init with indices + bs = BatchSampler(indices=list(range(100)), + shuffle=True, + batch_size=8, + drop_last=True) + + for batch_indices in bs: + print(batch_indices) + + # init with dataset + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + bs = BatchSampler(dataset=RandomDataset(100), + shuffle=False, + batch_size=16, + drop_last=False) + + for batch_indices in bs: + print(batch_indices) + + see `paddle.io.DataLoader` + + """ + + def __init__(self, + dataset=None, + indices=None, + shuffle=False, + batch_size=1, + drop_last=False): + if dataset is None: + assert indices is not None, \ + "either dataset or indices should be set" + assert isinstance(indices, list) or isinstance(indices, tuple), \ + "indices should be a list or tuple, but got {}".format(type(indices)) + self.indices = indices + else: + assert isinstance(dataset, Dataset), \ + "dataset should be an instance of paddle.io.Dataset" + assert indices is None, \ + "should not set both dataset and indices" + self.indices = list(range(len(dataset))) + + assert isinstance(batch_size, int) and batch_size > 0, \ + "batch_size should be a positive integer, but got {}".format(batch_size) + self.batch_size = batch_size + assert isinstance(shuffle, bool), \ + "shuffle should be a boolean value, but got {}".format(type(shuffle)) + self.shuffle = shuffle + assert isinstance(drop_last, bool), \ + "drop_last should be a boolean value, but got {}".format(type(drop_last)) + self.drop_last = drop_last + + def __iter__(self): + if self.shuffle: + np.random.shuffle(self.indices) + _iter = iter(self.indices) + + batch_indices = [] + for idx in _iter: + batch_indices.append(idx) + if len(batch_indices) == self.batch_size: + yield batch_indices + batch_indices = [] + if not self.drop_last and len(batch_indices) > 0: + yield batch_indices + + def __len__(self): + num_samples = len(self.indices) + num_samples += int(not self.drop_last) * (self.batch_size - 1) + return num_samples // self.batch_size diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6e05248f72d5a0499138586b25f6f35c4822af --- /dev/null +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -0,0 +1,528 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import six +import sys +import time +import signal +import logging +import itertools +import threading +import numpy as np +import multiprocessing + +# NOTE: queue has a different name in python2 and python3 +if six.PY2: + import Queue as queue +else: + import queue + +from .. import core +from ..framework import in_dygraph_mode +from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler + +# multi-process worker check indices queue interval, avoid +# hanging in subprocess data loading +MP_INDICES_CHECK_INTERVAL = 5 + + +def _default_collate_fn(batch): + sample = batch[0] + # dataset has only 1 field + if isinstance(sample, np.ndarray): + return [np.stack(batch, axis=0)] + + # batch each field + slots = [] + for items in batch: + for i, item in enumerate(items): + if len(slots) < len(items): + slots.append([item]) + else: + slots[i].append(item) + return [np.stack(slot, axis=0) for slot in slots] + + +class ParentWatchDog(object): + def __init__(self): + self._parent_pid = os.getppid() + self._parent_alive = True + + def is_alive(self): + if self._parent_alive: + self._parent_alive = os.getppid() == self._parent_pid + return self._parent_alive + + +class _DataLoaderIterBase(object): + """ + Iterator implement of DataLoader, will load and feed mini-batch + data by setting in given dataloader. + + Args: + loader(instance of DataLoader): instance of `fluid.io.DataLoader` + """ + + def __init__(self, loader): + self._dataset = loader.dataset + self._feed_list = loader.feed_list or [] + self._places = loader.places + self._return_list = loader.return_list + self._batch_sampler = loader.batch_sampler + self._sampler_iter = iter(loader.batch_sampler) + self._collate_fn = loader.collate_fn or _default_collate_fn + self._num_workers = loader.num_workers + self._use_buffer_reader = loader.use_buffer_reader + self._use_shared_memory = loader.use_shared_memory + self._timeout = loader.timeout if loader.timeout > 0 else MP_INDICES_CHECK_INTERVAL + self._worker_init_fn = loader.worker_init_fn + + # LoDTensorBlockingQueue instance for create_py_reader and a thread + # to put mini-batch data to self._blocking_queue, mini-batch data + # will be get from: + # 1. multi-process mode: get data from workers' result queue + # 2. single-process mode: read mini-batch data in main process + self._blocking_queue = None + self._thread = None + self._thread_done_event = threading.Event() + + def __iter__(self): + return self + + def __len__(self): + return len(self._batch_sampler) + + +class _DataLoaderIterSingleProcess(_DataLoaderIterBase): + """ + Single process implement of DataLoaderIter, loading data from + loader.data in main process + """ + + def __init__(self, loader): + super(_DataLoaderIterSingleProcess, self).__init__(loader) + + # NOTE: len(self._places) batch data compose as an output + # iteration, set blocking_queue can cache 2 iteration datas + # at most here + self._blocking_queue_capacity = 2 * len(self._places) + + self._init_thread() + + def _init_thread(self): + self._var_names = [v.name for v in self._feed_list] + self._shapes = [v.shape for v in self._feed_list] + self._dtypes = [v.dtype for v in self._feed_list] + self._need_check_feed = [ + v.desc.need_check_feed() for v in self._feed_list + ] + self._blocking_queue = core.init_lod_tensor_blocking_queue( + core.Variable(), self._blocking_queue_capacity, True) + self._reader = core.create_py_reader( + self._blocking_queue, self._var_names, self._shapes, self._dtypes, + self._need_check_feed, self._places, self._use_buffer_reader, True) + + self._thread = threading.Thread(target=self._thread_loop) + self._thread.daemon = True + self._thread.start() + + def _thread_loop(self): + try: + for indices in self._sampler_iter: + # read data from dataset in mini-batch + batch = [self._dataset[i] for i in indices] + if self._collate_fn is not None: + batch = self._collate_fn(batch) + + # pack as LoDTensorArray + array = core.LoDTensorArray() + for slot in batch: + if not isinstance(slot, core.LoDTensor): + self._check_input_array(slot) + tmp = core.LoDTensor() + tmp.set(slot, core.CPUPlace()) + slot = tmp + + array.append(slot) + + if not self._blocking_queue.push(array): + break + + self._blocking_queue.close() + self._thread = None + except Exception: + self._blocking_queue.kill() + self._thread = None + logging.warning("DataLoader reader thread raised an exception.") + six.reraise(*sys.exc_info()) + + @classmethod + def _check_input_array(cls, item): + arr = np.array(item) + if arr.dtype == np.object: + raise TypeError(( + "\n\tFaild to convert input data to a regular ndarray :\n\t* Usually " + "this means the input data contains nested lists with different lengths. " + "\n\t* Check the reader function passed to 'decorate_batch_generator'" + " to locate the data causes this issue.\n\t* Please consider using " + "'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")) + + def __next__(self): + try: + if in_dygraph_mode(): + return self._reader.read_next_var_list() + else: + if self._return_list: + return self._reader.read_next_list() + else: + return self._reader.read_next() + except StopIteration: + self._reader.reset() + six.reraise(*sys.exc_info()) + + # python2 compatibility + def next(self): + return self.__next__() + + +class _DataLoaderIterMultiProcess(_DataLoaderIterBase): + def __init__(self, loader): + super(_DataLoaderIterMultiProcess, self).__init__(loader) + + assert self._num_workers > 0, "Multi-process DataLoader " \ + "invalid num_workers({})".format(self._num_workers) + + # subprocess wrokers' result queue + self._data_queue = None + + # data get from _data_queue will be reordered by _rcvd_idx + # for data order keeping, data index not equal _rcvd_idx + # will be cached in _reorder_dict + self._send_idx = 0 + self._rcvd_idx = 0 + self._batches_outstanding = 0 + self._reorder_dict = {} + + # indices outstand as _outstanding_capacity at first, and + # blocking_queue capacity is also _outstanding_capacity. + # _outstanding_capacity here to make sure each indices_queue + # has at least 2 indices, and outstanding batch cached + # output data for at least 2 iterations(Note that len(_places) + # batches will be composed as an iteration output) + self._outstanding_capacity = 2 * max(self._num_workers, + len(self._places)) + + self._init_workers() + self._init_thread() + + self._shutdown = False + + for _ in range(self._outstanding_capacity): + self._try_put_indices() + + def _init_workers(self): + # multiprocess worker and indice queue list initial as empty + self._workers = [] + self._worker_status = [] + self._indices_queues = [] + self._workers_idx_cycle = itertools.cycle(range(self._num_workers)) + + # create data_queue for workers + self._data_queue = multiprocessing.Queue() + + # event for workers and thread, thread event is only need + # in multi-processing mode + self._workers_done_event = multiprocessing.Event() + self._thread_done_event = threading.Event() + + for i in range(self._num_workers): + indices_queue = multiprocessing.Queue() + self._indices_queues.append(indices_queue) + worker = multiprocessing.Process( + target=self._worker_loop, + args=(self._dataset, indices_queue, self._data_queue, + self._workers_done_event, self._collate_fn, + self._worker_init_fn, i)) + worker.daemon = True + worker.start() + self._workers.append(worker) + self._worker_status.append(True) + + core._set_process_pids(id(self), tuple(w.pid for w in self._workers)) + _set_SIGCHLD_handler() + + def _clear_and_remove_data_queue(self): + if self._data_queue is not None: + while True: + try: + self._data_queue.get_nowait() + except: + self._data_queue.cancel_join_thread() + self._data_queue.close() + break + + def _init_thread(self): + self._var_names = [v.name for v in self._feed_list] + self._shapes = [v.shape for v in self._feed_list] + self._dtypes = [v.dtype for v in self._feed_list] + self._need_check_feed = [ + v.desc.need_check_feed() for v in self._feed_list + ] + self._blocking_queue = core.init_lod_tensor_blocking_queue( + core.Variable(), self._outstanding_capacity, True) + self._reader = core.create_py_reader( + self._blocking_queue, self._var_names, self._shapes, self._dtypes, + self._need_check_feed, self._places, self._use_buffer_reader, True) + + self._thread_done_event = threading.Event() + self._thread = threading.Thread(target=self._thread_loop) + self._thread.daemon = True + self._thread.start() + + def _shutdown_worker(self, worker_id): + if self._worker_status[worker_id]: + self._indices_queues[worker_id].put(None) + self._worker_status[worker_id] = False + + def _try_shutdown_all(self): + if not self._shutdown: + try: + self._exit_thread_expectedly() + self._clear_and_remove_data_queue() + + # set _workers_done_event should be set before put None + # to indices_queue, workers wll exit on reading None from + # indices_queue + self._workers_done_event.set() + for i in range(self._num_workers): + self._shutdown_worker(i) + + for w in self._workers: + w.join() + for q in self._indices_queues: + q.cancel_join_thread() + q.close() + finally: + core._erase_process_pids(id(self)) + self._shutdown = True + + def _exit_thread_expectedly(self): + self._thread_done_event.set() + self._blocking_queue.close() + + def _exit_thread_unexpectedly(self): + self._thread_done_event.set() + self._blocking_queue.kill() + logging.error("DataLoader reader thread raised an exception!") + + def _worker_loop(self, dataset, indices_queue, out_queue, done_event, + collate_fn, init_fn, worker_id): + try: + # NOTE: [ mmap files clear ] When the child process exits unexpectedly, + # some shared memory objects may have been applied for but have not yet + # been put into the inter-process Queue. This part of the object needs + # to be cleaned up when the process ends. + CleanupFuncRegistrar.register(_cleanup_mmap) + + # set signal handler + core._set_process_signal_handler() + + init_exception = None + if init_fn is not None: + try: + init_fn(worker_id) + except: + init_exception = Exception("init_fn failed in worker {}: " \ + "{}".format(worker_id, sys.exc_info())) + + parent_watch_dog = ParentWatchDog() + + while parent_watch_dog.is_alive(): + try: + data = indices_queue.get(MP_INDICES_CHECK_INTERVAL) + except queue.Empty: + continue + + # None as poison piil, so worker event should be set + if data is None: + assert done_event.is_set( + ), "get None when worker done_event set" + break + # If worker done event is set but get still get data in + # indices_queue, remaining data should be get and skipped. + if done_event.is_set(): + continue + + idx, indices = data + try: + if init_exception is not None: + batch = init_exception + init_exception = None + else: + batch = [dataset[i] for i in indices] + if self._collate_fn is not None: + batch = self._collate_fn(batch) + except Exception as e: + out_queue.put((idx, e)) + else: + if self._use_shared_memory: + tensor_list = core._convert_to_tensor_list(batch) + out_queue.put((idx, tensor_list)) + core._remove_tensor_list_mmap_fds(tensor_list) + else: + out_queue.put((idx, batch)) + except KeyboardInterrupt: + # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process + pass + except: + six.reraise(*sys.exc_info()) + finally: + if self._use_shared_memory: + _cleanup_mmap() + + def _thread_loop(self): + while not self._thread_done_event.is_set(): + batch = self._get_data() + if not self._thread_done_event.is_set(): + if batch is None: + self._exit_thread_expectedly() + elif isinstance(batch, Exception): + self._exit_thread_unexpectedly() + else: + try: + # pack as LoDTensorArray + array = core.LoDTensorArray() + if self._use_shared_memory: + for tensor in batch: + array.append(tensor) + else: + # LoDTensor not in shared memory is not + # serializable, cannot be create in workers + for slot in batch: + if not isinstance(slot, core.LoDTensor): + # self._check_input_array(slot) + tmp = core.LoDTensor() + tmp.set(slot, core.CPUPlace()) + slot = tmp + array.append(slot) + + if not self._blocking_queue.push(array): + self._blocking_queue.close() + except: + self._exit_thread_unexpectedly() + six.reraise(*sys.exc_info()) + finally: + self._rcvd_idx += 1 + + def _get_data(self): + if self._rcvd_idx in self._reorder_dict.keys(): + return self._reorder_dict.pop(self._rcvd_idx) + + while not self._thread_done_event.is_set(): + try: + # [ avoid hang ]: main process may blocking at _reader.read_next when + # KeyboardInterrupt, we do following tradeoff: + # 1. get data with timeout, MP_INDICES_CHECK_INTERVAL(5s) as timeout + # default, if KeyboardInterrupt blocking, failed workers will be + # checked and raise RuntimeError to quit DataLoader in timeout + # exception handling. + # 2. if get data timeout and check workers all alive, continue to + # get data again + data = self._data_queue.get(timeout=self._timeout) + except Exception as e: + failed_workers = [] + for i, w in enumerate(self._workers): + if self._worker_status[i] and not w.is_alive(): + failed_workers.append(w) + self._shutdown_worker(i) + if len(failed_workers) > 0: + self._exit_thread_unexpectedly() + pids = ', '.join(str(w.pid) for w in failed_workers) + raise RuntimeError("DataLoader {} workers exit unexpectedly, " \ + "pids: {}".format(len(failed_workers), pids)) + + # get(timeout) will call _poll(timeout) and may raise IOError + if isinstance(e, queue.Empty) or isinstance(e, IOError): + # continue on timeout to keep getting data from queue + continue + + self._exit_thread_unexpectedly() + logging.error("DataLoader reader thread failed({}) to read data from " \ + "workers' result queue.".format(e)) + six.reraise(*sys.exc_info()) + else: + idx, batch = data + if idx == self._rcvd_idx: + return batch + else: + self._reorder_dict[idx] = batch + continue + + def _try_put_indices(self): + assert self._send_idx - self._rcvd_idx <= self._outstanding_capacity, \ + "too many indices have been put to queue" + try: + indices = next(self._sampler_iter) + except StopIteration: + return + + worker_idx = next(self._workers_idx_cycle) + self._indices_queues[worker_idx].put((self._send_idx, indices)) + self._batches_outstanding += 1 + self._send_idx += 1 + + def __del__(self): + self._try_shutdown_all() + + def __next__(self): + try: + # _batches_outstanding here record the total batch data number + # in 'from after _try_put_indices to beforeoutput data', this + # value should be _outstanding_capacity if data is not drained, + # if _batches_outstanding is less than _places number, there are + # no enough data to generate next output, close blocking_queue and + # set _thread_done_event here, py_reader will raise StopIteration, + # end workers and indices_queues in StopIteration handling + if self._batches_outstanding < len(self._places): + self._thread_done_event.set() + self._blocking_queue.close() + + if in_dygraph_mode(): + data = self._reader.read_next_var_list() + else: + if self._return_list: + data = self._reader.read_next_list() + # static graph organized data on multi-device with list, if + # place number is 1, there is only 1 device, extra the data + # from list for devices to be compatible with dygraph mode + if len(self._places) == 1: + data = data[0] + else: + data = self._reader.read_next() + self._on_output_batch() + return data + except StopIteration: + self._reader.reset() + self._try_shutdown_all() + six.reraise(*sys.exc_info()) + + # python2 compatibility + def next(self): + return self.__next__() + + def _on_output_batch(self): + for _ in range(len(self._places)): + self._batches_outstanding -= 1 + self._try_put_indices() diff --git a/python/paddle/fluid/dataloader/dataset.py b/python/paddle/fluid/dataloader/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b49ceaddefdef272d34ea9008d57d3f1ced43205 --- /dev/null +++ b/python/paddle/fluid/dataloader/dataset.py @@ -0,0 +1,73 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle.dataset.common + +__all__ = ["Dataset"] + + +class Dataset(object): + """ + An abstract class to encapsulates methods and behaviors of datasets. + + All datasets in map-style(dataset samples can be get by a given key) + should be a subclass of `paddle.io.Dataset`. All subclasses should + implement following methods: + + :code:`__getitem__`: get sample from dataset with a given index. This + method is required by reading dataset sample in :code:`paddle.io.DataLoader`. + + :code:`__len__`: return dataset sample number. This method is required + by some implements of :code:`paddle.io.BatchSampler` + + see :code:`paddle.io.DataLoader`. + + Examples: + + .. code-block:: python + + import numpy as np + from paddle.io import Dataset + + # define a random dataset + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + dataset = RandomDataset(10) + for i in range(len(dataset)): + print(dataset[i]) + + """ + + def __init__(self): + pass + + def __getitem__(self, idx): + raise NotImplementedError("'{}' not implement in class "\ + "{}".format('__getitem__', self.__class__.__name__)) + + def __len__(self): + raise NotImplementedError("'{}' not implement in class "\ + "{}".format('__len__', self.__class__.__name__)) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index c92eebcdab73cb83ad0ba96baa3aed454a4ba40c..ba1480377245e67ba3b8419b843b9ae563a80cd0 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -37,6 +37,8 @@ from paddle.fluid.compiler import CompiledProgram from paddle.fluid.log_helper import get_logger from . import reader from .reader import * +from . import dataloader +from .dataloader import * from . import core from .. import compat as cpt diff --git a/python/paddle/fluid/multiprocess_utils.py b/python/paddle/fluid/multiprocess_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a63825e73638b6af234b2596ce107c88bcfc94f9 --- /dev/null +++ b/python/paddle/fluid/multiprocess_utils.py @@ -0,0 +1,139 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import sys +import signal +import atexit + +from . import core + +# NOTE: queue has a different name in python2 and python3 +if six.PY2: + import Queue as queue +else: + import queue + +# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading, +# the data in the queue needs to be popped. Then the LoDTensor read by the main process +# from the child process will automatically clear the memory-mapped file. +multiprocess_queue_set = set() + + +def _clear_multiprocess_queue_set(): + global multiprocess_queue_set + for data_queue in multiprocess_queue_set: + while True: + try: + data_queue.get_nowait() + except queue.Empty: + break + + +# NOTE: main process clear function at exit +def _cleanup(): + # NOTE: inter-process Queue shared memory objects clear function + _clear_multiprocess_queue_set() + # NOTE: main process memory map files clear funciton + core._cleanup_mmap_fds() + + +# NOTE: for child process clear function at exit +def _cleanup_mmap(): + # clear memory map files in child process + core._cleanup_mmap_fds() + + +# NOTE used for register a function to be executed at interpreter exit. +class CleanupFuncRegistrar(): + # Record the cleanup functions that have been executed + _executed_func_set = set() + # Record the cleanup functions that have been registered + _registered_func_set = set() + + @classmethod + def register(cls, function, signals=[]): + def _func_exectuor(): + if function not in cls._executed_func_set: + try: + function() + finally: + cls._executed_func_set.add(function) + + def _func_register(function): + if not callable(function): + raise TypeError("%s is not callable object." % (function)) + # check function object whether hash-able + set([function]) + if function not in cls._registered_func_set: + atexit.register(_func_exectuor) + cls._registered_func_set.add(function) + + def _signal_handler(signum=None, frame=None): + _func_exectuor() + if signum is not None: + if signum == signal.SIGINT: + raise KeyboardInterrupt + sys.exit(signum) + + def _signal_register(signals): + signals = set(signals) + for sig in signals: + orig_handler = signal.signal(sig, _signal_handler) + if orig_handler not in (signal.SIG_DFL, signal.SIG_IGN): + if (sig == signal.SIGINT and + orig_handler is signal.default_int_handler): + continue + if orig_handler not in cls._registered_func_set: + atexit.register(orig_handler) + cls._registered_func_set.add(orig_handler) + + # deal with signals + _signal_register(signals) + # deal with function + _func_register(function) + + +# NOTE: [ mmap files clear ] When the main process exits unexpectedly, the remaining +# shared memory objects in the inter-process Queue and the main process (mostly in the +# BlockingQueue) may not be completely released, resulting in the corresponding +# memory-mapped file remaining on the disk (/dev/shm), so register this function +# to clean up shared memory objects in these two queues before the python interpreter exits. +# NOTE: Currently multi-process DataLoader only supports Linux platform +if not (sys.platform == 'darwin' or sys.platform == 'win32'): + CleanupFuncRegistrar.register(_cleanup) + +# ------------ SIGCHLD handler setting -------------- +_SIGCHLD_handler_set = False + + +def _set_SIGCHLD_handler(): + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + + current_handler = signal.getsignal(signal.SIGCHLD) + if not callable(current_handler): + current_handler = None + + def __handler__(signum, frame): + # NOTE: Here the signum is SIGCHLD, when the child process exits, + # this handler will be called whenever the child process exits + # normally or abnormally. + core._throw_error_if_process_failed() + if current_handler is not None: + current_handler(signum, frame) + + signal.signal(signal.SIGCHLD, __handler__) + _SIGCHLD_handler_set = True diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index fe5e9a2e045f3831a1e946b81de38ccf634f4d39..969cac4a6631ed627fd384f46a3f8610c11ff15c 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -21,29 +21,28 @@ import paddle from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places from .executor import global_scope from .data_feeder import DataFeeder, BatchedTensorProvider +from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler +from .dataloader import BatchSampler, Dataset +from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer from .unique_name import UniqueNameGenerator import logging from .dataset import DatasetBase, InMemoryDataset ### Dygraph DataLoader configs ### -import atexit import os import multiprocessing import signal + # NOTE: queue has a different name in python2 and python3 -if sys.version_info[0] == 2: +if six.PY2: import Queue as queue else: import queue + # NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process QUEUE_GET_TIMEOUT = 60 -# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading, -# the data in the queue needs to be popped. Then the LoDTensor read by the main process -# from the child process will automatically clear the memory-mapped file. -multiprocess_queue_set = set() - __all__ = ['PyReader', 'DataLoader'] data_loader_unique_name_generator = UniqueNameGenerator() @@ -75,84 +74,6 @@ def _convert_places(places): return ret -def _clear_multiprocess_queue_set(): - global multiprocess_queue_set - for data_queue in multiprocess_queue_set: - while True: - try: - data_queue.get_nowait() - except queue.Empty: - break - - -# NOTE: main process clear function at exit -def _cleanup(): - # NOTE: inter-process Queue shared memory objects clear function - _clear_multiprocess_queue_set() - # NOTE: main process memory map files clear funciton - core._cleanup_mmap_fds() - - -# NOTE used for register a function to be executed at interpreter exit. -class CleanupFuncRegistrar(): - # Record the cleanup functions that have been executed - _executed_func_set = set() - # Record the cleanup functions that have been registered - _registered_func_set = set() - - @classmethod - def register(cls, function, signals=[signal.SIGTERM]): - def _func_exectuor(): - if function not in cls._executed_func_set: - try: - function() - finally: - cls._executed_func_set.add(function) - - def _func_register(function): - if not callable(function): - raise TypeError("%s is not callable object." % (function)) - # check function object whether hash-able - set([function]) - if function not in cls._registered_func_set: - atexit.register(_func_exectuor) - cls._registered_func_set.add(function) - - def _signal_handler(signum=None, frame=None): - _func_exectuor() - if signum is not None: - if signum == signal.SIGINT: - raise KeyboardInterrupt - sys.exit(signum) - - def _signal_register(signals): - signals = set(signals) - for sig in signals: - orig_handler = signal.signal(sig, _signal_handler) - if orig_handler not in (signal.SIG_DFL, signal.SIG_IGN): - if (sig == signal.SIGINT and - orig_handler is signal.default_int_handler): - continue - if orig_handler not in cls._registered_func_set: - atexit.register(orig_handler) - cls._registered_func_set.add(orig_handler) - - # deal with signals - _signal_register(signals) - # deal with function - _func_register(function) - - -# NOTE: [ mmap files clear ] When the main process exits unexpectedly, the remaining -# shared memory objects in the inter-process Queue and the main process (mostly in the -# BlockingQueue) may not be completely released, resulting in the corresponding -# memory-mapped file remaining on the disk (/dev/shm), so register this function -# to clean up shared memory objects in these two queues before the python interpreter exits. -# NOTE: Currently multi-process DataLoader only supports Linux platform -if not (sys.platform == 'darwin' or sys.platform == 'win32'): - CleanupFuncRegistrar.register(_cleanup) - - class DataLoaderBase(object): def __init__(self): self._places = None @@ -177,6 +98,264 @@ class DataLoaderBase(object): class DataLoader(object): + """ + DataLoader prodives an iterator which iterates given dataset + once by the batch_sampler. + + DataLoader supports single-process and multi-prcess data loading, + multi-process workers will be used to load data asynchronously if + :attr:`num_workers` is set as a positive number. + + DataLoader only supports map-style dataset(can get a sample from + dataset with a given index) currently, for a map-style dataset, + please see :code:`paddle.io.Dataset`. + + batch_sampler please see :code:`paddle.io.BatchSampler` + + Args: + dataset(Dataset): the dataset to load data from, should be an + instance of subclass of :code:`paddle.io.Dataset`. + feed_list (list(Variable)|tuple(Variable)): feed variable list. + The variables should be created by :code:`fluid.data()`. + :attr:`feed_list` must be set if :attr:`return_list` is + False. Default None. + places(list(Place)|tuple(Place)): a list of Place, to put data + onto, :attr:`places` must be set in both static graph and + dynamic graph mode, in dynamic graph mode, place number must + be 1. Default None. + return_list (bool): whether the return value on each device is + presented as a list. If :attr:`return_list=False`, the return + value on each device would be a dict of str -> LoDTensor, where + the key of the dict is the name of each fed variables. If + :attr:`return_list=True`, the return value on each device would + be a list(LoDTensor). :attr:`return_list` can only be True + in dynamic graph mode. Default False. + batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler` + to generate batch indices to draw samples from :attr:`dataset` + and combine a batch. Default None. + batch_size(int): sample number in a mini-batch, a substitution + parameter for :attr:`batch_sampler`, if :attr:`batch_sampler` + is not set, a default `paddle.io.BatchSampler` will be used + and initialize by :attr:`batch_size`, :attr:`shuffle` and + :attr:`drop_last`. Default 1. + shuffle(bool): whther to shuffle indices order before genrate + batch indices, a substitution parameter for :attr:`batch_sampler` + see :attr:`batch_size`. Default False. + drop_last(bool): whether drop the last incomplete batch dataset size + is not divisible by the batch size, a substitution parameter + for :attr:`batch_sampler`, see :attr:`batch_size`. Default False + collate_fn(callable): function to generate mini-batch data by merging + the sample list, None for only stack each fields of sample in axis + 0(same as :attr::`np.stack(..., axis=0)`). Default None + num_workers(int): the number of subprocess to load data, 0 for no + subprocess used and loading data in main process. Default 0 + use_buffer_reader (bool): whether to use bufferred reader. + If use_buffer_reader=True, the DataLoader would prefetch next + batch data asynchronously, so it would speed up data feeding + and occupies a little more CPU or GPU memory, i.e., the memory + of one batch input data. Default True. + use_shared_memory (bool): whether to use shared memory to speed up + putting data into inter-process queue, set :attr:`use_shared_memory` + as True only when the shared memory space on your machine(e.g. + space of '/dev/shm' on Linux operating sysytem) is large enough. + Shared memory will only be enabled in multi-process mode(num_workers + > 0). Default True. + timeout(int): the timeout value for getting data form output queue + of subprocesses. Default 0. + worker_init_fn(callable): init function which will be called with + worker id on each subproces starting if not set as None. Default + None. + + Returns: + DataLoader: an iterable object for data iterating + + Examples: + + .. code-block:: python + + import numpy as np + import paddle.fluid as fluid + from paddle.io import Dataset, BatchSampler, DataLoader + + BATCH_NUM = 20 + BATCH_SIZE = 16 + EPOCH_NUM = 4 + + IMAGE_SIZE = 784 + CLASS_NUM = 10 + + USE_GPU = False # whether use GPU to run model + + # define a random dataset + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + # get places + places = fluid.cuda_places() if USE_GPU else fluid.cpu_places() + + # -------------------- static graph --------------------- + + def simple_net(image, label): + fc_tmp = fluid.layers.fc(image, size=CLASS_NUM, act='softmax') + cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label) + loss = fluid.layers.reduce_mean(cross_entropy) + sgd = fluid.optimizer.SGD(learning_rate=1e-3) + sgd.minimize(loss) + return loss + + image = fluid.data(name='image', shape=[None, IMAGE_SIZE], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + + loss = simple_net(image, label) + + exe = fluid.Executor(places[0]) + exe.run(fluid.default_startup_program()) + + prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) + + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + + loader = DataLoader(dataset, + feed_list=[image, label], + places=places, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + for e in range(EPOCH_NUM): + for i, data in enumerate(loader()): + l = exe.run(prog, feed=data, fetch_list=[loss], return_numpy=True) + print("Epoch {} batch {}: loss = {}".format(e, i, l[0][0])) + + # ------------------------------------------------------- + + # --------------------- dygraph mode -------------------- + + class SimpleNet(fluid.dygraph.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.fc = fluid.dygraph.nn.Linear(IMAGE_SIZE, CLASS_NUM, act='softmax') + + def forward(self, image, label=None): + return self.fc(image) + + with fluid.dygraph.guard(places[0]): + simple_net = SimpleNet() + opt = fluid.optimizer.SGD(learning_rate=1e-3, + parameter_list=simple_net.parameters()) + + loader = DataLoader(dataset, + places=places[0], + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + for e in range(EPOCH_NUM): + for i, (image, label) in enumerate(loader()): + out = simple_net(image) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.reduce_mean(loss) + avg_loss.backward() + opt.minimize(avg_loss) + simple_net.clear_gradients() + print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy()))) + + # ------------------------------------------------------- + + """ + + def __init__(self, + dataset, + feed_list=None, + places=None, + return_list=False, + batch_sampler=None, + batch_size=1, + shuffle=False, + drop_last=False, + collate_fn=None, + num_workers=0, + use_buffer_reader=True, + use_shared_memory=True, + timeout=0, + worker_init_fn=None): + self.return_list = return_list + self.collate_fn = collate_fn + self.use_buffer_reader = use_buffer_reader + self.worker_init_fn = worker_init_fn + + assert isinstance(dataset, Dataset), \ + "dataset should be subclass instance of paddle.io.Dataset" + self.dataset = dataset + + if not return_list and not in_dygraph_mode(): + assert feed_list is not None, \ + "feed_list should be set when return_list=False" + self.feed_list = feed_list + + assert places is not None, "places cannot be None" + self.places = _convert_places(places) + if in_dygraph_mode(): + assert len(self.places) == 1, \ + "Number of places must be 1 in dygraph mode" + + assert num_workers >= 0, "num_workers should be a non-negative value" + if num_workers > 0 and (sys.platform == 'darwin' or + sys.platform == 'win32'): + logging.warning( + "multi-process mode not support MacOs and Windows currently." \ + " use signle-process with num_workers = 0 instead") + num_workers = 0 + self.num_workers = num_workers + + self.use_shared_memory = use_shared_memory + if use_shared_memory and num_workers == 0: + self.use_shared_memory = False + + assert timeout >= 0, "timeout should be a non-negative value" + self.timeout = timeout + + if batch_sampler is not None: + assert isinstance(batch_sampler, BatchSampler), \ + "batch_sampler should be None or subclass instance " \ + "of paddle.io.BatchSampler" + assert batch_size == 1 and not shuffle and not drop_last, \ + "batch_size/shuffle/drop_last should not be set when " \ + "batch_sampler is given" + self.batch_sampler = batch_sampler + else: + assert batch_size is not None and batch_size > 0, \ + "batch_size should be a positive value when " \ + "batch_sampler is not given" + self.batch_sampler = BatchSampler( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last) + + def __len__(self): + return len(self.batch_sampler) + + def __iter__(self): + if self.num_workers == 0: + return _DataLoaderIterSingleProcess(self) + else: + return _DataLoaderIterMultiProcess(self) + + def __call__(self): + return self.__iter__() + @staticmethod def from_generator(feed_list=None, capacity=None, @@ -553,22 +732,7 @@ class DygraphGeneratorLoader(DataLoaderBase): if process is not None: process.join() # erase process id - core._erase_process_pid(id(self)) - - def _set_child_signal_handler(self): - core._set_process_pid(id(self), self._process.pid) - current_handler = signal.getsignal(signal.SIGCHLD) - if not callable(current_handler): - current_handler = None - - def __handler__(signum, frame): - # NOTE: Here the signum is SIGDHLD, when the child process exits, this handler - # will be called whenever the child process exits normally or abnormally. - core._throw_error_if_process_failed() - if current_handler is not None: - current_handler(signum, frame) - - signal.signal(signal.SIGCHLD, __handler__) + core._erase_process_pids(id(self)) def _init_iterable(self): self._wait_thread_ends() @@ -605,7 +769,8 @@ class DygraphGeneratorLoader(DataLoaderBase): # with SIGSEGV and SIGBUS of child process; 2. if the main process end before child # process, it shuts the all its daemonic children down with a SIGTERM (instead of # joining them without a timeout), so here nedd to deal with SIGTERM. - self._set_child_signal_handler() + core._set_process_pids(id(self), [self._process.pid]) + _set_SIGCHLD_handler() # Set reader_thread self._thread_done_event = threading.Event() @@ -666,16 +831,11 @@ class DygraphGeneratorLoader(DataLoaderBase): # set signal handler core._set_process_signal_handler() - # child process clear function at exit - def _cleanup(): - # clear memory map files in child process - core._cleanup_mmap_fds() - # NOTE: [ mmap files clear ] When the child process exits unexpectedly, # some shared memory objects may have been applied for but have not yet # been put into the inter-process Queue. This part of the object needs # to be cleaned up when the process ends. - CleanupFuncRegistrar.register(_cleanup) + CleanupFuncRegistrar.register(_cleanup_mmap) for batch in self._batch_reader(): tensor_list = core._convert_to_tensor_list(batch) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 400b57eec84dd29ee383f3a22b07e89a5caffa26..4ddcae9b240e69c2b2f012eba21f85b1baf58a5b 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -211,6 +211,8 @@ if (APPLE OR WIN32) list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_fds_clear) list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exit_func) list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler) + list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_base) + list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_exception) endif() if(NOT WITH_GPU OR WIN32 OR APPLE) @@ -381,4 +383,6 @@ if(NOT WIN32 AND NOT APPLE) set_tests_properties(test_imperative_data_loader_base PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) set_tests_properties(test_imperative_data_loader_fds_clear PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) # set_tests_properties(test_imperative_data_loader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) + set_tests_properties(test_multiprocess_dataloader_base PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) + set_tests_properties(test_multiprocess_dataloader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) endif() diff --git a/python/paddle/fluid/tests/unittests/test_batch_sampler.py b/python/paddle/fluid/tests/unittests/test_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7d90bbd0357bcc93cf7a66e99082feeb7e254db4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_batch_sampler.py @@ -0,0 +1,120 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import unittest + +import paddle.fluid as fluid +from paddle.io import BatchSampler, Dataset + + +class RandomDataset(Dataset): + def __init__(self, sample_num, class_num): + self.sample_num = sample_num + self.class_num = class_num + + def __getitem__(self, idx): + np.random.seed(idx) + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.sample_num + + +class TestBatchSampler(unittest.TestCase): + def setUp(self): + self.num_samples = 1000 + self.num_classes = 10 + self.batch_size = 32 + self.shuffle = False + self.drop_last = False + + def init_batch_sampler(self): + dataset = RandomDataset(self.num_samples, self.num_classes) + bs = BatchSampler( + dataset=dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_last=self.drop_last) + return bs + + def test_main(self): + bs = self.init_batch_sampler() + # length check + bs_len = (self.num_samples + int(not self.drop_last) \ + * (self.batch_size - 1)) // self.batch_size + self.assertTrue(bs_len == len(bs)) + + # output indices check + if not self.shuffle: + index = 0 + for indices in bs: + for idx in indices: + self.assertTrue(index == idx) + index += 1 + + +class TestBatchSamplerDropLast(TestBatchSampler): + def setUp(self): + self.num_samples = 1000 + self.num_classes = 10 + self.batch_size = 32 + self.shuffle = False + self.drop_last = True + + +class TestBatchSamplerShuffle(TestBatchSampler): + def setUp(self): + self.num_samples = 1000 + self.num_classes = 10 + self.batch_size = 32 + self.shuffle = True + self.drop_last = True + + +class TestBatchSamplerWithIndices(TestBatchSampler): + def init_batch_sampler(self): + bs = BatchSampler( + indices=list(range(self.num_samples)), + batch_size=self.batch_size, + drop_last=self.drop_last) + return bs + + +class TestBatchSamplerWithIndicesAndDataSource(unittest.TestCase): + def setUp(self): + self.num_samples = 1000 + self.num_classes = 10 + self.batch_size = 32 + self.shuffle = False + self.drop_last = True + + def test_main(self): + try: + dataset = RandomDataset(self.num_samples, self.num_classes) + bs = BatchSampler( + dataset=dataset, + indices=list(range(self.num_samples)), + batch_size=self.batch_size, + drop_last=self.drop_last) + self.assertTrue(False) + except AssertionError: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py b/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c498fe4a3c71296101bc08e6bbbe0887ac8b6c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import unittest +import numpy as np + +import paddle.fluid as fluid +from paddle.io import * + + +class TestDatasetAbstract(unittest.TestCase): + def test_main(self): + dataset = Dataset() + try: + d = dataset[0] + self.assertTrue(False) + except NotImplementedError: + pass + + try: + l = len(dataset) + self.assertTrue(False) + except NotImplementedError: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py index 664d4078d2c5bc27d79de4454675cd5b17521e3e..e6b60873273d55da1724051d3419da4f94b76e7c 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py @@ -17,6 +17,7 @@ import unittest import numpy as np import paddle.fluid as fluid from paddle.fluid import core +from paddle.io import Dataset, DataLoader def get_random_images_and_labels(image_shape, label_shape): @@ -35,6 +36,20 @@ def batch_generator_creator(batch_size, batch_num): return __reader__ +class RandomDataset(Dataset): + def __init__(self, sample_num): + self.sample_num = sample_num + + def __getitem__(self, idx): + np.random.seed(idx) + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.sample_num + + class TestDygraphDataLoaderMmapFdsClear(unittest.TestCase): def setUp(self): self.batch_size = 8 @@ -74,5 +89,19 @@ class TestDygraphDataLoaderMmapFdsClear(unittest.TestCase): self.run_one_epoch_with_break(loader) +class TestMultiProcessDataLoaderMmapFdsClear(TestDygraphDataLoaderMmapFdsClear): + def prepare_data_loader(self): + place = fluid.CPUPlace() + with fluid.dygraph.guard(place): + dataset = RandomDataset(self.batch_size * self.batch_num) + loader = DataLoader( + dataset, + places=place, + batch_size=self.batch_size, + drop_last=True, + num_workers=2) + return loader + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py b/python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py index 57d1b394f968124e3accf4cea4f7d282dcbfdb6f..b47834ffab85e56b12c787ac57823ee02dd18df7 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py @@ -24,7 +24,7 @@ from paddle.fluid import core def set_child_signal_handler(self, child_pid): - core._set_process_pid(id(self), child_pid) + core._set_process_pids(id(self), tuple([child_pid])) current_handler = signal.getsignal(signal.SIGCHLD) if not callable(current_handler): current_handler = None diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_base.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_base.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b3ed710caa1ca2d7511be3753e06359ef69556 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_base.py @@ -0,0 +1,260 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import os +import sys +import six +import time +import unittest +import multiprocessing +import numpy as np + +import paddle.fluid as fluid +from paddle.io import Dataset, BatchSampler, DataLoader +from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.dygraph.base import to_variable + +EPOCH_NUM = 5 +BATCH_SIZE = 16 +IMAGE_SIZE = 784 +SAMPLE_NUM = 400 +CLASS_NUM = 10 + + +class RandomDataset(Dataset): + def __init__(self, sample_num, class_num): + self.sample_num = sample_num + self.class_num = class_num + + def __getitem__(self, idx): + np.random.seed(idx) + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.sample_num + + +def simple_fc_net_static(): + startup_prog = fluid.Program() + main_prog = fluid.Program() + startup_prog.random_seed = 1 + main_prog.random_seed = 1 + + with fluid.unique_name.guard(): + with fluid.program_guard(main_prog, startup_prog): + image = fluid.data( + name='image', shape=[None, IMAGE_SIZE], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + hidden = image + param_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.8)) + bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5)) + for hidden_size in [10, 20, 30]: + hidden = fluid.layers.fc(hidden, + size=hidden_size, + act='tanh', + param_attr=param_attr, + bias_attr=bias_attr) + + predict_label = fluid.layers.fc(hidden, + size=CLASS_NUM, + act='softmax', + param_attr=param_attr, + bias_attr=bias_attr) + loss = fluid.layers.reduce_mean( + fluid.layers.cross_entropy( + input=predict_label, label=label)) + + optimizer = fluid.optimizer.Adam() + optimizer.minimize(loss) + return startup_prog, main_prog, image, label, loss + + +class SimpleFCNet(fluid.dygraph.Layer): + def __init__(self): + super(SimpleFCNet, self).__init__() + + param_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.8)) + bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5)) + self._fcs = [] + in_channel = IMAGE_SIZE + for hidden_size in [10, 20, 30]: + self._fcs.append( + Linear( + in_channel, + hidden_size, + act='tanh', + param_attr=param_attr, + bias_attr=bias_attr)) + in_channel = hidden_size + self._fcs.append( + Linear( + in_channel, + CLASS_NUM, + act='softmax', + param_attr=param_attr, + bias_attr=bias_attr)) + + def forward(self, image): + out = image + for fc in self._fcs: + out = fc(out) + return out + + +class TestStaticDataLoader(unittest.TestCase): + def run_main(self, num_workers, places, with_data_parallel): + scope = fluid.Scope() + with fluid.scope_guard(scope): + startup_prog, main_prog, image, label, loss = simple_fc_net_static() + + dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) + dataloader = DataLoader( + dataset, + feed_list=[image, label], + places=places, + num_workers=num_workers, + batch_size=BATCH_SIZE, + drop_last=True) + assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) + + exe = fluid.Executor(place=places[0]) + exe.run(startup_prog) + + prog = fluid.CompiledProgram(main_prog) + if with_data_parallel: + prog = prog.with_data_parallel( + loss_name=loss.name, places=places) + + step_list = [] + loss_list = [] + start_t = time.time() + for _ in six.moves.range(EPOCH_NUM): + step = 0 + for d in dataloader: + assert len(d) == len(places), "{} != {}".format( + len(d), len(places)) + for i, item in enumerate(d): + image = item['image'] + label = item['label'] + assert image.shape() == [BATCH_SIZE, IMAGE_SIZE] + assert label.shape() == [BATCH_SIZE, 1] + assert image._place()._equals(places[i]) + assert label._place()._equals(places[i]) + L, = exe.run(program=prog, + feed=d, + fetch_list=[loss], + use_program_cache=True) + loss_list.append(np.mean(L)) + step += 1 + step_list.append(step) + + end_t = time.time() + ret = { + "time": end_t - start_t, + "step": step_list, + "loss": np.array(loss_list) + } + print("time cost", ret['time'], 'step_list', ret['step']) + return ret + + def prepare_places(self, with_data_parallel, with_cpu=True, with_gpu=True): + places = [] + # FIXME: PR_CI_Py35 may hang on Multi-CPUs with multiprocess, but it + # works fine locally, this should be fixed. OTOH, multiprocessing + # is not recommended when running on CPU generally + if with_cpu and not sys.version.startswith('3.5'): + places.append([fluid.CPUPlace()]) + if with_data_parallel: + places.append([fluid.CPUPlace()] * 2) + + if with_gpu and fluid.core.is_compiled_with_cuda(): + tmp = fluid.cuda_places()[:2] + assert len(tmp) > 0, "no gpu detected" + if with_data_parallel: + places.append(tmp) + places.append([tmp[0]]) + return places + + def test_main(self): + for with_data_parallel in [False] if self.__class__.__name__ \ + == "TestDygraphDataLoader" else [True, False]: + for p in self.prepare_places(with_data_parallel): + results = [] + for num_workers in [0, 2]: + print(self.__class__.__name__, p, num_workers) + ret = self.run_main( + num_workers=num_workers, + places=p, + with_data_parallel=with_data_parallel) + results.append(ret) + diff = np.max( + np.abs(results[0]['loss'] - results[1]['loss']) / + np.abs(results[0]['loss'])) + self.assertLess(diff, 1e-2) + + +class TestDygraphDataLoader(TestStaticDataLoader): + def run_main(self, num_workers, places, with_data_parallel): + fluid.default_startup_program().random_seed = 1 + fluid.default_main_program().random_seed = 1 + with fluid.dygraph.guard(places[0]): + fc_net = SimpleFCNet() + optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters()) + + dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) + dataloader = DataLoader( + dataset, + places=places, + num_workers=num_workers, + batch_size=BATCH_SIZE, + drop_last=True) + assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) + + step_list = [] + loss_list = [] + start_t = time.time() + for _ in six.moves.range(EPOCH_NUM): + step = 0 + for image, label in dataloader(): + out = fc_net(image) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.reduce_mean(loss) + avg_loss.backward() + optimizer.minimize(avg_loss) + fc_net.clear_gradients() + + loss_list.append(np.mean(avg_loss.numpy())) + step += 1 + step_list.append(step) + + end_t = time.time() + ret = { + "time": end_t - start_t, + "step": step_list, + "loss": np.array(loss_list) + } + print("time cost", ret['time'], 'step_list', ret['step']) + return ret + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e6999112e498c3311c7b5a037912db1000eaf8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py @@ -0,0 +1,199 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import os +import sys +import six +import time +import unittest +import multiprocessing +import numpy as np + +import paddle.fluid as fluid +from paddle.io import Dataset, BatchSampler, DataLoader +from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.dygraph.base import to_variable + + +class RandomDataset(Dataset): + def __init__(self, sample_num): + self.sample_num = sample_num + + def __getitem__(self, idx): + np.random.seed(idx) + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.sample_num + + +class TestDataLoaderAssert(unittest.TestCase): + def test_main(self): + place = fluid.cpu_places()[0] + with fluid.dygraph.guard(place): + dataset = RandomDataset(100) + batch_sampler = BatchSampler(dataset=dataset, batch_size=4) + + # dataset is not instance of Dataset + try: + loader = DataLoader(dataset=batch_sampler, places=place) + self.assertTrue(False) + except AssertionError: + pass + + # places is None + try: + loader = DataLoader(dataset=dataset, places=None) + self.assertTrue(False) + except AssertionError: + pass + + # num_workers < 0 + try: + loader = DataLoader( + dataset=dataset, places=place, num_workers=-1) + self.assertTrue(False) + except AssertionError: + pass + + # timeout < 0 + try: + loader = DataLoader(dataset=dataset, places=place, timeout=-1) + self.assertTrue(False) + except AssertionError: + pass + + # batch_sampler is not instance of BatchSampler + try: + loader = DataLoader( + dataset=dataset, places=place, batch_sampler=dataset) + self.assertTrue(False) + except AssertionError: + pass + + # set batch_sampler and shuffle/batch_size/drop_last + try: + loader = DataLoader( + dataset=dataset, + places=place, + batch_sampler=batch_sampler, + shuffle=True, + drop_last=True) + self.assertTrue(False) + except AssertionError: + pass + + # set batch_sampler correctly + try: + loader = DataLoader( + dataset=dataset, places=place, batch_sampler=batch_sampler) + self.assertTrue(True) + except AssertionError: + self.assertTrue(False) + + +# CI Converage cannot record stub in subprocess, +# HACK a _worker_loop in main process call here +class TestDataLoaderWorkerLoop(unittest.TestCase): + def run_without_worker_done(self, use_shared_memory=True): + try: + place = fluid.cpu_places()[0] + with fluid.dygraph.guard(place): + dataset = RandomDataset(800) + + # test init_fn + def _init_fn(worker_id): + pass + + # test collate_fn + def _collate_fn(sample_list): + return [ + np.stack( + s, axis=0) for s in list(zip(*sample_list)) + ] + + loader = DataLoader( + dataset, + num_workers=1, + places=place, + use_shared_memory=use_shared_memory) + assert loader.num_workers > 0, \ + "go to AssertionError and pass in Mac and Windows" + loader = iter(loader) + print("loader length", len(loader)) + indices_queue = multiprocessing.Queue() + for i in range(10): + indices_queue.put([i, i + 10]) + indices_queue.put(None) + loader._worker_loop( + loader._dataset, indices_queue, loader._data_queue, + loader._workers_done_event, _collate_fn, _init_fn, 0) + self.assertTrue(False) + except AssertionError: + pass + except Exception: + self.assertTrue(False) + + def run_with_worker_done(self, use_shared_memory=True): + try: + place = fluid.cpu_places()[0] + with fluid.dygraph.guard(place): + dataset = RandomDataset(800) + + # test init_fn + def _init_fn(worker_id): + pass + + # test collate_fn + def _collate_fn(sample_list): + return [ + np.stack( + s, axis=0) for s in list(zip(*sample_list)) + ] + + loader = DataLoader( + dataset, + num_workers=1, + places=place, + use_shared_memory=use_shared_memory) + assert loader.num_workers > 0, \ + "go to AssertionError and pass in Mac and Windows" + loader = iter(loader) + print("loader length", len(loader)) + indices_queue = multiprocessing.Queue() + for i in range(10): + indices_queue.put([i, i + 10]) + indices_queue.put(None) + loader._workers_done_event.set() + loader._worker_loop( + loader._dataset, indices_queue, loader._data_queue, + loader._workers_done_event, _collate_fn, _init_fn, 0) + self.assertTrue(True) + except AssertionError: + pass + except Exception: + self.assertTrue(False) + + def test_main(self): + for use_shared_memory in [True, False]: + self.run_without_worker_done(use_shared_memory) + self.run_with_worker_done(use_shared_memory) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/io/__init__.py b/python/paddle/io/__init__.py index b9ad3e4c77c5a282fe71d8e82939c8705c92d27c..b85ef741b062633e028e5767f8ee4fbaf816d262 100644 --- a/python/paddle/io/__init__.py +++ b/python/paddle/io/__init__.py @@ -13,22 +13,27 @@ # limitations under the License. # TODO: define all functions about input & output in this directory -# __all__ = ['Dataset', -# 'Sampler', -# 'Transform', -# 'DataLoader', -# 'load', -# 'save', -# 'load_program_state', -# 'set_program_state', -# 'load_inference_model', -# 'save_inference_model', -# 'batch', -# 'shuffle', -# 'buffered', -# 'cache', -# 'chain', -# 'firstn', -# 'compose', -# 'map_readers', -# 'xmap_readers'] +__all__ = [ + 'Dataset', + 'BatchSampler', + # 'Transform', + 'DataLoader', + # 'load', + # 'save', + # 'load_program_state', + # 'set_program_state', + # 'load_inference_model', + # 'save_inference_model', + # 'batch', + # 'shuffle', + # 'buffered', + # 'cache', + # 'chain', + # 'firstn', + # 'compose', + # 'map_readers', + # 'xmap_readers' +] + +from ..fluid.io import DataLoader +from ..fluid.dataloader import Dataset, BatchSampler diff --git a/python/setup.py.in b/python/setup.py.in index ed77787d4cdd2711bd203fddc3e38199668d91af..f6ec0fa26b9846f02895afb8ad0082aaa9e78b17 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -149,6 +149,7 @@ packages=['paddle', 'paddle.fluid.proto.profiler', 'paddle.fluid.distributed', 'paddle.fluid.layers', + 'paddle.fluid.dataloader', 'paddle.fluid.contrib', 'paddle.fluid.contrib.decoder', 'paddle.fluid.contrib.quantize', @@ -176,6 +177,7 @@ packages=['paddle', 'paddle.fluid.incubate.fleet.parameter_server.pslib', 'paddle.fluid.incubate.fleet.collective', 'paddle.fluid.incubate.fleet.utils', + 'paddle.io', 'paddle.nn', 'paddle.nn.functional', 'paddle.nn.layer',