未验证 提交 20ee36bd 编写于 作者: K Kaipeng Deng 提交者: GitHub

make default_collate_fn visible. test=develop (#25244)

* make default_collate_fn visible. test=develop
上级 ee44bcdd
......@@ -38,7 +38,27 @@ from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCH
MP_INDICES_CHECK_INTERVAL = 5
def _default_collate_fn(batch):
def default_collate_fn(batch):
"""
Default batch collating function for :code:`fluid.io.DataLoader`.
The batch should be a list of samples, and each sample should be a list
of fields as follows:
[[field1, field2, ...], [field1, field2, ...], ...]
This default collate function zips each field together and stacks
each field into the batch field as follows:
[batch_field1, batch_field2, ...]
Args:
batch(list of list of numpy array): the batch data. Each field
should be a numpy array, each sample should be a list of
fields, and the batch should be a list of samples.
Returns:
a list of numpy arrays: the collated batch
"""
sample = batch[0]
# dataset has only 1 field
if isinstance(sample, np.ndarray):
......@@ -82,7 +102,7 @@ class _DataLoaderIterBase(object):
self._return_list = loader.return_list
self._batch_sampler = loader.batch_sampler
self._sampler_iter = iter(loader.batch_sampler)
self._collate_fn = loader.collate_fn or _default_collate_fn
self._collate_fn = loader.collate_fn or default_collate_fn
self._num_workers = loader.num_workers
self._use_buffer_reader = loader.use_buffer_reader
self._use_shared_memory = loader.use_shared_memory
......
......@@ -23,7 +23,7 @@ from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider
from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler
from .dataloader import BatchSampler, Dataset
from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess
from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess, default_collate_fn
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
from .unique_name import UniqueNameGenerator
import logging
......@@ -44,7 +44,7 @@ else:
# NOTE: [ avoid hanging & fail quickly ] This value is used when getting data from another process
QUEUE_GET_TIMEOUT = 60
__all__ = ['PyReader', 'DataLoader']
__all__ = ['PyReader', 'DataLoader', 'default_collate_fn']
data_loader_unique_name_generator = UniqueNameGenerator()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册