提交 30a2946e 编写于 作者: A acosta123 提交者: Cheerego

Update use_py_reader_en.rst (#793)

* Update use_py_reader_en.rst

* Update use_py_reader_en.rst

* Update use_py_reader_en.rst

* Update doc/fluid/user_guides/howto/prepare_data/use_py_reader_en.rst
Co-Authored-By: Nacosta123 <42226556+acosta123@users.noreply.github.com>

* Update use_py_reader_en.rst
上级 93ec3b37
......@@ -4,8 +4,7 @@
Use PyReader to read training and test data
############################################
Paddle Fluid supports PyReader, which implements feeding data from Python to C++. Different from :ref:`user_guide_use_numpy_array_as_train_data_en` , the process of loading data to Python is asynchronous with the process of :code:`Executor::Run()` reading data when PyReader is in use.
Moreover, PyReader is able to work with :code:`double_buffer_reader` to upgrade the performance of reading data.
Besides Python Reader, we provide PyReader. The performance of PyReader is better than :ref:`user_guide_use_numpy_array_as_train_data` , because the process of loading data is asynchronous with the process of training model when PyReader is in use. And PyReader can coordinate with :code:`double_buffer_reader` to improve the performance of reading data. What's more, :code:`double_buffer_reader` can achieve the transformation from CPU Tensor to GPU Tensor, which improve the efficiency of reading data to some extent.
Create PyReader Object
################################
......@@ -17,7 +16,7 @@ You can create PyReader object as follows:
import paddle.fluid as fluid
py_reader = fluid.layers.py_reader(capacity=64,
shapes=[(-1,3,224,224), (-1,1)],
shapes=[(-1,784), (-1,1)],
dtypes=['float32', 'int64'],
name='py_reader',
use_double_buffer=True)
......@@ -28,14 +27,14 @@ In the code, ``capacity`` is buffer size of PyReader;
``name`` is name of PyReader instance;
``use_double_buffer`` is True by default, which means :code:`double_buffer_reader` is used.
To create some different PyReader objects (Usually, you have to create two different PyReader objects for training and testing phase), the names of objects must be different. For example, In the same task, PyReader objects in training and testing period are created as follows:
Attention: If you want to create multiple PyReader objects(such as two different PyReader in training and inference period respectively), you have to appoint different names for different PyReader objects,since PaddlePaddle uses different names to distinguish different variables, and `Program.clone()` (reference to :ref:`api_fluid_Program_clone` )can't copy PyReader objects.
.. code-block:: python
import paddle.fluid as fluid
train_py_reader = fluid.layers.py_reader(capacity=64,
shapes=[(-1,3,224,224), (-1,1)],
shapes=[(-1,784), (-1,1)],
dtypes=['float32', 'int64'],
name='train',
use_double_buffer=True)
......@@ -46,11 +45,10 @@ To create some different PyReader objects (Usually, you have to create two diffe
name='test',
use_double_buffer=True)
Note: You could not copy PyReader object with :code:`Program.clone()` so you have to create PyReader objects in training and testing phase with the method mentioned above
While using PyReader, if you need to share the model parameters of training and test periods, you can use :code:`fluid.unique_name.guard()` .
Notes: Paddle use different names to distinguish different variables, and the names are generated by the counter in :code:`unique_name` module. By the way, the counts rise by one every time a variable name is generated. :code:`fluid.unique_name.guard()` aims to reset the counter in :code:`unique_name` module, in order to ensure that the variable names are the same when calling :code:`fluid.unique_name.guard()` repeatedly, so that parameters can be shared.
Because you could not copy PyReader with :code:`Program.clone()` so you have to share the parameters of training phase with testing phase through :code:`fluid.unique_name.guard()` .
Details are as follows:
An example of configuring networks during the training and test periods by PyReader is as follows:
.. code-block:: python
......@@ -61,41 +59,97 @@ Details are as follows:
import numpy
def network(is_train):
# Create py_reader object and give different names
# when is_train = True and is_train = False
reader = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name="train_reader" if is_train else "test_reader",
use_double_buffer=True)
# Use read_file() method to read out the data from py_reader
img, label = fluid.layers.read_file(reader)
...
# Here, we omitted the definition of loss of the model
return loss , reader
# Create main program and startup program for training
train_prog = fluid.Program()
train_startup = fluid.Program()
with fluid.program_guard(train_prog, train_startup):
# Use fluid.unique_name.guard() to share parameters with test network
with fluid.unique_name.guard():
train_loss, train_reader = network(True)
adam = fluid.optimizer.Adam(learning_rate=0.01)
adam.minimize(train_loss)
# Create main program and startup program for testing
test_prog = fluid.Program()
test_startup = fluid.Program()
with fluid.program_guard(test_prog, test_startup):
# Use fluid.unique_name.guard() to share parameters with train network
with fluid.unique_name.guard():
test_loss, test_reader = network(False)
Configure data source of PyReader objects
##########################################
PyReader provides :code:`decorate_tensor_provider` and :code:`decorate_paddle_reader` , both of which receieve Python :code:`generator` as data source.The difference is:
PyReader object sets the data source by :code:`decorate_paddle_reader()` or :code:`decorate_tensor_provider()` :code:`decorate_paddle_reader()` and :code:`decorate_tensor_provider()` both receive the Python generator :code:`generator` as parameters. :code:`generator` generates a batch of data every time by yield ways inside.
The differences of :code:`decorate_paddle_reader()` and :code:`decorate_tensor_provider()` ways are:
- :code:`generator` of :code:`decorate_paddle_reader()` should return data of Numpy Array type, but :code:`generator` of :code:`decorate_tensor_provider()` should return LoDTensor type.
- :code:`decorate_tensor_provider()` requires that the returned data type and size of LoDTensor of :code:`generator` have to match the appointed dtypes and shapes parameters while configuring py_reader, but :code:`decorate_paddle_reader()` doesn't have the requirements, since the data type and size can transform inside.
Specific ways are as follows:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
1. :code:`decorate_tensor_provider` : :code:`generator` generates a :code:`list` or :code:`tuple` each time, with each element of :code:`list` or :code:`tuple` being :code:`LoDTensor` or Numpy array, and :code:`LoDTensor` or :code:`shape` of Numpy array must be the same as :code:`shapes` stated while PyReader is created.
BATCH_SIZE = 32
# Case 1: Use decorate_paddle_reader() method to set the data source of py_reader
# The generator yields Numpy-typed batched data
def fake_random_numpy_reader():
image = np.random.random(size=(BATCH_SIZE, 784))
label = np.random.random_integers(size=(BATCH_SIZE, 1), low=0, high=9)
yield image, label
2. :code:`decorate_paddle_reader` : :code:`generator` generates a :code:`list` or :code:`tuple` each time, with each element of :code:`list` or :code:`tuple` being Numpy array,but the :code:`shape` of Numpy array doesn't have to be the same as :code:`shape` stated while PyReader is created. :code:`decorate_paddle_reader` will :code:`reshape` Numpy array internally.
py_reader1 = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name='py_reader1',
use_double_buffer=True)
py_reader1.decorate_paddle_reader(fake_random_reader)
# Case 2: Use decorate_tensor_provider() method to set the data source of py_reader
# The generator yields Tensor-typed batched data
def fake_random_tensor_provider():
image = np.random.random(size=(BATCH_SIZE, 784)).astype('float32')
label = np.random.random_integers(size=(BATCH_SIZE, 1), low=0, high=9).astype('int64')
image_tensor = fluid.LoDTensor()
image_tensor.set(image, fluid.CPUPlace())
label_tensor = fluid.LoDTensor()
label_tensor.set(label, fluid.CPUPlace())
yield image_tensor, label_tensor
py_reader2 = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name='py_reader2',
use_double_buffer=True)
py_reader2.decorate_tensor_provider(fake_random_tensor_provider)
example usage:
.. code-block:: python
......@@ -142,32 +196,75 @@ example usage:
Train and test model with PyReader
##################################
Details are as follows(the remaining part of the code above):
Examples by using PyReader to train models and test are as follows:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import paddle.dataset.mnist as mnist
import six
def network(is_train):
# Create py_reader object and give different names
# when is_train = True and is_train = False
reader = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name="train_reader" if is_train else "test_reader",
use_double_buffer=True)
img, label = fluid.layers.read_file(reader)
...
# Here, we omitted the definition of loss of the model
return loss , reader
# Create main program and startup program for training
train_prog = fluid.Program()
train_startup = fluid.Program()
# Define train network
with fluid.program_guard(train_prog, train_startup):
# Use fluid.unique_name.guard() to share parameters with test network
with fluid.unique_name.guard():
train_loss, train_reader = network(True)
adam = fluid.optimizer.Adam(learning_rate=0.01)
adam.minimize(train_loss)
# Create main program and startup program for testing
test_prog = fluid.Program()
test_startup = fluid.Program()
# Define test network
with fluid.program_guard(test_prog, test_startup):
# Use fluid.unique_name.guard() to share parameters with train network
with fluid.unique_name.guard():
test_loss, test_reader = network(False)
place = fluid.CUDAPlace(0)
startup_exe = fluid.Executor(place)
startup_exe.run(train_startup)
startup_exe.run(test_startup)
exe = fluid.Executor(place)
trainer = fluid.ParallelExecutor(
use_cuda=True, loss_name=train_loss.name, main_program=train_prog)
# Run startup program
exe.run(train_startup)
exe.run(test_startup)
tester = fluid.ParallelExecutor(
use_cuda=True, share_vars_from=trainer, main_program=test_prog)
# Compile programs
train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(loss_name=train_loss.name)
test_prog = fluid.CompiledProgram(test_prog).with_data_parallel(share_vars_from=train_prog)
# Set the data source of py_reader using decorate_paddle_reader() method
train_reader.decorate_paddle_reader(
paddle.reader.shuffle(paddle.batch(mnist.train(), 512), buf_size=8192))
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512))
for epoch_id in xrange(10):
for epoch_id in six.moves.range(10):
train_reader.start()
try:
while True:
print 'train_loss', numpy.array(
trainer.run(fetch_list=[train_loss.name]))
loss = exe.run(program=train_prog, fetch_list=[train_loss])
print 'train_loss', loss
except fluid.core.EOFException:
print 'End of epoch', epoch_id
train_reader.reset()
......@@ -175,8 +272,8 @@ Details are as follows(the remaining part of the code above):
test_reader.start()
try:
while True:
print 'test loss', numpy.array(
tester.run(fetch_list=[test_loss.name]))
loss = exe.run(program=test_prog, fetch_list=[test_loss])
print 'test loss', loss
except fluid.core.EOFException:
print 'End of testing'
test_reader.reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册