未验证 提交 426de255 编写于 作者: H Huihuang Zheng 提交者: GitHub

Refine Executor API English Doc for 2.0rc (#27857)

As the title
上级 9b3b3b74
...@@ -480,7 +480,7 @@ class Executor(object): ...@@ -480,7 +480,7 @@ class Executor(object):
and single/multiple-CPU running. and single/multiple-CPU running.
Args: Args:
place(fluid.CPUPlace()|fluid.CUDAPlace(n)|None): This parameter represents place(paddle.CPUPlace()|paddle.CUDAPlace(n)|None): This parameter represents
which device the executor runs on. When this parameter is None, PaddlePaddle which device the executor runs on. When this parameter is None, PaddlePaddle
will set the default device according to its installation version. If Paddle will set the default device according to its installation version. If Paddle
is CPU version, the default device would be set to `CPUPlace()` . If Paddle is is CPU version, the default device would be set to `CPUPlace()` . If Paddle is
...@@ -492,43 +492,42 @@ class Executor(object): ...@@ -492,43 +492,42 @@ class Executor(object):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
import paddle.fluid.compiler as compiler
import numpy import numpy
import os import os
# Executor is only used in static graph mode
paddle.enable_static()
# Set place explicitly. # Set place explicitly.
# use_cuda = True # use_cuda = True
# place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
# exe = fluid.Executor(place) # exe = paddle.static.Executor(place)
# If you don't set place, PaddlePaddle sets the default device. # If you don't set place, PaddlePaddle sets the default device.
exe = fluid.Executor() exe = paddle.static.Executor()
train_program = fluid.Program() train_program = paddle.static.Program()
startup_program = fluid.Program() startup_program = paddle.static.Program()
with fluid.program_guard(train_program, startup_program): with paddle.static.program_guard(train_program, startup_program):
data = fluid.data(name='X', shape=[None, 1], dtype='float32') data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10) hidden = paddle.static.nn.fc(data, 10)
loss = fluid.layers.mean(hidden) loss = paddle.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss) paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
# Run the startup program once and only once. # Run the startup program once and only once.
# Not need to optimize/compile the startup program. # Not need to optimize/compile the startup program.
startup_program.random_seed=1
exe.run(startup_program) exe.run(startup_program)
# Run the main program directly without compile. # Run the main program directly without compile.
x = numpy.random.random(size=(10, 1)).astype('float32') x = numpy.random.random(size=(10, 1)).astype('float32')
loss_data, = exe.run(train_program, loss_data, = exe.run(train_program, feed={"X": x}, fetch_list=[loss.name])
feed={"X": x},
fetch_list=[loss.name])
# Or, compiled the program and run. See `CompiledProgram` # Or, compiled the program and run. See `CompiledProgram`
# for more detail. # for more details.
# NOTE: If you use CPU to run the program or Paddle is # NOTE: If you use CPU to run the program or Paddle is
# CPU version, you need to specify the CPU_NUM, otherwise, # CPU version, you need to specify the CPU_NUM, otherwise,
# fluid will use all the number of the logic core as # PaddlePaddle will use all the number of the logic core as
# the CPU_NUM, in that case, the batch size of the input # the CPU_NUM, in that case, the batch size of the input
# should be greater than CPU_NUM, if not, the process will be # should be greater than CPU_NUM, if not, the process will be
# failed by an exception. # failed by an exception.
...@@ -540,12 +539,10 @@ class Executor(object): ...@@ -540,12 +539,10 @@ class Executor(object):
# If you don't set place and PaddlePaddle is CPU version # If you don't set place and PaddlePaddle is CPU version
os.environ['CPU_NUM'] = str(2) os.environ['CPU_NUM'] = str(2)
compiled_prog = compiler.CompiledProgram( compiled_prog = paddle.static.CompiledProgram(
train_program).with_data_parallel( train_program).with_data_parallel(loss_name=loss.name)
loss_name=loss.name) loss_data, = exe.run(compiled_prog, feed={"X": x}, fetch_list=[loss.name])
loss_data, = exe.run(compiled_prog,
feed={"X": x},
fetch_list=[loss.name])
""" """
def __init__(self, place=None): def __init__(self, place=None):
...@@ -842,10 +839,10 @@ class Executor(object): ...@@ -842,10 +839,10 @@ class Executor(object):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
cpu = fluid.CPUPlace() cpu = paddle.CPUPlace()
exe = fluid.Executor(cpu) exe = paddle.static.Executor(cpu)
# execute training or testing # execute training or testing
exe.close() exe.close()
""" """
...@@ -928,17 +925,17 @@ class Executor(object): ...@@ -928,17 +925,17 @@ class Executor(object):
Run the specified :code:`Program` or :code:`CompiledProgram`. It should be noted that the executor Run the specified :code:`Program` or :code:`CompiledProgram`. It should be noted that the executor
will execute all the operators in :code:`Program` or :code:`CompiledProgram` without pruning some will execute all the operators in :code:`Program` or :code:`CompiledProgram` without pruning some
operators of the :code:`Program` or :code:`CompiledProgram` according to fetch_list. And you could operators of the :code:`Program` or :code:`CompiledProgram` according to fetch_list. And you could
specify the scope to store the :code:`Variables` during the executor running if the scope specify the scope to store the :code:`Tensor` during the executor running if the scope
is not set, the executor will use the global scope, i.e. :code:`fluid.global_scope()`. is not set, the executor will use the global scope, i.e. :code:`paddle.static.global_scope()`.
Args: Args:
program(Program|CompiledProgram): This parameter represents the :code:`Program` or program(Program|CompiledProgram): This parameter represents the :code:`Program` or
:code:`CompiledProgram` to be executed. If this parameter is not provided, that :code:`CompiledProgram` to be executed. If this parameter is not provided, that
parameter is None, the program will be set to :code:`fluid.default_main_program()`. parameter is None, the program will be set to :code:`paddle.static.default_main_program()`.
The default is None. The default is None.
feed(list|dict): This parameter represents the input variables of the model. feed(list|dict): This parameter represents the input Tensors of the model.
If it is single card training, the feed is dict type, and if it is multi-card If it is single card training, the feed is dict type, and if it is multi-card
training, the parameter feed can be dict or list type variable. If the training, the parameter feed can be dict or list of Tensors. If the
parameter type is dict, the data in the feed will be split and sent to parameter type is dict, the data in the feed will be split and sent to
multiple devices (CPU/GPU), that is to say, the input data will be evenly multiple devices (CPU/GPU), that is to say, the input data will be evenly
sent to different devices, so you should make sure the number of samples of sent to different devices, so you should make sure the number of samples of
...@@ -946,23 +943,23 @@ class Executor(object): ...@@ -946,23 +943,23 @@ class Executor(object):
if the parameter type is list, those data are copied directly to each device, if the parameter type is list, those data are copied directly to each device,
so the length of this list should be equal to the number of places. so the length of this list should be equal to the number of places.
The default is None. The default is None.
fetch_list(list): This parameter represents the variables that need to be returned fetch_list(list): This parameter represents the Tensors that need to be returned
after the model runs. The default is None. after the model runs. The default is None.
feed_var_name(str): This parameter represents the name of the input variable of feed_var_name(str): This parameter represents the name of the input Tensor of
the feed operator. The default is "feed". the feed operator. The default is "feed".
fetch_var_name(str): This parameter represents the name of the output variable of fetch_var_name(str): This parameter represents the name of the output Tensor of
the fetch operator. The default is "fetch". the fetch operator. The default is "fetch".
scope(Scope): the scope used to run this program, you can switch scope(Scope): the scope used to run this program, you can switch
it to different scope. default is :code:`fluid.global_scope()` it to different scope. default is :code:`paddle.static.global_scope()`
return_numpy(bool): This parameter indicates whether convert the fetched variables return_numpy(bool): This parameter indicates whether convert the fetched Tensors
(the variable specified in the fetch list) to numpy.ndarray. if it is False, (the Tensor specified in the fetch list) to numpy.ndarray. if it is False,
the type of the return value is a list of :code:`LoDTensor`. The default is True. the type of the return value is a list of :code:`LoDTensor`. The default is True.
use_program_cache(bool): This parameter indicates whether the input :code:`Program` is cached. use_program_cache(bool): This parameter indicates whether the input :code:`Program` is cached.
If the parameter is True, the model may run faster in the following cases: If the parameter is True, the model may run faster in the following cases:
the input program is :code:`fluid.Program`, and the parameters(program, feed variable name the input program is :code:`paddle.static.Program`, and the parameters(program, feed Tensor name
and fetch_list variable) of this interface remains unchanged during running. and fetch_list Tensor) of this interface remains unchanged during running.
The default is False. The default is False.
return_merged(bool): This parameter indicates whether fetched variables (the variables return_merged(bool): This parameter indicates whether fetched Tensors (the Tensors
specified in the fetch list) should be merged according to the execution device dimension. specified in the fetch list) should be merged according to the execution device dimension.
If :code:`return_merged` is False, the type of the return value is a two-dimensional list If :code:`return_merged` is False, the type of the return value is a two-dimensional list
of :code:`Tensor` / :code:`LoDTensorArray` ( :code:`return_numpy` is False) or a two-dimensional of :code:`Tensor` / :code:`LoDTensorArray` ( :code:`return_numpy` is False) or a two-dimensional
...@@ -996,29 +993,30 @@ class Executor(object): ...@@ -996,29 +993,30 @@ class Executor(object):
number of CPU cores or GPU cards, if it is less than, it is recommended that number of CPU cores or GPU cards, if it is less than, it is recommended that
the batch be discarded. the batch be discarded.
2. If the number of CPU cores or GPU cards available is greater than 1, the fetch 2. If the number of CPU cores or GPU cards available is greater than 1, the fetch
results are spliced together in dimension 0 for the same variable values results are spliced together in dimension 0 for the same Tensor values
(variables in fetch_list) on different devices. (Tensors in fetch_list) on different devices.
Examples 1: Examples 1:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
import numpy import numpy
# First create the Executor. # First create the Executor.
place = fluid.CPUPlace() # fluid.CUDAPlace(0) paddle.enable_static()
exe = fluid.Executor(place) place = paddle.CPUPlace() # paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
data = fluid.data(name='X', shape=[None, 1], dtype='float32') data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10) hidden = paddle.static.nn.fc(data, 10)
loss = fluid.layers.mean(hidden) loss = paddle.mean(hidden)
adam = fluid.optimizer.Adam() adam = paddle.optimizer.Adam()
adam.minimize(loss) adam.minimize(loss)
i = fluid.layers.zeros(shape=[1], dtype='int64') i = paddle.zeros(shape=[1], dtype='int64')
array = fluid.layers.array_write(x=loss, i=i) array = paddle.fluid.layers.array_write(x=loss, i=i)
# Run the startup program once and only once. # Run the startup program once and only once.
exe.run(fluid.default_startup_program()) exe.run(paddle.static.default_startup_program())
x = numpy.random.random(size=(10, 1)).astype('float32') x = numpy.random.random(size=(10, 1)).astype('float32')
loss_val, array_val = exe.run(feed={'X': x}, loss_val, array_val = exe.run(feed={'X': x},
...@@ -1029,46 +1027,52 @@ class Executor(object): ...@@ -1029,46 +1027,52 @@ class Executor(object):
Examples 2: Examples 2:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
import numpy as np import numpy as np
# First create the Executor. # First create the Executor.
place = fluid.CUDAPlace(0) paddle.enable_static()
exe = fluid.Executor(place) place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
data = fluid.data(name='X', shape=[None, 1], dtype='float32') data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
class_dim = 2 class_dim = 2
prediction = fluid.layers.fc(input=data, size=class_dim) prediction = paddle.static.nn.fc(data, class_dim)
loss = fluid.layers.mean(prediction) loss = paddle.mean(prediction)
adam = fluid.optimizer.Adam() adam = paddle.optimizer.Adam()
adam.minimize(loss) adam.minimize(loss)
# Run the startup program once and only once. # Run the startup program once and only once.
exe.run(fluid.default_startup_program()) exe.run(paddle.static.default_startup_program())
build_strategy = fluid.BuildStrategy() build_strategy = paddle.static.BuildStrategy()
binary = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel( binary = paddle.static.CompiledProgram(
paddle.static.default_main_program()).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy) loss_name=loss.name, build_strategy=build_strategy)
batch_size = 6 batch_size = 6
x = np.random.random(size=(batch_size, 1)).astype('float32') x = np.random.random(size=(batch_size, 1)).astype('float32')
# Set return_merged as False to fetch unmerged results: # Set return_merged as False to fetch unmerged results:
unmerged_prediction, = exe.run(binary, feed={'X': x}, unmerged_prediction, = exe.run(binary,
feed={'X': x},
fetch_list=[prediction.name], fetch_list=[prediction.name],
return_merged=False) return_merged=False)
# If the user uses two GPU cards to run this python code, the printed result will be # If the user uses two GPU cards to run this python code, the printed result will be
# (2, 3, class_dim). The first dimension value of the printed result is the number of used # (2, 3, class_dim). The first dimension value of the printed result is the number of used
# GPU cards, and the second dimension value is the quotient of batch_size and the # GPU cards, and the second dimension value is the quotient of batch_size and the
# number of used GPU cards. # number of used GPU cards.
print("The unmerged prediction shape: {}".format(np.array(unmerged_prediction).shape)) print("The unmerged prediction shape: {}".format(
np.array(unmerged_prediction).shape))
print(unmerged_prediction) print(unmerged_prediction)
# Set return_merged as True to fetch merged results: # Set return_merged as True to fetch merged results:
merged_prediction, = exe.run(binary, feed={'X': x}, merged_prediction, = exe.run(binary,
feed={'X': x},
fetch_list=[prediction.name], fetch_list=[prediction.name],
return_merged=True) return_merged=True)
# If the user uses two GPU cards to run this python code, the printed result will be # If the user uses two GPU cards to run this python code, the printed result will be
# (6, class_dim). The first dimension value of the printed result is the batch_size. # (6, class_dim). The first dimension value of the printed result is the batch_size.
print("The merged prediction shape: {}".format(np.array(merged_prediction).shape)) print("The merged prediction shape: {}".format(
np.array(merged_prediction).shape))
print(merged_prediction) print(merged_prediction)
# Out: # Out:
...@@ -1085,6 +1089,7 @@ class Executor(object): ...@@ -1085,6 +1089,7 @@ class Executor(object):
# [-0.24635398 -0.13003758] # [-0.24635398 -0.13003758]
# [-0.49232286 -0.25939852] # [-0.49232286 -0.25939852]
# [-0.44514108 -0.2345845 ]] # [-0.44514108 -0.2345845 ]]
""" """
try: try:
return self._run_impl( return self._run_impl(
...@@ -1508,9 +1513,9 @@ class Executor(object): ...@@ -1508,9 +1513,9 @@ class Executor(object):
thread(int): number of thread a user wants to run in this function. Default is 0, which thread(int): number of thread a user wants to run in this function. Default is 0, which
means using thread num of dataset means using thread num of dataset
debug(bool): whether a user wants to run infer_from_dataset, default is False debug(bool): whether a user wants to run infer_from_dataset, default is False
fetch_list(Variable List): fetch variable list, each variable will be printed during fetch_list(Tensor List): fetch Tensor list, each Tensor will be printed during
training, default is None training, default is None
fetch_info(String List): print information for each variable, default is None fetch_info(String List): print information for each Tensor, default is None
print_period(int): the number of mini-batches for each print, default is 100 print_period(int): the number of mini-batches for each print, default is 100
fetch_handler(FetchHandler): a user define class for fetch output. fetch_handler(FetchHandler): a user define class for fetch output.
...@@ -1521,19 +1526,21 @@ class Executor(object): ...@@ -1521,19 +1526,21 @@ class Executor(object):
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
place = fluid.CPUPlace() # you can set place = fluid.CUDAPlace(0) to use gpu paddle.enable_static()
exe = fluid.Executor(place) place = paddle.CPUPlace() # you can set place = paddle.CUDAPlace(0) to use gpu
x = fluid.data(name="x", shape=[None, 10, 10], dtype="int64") exe = paddle.static.Executor(place)
y = fluid.data(name="y", shape=[None, 1], dtype="int64", lod_level=1) x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64")
dataset = fluid.DatasetFactory().create_dataset() y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1)
dataset = paddle.fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y]) dataset.set_use_var([x, y])
dataset.set_thread(1) dataset.set_thread(1)
filelist = [] # you should set your own filelist, e.g. filelist = ["dataA.txt"] # you should set your own filelist, e.g. filelist = ["dataA.txt"]
filelist = []
dataset.set_filelist(filelist) dataset.set_filelist(filelist)
exe.run(fluid.default_startup_program()) exe.run(paddle.static.default_startup_program())
exe.infer_from_dataset(program=fluid.default_main_program(), exe.infer_from_dataset(program=paddle.static.default_main_program(),
dataset=dataset) dataset=dataset)
""" """
...@@ -1627,9 +1634,9 @@ class Executor(object): ...@@ -1627,9 +1634,9 @@ class Executor(object):
thread(int): number of thread a user wants to run in this function. Default is 0, which thread(int): number of thread a user wants to run in this function. Default is 0, which
means using thread num of dataset means using thread num of dataset
debug(bool): whether a user wants to run train_from_dataset debug(bool): whether a user wants to run train_from_dataset
fetch_list(Variable List): fetch variable list, each variable will be printed fetch_list(Tensor List): fetch Tensor list, each variable will be printed
during training during training
fetch_info(String List): print information for each variable, its length should be equal fetch_info(String List): print information for each Tensor, its length should be equal
to fetch_list to fetch_list
print_period(int): the number of mini-batches for each print, default is 100 print_period(int): the number of mini-batches for each print, default is 100
fetch_handler(FetchHandler): a user define class for fetch output. fetch_handler(FetchHandler): a user define class for fetch output.
...@@ -1641,19 +1648,21 @@ class Executor(object): ...@@ -1641,19 +1648,21 @@ class Executor(object):
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
place = fluid.CPUPlace() # you can set place = fluid.CUDAPlace(0) to use gpu paddle.enable_static()
exe = fluid.Executor(place) place = paddle.CPUPlace() # you can set place = paddle.CUDAPlace(0) to use gpu
x = fluid.data(name="x", shape=[None, 10, 10], dtype="int64") exe = paddle.static.Executor(place)
y = fluid.data(name="y", shape=[None, 1], dtype="int64", lod_level=1) x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64")
dataset = fluid.DatasetFactory().create_dataset() y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1)
dataset = paddle.fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y]) dataset.set_use_var([x, y])
dataset.set_thread(1) dataset.set_thread(1)
filelist = [] # you should set your own filelist, e.g. filelist = ["dataA.txt"] # you should set your own filelist, e.g. filelist = ["dataA.txt"]
filelist = []
dataset.set_filelist(filelist) dataset.set_filelist(filelist)
exe.run(fluid.default_startup_program()) exe.run(paddle.static.default_startup_program())
exe.train_from_dataset(program=fluid.default_main_program(), exe.train_from_dataset(program=paddle.static.default_main_program(),
dataset=dataset) dataset=dataset)
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册