未验证 提交 bf09dcb3 编写于 作者: K Kaipeng Deng 提交者: GitHub

add GPU tensor notice & update default_collate_fn/default_convert_fn. test=develop (#31763)

上级 27f2d8df
...@@ -27,24 +27,31 @@ except: ...@@ -27,24 +27,31 @@ except:
def default_collate_fn(batch): def default_collate_fn(batch):
""" """
Default batch collating function for :code:`paddle.io.DataLoader`, Default batch collating function for :code:`paddle.io.DataLoader`,
batch should be a list of samples, and each sample should be a list get input data as a list of sample datas, each element in list
of fields as follows: if the data of a sample, and sample data should composed of list,
dictionary, string, number, numpy array and paddle.Tensor, this
function will parse input data recursively and stack number,
numpy array and paddle.Tensor datas as batch datas. e.g. for
following input data:
[{'image': np.array(shape=[3, 224, 224]), 'label': 1},
{'image': np.array(shape=[3, 224, 224]), 'label': 3},
{'image': np.array(shape=[3, 224, 224]), 'label': 4},
{'image': np.array(shape=[3, 224, 224]), 'label': 5},]
[[filed1, filed2, ...], [filed1, filed2, ...], ...]
This default collate function zipped each filed together and stack This default collate function zipped each number and numpy array
each filed as the batch field as follows: field together and stack each field as the batch field as follows:
{'image': np.array(shape=[4, 3, 224, 224]), 'label': np.array([1, 3, 4, 5])}
[batch_filed1, batch_filed2, ...]
Args: Args:
batch(list of list of numpy array|paddle.Tensor): the batch data, each fields batch(list of sample data): batch should be a list of sample data.
should be a numpy array, each sample should be a list of
fileds, and batch should be a list of sample.
Returns: Returns:
a list of numpy array|Paddle.Tensor: collated batch of input batch data, Batched data: batched each number, numpy array and paddle.Tensor
fields data type as same as fields in each sample. in input data.
""" """
sample = batch[0] sample = batch[0]
if isinstance(sample, np.ndarray): if isinstance(sample, np.ndarray):
...@@ -75,6 +82,24 @@ def default_collate_fn(batch): ...@@ -75,6 +82,24 @@ def default_collate_fn(batch):
def default_convert_fn(batch): def default_convert_fn(batch):
"""
Default batch converting function for :code:`paddle.io.DataLoader`.
get input data as a list of sample datas, each element in list
if the data of a sample, and sample data should composed of list,
dictionary, string, number, numpy array and paddle.Tensor.
.. note::
This function is default :attr:`collate_fn` in **Distable
automatic batching** mode, for **Distable automatic batching**
mode, please ses :attr:`paddle.io.DataLoader`
Args:
batch(list of sample data): batch should be a list of sample data.
Returns:
Batched data: batched each number, numpy array and paddle.Tensor
in input data.
"""
if isinstance(batch, (paddle.Tensor, np.ndarray)): if isinstance(batch, (paddle.Tensor, np.ndarray)):
return batch return batch
elif isinstance(batch, (str, bytes)): elif isinstance(batch, (str, bytes)):
......
...@@ -165,6 +165,12 @@ class DataLoader(object): ...@@ -165,6 +165,12 @@ class DataLoader(object):
For :code:`batch_sampler` please see :code:`paddle.io.BatchSampler` For :code:`batch_sampler` please see :code:`paddle.io.BatchSampler`
.. note::
GPU tensor operation is not supported in subprocess currently,
please don't use GPU tensor operations in pipeline which will
be performed in subprocess, such as dataset transforms, collte_fn,
etc. Numpy array and CPU tensor operation is supported.
**Disable automatic batching** **Disable automatic batching**
In certain cases such as some NLP tasks, instead of automatic batching, In certain cases such as some NLP tasks, instead of automatic batching,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册