Unverified commit 2b4dd5b9 authored by wuhuachaocoding, committed by GitHub

update recompute doc. (#50088)

Parent 3c557e2f
@@ -453,13 +453,12 @@ def recompute(function, *args, **kwargs):
def recompute_sequential(ctx, functions, *args, **kwargs):
""" """
recompute intermediate activations to save memory for 'Sequential' models. Use 'ctx' to pass some context params; it is similar to the 'recompute_hybrid' API.

Parameters:
    ctx(dict): includes the 'segments' and 'preserve_rng_state' keys. The key 'segments' (int, default 1) is the number of chunks to create in the model;
        the key 'preserve_rng_state' (bool, optional, default=True) indicates whether to save the forward rng state. If it is True, the saved rng state
        will be restored when the forward pass is recomputed during backpropagation.
    functions(paddle.nn.Sequential): a layer or a sequence of layers that describes part of the forward pass of the model,
        whose intermediate activations will be released in the forward stage to save memory and recomputed
        in the backward stage for gradient calculation.
...
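For context, a minimal usage sketch of recompute_sequential (not part of this commit). It assumes recompute_sequential is importable from paddle.distributed.fleet; the exact export path may vary across Paddle versions.

import paddle
from paddle.distributed import fleet  # assumed export path for recompute_sequential

# A small 'Sequential' model; recompute_sequential splits it into 'segments' chunks
# and recomputes each chunk's activations during the backward pass.
model = paddle.nn.Sequential(
    paddle.nn.Linear(32, 32),
    paddle.nn.ReLU(),
    paddle.nn.Linear(32, 32),
    paddle.nn.ReLU(),
)
x = paddle.randn([8, 32])
x.stop_gradient = False

ctx = {"segments": 2, "preserve_rng_state": True}
y = fleet.recompute_sequential(ctx, model, x)
y.sum().backward()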
@@ -246,6 +246,7 @@ class _HPRecomputeFunction(PyLayer):
def recompute_hybrid(ctx, function, *args, **kwargs):
""" """
recompute intermediate activations to save memory in the hybrid parallel scenario.
# NOTE(shenliang03): The current hybrid parallel recompute has limitations.
# It cannot handle the following situations:
# 1. The calculation output of recompute contains tensors that do not require gradients.
@@ -255,8 +256,7 @@ def recompute_hybrid(ctx, function, *args, **kwargs):
Parameters:
    ctx(dict): includes the 'mp_group', 'offload', and 'partition' keys. The key 'mp_group' (Group) is the group in which the activations are split;
        the key 'offload' (bool, optional, default=False) indicates whether to offload the activations to CPU; the key 'partition' (bool, optional, default=False)
        indicates whether to split the activations within the mp_group.
    function(paddle.nn.Layer): a layer or a sequence of layers that describes part of the forward pass of the model,
        whose intermediate activations will be released in the forward stage to save memory and recomputed
        in the backward stage for gradient calculation.
...
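And a hedged sketch of how recompute_hybrid might be called under a hybrid parallel setup (not from this commit; it assumes recompute_hybrid is exported from paddle.distributed.fleet and that the script is started with paddle.distributed.launch on multiple devices):

# Launch with, e.g.: python -m paddle.distributed.launch --devices=0,1 train.py
import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# The model-parallel group in which activations may be partitioned.
hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()

block = paddle.nn.Sequential(paddle.nn.Linear(64, 64), paddle.nn.GELU())
x = paddle.randn([4, 64])
x.stop_gradient = False

# 'partition' splits the saved activations across mp_group; 'offload' moves them to CPU.
ctx = {"mp_group": mp_group, "offload": False, "partition": False}
y = fleet.recompute_hybrid(ctx, block, x)  # assumed export path for recompute_hybrid
y.mean().backward()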
@@ -27,18 +27,20 @@ __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"]  # noqa
def recompute(function, *args, **kwargs):
""" """
recompute intermediate activations to save memory.

Parameters:
    function(paddle.nn.Layer): a layer or a sequence of layers that describes part of the forward pass of the model,
        whose intermediate activations will be released in the forward stage to save memory and recomputed
        in the backward stage for gradient calculation.
    *args(Tensor): inputs to the function.
    **kwargs(Dict): Kwargs should only contain two kinds of key-value params: one kind is part of the function's own key-value params,
        and the other contains ``preserve_rng_state`` and ``use_reentrant``. The key-value pair of ``preserve_rng_state``
        indicates whether to save the forward rng state. If it is True, the saved rng state will be restored when the forward
        pass is recomputed during backpropagation; its default value is True.
        The key-value pair of ``use_reentrant`` indicates which implementation of recompute will be used:
        ``use_reentrant=True`` means to use the PyLayer implementation of recompute, and ``use_reentrant=False`` means to
        use the Hook implementation of recompute; its default value is True.

Returns:
    Output of function on args.
...
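A minimal sketch of the top-level recompute API described above (not part of this commit). paddle.distributed.fleet.utils.recompute is the documented entry point; passing use_reentrant=False selects the Hook implementation mentioned in the docstring.

import paddle
from paddle.distributed.fleet.utils import recompute

class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear1 = paddle.nn.Linear(32, 32)
        self.linear2 = paddle.nn.Linear(32, 32)

    def forward(self, x):
        return self.linear2(paddle.nn.functional.relu(self.linear1(x)))

block = Block()
x = paddle.randn([8, 32])
x.stop_gradient = False

# Activations inside `block` are released in the forward pass and recomputed
# during backward; use_reentrant=False picks the Hook implementation.
y = recompute(block, x, use_reentrant=False)
y.sum().backward()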