未验证 提交 af415bc2 编写于 作者: J Jackwaterveg 提交者: GitHub

[Cherry-pick ] to Release/2.3, Add prefetch_factor in dataloader (#43674)

* fix usage of prefetch_factor

* add assert

* add docstring and change prefetch_factor when num_workers=0

* fix doc
上级 9783e887
......@@ -96,6 +96,7 @@ class _DataLoaderIterBase(object):
self._auto_collate_batch = loader.auto_collate_batch
self._num_workers = loader.num_workers
self._use_buffer_reader = loader.use_buffer_reader
self._prefetch_factor = loader.prefetch_factor
self._use_shared_memory = loader.use_shared_memory
self._timeout = loader.timeout if loader.timeout > 0 else MP_STATUS_CHECK_INTERVAL
self._worker_init_fn = loader.worker_init_fn
......@@ -166,9 +167,10 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._structure_infos = []
# NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas
# iteration, set blocking_queue can cache "self._prefetch_factor" iteration datas
# at most here
self._blocking_queue_capacity = 1 * len(self._places)
self._blocking_queue_capacity = self._prefetch_factor * len(
self._places)
self._init_thread()
self._shutdown = False
......@@ -363,11 +365,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# indices outstand as _outstanding_capacity at first, and
# blocking_queue capacity is also _outstanding_capacity.
# _outstanding_capacity here to make sure each indices_queue
# has at least 2 indices, and outstanding batch cached
# output data for at least 2 iterations(Note that len(_places)
# has at least "_prefetch_factor" indices, and outstanding batch cached
# output data for at least "_prefetch_factor" iterations(Note that len(_places)
# batches will be composed as an iteration output)
self._outstanding_capacity = 2 * max(self._num_workers,
len(self._places))
self._outstanding_capacity = self._prefetch_factor * max(
self._num_workers, len(self._places))
# see _try_put_indices
self._thread_lock = threading.Lock()
......
......@@ -314,56 +314,58 @@ class DataLoader(object):
dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset` or
:code:`paddle.io.IterableDataset`.
feed_list (list(Tensor)|tuple(Tensor)): feed Tensor list.
feed_list (list(Tensor)|tuple(Tensor), optional): feed Tensor list.
The Tensors should be created by :code:`paddle.static.data()`.
:attr:`feed_list` must be set if :attr:`return_list` is
False. Default None.
places(list(Place)|tuple(Place)|list(str)|optional): a list of Place,
places(list(Place)|tuple(Place)|list(str), optional): a list of Place,
to put data onto, :attr:`places` can be None, if
:attr:`places` is None, default place(CPUPlace or CUDAPlace(0))
will be used. Default None. If ``places`` is list of string,
the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``,
where ``x`` is the index of the GPUs.
return_list (bool): whether the return value on each device is
return_list (bool, optional): whether the return value on each device is
presented as a list. If :attr:`return_list=False`, the return
value on each device would be a dict of str -> Tensor, where
the key of the dict is the name of each fed Tensors. If
:attr:`return_list=True`, the return value on each device would
be a list(Tensor). :attr:`return_list` can only be True
in dynamic graph mode. Default True.
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
batch_sampler(BatchSampler, optional): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset`
and combine a batch. Default None.
batch_size(int|None): sample number in a mini-batch, a substitution
batch_size(int|None, optional): sample number in a mini-batch, a substitution
parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
is not set, a default `paddle.io.BatchSampler` will be used
and initialize by :attr:`batch_size`, :attr:`shuffle` and
:attr:`drop_last`. Default 1.
shuffle(bool): whether to shuffle indices order before generating
shuffle(bool, optional): whether to shuffle indices order before generating
batch indices, a substitution parameter for :attr:`batch_sampler`
see :attr:`batch_size`. Default False.
drop_last(bool): whether to drop the last incomplete batch if the dataset size
drop_last(bool, optional): whether to drop the last incomplete batch if the dataset size
is not divisible by the batch size, a substitution parameter
for :attr:`batch_sampler`, see :attr:`batch_size`. Default False
collate_fn(callable): function to generate mini-batch data by merging
collate_fn(callable, optional): function to generate mini-batch data by merging
the sample list, None for only stack each fields of sample in axis
0 (same as :code:`np.stack(..., axis=0)`). Default None
num_workers(int): the number of subprocess to load data, 0 for no
num_workers(int, optional): the number of subprocess to load data, 0 for no
subprocess used and loading data in main process. Default 0
use_buffer_reader (bool): whether to use buffered reader.
If use_buffer_reader=True, the DataLoader would prefetch next
use_buffer_reader (bool, optional): whether to use buffered reader.
If use_buffer_reader=True, the DataLoader would prefetch
batch data asynchronously, so it would speed up data feeding
and occupies a little more CPU or GPU memory, i.e., the memory
of one batch input data. Default True.
use_shared_memory (bool): whether to use shared memory to speed up
prefetch_factor (int, optional): Number of batch data the DataLoader would prefetch
if use_buffer_reader=True. Default 2.
use_shared_memory (bool, optional): whether to use shared memory to speed up
putting data into inter-process queue, set :attr:`use_shared_memory`
as True only when the shared memory space on your machine(e.g.
space of '/dev/shm' on Linux operating system) is large enough.
Shared memory will only be enabled in multi-process mode(num_workers
> 0). Default True.
timeout(int): the timeout value for getting data from output queue
timeout(int, optional): the timeout value for getting data from output queue
of subprocesses. Default 0.
worker_init_fn(callable): init function which will be called with
worker_init_fn(callable, optional): init function which will be called with
worker id on each subprocess starting if not set as None. Default
None.
......@@ -450,6 +452,7 @@ class DataLoader(object):
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
prefetch_factor=2,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
......@@ -457,6 +460,7 @@ class DataLoader(object):
self.return_list = return_list
self.collate_fn = collate_fn
self.use_buffer_reader = use_buffer_reader
self.prefetch_factor = prefetch_factor
self.worker_init_fn = worker_init_fn
self.dataset = dataset
......@@ -483,6 +487,8 @@ class DataLoader(object):
num_workers = 0
self.num_workers = num_workers
assert prefetch_factor > 0, "prefetch_factor should be a positive value"
self.use_shared_memory = use_shared_memory
if use_shared_memory and num_workers == 0:
self.use_shared_memory = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册