diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 6753c18da464926dbc5c18d71046cd95b88d441b..214cd772af6b1fa6e24ec972c0f0644dc1c09f95 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -38,7 +38,27 @@ from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCH MP_INDICES_CHECK_INTERVAL = 5 -def _default_collate_fn(batch): +def default_collate_fn(batch): + """ + Default batch collating function for :code:`fluid.io.DataLoader`, + batch should be a list of samples, and each sample should be a list + of fields as follows: + + [[filed1, filed2, ...], [filed1, filed2, ...], ...] + + This default collate function zipped each filed together and stack + each filed as the batch field as follows: + + [batch_filed1, batch_filed2, ...] + + Args: + batch(list of list of numpy array): the batch data, each fields + should be a numpy array, each sample should be a list of + fileds, and batch should be a list of sample. + + Returns: + a list of numpy array: collated batch + """ sample = batch[0] # dataset has only 1 field if isinstance(sample, np.ndarray): @@ -82,7 +102,7 @@ class _DataLoaderIterBase(object): self._return_list = loader.return_list self._batch_sampler = loader.batch_sampler self._sampler_iter = iter(loader.batch_sampler) - self._collate_fn = loader.collate_fn or _default_collate_fn + self._collate_fn = loader.collate_fn or default_collate_fn self._num_workers = loader.num_workers self._use_buffer_reader = loader.use_buffer_reader self._use_shared_memory = loader.use_shared_memory diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index ebe16a8bbbc31735fb203a6ac93c3e0f24ab3d35..0289ecea34acf65d01aa13b555ee523f7127b48d 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -23,7 +23,7 @@ from .executor import global_scope from .data_feeder import DataFeeder, BatchedTensorProvider from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler from .dataloader import BatchSampler, Dataset -from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess +from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess, default_collate_fn from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer from .unique_name import UniqueNameGenerator import logging @@ -44,7 +44,7 @@ else: # NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process QUEUE_GET_TIMEOUT = 60 -__all__ = ['PyReader', 'DataLoader'] +__all__ = ['PyReader', 'DataLoader', 'default_collate_fn'] data_loader_unique_name_generator = UniqueNameGenerator()