未验证 提交 2bfe8b2c 编写于 作者: J Jackwaterveg 提交者: GitHub

[Dataloader]Add prefetch_factor in dataloader (#43081)

* fix usage of prefetch_factor

* add assert

* add docstring and change prefetch_factor when num_workers=0

* fix doc
上级 67163fb4
...@@ -96,6 +96,7 @@ class _DataLoaderIterBase(object): ...@@ -96,6 +96,7 @@ class _DataLoaderIterBase(object):
self._auto_collate_batch = loader.auto_collate_batch self._auto_collate_batch = loader.auto_collate_batch
self._num_workers = loader.num_workers self._num_workers = loader.num_workers
self._use_buffer_reader = loader.use_buffer_reader self._use_buffer_reader = loader.use_buffer_reader
self._prefetch_factor = loader.prefetch_factor
self._use_shared_memory = loader.use_shared_memory self._use_shared_memory = loader.use_shared_memory
self._timeout = loader.timeout if loader.timeout > 0 else MP_STATUS_CHECK_INTERVAL self._timeout = loader.timeout if loader.timeout > 0 else MP_STATUS_CHECK_INTERVAL
self._worker_init_fn = loader.worker_init_fn self._worker_init_fn = loader.worker_init_fn
...@@ -166,9 +167,10 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -166,9 +167,10 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._structure_infos = [] self._structure_infos = []
# NOTE: len(self._places) batch data compose as an output # NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas # iteration, set blocking_queue can cache "self._prefetch_factor" iteration datas
# at most here # at most here
self._blocking_queue_capacity = 1 * len(self._places) self._blocking_queue_capacity = self._prefetch_factor * len(
self._places)
self._init_thread() self._init_thread()
self._shutdown = False self._shutdown = False
...@@ -363,11 +365,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -363,11 +365,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# indices outstand as _outstanding_capacity at first, and # indices outstand as _outstanding_capacity at first, and
# blocking_queue capacity is also _outstanding_capacity. # blocking_queue capacity is also _outstanding_capacity.
# _outstanding_capacity here to make sure each indices_queue # _outstanding_capacity here to make sure each indices_queue
# has at least 2 indices, and outstanding batch cached # has at least "_prefetch_factor" indices, and outstanding batch cached
# output data for at least 2 iterations(Note that len(_places) # output data for at least "_prefetch_factor" iterations(Note that len(_places)
# batches will be composed as an iteration output) # batches will be composed as an iteration output)
self._outstanding_capacity = 2 * max(self._num_workers, self._outstanding_capacity = self._prefetch_factor * max(
len(self._places)) self._num_workers, len(self._places))
# see _try_put_indices # see _try_put_indices
self._thread_lock = threading.Lock() self._thread_lock = threading.Lock()
......
...@@ -314,56 +314,58 @@ class DataLoader(object): ...@@ -314,56 +314,58 @@ class DataLoader(object):
dataset(Dataset): the dataset to load data from, should be an dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset` or instance of subclass of :code:`paddle.io.Dataset` or
:code:`paddle.io.IterableDataset`. :code:`paddle.io.IterableDataset`.
feed_list (list(Tensor)|tuple(Tensor)): feed Tensor list. feed_list (list(Tensor)|tuple(Tensor), optional): feed Tensor list.
The Tensors should be created by :code:`paddle.static.data()`. The Tensors should be created by :code:`paddle.static.data()`.
:attr:`feed_list` must be set if :attr:`return_list` is :attr:`feed_list` must be set if :attr:`return_list` is
False. Default None. False. Default None.
places(list(Place)|tuple(Place)|list(str)|optional): a list of Place, places(list(Place)|tuple(Place)|list(str), optional): a list of Place,
to put data onto, :attr:`places` can be None, if to put data onto, :attr:`places` can be None, if
:attr:`places` is None, default place(CPUPlace or CUDAPlace(0)) :attr:`places` is None, default place(CPUPlace or CUDAPlace(0))
will be used. Default None. If ``places`` is list of string, will be used. Default None. If ``places`` is list of string,
the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``,
where ``x`` is the index of the GPUs. where ``x`` is the index of the GPUs.
return_list (bool): whether the return value on each device is return_list (bool, optional): whether the return value on each device is
presented as a list. If :attr:`return_list=False`, the return presented as a list. If :attr:`return_list=False`, the return
value on each device would be a dict of str -> Tensor, where value on each device would be a dict of str -> Tensor, where
the key of the dict is the name of each fed Tensors. If the key of the dict is the name of each fed Tensors. If
:attr:`return_list=True`, the return value on each device would :attr:`return_list=True`, the return value on each device would
be a list(Tensor). :attr:`return_list` can only be True be a list(Tensor). :attr:`return_list` can only be True
in dynamic graph mode. Default True. in dynamic graph mode. Default True.
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler` batch_sampler(BatchSampler, optional): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset` to generate batch indices to draw samples from :attr:`dataset`
and combine a batch. Default None. and combine a batch. Default None.
batch_size(int|None): sample number in a mini-batch, a substitution batch_size(int|None, optional): sample number in a mini-batch, a substitution
parameter for :attr:`batch_sampler`, if :attr:`batch_sampler` parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
is not set, a default `paddle.io.BatchSampler` will be used is not set, a default `paddle.io.BatchSampler` will be used
and initialize by :attr:`batch_size`, :attr:`shuffle` and and initialize by :attr:`batch_size`, :attr:`shuffle` and
:attr:`drop_last`. Default 1. :attr:`drop_last`. Default 1.
shuffle(bool): whther to shuffle indices order before genrate shuffle(bool, optional): whther to shuffle indices order before genrate
batch indices, a substitution parameter for :attr:`batch_sampler` batch indices, a substitution parameter for :attr:`batch_sampler`
see :attr:`batch_size`. Default False. see :attr:`batch_size`. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size drop_last(bool, optional): whether drop the last incomplete batch dataset size
is not divisible by the batch size, a substitution parameter is not divisible by the batch size, a substitution parameter
for :attr:`batch_sampler`, see :attr:`batch_size`. Default False for :attr:`batch_sampler`, see :attr:`batch_size`. Default False
collate_fn(callable): function to generate mini-batch data by merging collate_fn(callable, optional): function to generate mini-batch data by merging
the sample list, None for only stack each fields of sample in axis the sample list, None for only stack each fields of sample in axis
0(same as :attr::`np.stack(..., axis=0)`). Default None 0(same as :attr::`np.stack(..., axis=0)`). Default None
num_workers(int): the number of subprocess to load data, 0 for no num_workers(int, optional): the number of subprocess to load data, 0 for no
subprocess used and loading data in main process. Default 0 subprocess used and loading data in main process. Default 0
use_buffer_reader (bool): whether to use bufferred reader. use_buffer_reader (bool, optional): whether to use bufferred reader.
If use_buffer_reader=True, the DataLoader would prefetch next If use_buffer_reader=True, the DataLoader would prefetch
batch data asynchronously, so it would speed up data feeding batch data asynchronously, so it would speed up data feeding
and occupies a little more CPU or GPU memory, i.e., the memory and occupies a little more CPU or GPU memory, i.e., the memory
of one batch input data. Default True. of one batch input data. Default True.
use_shared_memory (bool): whether to use shared memory to speed up prefetch_factor (int, optional): Number of batch data the DataLoader would prefetch
if use_buffer_reader=True. Default 2.
use_shared_memory (bool, optional): whether to use shared memory to speed up
putting data into inter-process queue, set :attr:`use_shared_memory` putting data into inter-process queue, set :attr:`use_shared_memory`
as True only when the shared memory space on your machine(e.g. as True only when the shared memory space on your machine(e.g.
space of '/dev/shm' on Linux operating sysytem) is large enough. space of '/dev/shm' on Linux operating sysytem) is large enough.
Shared memory will only be enabled in multi-process mode(num_workers Shared memory will only be enabled in multi-process mode(num_workers
> 0). Default True. > 0). Default True.
timeout(int): the timeout value for getting data form output queue timeout(int, optional): the timeout value for getting data form output queue
of subprocesses. Default 0. of subprocesses. Default 0.
worker_init_fn(callable): init function which will be called with worker_init_fn(callable, optional): init function which will be called with
worker id on each subproces starting if not set as None. Default worker id on each subproces starting if not set as None. Default
None. None.
...@@ -450,6 +452,7 @@ class DataLoader(object): ...@@ -450,6 +452,7 @@ class DataLoader(object):
collate_fn=None, collate_fn=None,
num_workers=0, num_workers=0,
use_buffer_reader=True, use_buffer_reader=True,
prefetch_factor=2,
use_shared_memory=True, use_shared_memory=True,
timeout=0, timeout=0,
worker_init_fn=None, worker_init_fn=None,
...@@ -457,6 +460,7 @@ class DataLoader(object): ...@@ -457,6 +460,7 @@ class DataLoader(object):
self.return_list = return_list self.return_list = return_list
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.use_buffer_reader = use_buffer_reader self.use_buffer_reader = use_buffer_reader
self.prefetch_factor = prefetch_factor
self.worker_init_fn = worker_init_fn self.worker_init_fn = worker_init_fn
self.dataset = dataset self.dataset = dataset
...@@ -483,6 +487,8 @@ class DataLoader(object): ...@@ -483,6 +487,8 @@ class DataLoader(object):
num_workers = 0 num_workers = 0
self.num_workers = num_workers self.num_workers = num_workers
assert prefetch_factor > 0, "prefetch_factor should be a positive value"
self.use_shared_memory = use_shared_memory self.use_shared_memory = use_shared_memory
if use_shared_memory and num_workers == 0: if use_shared_memory and num_workers == 0:
self.use_shared_memory = False self.use_shared_memory = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册