未验证 提交 95b18683 编写于 作者: K Kaipeng Deng 提交者: GitHub

update DataLoader doc (#28290)

* update DataLoader doc. test=develop
上级 71d62207
...@@ -153,18 +153,22 @@ class DataLoader(object): ...@@ -153,18 +153,22 @@ class DataLoader(object):
multi-process workers will be used to load data asynchronously if multi-process workers will be used to load data asynchronously if
:attr:`num_workers` is set as a positive number. :attr:`num_workers` is set as a positive number.
DataLoader only supports map-style dataset(can get a sample from DataLoader supports map-style dataset and iterable-style dataset.
dataset with a given index) currently, for a map-style dataset,
please see :code:`paddle.io.Dataset`.
batch_sampler please see :code:`paddle.io.BatchSampler` For map-style datast(can get a sample from dataset with a given
index), please see :code:`paddle.io.Dataset`.
For iterable-style datast(get samples from dataset iteratively,
like a Python iterator), please see :code:`paddle.io.IterableDataset`.
For :code:`batch_sampler` please see :code:`paddle.io.BatchSampler`
Args: Args:
dataset(Dataset): the dataset to load data from, should be an dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset` or instance of subclass of :code:`paddle.io.Dataset` or
:code:`paddle.io.IterableDataset`. :code:`paddle.io.IterableDataset`.
feed_list (list(Tensor)|tuple(Tensor)): feed variable list. feed_list (list(Tensor)|tuple(Tensor)): feed variable list.
The variables should be created by :code:`fluid.data()`. The variables should be created by :code:`paddle.static.data()`.
:attr:`feed_list` must be set if :attr:`return_list` is :attr:`feed_list` must be set if :attr:`return_list` is
False. Default None. False. Default None.
places(list(Place)|tuple(Place)|optional): a list of Place, places(list(Place)|tuple(Place)|optional): a list of Place,
...@@ -173,10 +177,10 @@ class DataLoader(object): ...@@ -173,10 +177,10 @@ class DataLoader(object):
will be used. Default None. will be used. Default None.
return_list (bool): whether the return value on each device is return_list (bool): whether the return value on each device is
presented as a list. If :attr:`return_list=False`, the return presented as a list. If :attr:`return_list=False`, the return
value on each device would be a dict of str -> LoDTensor, where value on each device would be a dict of str -> Tensor, where
the key of the dict is the name of each fed variables. If the key of the dict is the name of each fed variables. If
:attr:`return_list=True`, the return value on each device would :attr:`return_list=True`, the return value on each device would
be a list(LoDTensor). :attr:`return_list` can only be True be a list(Tensor). :attr:`return_list` can only be True
in dynamic graph mode. Default False. in dynamic graph mode. Default False.
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler` batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset` to generate batch indices to draw samples from :attr:`dataset`
...@@ -224,7 +228,8 @@ class DataLoader(object): ...@@ -224,7 +228,8 @@ class DataLoader(object):
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset, BatchSampler, DataLoader from paddle.io import Dataset, BatchSampler, DataLoader
BATCH_NUM = 20 BATCH_NUM = 20
...@@ -234,8 +239,6 @@ class DataLoader(object): ...@@ -234,8 +239,6 @@ class DataLoader(object):
IMAGE_SIZE = 784 IMAGE_SIZE = 784
CLASS_NUM = 10 CLASS_NUM = 10
USE_GPU = False # whether use GPU to run model
# define a random dataset # define a random dataset
class RandomDataset(Dataset): class RandomDataset(Dataset):
def __init__(self, num_samples): def __init__(self, num_samples):
...@@ -251,78 +254,34 @@ class DataLoader(object): ...@@ -251,78 +254,34 @@ class DataLoader(object):
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
# get places class SimpleNet(nn.Layer):
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()
# --------------------- dygraph mode --------------------
class SimpleNet(fluid.dygraph.Layer):
def __init__(self): def __init__(self):
super(SimpleNet, self).__init__() super(SimpleNet, self).__init__()
self.fc = fluid.dygraph.nn.Linear(IMAGE_SIZE, CLASS_NUM, act='softmax') self.fc = nn.Linear(IMAGE_SIZE, CLASS_NUM)
def forward(self, image, label=None): def forward(self, image, label=None):
return self.fc(image) return self.fc(image)
with fluid.dygraph.guard(places[0]): simple_net = SimpleNet()
simple_net = SimpleNet() opt = paddle.optimizer.SGD(learning_rate=1e-3,
opt = fluid.optimizer.SGD(learning_rate=1e-3, parameters=simple_net.parameters())
parameter_list=simple_net.parameters())
loader = DataLoader(dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
for e in range(EPOCH_NUM):
for i, (image, label) in enumerate(loader()):
out = simple_net(image)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.reduce_mean(loss)
avg_loss.backward()
opt.minimize(avg_loss)
simple_net.clear_gradients()
print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))
# -------------------------------------------------------
# -------------------- static graph ---------------------
paddle.enable_static()
def simple_net(image, label):
fc_tmp = fluid.layers.fc(image, size=CLASS_NUM, act='softmax')
cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label)
loss = fluid.layers.reduce_mean(cross_entropy)
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
return loss
image = fluid.data(name='image', shape=[None, IMAGE_SIZE], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
loss = simple_net(image, label)
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
loader = DataLoader(dataset, loader = DataLoader(dataset,
feed_list=[image, label], batch_size=BATCH_SIZE,
batch_size=BATCH_SIZE,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
num_workers=2) num_workers=2)
for e in range(EPOCH_NUM): for e in range(EPOCH_NUM):
for i, data in enumerate(loader()): for i, (image, label) in enumerate(loader()):
l = exe.run(prog, feed=data, fetch_list=[loss], return_numpy=True) out = simple_net(image)
print("Epoch {} batch {}: loss = {}".format(e, i, l[0][0])) loss = F.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
opt.minimize(avg_loss)
simple_net.clear_gradients()
print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))
# -------------------------------------------------------
.. note:: .. note::
For reading iterable dataset with multiprocess Dataloader, For reading iterable dataset with multiprocess Dataloader,
...@@ -439,6 +398,10 @@ class DataLoader(object): ...@@ -439,6 +398,10 @@ class DataLoader(object):
use_multiprocess=False, use_multiprocess=False,
drop_last=True): drop_last=True):
""" """
.. warning::
This API will be deprecated in the future, it is recommended to use
:code:`paddle.io.DataLoader` which supports multi-processes acceleration.
.. note:: .. note::
**The framework ensures that the data loading order of DataLoader is exactly the same as the user-defined data source.** **The framework ensures that the data loading order of DataLoader is exactly the same as the user-defined data source.**
...@@ -684,6 +647,10 @@ class DataLoader(object): ...@@ -684,6 +647,10 @@ class DataLoader(object):
@staticmethod @staticmethod
def from_dataset(dataset, places, drop_last=True): def from_dataset(dataset, places, drop_last=True):
""" """
.. warning::
This API will be deprecated in the future, it is recommended to use
:code:`paddle.io.DataLoader` which supports multi-processes acceleration.
Create an iterable DataLoader object for loading data from Dataset. Create an iterable DataLoader object for loading data from Dataset.
Dataset is only supported in Linux system currently. Dataset is only supported in Linux system currently.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册