提交 2117932d 编写于 作者: L Leo Chen 提交者: Zeng Jinle

Polish en doc of data_loader (#1484)

* polish en doc of data_loader

* follow comments

* use fluid.data()
上级 86f9c0db
.. _user_guide_use_py_reader_en:
############################################
Use PyReader to read training and test data
Asynchronous Data Reading
############################################
Besides Python Reader, we provide PyReader. The performance of PyReader is better than :ref:`user_guide_use_numpy_array_as_train_data_en` , because the process of loading data is asynchronous with the process of training model when PyReader is in use. And PyReader can coordinate with :code:`double_buffer_reader` to improve the performance of reading data. What's more, :code:`double_buffer_reader` can achieve the transformation from CPU Tensor to GPU Tensor, which improve the efficiency of reading data to some extent.
Besides synchronous data reading, we provide DataLoader. The performance of DataLoader is better than :ref:`user_guide_use_numpy_array_as_train_data_en` , because data reading and model training process is asynchronous
when DataLoader is in use, and it can cooperate with :code:`double_buffer_reader` to improve the performance of reading data. What's more, :code:`double_buffer_reader` can achieve the asynchronous transformation from CPU Tensor to GPU Tensor, which improves the efficiency of reading data to some extent.
Create PyReader Object
Create DataLoader Object
################################
You can create PyReader object as follows:
You can create DataLoader object as follows:
.. code-block:: python
import paddle.fluid as fluid
py_reader = fluid.layers.py_reader(capacity=64,
shapes=[(-1,784), (-1,1)],
dtypes=['float32', 'int64'],
name='py_reader',
use_double_buffer=True)
image = fluid.data(name='image', dtype='float32', shape=[None, 784])
label = fluid.data(name='label', dtype='int64', shape=[None, 1])
In the code, ``capacity`` is buffer size of PyReader;
``shapes`` is the size of parameters in the batch (such as image and label in picture classification task);
``dtypes`` is data type of parameters in the batch;
``name`` is name of PyReader instance;
``use_double_buffer`` is True by default, which means :code:`double_buffer_reader` is used.
ITERABLE = True
Attention: If you want to create multiple PyReader objects(such as two different PyReader in training and inference period respectively), you have to appoint different names for different PyReader objects,since PaddlePaddle uses different names to distinguish different variables, and `Program.clone()` (reference to :ref:`api_fluid_Program_clone` )can't copy PyReader objects.
data_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label], capacity=64, use_double_buffer=True, iterable=ITERABLE)
.. code-block:: python
In the code,
- ``feed_list`` is the list of input variables;
- ``capacity`` is the buffer size of the DataLoader object in batches;
- ``use_double_buffer`` is True by default, which means ``double_buffer_reader`` is used. It is recommended, because it can improve data reading speed;
- ``iterable`` is True by default, which means the DataLoader object is For-Range iterative. When ``iterable = True`` , DataLoader decouples from the Program, which means defining DataLoader objects does not change Program; when When ``iterable = False`` , DataLoader inserts operators related to data reading into Program.
import paddle.fluid as fluid
train_py_reader = fluid.layers.py_reader(capacity=64,
shapes=[(-1,784), (-1,1)],
dtypes=['float32', 'int64'],
name='train',
use_double_buffer=True)
Attention: ``Program.clone()`` (reference to :ref:`api_fluid_Program` )can't copy DataLoader objects.
If you want to create multiple DataLoader objects(such as two different DataLoaders in training and inference period respectively), you have to define different DataLoader objects.
While using DataLoader, if you need to share the model parameters of training and testing periods, you can use :code:`fluid.unique_name.guard()`.
test_py_reader = fluid.layers.py_reader(capacity=64,
shapes=[(-1,3,224,224), (-1,1)],
dtypes=['float32', 'int64'],
name='test',
use_double_buffer=True)
Notes: Paddle use different names to distinguish different variables, and the names are generated by the counter in :code:`unique_name` module, which rises by one every time a variable name is generated. :code:`fluid.unique_name.guard()` aims to reset the counter in :code:`unique_name` module, in order to ensure that the variable names are the same when calling :code:`fluid.unique_name.guard()` repeatedly, so that parameters can be shared.
While using PyReader, if you need to share the model parameters of training and test periods, you can use :code:`fluid.unique_name.guard()` .
Notes: Paddle use different names to distinguish different variables, and the names are generated by the counter in :code:`unique_name` module. By the way, the counts rise by one every time a variable name is generated. :code:`fluid.unique_name.guard()` aims to reset the counter in :code:`unique_name` module, in order to ensure that the variable names are the same when calling :code:`fluid.unique_name.guard()` repeatedly, so that parameters can be shared.
An example of configuring networks during the training and test periods by PyReader is as follows:
An example of configuring networks during the training and testing periods by DataLoader is as follows:
.. code-block:: python
......@@ -56,230 +47,220 @@ An example of configuring networks during the training and test periods by PyRea
import paddle.fluid as fluid
import paddle.dataset.mnist as mnist
import numpy
def network(is_train):
# Create py_reader object and give different names
# when is_train = True and is_train = False
reader = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name="train_reader" if is_train else "test_reader",
use_double_buffer=True)
# Use read_file() method to read out the data from py_reader
img, label = fluid.layers.read_file(reader)
...
# Here, we omitted the definition of loss of the model
return loss , reader
def network():
image = fluid.data(name='image', dtype='float32', shape=[None, 784])
label = fluid.data(name='label', dtype='int64', shape=[None, 1])
loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=64)
# Define model.
fc = fluid.layers.fc(image, size=10)
xe = fluid.layers.softmax_with_cross_entropy(fc, label)
loss = fluid.layers.reduce_mean(xe)
return loss , loader
# Create main program and startup program for training
# Create main program and startup program for training.
train_prog = fluid.Program()
train_startup = fluid.Program()
with fluid.program_guard(train_prog, train_startup):
# Use fluid.unique_name.guard() to share parameters with test network
# Use fluid.unique_name.guard() to share parameters with test network.
with fluid.unique_name.guard():
train_loss, train_reader = network(True)
train_loss, train_loader = network()
adam = fluid.optimizer.Adam(learning_rate=0.01)
adam.minimize(train_loss)
# Create main program and startup program for testing
# Create main program and startup program for testing.
test_prog = fluid.Program()
test_startup = fluid.Program()
with fluid.program_guard(test_prog, test_startup):
# Use fluid.unique_name.guard() to share parameters with train network
with fluid.unique_name.guard():
test_loss, test_reader = network(False)
test_loss, test_loader = network()
Configure data source of PyReader objects
##########################################
PyReader object sets the data source by :code:`decorate_paddle_reader()` or :code:`decorate_tensor_provider()` :code:`decorate_paddle_reader()` and :code:`decorate_tensor_provider()` both receive the Python generator :code:`generator` as parameters. :code:`generator` generates a batch of data every time by yield ways inside.
The differences of :code:`decorate_paddle_reader()` and :code:`decorate_tensor_provider()` ways are:
Configure data source of DataLoader object
##########################################
DataLoader object sets the data source by :code:`set_sample_generator()`, :code:`set_sample_list_generator()` or :code:`set_batch_generator()` . These three methods all receive the Python generator :code:`generator` as parameters. The differences of are:
- :code:`generator` of :code:`decorate_paddle_reader()` should return data of Numpy Array type, but :code:`generator` of :code:`decorate_tensor_provider()` should return LoDTensor type.
- :code:`generator` of :code:`set_sample_generator()` should return data of :code:`[img_1, label_1]` type, in which ``img_1`` and ``label_1`` is one sample's data of Numpy array type.
- :code:`decorate_tensor_provider()` requires that the returned data type and size of LoDTensor of :code:`generator` have to match the appointed dtypes and shapes parameters while configuring py_reader, but :code:`decorate_paddle_reader()` doesn't have the requirements, since the data type and size can transform inside.
- :code:`generator` of :code:`set_sample_list_generator()` should return data of :code:`[(img_1, label_1), (img_2, label_2), ..., (img_n, label_n)]` type, in which ``img_i`` and ``label_i`` is one sample's data of Numpy array type, and ``n`` is batch size.
Specific ways are as follows:
- :code:`generator` of :code:`set_batch_generator()` should return data of :code:`[batched_imgs, batched_labels]` type, in which ``batched_imgs`` and ``batched_labels`` is one batch's data of Numpy array or LoDTensor type.
.. code-block:: python
Please note that, when using DataLoader for multi-GPU card (or multi-CPU core) training, the actual total batch size is the batch size of incoming user generator multiplied by the number of devices.
import paddle.fluid as fluid
import numpy as np
When :code:`iterable = True` (default) of DataLoader, ``places`` parameters must be passed to these three methods, specifying whether to convert data to CPU Tensor or GPU Tensor. When :code:`iterable = False` of DataLoader, there is no need to pass the ``places`` parameter.
BATCH_SIZE = 32
For example, suppose we have two readers, ``fake_sample_reader`` returns one sample's data at a time and ``fake_batch_reader`` returns one batch's data at a time.
# Case 1: Use decorate_paddle_reader() method to set the data source of py_reader
# The generator yields Numpy-typed batched data
def fake_random_numpy_reader():
image = np.random.random(size=(BATCH_SIZE, 784))
label = np.random.random_integers(size=(BATCH_SIZE, 1), low=0, high=9)
yield image, label
.. code-block:: python
py_reader1 = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name='py_reader1',
use_double_buffer=True)
import paddle.fluid as fluid
import numpy as np
py_reader1.decorate_paddle_reader(fake_random_reader)
# Declare sample reader.
def fake_sample_reader():
for _ in range(100):
sample_image = np.random.random(size=(784, )).astype('float32')
sample_label = np.random.random_integers(size=(1, ), low=0, high=9).astype('int64')
yield sample_image, sample_label
# Case 2: Use decorate_tensor_provider() method to set the data source of py_reader
# The generator yields Tensor-typed batched data
def fake_random_tensor_provider():
image = np.random.random(size=(BATCH_SIZE, 784)).astype('float32')
label = np.random.random_integers(size=(BATCH_SIZE, 1), low=0, high=9).astype('int64')
# Declare batch reader.
def fake_batch_reader():
batch_size = 32
for _ in range(100):
batch_image = np.random.random(size=(batch_size, 784)).astype('float32')
batch_label = np.random.random_integers(size=(batch_size, 1), low=0, high=9).astype('int64')
yield batch_image, batch_label
image_tensor = fluid.LoDTensor()
image_tensor.set(image, fluid.CPUPlace())
image1 = fluid.data(name='image1', dtype='float32', shape=[None, 784])
label1 = fluid.data(name='label1', dtype='int64', shape=[None, 1])
label_tensor = fluid.LoDTensor()
label_tensor.set(label, fluid.CPUPlace())
yield image_tensor, label_tensor
image2 = fluid.data(name='image2', dtype='float32', shape=[None, 784])
label2 = fluid.data(name='label2', dtype='int64', shape=[None, 1])
py_reader2 = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name='py_reader2',
use_double_buffer=True)
image3 = fluid.data(name='image3', dtype='float32', shape=[None, 784])
label3 = fluid.data(name='label3', dtype='int64', shape=[None, 1])
py_reader2.decorate_tensor_provider(fake_random_tensor_provider)
example usage:
The corresponding DataLoader are defined as follows:
.. code-block:: python
import paddle.batch
import paddle
import paddle.fluid as fluid
import numpy as np
BATCH_SIZE = 32
ITERABLE = True
USE_CUDA = True
USE_DATA_PARALLEL = True
if ITERABLE:
# If DataLoader is iterable, places should be set.
if USE_DATA_PARALLEL:
# Use all GPU cards or 8 CPU cores to train.
places = fluid.cuda_places() if USE_CUDA else fluid.cpu_places(8)
else:
# Use single GPU card or CPU core.
places = fluid.cuda_places(0) if USE_CUDA else fluid.cpu_places(1)
else:
# If DataLoader is not iterable, places shouldn't be set.
places = None
# Use sample reader to configure data source of DataLoader.
data_loader1 = fluid.io.DataLoader.from_generator(feed_list=[image1, label1], capacity=10, iterable=ITERABLE)
data_loader1.set_sample_generator(fake_sample_reader, batch_size=32, places=places)
# Use sample reader + fluid.io.batch to configure data source of DataLoader.
data_loader2 = fluid.io.DataLoader.from_generator(feed_list=[image2, label2], capacity=10, iterable=ITERABLE)
sample_list_reader = fluid.io.batch(fake_sample_reader, batch_size=32)
sample_list_reader = fluid.io.shuffle(sample_list_reader, buf_size=64) # Shuffle data if needed.
data_loader2.set_sample_list_generator(sample_list_reader, places=places)
# Use batch to configure data source of DataLoader.
data_loader3 = fluid.io.DataLoader.from_generator(feed_list=[image3, label3], capacity=10, iterable=ITERABLE)
data_loader3.set_batch_generator(fake_batch_reader, places=places)
Train and test model with DataLoader
##################################
# Case 1: Use decorate_paddle_reader() method to set the data source of py_reader
# The generator yields Numpy-typed batched data
def fake_random_numpy_reader():
image = np.random.random(size=(784, ))
label = np.random.random_integers(size=(1, ), low=0, high=9)
yield image, label
Examples of using DataLoader to train and test models are as follows:
py_reader1 = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name='py_reader1',
use_double_buffer=True)
- Step 1, we need to set up training network and testing network, define the corresponding DataLoader object, and configure the data source of DataLoader object.
py_reader1.decorate_paddle_reader(paddle.batch(fake_random_numpy_reader, batch_size=BATCH_SIZE))
.. code-block:: python
import paddle
import paddle.fluid as fluid
import paddle.dataset.mnist as mnist
import six
# Case 2: Use decorate_tensor_provider() method to set the data source of py_reader
# The generator yields Tensor-typed batched data
def fake_random_tensor_provider():
image = np.random.random(size=(BATCH_SIZE, 784)).astype('float32')
label = np.random.random_integers(size=(BATCH_SIZE, 1), low=0, high=9).astype('int64')
yield image_tensor, label_tensor
ITERABLE = True
py_reader2 = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name='py_reader2',
use_double_buffer=True)
def network():
# Create data holder.
image = fluid.data(name='image', dtype='float32', shape=[None, 784])
label = fluid.data(name='label', dtype='int64', shape=[None, 1])
py_reader2.decorate_tensor_provider(fake_random_tensor_provider)
# Create DataLoader object.
reader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=64, iterable=ITERABLE)
Train and test model with PyReader
##################################
# Define model.
fc = fluid.layers.fc(image, size=10)
xe = fluid.layers.softmax_with_cross_entropy(fc, label)
loss = fluid.layers.reduce_mean(xe)
return loss , reader
Examples by using PyReader to train models and test are as follows:
# Create main program and startup program for training.
train_prog = fluid.Program()
train_startup = fluid.Program()
.. code-block:: python
# Define training network.
with fluid.program_guard(train_prog, train_startup):
# fluid.unique_name.guard() to share parameters with test network
with fluid.unique_name.guard():
train_loss, train_loader = network()
adam = fluid.optimizer.Adam(learning_rate=0.01)
adam.minimize(train_loss)
import paddle
import paddle.fluid as fluid
import paddle.dataset.mnist as mnist
import six
def network(is_train):
# Create py_reader object and give different names
# when is_train = True and is_train = False
reader = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name="train_reader" if is_train else "test_reader",
use_double_buffer=True)
img, label = fluid.layers.read_file(reader)
...
# Here, we omitted the definition of loss of the model
return loss , reader
# Create main program and startup program for training
train_prog = fluid.Program()
train_startup = fluid.Program()
# Define train network
with fluid.program_guard(train_prog, train_startup):
# Use fluid.unique_name.guard() to share parameters with test network
with fluid.unique_name.guard():
train_loss, train_reader = network(True)
adam = fluid.optimizer.Adam(learning_rate=0.01)
adam.minimize(train_loss)
# Create main program and startup program for testing
test_prog = fluid.Program()
test_startup = fluid.Program()
# Define test network
with fluid.program_guard(test_prog, test_startup):
# Use fluid.unique_name.guard() to share parameters with train network
with fluid.unique_name.guard():
test_loss, test_reader = network(False)
# Create main program and startup program for testing.
test_prog = fluid.Program()
test_startup = fluid.Program()
# Define testing network.
with fluid.program_guard(test_prog, test_startup):
# Use fluid.unique_name.guard() to share parameters with train network
with fluid.unique_name.guard():
test_loss, test_loader = network()
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
# Run startup program
# Run startup_program for initialization.
exe.run(train_startup)
exe.run(test_startup)
# Compile programs
# Compile programs.
train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(loss_name=train_loss.name)
test_prog = fluid.CompiledProgram(test_prog).with_data_parallel(share_vars_from=train_prog)
# Set the data source of py_reader using decorate_paddle_reader() method
train_reader.decorate_paddle_reader(
paddle.reader.shuffle(paddle.batch(mnist.train(), 512), buf_size=8192))
# Configure data source of DataLoader.
places = fluid.cuda_places() if ITERABLE else None
train_loader.set_sample_list_generator(
fluid.io.shuffle(fluid.io.batch(mnist.train(), 512), buf_size=1024), places=places)
test_loader.set_sample_list_generator(fluid.io.batch(mnist.test(), 512), places=places)
- Step 2, we choose different ways to run the network according to whether the DataLoader object is iterable or not.
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512))
If :code:`iterable = True`, the DataLoader object is a Python generator that can iterate directly using for-range. The results returned by for-range are passed to the executor through the ``feed`` parameter of ``exe.run()``.
.. code-block:: python
def run_iterable(program, exe, loss, data_loader):
for data in data_loader():
loss_value = exe.run(program=program, feed=data, fetch_list=[loss])
print('loss is {}'.format(loss_value))
for epoch_id in six.moves.range(10):
train_reader.start()
try:
while True:
loss = exe.run(program=train_prog, fetch_list=[train_loss])
print 'train_loss', loss
except fluid.core.EOFException:
print 'End of epoch', epoch_id
train_reader.reset()
run_iterable(train_prog, exe, train_loss, train_loader)
run_iterable(test_prog, exe, test_loss, test_loader)
test_reader.start()
If :code:`iterable = False`, call the ``start()`` method to start the DataLoader object before each epoch starts, and call the ``reset()`` method to reset the status of the DataLoader object after catching the exception to start the iteration of next epoch, since ``exe.run()`` throws a ``fluid.core.EOFException`` exception at the end of each epoch. When :code:`iterable = False`, there is no need to pass ``feed`` parameter to ``exe.run()``. The specific ways are as follows:
.. code-block:: python
def run_non_iterable(program, exe, loss, data_loader):
data_loader.start()
try:
while True:
loss = exe.run(program=test_prog, fetch_list=[test_loss])
print 'test loss', loss
loss_value = exe.run(program=program, fetch_list=[loss])
print('loss is {}'.format(loss_value))
except fluid.core.EOFException:
print 'End of testing'
test_reader.reset()
print('End of epoch')
data_loader.reset()
Specific steps are as follows:
1. Before the start of every epoch, call :code:`start()` to invoke PyReader;
2. At the end of every epoch, :code:`read_file` throws exception :code:`fluid.core.EOFException` . Call :code:`reset()` after catching up exception to reset the state of PyReader in order to start next epoch.
for epoch_id in six.moves.range(10):
run_non_iterable(train_prog, exe, train_loss, train_loader)
run_non_iterable(test_prog, exe, test_loss, test_loader)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册