未验证 提交 9cd09a85 编写于 作者: C Chen Weihang 提交者: GitHub

Polish dataloader doc detail & update example (#28975)

* polish dataloader doc detail, test=decument_fix

* fix commnet error

* fix word error
上级 fbf9564f
...@@ -182,8 +182,8 @@ class DataLoader(object): ...@@ -182,8 +182,8 @@ class DataLoader(object):
dataset(Dataset): the dataset to load data from, should be an dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset` or instance of subclass of :code:`paddle.io.Dataset` or
:code:`paddle.io.IterableDataset`. :code:`paddle.io.IterableDataset`.
feed_list (list(Tensor)|tuple(Tensor)): feed variable list. feed_list (list(Tensor)|tuple(Tensor)): feed Tensor list.
The variables should be created by :code:`paddle.static.data()`. The Tensors should be created by :code:`paddle.static.data()`.
:attr:`feed_list` must be set if :attr:`return_list` is :attr:`feed_list` must be set if :attr:`return_list` is
False. Default None. False. Default None.
places(list(Place)|tuple(Place)|optional): a list of Place, places(list(Place)|tuple(Place)|optional): a list of Place,
...@@ -193,7 +193,7 @@ class DataLoader(object): ...@@ -193,7 +193,7 @@ class DataLoader(object):
return_list (bool): whether the return value on each device is return_list (bool): whether the return value on each device is
presented as a list. If :attr:`return_list=False`, the return presented as a list. If :attr:`return_list=False`, the return
value on each device would be a dict of str -> Tensor, where value on each device would be a dict of str -> Tensor, where
the key of the dict is the name of each fed variables. If the key of the dict is the name of each fed Tensors. If
:attr:`return_list=True`, the return value on each device would :attr:`return_list=True`, the return value on each device would
be a list(Tensor). :attr:`return_list` can only be True be a list(Tensor). :attr:`return_list` can only be True
in dynamic graph mode. Default True. in dynamic graph mode. Default True.
...@@ -447,14 +447,11 @@ class DataLoader(object): ...@@ -447,14 +447,11 @@ class DataLoader(object):
If iterable = False, the created DataLoader object provides If iterable = False, the created DataLoader object provides
:code:`start()` and :code:`reset()` method to control the data reading :code:`start()` and :code:`reset()` method to control the data reading
process. This mode is designed to be compatible with the process.
:code:`fluid.layers.py_reader` interface. Users can migrate the codes
from :code:`fluid.layers.py_reader` to :code:`fluid.io.DataLoader`
easily when using iterable=False.
Args: Args:
feed_list (list(Variable)|tuple(Variable)): feed variable list. feed_list (list(Tensor)|tuple(Tensor)): feed Tensor list.
The variables should be created by :code:`fluid.data()`. The Tensors should be created by :code:`fluid.data()`.
capacity (int): capacity of the queue maintained in DataLoader. capacity (int): capacity of the queue maintained in DataLoader.
The unit is batch number. Set larger capacity if your reader The unit is batch number. Set larger capacity if your reader
is fast. is fast.
...@@ -468,7 +465,7 @@ class DataLoader(object): ...@@ -468,7 +465,7 @@ class DataLoader(object):
presented as a list. It is only valid when iterable=True. presented as a list. It is only valid when iterable=True.
If return_list=False, the return value on each device would If return_list=False, the return value on each device would
be a dict of str -> LoDTensor, where the key of the dict is be a dict of str -> LoDTensor, where the key of the dict is
the name of each fed variables. If return_list=True, the the name of each fed Tensors. If return_list=True, the
return value on each device would be a list(LoDTensor). It is return value on each device would be a list(LoDTensor). It is
recommended to use return_list=False in static graph mode and recommended to use return_list=False in static graph mode and
use return_list=True in dygraph mode. use return_list=True in dygraph mode.
...@@ -492,9 +489,16 @@ class DataLoader(object): ...@@ -492,9 +489,16 @@ class DataLoader(object):
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid '''
Example in static graph mode
'''
import numpy as np import numpy as np
import paddle
import paddle.static as static
import paddle.nn.functional as F
BATCH_NUM = 10 BATCH_NUM = 10
BATCH_SIZE = 16 BATCH_SIZE = 16
EPOCH_NUM = 4 EPOCH_NUM = 4
...@@ -506,11 +510,13 @@ class DataLoader(object): ...@@ -506,11 +510,13 @@ class DataLoader(object):
DATA_FORMAT = 'batch_generator' # data format of data source user provides DATA_FORMAT = 'batch_generator' # data format of data source user provides
paddle.enable_static()
def simple_net(image, label): def simple_net(image, label):
fc_tmp = fluid.layers.fc(image, size=CLASS_NUM) fc_tmp = static.nn.fc(image, size=CLASS_NUM)
cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label) cross_entropy = F.softmax_with_cross_entropy(image, label)
loss = fluid.layers.reduce_mean(cross_entropy) loss = paddle.mean(cross_entropy)
sgd = fluid.optimizer.SGD(learning_rate=1e-3) sgd = paddle.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss) sgd.minimize(loss)
return loss return loss
...@@ -566,7 +572,7 @@ class DataLoader(object): ...@@ -566,7 +572,7 @@ class DataLoader(object):
try: try:
while True: while True:
exe.run(prog, fetch_list=[loss]) exe.run(prog, fetch_list=[loss])
except fluid.core.EOFException: except paddle.core.EOFException:
loader.reset() # call DataLoader.reset() after catching EOFException loader.reset() # call DataLoader.reset() after catching EOFException
def set_data_source(loader, places): def set_data_source(loader, places):
...@@ -579,11 +585,11 @@ class DataLoader(object): ...@@ -579,11 +585,11 @@ class DataLoader(object):
else: else:
raise ValueError('Unsupported data format') raise ValueError('Unsupported data format')
image = fluid.data(name='image', shape=[None, 784], dtype='float32') image = static.data(name='image', shape=[None, 784], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64') label = static.data(name='label', shape=[None, 1], dtype='int64')
# Define DataLoader # Define DataLoader
loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE) loader = paddle.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)
# Define network # Define network
loss = simple_net(image, label) loss = simple_net(image, label)
...@@ -591,17 +597,17 @@ class DataLoader(object): ...@@ -591,17 +597,17 @@ class DataLoader(object):
# Set data source of DataLoader # Set data source of DataLoader
# #
# If DataLoader is iterable, places must be given and the number of places must be the same with device number. # If DataLoader is iterable, places must be given and the number of places must be the same with device number.
# - If you are using GPU, call `fluid.cuda_places()` to get all GPU places. # - If you are using GPU, call `paddle.static.cuda_places()` to get all GPU places.
# - If you are using CPU, call `fluid.cpu_places()` to get all CPU places. # - If you are using CPU, call `paddle.static.cpu_places()` to get all CPU places.
# #
# If DataLoader is not iterable, places can be None. # If DataLoader is not iterable, places can be None.
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places() places = static.cuda_places() if USE_GPU else static.cpu_places()
set_data_source(loader, places) set_data_source(loader, places)
exe = fluid.Executor(places[0]) exe = static.Executor(places[0])
exe.run(fluid.default_startup_program()) exe.run(static.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) prog = static.CompiledProgram(static.default_main_program()).with_data_parallel(loss_name=loss.name)
if loader.iterable: if loader.iterable:
train_iterable(exe, prog, loss, loader) train_iterable(exe, prog, loss, loader)
...@@ -609,45 +615,110 @@ class DataLoader(object): ...@@ -609,45 +615,110 @@ class DataLoader(object):
train_non_iterable(exe, prog, loss, loader) train_non_iterable(exe, prog, loss, loader)
Examples 2:
.. code-block:: python
''' '''
Users can use return_list = True in dygraph mode. Example in dynamic graph mode.
''' '''
with fluid.dygraph.guard(places[0]): import numpy as np
loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)
set_data_source(loader, places[0])
for image, label in loader():
relu = fluid.layers.relu(image)
assert image.shape == [BATCH_SIZE, 784]
assert label.shape == [BATCH_SIZE, 1]
assert relu.shape == [BATCH_SIZE, 784]
Examples 2: import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4
IMAGE_SIZE = 784
CLASS_NUM = 10
USE_GPU = False # whether to use GPU
def _get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def __reader__():
for _ in range(BATCH_NUM):
batch_image, batch_label = _get_random_images_and_labels(
[BATCH_SIZE, IMAGE_SIZE], [BATCH_SIZE, CLASS_NUM])
yield batch_image, batch_label
def random_batch_reader():
return __reader__
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
@paddle.jit.to_static
def forward(self, x):
return self._linear(x)
# set device
paddle.set_device('gpu' if USE_GPU else 'cpu')
# create network
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.CrossEntropyLoss()
adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())
# create data loader
loader = paddle.io.DataLoader.from_generator(capacity=5)
loader.set_batch_generator(random_batch_reader())
for epoch_id in range(EPOCH_NUM):
for batch_id, (image, label) in enumerate(loader()):
out = layer(image)
loss = loss_fn(out, label)
loss.backward()
adam.step()
adam.clear_grad()
print("Epoch {} batch {}: loss = {}".format(
epoch_id, batch_id, np.mean(loss.numpy())))
Examples 3:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid '''
Example of `drop_last` using in static graph multi-cards mode
'''
import paddle
import paddle.static as static
import numpy as np import numpy as np
import os import os
# We use 2 CPU cores to run inference network # We use 2 CPU cores to run inference network
os.environ['CPU_NUM'] = '2' os.environ['CPU_NUM'] = '2'
paddle.enable_static()
# The data source has only 3 batches, which can not be # The data source has only 3 batches, which can not be
# divided evenly to each CPU core # divided evenly to each CPU core
def batch_generator(): def batch_generator():
for i in range(3): for i in range(3):
yield np.array([i+1]).astype('float32'), yield np.array([i+1]).astype('float32'),
x = fluid.data(name='x', shape=[None], dtype='float32') x = static.data(name='x', shape=[None], dtype='float32')
y = x * x y = x * x
def run_inference(drop_last): def run_inference(drop_last):
loader = fluid.io.DataLoader.from_generator(feed_list=[x], loader = paddle.io.DataLoader.from_generator(feed_list=[x],
capacity=8, drop_last=drop_last) capacity=8, drop_last=drop_last)
loader.set_batch_generator(batch_generator, fluid.cpu_places()) loader.set_batch_generator(batch_generator, static.cpu_places())
exe = fluid.Executor(fluid.CPUPlace()) exe = static.Executor(paddle.CPUPlace())
prog = fluid.CompiledProgram(fluid.default_main_program()) prog = static.CompiledProgram(static.default_main_program())
prog = prog.with_data_parallel() prog = prog.with_data_parallel()
result = [] result = []
...@@ -698,18 +769,22 @@ class DataLoader(object): ...@@ -698,18 +769,22 @@ class DataLoader(object):
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
import paddle.static as static
image = fluid.data(name='image', shape=[None, 784], dtype='float32') paddle.enable_static()
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
image = static.data(name='image', shape=[None, 784], dtype='float32')
label = static.data(name='label', shape=[None, 1], dtype='int64')
dataset = fluid.DatasetFactory().create_dataset("QueueDataset") dataset = paddle.distributed.QueueDataset()
dataset.set_batch_size(32) dataset.init(
batch_size=32,
pipe_command='cat',
use_var=[image, label])
dataset.set_filelist(['a.txt', 'b.txt', 'c.txt']) dataset.set_filelist(['a.txt', 'b.txt', 'c.txt'])
dataset.set_use_var([image, label])
dataset.set_pipe_command('cat')
loader = fluid.io.DataLoader.from_dataset(dataset, fluid.cpu_places()) loader = paddle.io.DataLoader.from_dataset(dataset, static.cpu_places())
""" """
return DatasetLoader(dataset, places, drop_last) return DatasetLoader(dataset, places, drop_last)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册