Unverified commit 17391cd4 authored by chengduo, committed by GitHub

[Cherry pick] Polish executor doc (#20348)

* polish executor doc
test=develop

* polish compiler and executor

* polish doc
test=document_fix test=develop

* Update doc
test=document_fix

* fix fluid.data
test=document_fix

* fix doc batch dim
test=document_fix
Parent f11393d6
......@@ -27,11 +27,11 @@ paddle.fluid.Variable.numpy (ArgSpec(args=['self'], varargs=None, keywords=None,
paddle.fluid.Variable.set_value (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', '69deb77a9dedc61f7b731a6a7709fa5b'))
paddle.fluid.Variable.to_string (ArgSpec(args=['self', 'throw_on_error', 'with_details'], varargs=None, keywords=None, defaults=(False,)), ('document', '65cd237e2d30c12e412c9cafbbd00791'))
paddle.fluid.load_op_library (ArgSpec(args=['lib_filename'], varargs=None, keywords=None, defaults=None), ('document', 'c009b2ea5fb6520f2d2f53aafec788e0'))
paddle.fluid.Executor ('paddle.fluid.executor.Executor', ('document', '34e8c1769313fbeff7817212dda6259e'))
paddle.fluid.Executor ('paddle.fluid.executor.Executor', ('document', '4d963107d87438b5add4a5288855bd04'))
paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '3a584496aa1343f36eebf3c46b323a74'))
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '90b3268b71a8aceedd0dc9e311921d15'))
paddle.fluid.Executor.infer_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period', 'fetch_handler'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100, None)), ('document', '4ff256774ecaeee01c840a5fb5de8f7a'))
paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', '4cfcd9c15b766a51b584cc46d38f1ad8'))
paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', 'de3878f012e60edad05fb24fd88ce910'))
paddle.fluid.Executor.train_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period', 'fetch_handler'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100, None)), ('document', '73024c79f46b4f14f1060edeaa4919c8'))
paddle.fluid.global_scope (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'f65788d9ead293ada47551339df12203'))
paddle.fluid.scope_guard (ArgSpec(args=['scope'], varargs=None, keywords=None, defaults=None), ('document', '02fcfc1eda07c03a84ed62422366239c'))
......@@ -46,10 +46,10 @@ paddle.fluid.memory_optimize (ArgSpec(args=['input_program', 'skip_opt_set', 'pr
paddle.fluid.release_memory (ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)), ('document', '2be29dc8ecdec9baa7728fb0c7f80e24'))
paddle.fluid.DistributeTranspilerConfig ('paddle.fluid.transpiler.distribute_transpiler.DistributeTranspilerConfig', ('document', 'beac6f89fe97eb8c66a25de5a09c56d2'))
paddle.fluid.DistributeTranspilerConfig.__init__ (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.ParallelExecutor ('paddle.fluid.parallel_executor.ParallelExecutor', ('document', '2b4d2e859f2e0c6161f4fed995f7956d'))
paddle.fluid.ParallelExecutor ('paddle.fluid.parallel_executor.ParallelExecutor', ('document', 'dbff7bd7d365d755cec5ce977aa9db83'))
paddle.fluid.ParallelExecutor.__init__ (ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.ParallelExecutor.drop_local_exe_scopes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '77c739744ea5708b80fb1b37cc89db40'))
paddle.fluid.ParallelExecutor.run (ArgSpec(args=['self', 'fetch_list', 'feed', 'feed_dict', 'return_numpy'], varargs=None, keywords=None, defaults=(None, None, True)), ('document', '0af092676e5b1320bb4232396154ce4b'))
paddle.fluid.ParallelExecutor.drop_local_exe_scopes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '4069452c2d8772920b3458ffda7ec562'))
paddle.fluid.ParallelExecutor.run (ArgSpec(args=['self', 'fetch_list', 'feed', 'feed_dict', 'return_numpy'], varargs=None, keywords=None, defaults=(None, None, True)), ('document', 'bc758cc655d6b79129a38517769965b6'))
paddle.fluid.create_lod_tensor (ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None), ('document', '0627369b86ff974f433f7078d1e78349'))
paddle.fluid.create_random_int_lodtensor (ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None), ('document', '4829bd8c4a4f1b19438500def321cb65'))
paddle.fluid.DataFeedDesc ('paddle.fluid.data_feed_desc.DataFeedDesc', ('document', '43877a0d9357db94d3dbc7359cbe8c73'))
......@@ -58,9 +58,9 @@ paddle.fluid.DataFeedDesc.desc (ArgSpec(args=['self'], varargs=None, keywords=No
paddle.fluid.DataFeedDesc.set_batch_size (ArgSpec(args=['self', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', 'a34790bff4a2891713ddd644db56418d'))
paddle.fluid.DataFeedDesc.set_dense_slots (ArgSpec(args=['self', 'dense_slots_name'], varargs=None, keywords=None, defaults=None), ('document', 'fdd07ce63e72bed57f2c0db5bec5720f'))
paddle.fluid.DataFeedDesc.set_use_slots (ArgSpec(args=['self', 'use_slots_name'], varargs=None, keywords=None, defaults=None), ('document', 'c23a79dfa04edd014b477bd4b183da06'))
paddle.fluid.CompiledProgram ('paddle.fluid.compiler.CompiledProgram', ('document', '598d294107d44d7620bce76527a92c37'))
paddle.fluid.CompiledProgram ('paddle.fluid.compiler.CompiledProgram', ('document', 'c49ba191cbbbdf7c02b7ac978c06d7e0'))
paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph', 'build_strategy'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', '1c7c6171bbf6d77f2fce0166aa0ec43b'))
paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', '3b61147fc4f54e1724aa9ead8a1d5f26'))
paddle.fluid.ExecutionStrategy ('paddle.fluid.core_avx.ExecutionStrategy', ('document', '535ce28c4671176386e3cd283a764084'))
paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core_avx.ParallelExecutor.ExecutionStrategy) -> None
paddle.fluid.BuildStrategy ('paddle.fluid.core_avx.BuildStrategy', ('document', 'eec64b9b7cba58b0a63687b4c34ffe56'))
......
......@@ -64,18 +64,24 @@ def _prune_feed_ops(program):
class CompiledProgram(object):
"""
Compiles to Graph for execution.
The CompiledProgram is used to transform a program or graph for
various optimizations according to the configuration of build_strategy,
for example, operator fusion in the computation graph and memory
optimization during the execution of the computation graph.
For more information about build_strategy, please refer to
:code:`fluid.BuildStrategy`.
1. Users first create the program with layers.
2. Optionally, users use CompiledProgram to optimize the program before run.
3. The original program or CompiledProgram is run by executor.
Args:
program_or_graph (Graph|Program): This parameter is the Program or Graph
being executed.
build_strategy(BuildStrategy): This parameter is used to compile the
program or graph with the specified options, such as operator fusion
in the computation graph and memory optimization during its execution.
For more information about build_strategy, please refer to
:code:`fluid.BuildStrategy`. The default is None.
The CompiledProgram is used to transform a program for various
optimizations, for example.
* Pre-compute some logic once so that each run is faster.
* Transform the program so that it can run in multiple devices.
* Transform the program for optimized inference or distributed
training. **Note that: this part is not finished.**
Returns:
CompiledProgram
Example:
.. code-block:: python
......@@ -88,7 +94,7 @@ class CompiledProgram(object):
place = fluid.CUDAPlace(0) # fluid.CPUPlace()
exe = fluid.Executor(place)
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
......@@ -102,17 +108,6 @@ class CompiledProgram(object):
loss_data, = exe.run(compiled_prog,
feed={"X": x},
fetch_list=[loss.name])
Args:
program_or_graph (Graph|Program): If it's Program, it will be first
lowered to a graph for further optimizations. If it's a graph
(potentially optimized before), it will be directly used for
further optimizations. Note: graph is only supported when compiled
with with_data_parallel option.
build_strategy(BuildStrategy): build_strategy is used to
build the graph with the specified options.
For more information, please refer to fluid.BuildStrategy.
Default None.
"""
def __init__(self, program_or_graph, build_strategy=None):
......@@ -146,7 +141,53 @@ class CompiledProgram(object):
exec_strategy=None,
share_vars_from=None,
places=None):
"""Configs the program to run in data parallel way.
"""
This interface is used to transform the input Program or Graph into a multi-graph
so that the model can run in data parallel mode. Users can use build_strategy and
exec_strategy to set optimizations applied during the construction and execution
of the Graph, such as reducing the number of AllReduce operations or specifying
the size of the thread pool used when running the Graph. **Note: If build_strategy
is specified both when building the CompiledProgram and when calling
with_data_parallel, the build_strategy in CompiledProgram will be overwritten by
the latter; therefore, for data parallel training it is recommended to set
build_strategy when calling the with_data_parallel interface.**
Args:
loss_name (str): This parameter is the name of the loss variable of the model.
**Note: For model training you must set loss_name, otherwise the
result may be wrong**. The default is None.
build_strategy(BuildStrategy): This parameter is used to compile the
program or graph with the specified options, such as operator fusion
in the computation graph and memory optimization during its execution.
For more information about build_strategy, please refer to
:code:`fluid.BuildStrategy`. The default is None.
exec_strategy(ExecutionStrategy): exec_strategy specifies the options that can
be changed when running the current model, such as the thread pool size.
For more information about exec_strategy, please refer to :code:`fluid.ExecutionStrategy`.
The default is None.
share_vars_from(CompiledProgram): If share_vars_from is set, the current
CompiledProgram will share parameter values with the CompiledProgram
specified by share_vars_from. This parameter needs to be set when the model
has to be tested during training and both training and testing use the
data parallel mode. Since CompiledProgram only distributes parameter
variables to other devices when it is first executed, the CompiledProgram
specified by share_vars_from must be run before the current CompiledProgram.
The default is None.
places(list(CUDAPlace)|list(CPUPlace)|None): This parameter specifies the devices
on which the model runs. If you want to run on GPU0 and GPU1, places is
[fluid.CUDAPlace(0), fluid.CUDAPlace(1)]; if you want to run with 2 CPUs,
places is [fluid.CPUPlace()] * 2. If this parameter is not set, i.e. it is
None, the available devices are obtained from environment variables when the
model is executed: if GPUs are used, the currently available device IDs are
obtained from the environment variable FLAGS_selected_gpus or
CUDA_VISIBLE_DEVICES; if CPUs are used, the number of currently available CPU
devices is obtained from the environment variable CPU_NUM, for example
export CPU_NUM=4. If that environment variable is not set, the executor adds
it to the environment and sets its value to 1.
The default is None.
Returns:
CompiledProgram
Example:
.. code-block:: python
......@@ -170,7 +211,7 @@ class CompiledProgram(object):
exe = fluid.Executor(place)
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
......@@ -185,35 +226,6 @@ class CompiledProgram(object):
loss_data, = exe.run(compiled_prog,
feed={"X": x},
fetch_list=[loss.name])
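The note above suggests passing build_strategy to with_data_parallel rather than to
the CompiledProgram constructor. A minimal sketch of that usage, assuming two CPU
devices and an illustrative reduce strategy (neither is required by the interface):

.. code-block:: python

    import os
    import numpy
    import paddle.fluid as fluid

    os.environ['CPU_NUM'] = '2'
    exe = fluid.Executor(fluid.CPUPlace())

    data = fluid.data(name='X', shape=[None, 1], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=data, size=10))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
    exe.run(fluid.default_startup_program())

    # Set build_strategy when calling with_data_parallel so that it is
    # not overwritten, as recommended in the note above.
    build_strategy = fluid.BuildStrategy()
    build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
    compiled_prog = fluid.CompiledProgram(
        fluid.default_main_program()).with_data_parallel(
            loss_name=loss.name,
            build_strategy=build_strategy,
            places=[fluid.CPUPlace()] * 2)

    # The dict feed must contain at least as many samples as places.
    x = numpy.random.random(size=(4, 1)).astype('float32')
    loss_data, = exe.run(compiled_prog,
                         feed={"X": x},
                         fetch_list=[loss.name])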
Args:
loss_name (str): The loss name must set in training. Default None.
build_strategy(BuildStrategy): build_strategy is used to
build the graph with the specified options.
For more information, please refer to fluid.BuildStrategy.
Note that, if you set build_strategy in the argument list when
creating CompiledProgram and calling with_data_parallel,
the build_strategy in CompiledProgram will be overwritten by the latter.
Default None.
exec_strategy(ExecutionStrategy): exec_strategy is used to
to select the a way to execute the graph, for example how many
threads are used, how many iterations to clean up the temp
variables. For more information, please refer
to fluid.ExecutionStrategy. Default None.
share_vars_from(CompiledProgram): If provided, this CompiledProgram
will share variables from `share_vars_from`. `share_vars_from`
must be run by the executor before this CompiledProgram so that
vars are ready.
places(list(CUDAPlace)|list(CPUPlace)|None): If provided, only compile
program in the given places. Otherwise, the places used when compiled
is determined by the Executor, and the places used are controlled
by environment variables: FLAGS_selected_gpus or CUDA_VISIBLE_DEVICES
if using GPU; or CPU_NUM if using CPU. For example, if you want to
run on GPU 0 and 1, set places=[fluid.CUDAPlace(0), fluid.CUDAPlace(1)].
If you want to run on 2 CPU cores, set places=[fluid.CPUPlace()]*2.
Returns:
self
"""
assert not self._is_data_parallel, "Already compiled with parallel."
assert not self._is_inference, "Cannot compile both data parallel and inference"
......
......@@ -418,16 +418,15 @@ class FetchHandlerExamlpe(FetchHandler):
class Executor(object):
"""
An Executor in Python, supports single/multiple-GPU running,
and single/multiple-CPU running. Python executor takes a program,
adds feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the
program. fetch_list provides the variables(or names) that user wants
to get after program runs. Note: the executor will run all operators
in the program but not only the operators dependent by the fetch_list.
It stores the global variables into the global scope, and creates a
local scope for the temporary variables. The contents in local scope
may be discarded after every minibatch forward/backward finished.
But the global scope variables will be persistent through different runs.
and single/multiple-CPU running. When constructing the Executor,
the target device must be specified.
Args:
place(fluid.CPUPlace()|fluid.CUDAPlace(n)): This parameter represents
the device on which the executor runs.
Returns:
Executor
Examples:
.. code-block:: python
......@@ -444,7 +443,7 @@ class Executor(object):
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
......@@ -477,10 +476,6 @@ class Executor(object):
loss_data, = exe.run(compiled_prog,
feed={"X": x},
fetch_list=[loss.name])
Args:
place(fluid.CPUPlace|fluid.CUDAPlace(n)): indicate the executor run on which device.
"""
def __init__(self, place):
......@@ -595,11 +590,12 @@ class Executor(object):
def close(self):
"""
Close this executor.
Close the executor. This interface is used for distributed training (PServers mode).
The executor cannot be used after calling this interface, because
it releases the resources associated with the current Trainer.
You can no longer use this executor after calling this method.
For the distributed training, this method would free the resource
on PServers related to the current Trainer.
Returns:
None
Examples:
.. code-block:: python
......@@ -683,14 +679,61 @@ class Executor(object):
return_numpy=True,
use_program_cache=False):
"""
Run program by this Executor. Feed data by feed map, fetch result by
fetch_list. Python executor takes a program, add feed operators and
fetch operators to this program according to feed map and fetch_list.
Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run.
Run the specified :code:`Program` or :code:`CompiledProgram`. Note that the executor
executes all the operators in the :code:`Program` or :code:`CompiledProgram` and does
not prune operators according to fetch_list. You can also specify the scope in which
the :code:`Variables` are stored during execution; if the scope is not set, the
executor uses the global scope, i.e. :code:`fluid.global_scope()`.
Args:
program(Program|CompiledProgram): This parameter represents the :code:`Program` or
:code:`CompiledProgram` to be executed. If this parameter is not provided, i.e.
it is None, the program defaults to :code:`fluid.default_main_program()`.
The default is None.
feed(list|dict): This parameter represents the input variables of the model.
For single-card training, feed is a dict; for multi-card training, feed
can be a dict or a list. If it is a dict, the feed data will be split and
sent to the multiple devices (CPU/GPU), that is, the input data is evenly
distributed across the devices, so the number of samples in the current
mini-batch must be no less than the number of places; if it is a list, the
data is copied directly to each device, so the length of the list must be
equal to the number of places.
The default is None.
fetch_list(list): This parameter represents the variables that need to be returned
after the model runs. The default is None.
feed_var_name(str): This parameter represents the name of the input variable of
the feed operator. The default is "feed".
fetch_var_name(str): This parameter represents the name of the output variable of
the fetch operator. The default is "fetch".
scope(Scope): The scope used to run this program; you can switch it to a
different scope. The default is :code:`fluid.global_scope()`.
return_numpy(bool): This parameter indicates whether to convert the fetched variables
(the variables specified in fetch_list) to numpy.ndarray. If it is False,
the type of the return value is a list of :code:`LoDTensor`. The default is True.
use_program_cache(bool): This parameter indicates whether to cache the input :code:`Program`.
If it is True, the model may run faster in the following case:
the input program is :code:`fluid.Program`, and the parameters of this interface
(program, feed variable names and fetch_list variables) remain unchanged
between runs. The default is False.
Returns:
Note: the executor will run all operators in the program but not
only the operators dependent by the fetch_list.
List: The fetched result list.
NOTES:
1. If it is multi-card running and the feed parameter is a dict, the input data
will be evenly distributed to the different cards. For example, when two GPUs
are used to run the model and the mini-batch contains 3 samples, i.e. [0, 1, 2],
GPU0 gets 1 sample, i.e. [0], and GPU1 gets 2 samples, i.e. [1, 2].
If the number of samples is less than the number of devices, the program will
throw an exception, so when running the model make sure that the number of
samples in the last batch of the data set is not less than the number of CPU
cores or GPU cards; if it is smaller, it is recommended to discard that batch.
2. If the number of available CPU cores or GPU cards is greater than 1, the
fetched results of the same variable (a variable in fetch_list) on different
devices are concatenated along dimension 0.
Examples:
.. code-block:: python
......@@ -702,7 +745,7 @@ class Executor(object):
place = fluid.CPUPlace() # fluid.CUDAPlace(0)
exe = fluid.Executor(place)
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
adam = fluid.optimizer.Adam()
......@@ -714,29 +757,6 @@ class Executor(object):
x = numpy.random.random(size=(10, 1)).astype('float32')
outs = exe.run(feed={'X': x},
fetch_list=[loss.name])
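Where the use_program_cache note above applies, a minimal sketch of repeated runs
with an unchanged program, feed variable name and fetch_list (the loop count is an
illustrative assumption):

.. code-block:: python

    import numpy
    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())

    data = fluid.data(name='X', shape=[None, 1], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=data, size=10))
    fluid.optimizer.Adam().minimize(loss)
    exe.run(fluid.default_startup_program())

    # The program, feed variable name and fetch_list stay the same
    # across iterations, so caching the program may speed up the loop.
    for _ in range(10):
        x = numpy.random.random(size=(10, 1)).astype('float32')
        loss_data, = exe.run(fluid.default_main_program(),
                             feed={'X': x},
                             fetch_list=[loss.name],
                             use_program_cache=True)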
Args:
program(Program|CompiledProgram): the program that need to run,
if not provided, then default_main_program (not compiled) will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData}
fetch_list(list): a list of variable or variable names that user
wants to get, this method will return them according to this list.
feed_var_name(str): the name for the input variable of
feed Operator.
fetch_var_name(str): the name for the output variable of
fetch Operator.
scope(Scope): the scope used to run this program, you can switch
it to different scope. default is global_scope
return_numpy(bool): if convert the fetched tensor to numpy
use_program_cache(bool): whether to use the cached program
settings across batches. Setting it be true would be faster
only when (1) the program is not compiled with data parallel,
and (2) program, feed variable names and fetch_list variable
names do not changed compared to the last step.
Returns:
list(numpy.array): fetch result according to fetch_list.
"""
try:
return self._run_impl(
......
......@@ -27,15 +27,73 @@ BuildStrategy = core.ParallelExecutor.BuildStrategy
class ParallelExecutor(object):
"""
ParallelExecutor is designed for data parallelism, which focuses on distributing
the data across different nodes and every node operates on the data in parallel.
If you use ParallelExecutor to run the current program on GPU, the node means GPU
device, and ParallelExecutor will get the available GPU device automatically on
the current machine. If you use ParallelExecutor to run the current program on CPU,
the node means the CPU device, and you can specify the CPU device number by adding
'CPU_NUM' environment variable, for example 'CPU_NUM=4', if the environment variable
is not found, ParallelExecutor will call `multiprocessing.cpu_count` to get the number
of CPUs in the system.
The ParallelExecutor is an upgraded version of :code:`fluid.Executor` that supports multi-node model
training and testing based on the data-parallel mode. In data-parallel mode,
ParallelExecutor broadcasts the parameters from Node0 to the other nodes during
construction and copies the input Program from Node0 to the other nodes, so that
the initial state on every node is the same. Each node runs the model independently;
the parameters' gradients are aggregated across the nodes during backward
computation, and then each node updates its parameters independently. If the
model runs on GPU, i.e. use_cuda=True, a node refers to a GPU, and ParallelExecutor
automatically obtains the GPU resources available on the current machine; users
can also restrict the available GPUs through an environment variable, for example
export CUDA_VISIBLE_DEVICES=0,1 to use GPU0 and GPU1. If the model runs on CPU,
i.e. use_cuda=False, a node refers to a CPU device. **Note: In this case, the user
needs to set the number of CPU devices manually through the CPU_NUM environment
variable, for example export CPU_NUM=4; if the variable is not set, the executor
adds it to the environment and sets it to 1.**
Args:
use_cuda (bool): Whether to use CUDA or not.
loss_name (str): This parameter is the name of the loss variable of the
model. **Note: For data-parallel model training you must set loss_name,
otherwise the results may be wrong**. The default is None.
main_program (Program): This parameter represents the Program to be executed.
If this parameter is not provided, i.e. it is None, the program defaults
to :code:`fluid.default_main_program()`. The default is None.
share_vars_from(ParallelExecutor): If share_vars_from is set, the current
ParallelExecutor will share its parameters with the ParallelExecutor
specified by share_vars_from. This parameter needs to be set when the model
has to be tested during training and both training and testing use the
data parallel mode. Since ParallelExecutor only distributes parameter
variables to other devices when it is first executed, the ParallelExecutor
specified by share_vars_from must be run before the current ParallelExecutor.
The default is None.
exec_strategy(ExecutionStrategy): exec_strategy specifies the options that can
be changed when running the current model, such as the thread pool size.
For more information about exec_strategy, please refer to :code:`fluid.ExecutionStrategy`.
The default is None.
build_strategy(BuildStrategy): By configuring build_strategy, we can
optimize the computation graph, for example through operator fusion
and memory optimization during the execution of the computation graph.
For more information about build_strategy, please refer to
:code:`fluid.BuildStrategy`. The default is None.
num_trainers(int): This parameter needs to be set in GPU distributed training.
If the value is greater than 1, NCCL will be initialized across multiple
nodes, and each node should have the same number of GPUs. The default is 1.
trainer_id(int): This parameter needs to be set when performing GPU distributed
training. It must be used together with the num_trainers parameter.
trainer_id indicates the "rank" of the current node and starts counting
from 0. The default is 0.
scope(Scope): Specifies the scope in which the program is executed.
The default is fluid.global_scope().
Returns:
ParallelExecutor: The initialized ParallelExecutor object.
Raises:
TypeError: If share_vars_from is provided, but not ParallelExecutor object.
NOTES:
1. If you only use ParallelExecutor for multi-card testing, you do not need to set
loss_name or share_vars_from.
2. If you need to both train and test a model with ParallelExecutor, share_vars_from
must be set when building the ParallelExecutor used for testing; otherwise the
parameters used for testing and for training are inconsistent.
Examples:
.. code-block:: python
......@@ -61,7 +119,7 @@ class ParallelExecutor(object):
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
test_program = fluid.default_main_program().clone(for_test=True)
......@@ -84,35 +142,6 @@ class ParallelExecutor(object):
loss_data, = test_exe.run(feed={"X": x},
fetch_list=[loss.name])
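As a hedged complement to the CPU_NUM note above, a minimal CPU-only sketch
(two emulated CPU devices and the 4-sample batch are illustrative assumptions):

.. code-block:: python

    import os
    import numpy
    import paddle.fluid as fluid

    os.environ['CPU_NUM'] = '2'   # number of CPU devices used by ParallelExecutor

    data = fluid.data(name='X', shape=[None, 1], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=data, size=10))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    train_exe = fluid.ParallelExecutor(use_cuda=False,
                                       loss_name=loss.name,
                                       main_program=fluid.default_main_program())

    # With a dict feed, the 4 samples are split evenly across the 2 devices.
    x = numpy.random.random(size=(4, 1)).astype('float32')
    loss_data, = train_exe.run(feed={"X": x},
                               fetch_list=[loss.name])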
Args:
use_cuda (bool): Whether to use CUDA or not.
loss_name (str): The loss name must set in training. Default None.
main_program (Program): The program that need to run, if not provided,
then default_main_program will be used. Default None.
share_vars_from(ParallelExecutor): If provide, it will share variables
from the specified ParallelExecutor. Default None.
exec_strategy(ExecutionStrategy): exec_strategy is used to control how to run
the program in ParallelExecutor, for example how many threads are used to
execute the program, how many iterations to clean up the temp variables
which is generated during execution. For more information, please refer
to fluid.ExecutionStrategy. Default None.
build_strategy(BuildStrategy): build_strategy is used to control how to
build the SSA Graph in ParallelExecutor by setting the property,
for example reduce_strategy, gradient_scale_strategy. For more information,
please refer to fluid.BuildStrategy. Default None.
num_trainers(int): If greater than 1, NCCL will be initialized with
multiple rank of nodes, each node should have same number of GPUs.
Distributed training will be enabled then. Default 1.
trainer_id(int): Must use together with num_trainers. trainer_id is the
"rank" of current node starts from 0. Default 0.
scope(Scope): scope to run with, default use fluid.global_scope().
Returns:
ParallelExecutor: The initialized ParallelExecutor object.
Raises:
TypeError: If share_vars_from is provided, but not ParallelExecutor object.
"""
def __init__(self,
......@@ -176,12 +205,51 @@ class ParallelExecutor(object):
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
"""
Run a parallel executor with fetch_list.
This interface is used to run the current model. Note that the executor
executes all the operators in the Program and does not prune operators
in the Program according to fetch_list.
Args:
fetch_list(list): This parameter represents the variables that need to be returned
after the model runs.
feed(list|dict): This parameter represents the input variables of the model.
For single-card training, feed is a dict; for multi-card training, feed
can be a dict or a list. If it is a dict, the feed data will be split and
sent to the multiple devices (CPU/GPU), that is, the input data is evenly
distributed across the devices, so the number of samples in the current
mini-batch must be no less than the number of places; if it is a list, the
data is copied directly to each device, so the length of the list must be
equal to the number of places.
The default is None.
feed_dict: Alias for feed parameter, for backward compatibility.
This parameter has been deprecated. Default None.
return_numpy(bool): This parameter indicates whether to convert the fetched variables
(the variables specified in fetch_list) to numpy.ndarray. If it is False,
the type of the return value is a list of :code:`LoDTensor`. The default is True.
Returns:
List: The fetched result list.
Raises:
ValueError: If the feed is a list, but its length is not equal to the
number of active places, or one of its elements is not a dict.
NOTES:
1. If the feed parameter is a dict, the input data will be evenly distributed
to the different cards. For example, when two GPUs are used to run the model
and the mini-batch contains 3 samples, i.e. [0, 1, 2], GPU0 gets 1 sample,
i.e. [0], and GPU1 gets 2 samples, i.e. [1, 2].
If the number of samples is less than the number of devices, the program will
throw an exception, so when running the model make sure that the number of
samples in the last batch of the data set is not less than the number of CPU
cores or GPU cards; if it is smaller, it is recommended to discard that batch.
2. If the number of available CPU cores or GPU cards is greater than 1, the
fetched results of the same variable (a variable in fetch_list) on different
devices are concatenated along dimension 0.
The feed parameter can be a dict or a list. If feed is a dict, the
feed data will be split into multiple devices. If feed is a list, we
assume the data has been split into multiple devices, the each
element in the list will be copied to each device directly.
Examples:
.. code-block:: python
......@@ -207,7 +275,7 @@ class ParallelExecutor(object):
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
......@@ -235,42 +303,6 @@ class ParallelExecutor(object):
loss_data, = train_exe.run(feed=[{"X": x}, {"X": x2}],
fetch_list=[loss.name])
Args:
fetch_list(list): The fetched variable names
feed(list|dict|None): The feed variables. If the feed is a dict,
tensors in that dict will be split into each devices. If
the feed is a list, each element of the list will be copied
to each device. Default None.
feed_dict: Alias for feed parameter, for backward compatibility.
This parameter has been deprecated. Default None.
return_numpy(bool): Whether converts the fetched tensor to numpy.
Default: True.
Returns:
List: The fetched result list.
Raises:
ValueError: If the feed is a list, but its length is not equal the
length of active places, or its element's is not dict.
NOTES:
1. If the feed's type is dict, the number of data that feeds to
ParallelExecutor must be bigger than active places. Otherwise,
it will throw exception from C++ side. Special attention should be
paid to check whether the last batch of the dataset is bigger
than active places.
2. If active places are more than one, the fetch results for each
variable is a list, and each element of this list is the variable of
respective active place.
Examples:
.. code-block:: python
pe = fluid.ParallelExecutor(use_cuda=use_cuda,
loss_name=avg_cost.name,
main_program=fluid.default_main_program())
loss = pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name]))
"""
return self._exe.run(program=self._compiled_program,
scope=self._scope,
......@@ -284,19 +316,17 @@ class ParallelExecutor(object):
def drop_local_exe_scopes(self):
"""
Drop the local execution scope immediately.
During the execution of the Program, the generate intermediate
results are placed in local execution scope, in some model the
creation and deletion of those intermediate results are time-consuming.
To resolve that problem, ParallelExecutor provides an option in
ExecutionStrategy, i.g. num_iteration_per_drop_scope, this option
indicates how many iterations to run before dropping the local execution
scope. But in some situation, each iteration generates different
intermediate results, it will lead to the result that the memory which
is needed by local execution scope gradually increase. And if you want
to run another program at this time, there may be insufficient storage,
At this point you should drop the local execution scope of other Programs.
Drop the local execution scopes immediately. To avoid frequent allocation
and release of temporary variables, ParallelExecutor drops the local execution
scopes only every several iterations. ParallelExecutor provides the
num_iteration_per_drop_scope option in :code:`fluid.ExecutionStrategy`, which
indicates after how many iterations the local execution scopes are dropped.
If num_iteration_per_drop_scope is 100 but you want to drop the local execution
scopes after 50 iterations, you can call this interface manually.
Returns:
None
Examples:
.. code-block:: python
......@@ -318,7 +348,7 @@ class ParallelExecutor(object):
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
......
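A hedged sketch of calling drop_local_exe_scopes manually before
num_iteration_per_drop_scope is reached (the value 100 and the two-CPU setup are
illustrative assumptions):

.. code-block:: python

    import os
    import numpy
    import paddle.fluid as fluid

    os.environ['CPU_NUM'] = '2'

    data = fluid.data(name='X', shape=[None, 1], dtype='float32')
    loss = fluid.layers.mean(fluid.layers.fc(input=data, size=10))

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_iteration_per_drop_scope = 100

    pe = fluid.ParallelExecutor(use_cuda=False,
                                loss_name=loss.name,
                                exec_strategy=exec_strategy,
                                main_program=fluid.default_main_program())

    x = numpy.random.random(size=(4, 1)).astype('float32')
    loss_data, = pe.run(feed={"X": x}, fetch_list=[loss.name])

    # Free the local execution scopes now instead of waiting for
    # num_iteration_per_drop_scope iterations.
    pe.drop_local_exe_scopes()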