diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
index 70a38ff3b9aa663f8b6b21c85ba7e1677310f109..54c486a4d2be50246db55d204603574269a37bf7 100755
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -453,13 +453,12 @@ def recompute(function, *args, **kwargs):
 def recompute_sequential(ctx, functions, *args, **kwargs):
     """
-    recompute intermediate activations to save then memory for 'Sequential' models.
+    recompute intermediate activations to save memory for 'Sequential' models.
     use 'ctx' to transmit some context params, it is similar to 'recompute_hybrid' API.
 
     Parameters:
         ctx(dict): include 'segments' and 'preserve_rng_state' keys, the key 'segments' (int, default 1), represents the number of chunks to create in the model,
                    the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng. If it is True, then the last forward rng value will be
-                   restored when the forward recalculation of backpropagation is performed. and some keys such as 'mp_group', 'offload' and 'partition' are invalid here,
-                   they are useful in 'recompute_hybrid' API.
+                   restored when the forward recalculation of backpropagation is performed.
         functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model
                    whose intermediate activations will be released to save memory in forward stage and will be recomputed
                    in backward stage for gradient calculation.
diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
index 781f44e4061affd78b0b6e7fe8005a66bcca6a9d..44faccf9dd42e751377692d9a315570ab80305b9 100644
--- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
+++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
@@ -246,6 +246,7 @@ class _HPRecomputeFunction(PyLayer):
 def recompute_hybrid(ctx, function, *args, **kwargs):
     """
+    recompute intermediate activations to save memory in the hybrid parallel scenario.
     # NODTE(shenliang03)The current hybrid parallel recompute has limitations.
     # It cannot handle the following situations:
     # 1. The calculation output of recompute, there are tensors that do not require gradients.
@@ -255,8 +256,7 @@ def recompute_hybrid(ctx, function, *args, **kwargs):
     Parameters:
         ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group), represents the avtivations are splitted in which group.
                    the key 'offload' (bool, optional, default=False), represents whether to offload to cpu. the key 'partition' (bool, optional, default=False),
-                   represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in
-                   'recompute_sequential' API.
+                   represents whether to split activations in the mp_group.
         function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
                    whose intermediate activations will be released to save memory in forward stage and will be recomputed
                    in backward stage for gradient calculation.
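For readers unfamiliar with these APIs, here is a minimal sketch of how the `recompute_sequential` function documented in the hunk above might be called. It is a hypothetical example, not part of the patch: the toy `paddle.nn.Sequential` model, the tensor shapes, and the module import path (derived from the file layout shown in this diff) are assumptions.

```python
# Hypothetical usage sketch of recompute_sequential (not part of this patch).
# The import path mirrors the file layout shown above; released Paddle versions
# may also expose it as paddle.distributed.fleet.recompute_sequential.
import paddle
from paddle.distributed.fleet.recompute.recompute import recompute_sequential

model = paddle.nn.Sequential(
    paddle.nn.Linear(32, 32),
    paddle.nn.ReLU(),
    paddle.nn.Linear(32, 32),
)

x = paddle.randn([4, 32])
x.stop_gradient = False  # gradients must flow through the input

# ctx carries the keys documented above: split the Sequential into 2 chunks and
# preserve the forward rng state so dropout-like layers recompute identically.
ctx = {"segments": 2, "preserve_rng_state": True}
out = recompute_sequential(ctx, model, x)
out.mean().backward()
```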
diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py
index a9d2a0454257a200497054ca4d3869a5bc5f41c0..c13908ba62d2e920f0614d9d537f5455357c5ce8 100644
--- a/python/paddle/distributed/fleet/utils/__init__.py
+++ b/python/paddle/distributed/fleet/utils/__init__.py
@@ -27,18 +27,20 @@ __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"] # noqa
 def recompute(function, *args, **kwargs):
     """
-    recompute intermediate activations to save then memory.
+    recompute intermediate activations to save memory.
 
     Parameters:
         function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
               whose intermediate activations will be released to save memory in forward stage and will be recomputed
               in backward stage for gradient calculation.
         *args(Tensor): inputs to the function.
-        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
-              indicate whether to save the forward rng. If it is True, then the last forward rng value will be
-              restored when the forward recalculation of backpropagation is performed. The default
-              preserve_rng_state is True.
-
+        **kwargs(Dict): Kwargs should only contain two kinds of key-value params: keyword arguments of ``function`` itself,
+              plus ``preserve_rng_state`` and ``use_reentrant``. ``preserve_rng_state`` indicates whether to save the forward rng;
+              if it is True, the last forward rng value will be restored when the forward recalculation of backpropagation is
+              performed. Its default value is True. ``use_reentrant`` indicates which implementation of recompute is used:
+              ``use_reentrant=True`` selects the PyLayer implementation and ``use_reentrant=False`` selects the Hook
+              implementation. Its default value is True.
 
     Returns:
         Output of function on args.
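Below is a minimal sketch of calling the `recompute` wrapper whose docstring is updated above, assuming dygraph mode on a small custom block. The `Block` layer, its sizes, and the input shape are illustrative assumptions; `preserve_rng_state` and `use_reentrant` are the keyword arguments described in the new docstring.

```python
# Hypothetical usage sketch (not part of this patch): wrap one block of a model
# with recompute so its intermediate activations are rebuilt during backward.
import paddle
from paddle.distributed.fleet.utils import recompute


class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(64, 64)
        self.fc2 = paddle.nn.Linear(64, 64)

    def forward(self, x):
        return self.fc2(paddle.nn.functional.relu(self.fc1(x)))


block = Block()
x = paddle.randn([8, 64])
x.stop_gradient = False  # recompute needs at least one input that requires grad

# use_reentrant=True selects the PyLayer implementation, False the Hook one;
# preserve_rng_state=True restores the forward RNG state during recomputation.
y = recompute(block, x, preserve_rng_state=True, use_reentrant=True)
y.sum().backward()
```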