未验证 提交 2b4dd5b9 编写于 作者: W wuhuachaocoding 提交者: GitHub

update recompute doc. (#50088)

上级 3c557e2f
......@@ -453,13 +453,12 @@ def recompute(function, *args, **kwargs):
def recompute_sequential(ctx, functions, *args, **kwargs):
"""
recompute intermediate activations to save then memory for 'Sequential' models.
recompute intermediate activations to save the memory for 'Sequential' models. use 'ctx' to transmit some context params, it is similar to 'recompute_hybrid' API.
Parameters:
ctx(dict): include 'segments' and 'preserve_rng_state' keys, the key 'segments' (int, default 1), represents the number of chunks to create in the model,
the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng. If it is True, then the last forward rng value will be
restored when the forward recalculation of backpropagation is performed. and some keys such as 'mp_group', 'offload' and 'partition' are invalid here,
they are useful in 'recompute_hybrid' API.
restored when the forward recalculation of backpropagation is performed.
functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model
whose intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation.
......
......@@ -246,6 +246,7 @@ class _HPRecomputeFunction(PyLayer):
def recompute_hybrid(ctx, function, *args, **kwargs):
"""
recompute intermediate activations to save the memory in hybrid parallel scene.
# NODTE(shenliang03)The current hybrid parallel recompute has limitations.
# It cannot handle the following situations:
# 1. The calculation output of recompute, there are tensors that do not require gradients.
......@@ -255,8 +256,7 @@ def recompute_hybrid(ctx, function, *args, **kwargs):
Parameters:
ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group), represents the avtivations are splitted
in which group. the key 'offload' (bool, optional, default=False), represents whether to offload to cpu. the key 'partition' (bool, optional, default=False),
represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in
'recompute_sequential' API.
represents whether to split activations in the mp_group.
function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
whose intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation.
......
......@@ -27,18 +27,20 @@ __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"] # noqa
def recompute(function, *args, **kwargs):
"""
recompute intermediate activations to save then memory.
recompute intermediate activations to save the memory.
Parameters:
function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
whose intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation.
*args(Tensor): inputs to the function.
**kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
indicate whether to save the forward rng. If it is True, then the last forward rng value will be
restored when the forward recalculation of backpropagation is performed. The default
preserve_rng_state is True.
**kwargs(Dict): Kwargs should only contain two kinds of key-value params, the one is part of function's key-value params,
and the other contains ``preserve_rng_state`` and ``use_reentrant``. the key-value pair of ``preserve_rng_state``,
which is used to indicate whether to save the forward rng. If it is True, then the last forward rng value
will be restored when the forward recalculation of backpropagation is performed, its default value is True.
the key-value pair of ``use_reentrant`` is used to indicate which implementation of recompute you will be used.
``use_reentrant=True`` means to use the PyLayer implementation of recompute, ``use_reentrant=False`` means to
use the Hook implementation of recompute, its default value is True.
Returns:
Output of function on args.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册