From 2b4dd5b90c5f4b6ff3f80f6e3bc100714c9e86c2 Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Thu, 2 Feb 2023 11:47:03 +0800
Subject: [PATCH] update recompute doc. (#50088)

---
 .../distributed/fleet/recompute/recompute.py      |  5 ++---
 .../fleet/recompute/recompute_hybrid.py           |  4 ++--
 python/paddle/distributed/fleet/utils/__init__.py | 14 ++++++++------
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
index 70a38ff3b9..54c486a4d2 100755
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -453,13 +453,12 @@ def recompute(function, *args, **kwargs):
 def recompute_sequential(ctx, functions, *args, **kwargs):
     """
-    recompute intermediate activations to save then memory for 'Sequential' models.
+    recompute intermediate activations to save memory for 'Sequential' models.
     use 'ctx' to transmit some context params; it is similar to the 'recompute_hybrid' API.
 
     Parameters:
         ctx(dict): include 'segments' and 'preserve_rng_state' keys. the key 'segments' (int, default 1) represents the number of chunks to create in the model;
                    the key 'preserve_rng_state' (bool, optional, default=True) indicates whether to save the forward rng. If it is True, then the last forward rng value will be
-                   restored when the forward recalculation of backpropagation is performed. and some keys such as 'mp_group', 'offload' and 'partition' are invalid here,
-                   they are useful in 'recompute_hybrid' API.
+                   restored when the forward recalculation of backpropagation is performed.
         functions(paddle.nn.Sequential): layer or sequence of layers that describes part of the forward pass of the model
               whose intermediate activations will be released to save memory in the forward stage and will be recomputed
               in the backward stage for gradient calculation.
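For context, a minimal usage sketch (editorial, not part of the patch) of the 'recompute_sequential' API documented above. It assumes recompute_sequential is exported from paddle.distributed.fleet; everything else is standard paddle.nn:

    import paddle
    from paddle.distributed import fleet  # assumption: recompute_sequential is exported here

    # A small Sequential model. recompute_sequential cuts it into 'segments'
    # chunks, drops each chunk's activations in the forward pass, and
    # recomputes them in the backward pass.
    model = paddle.nn.Sequential(
        paddle.nn.Linear(64, 64),
        paddle.nn.ReLU(),
        paddle.nn.Linear(64, 64),
        paddle.nn.ReLU(),
    )

    x = paddle.randn([8, 64])
    x.stop_gradient = False  # request gradients w.r.t. the input

    # ctx carries 'segments' and 'preserve_rng_state', per the docstring above.
    y = fleet.recompute_sequential(
        {"segments": 2, "preserve_rng_state": True}, model, x
    )
    y.sum().backward()

With 'segments': 2, the four layers above are treated as two recomputed chunks of two layers each.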
diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
index 781f44e406..44faccf9dd 100644
--- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
+++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py
@@ -246,6 +246,7 @@ class _HPRecomputeFunction(PyLayer):
 
 def recompute_hybrid(ctx, function, *args, **kwargs):
     """
+    recompute intermediate activations to save memory in the hybrid parallel scenario.
     # NOTE(shenliang03): The current hybrid parallel recompute has limitations.
     # It cannot handle the following situations:
     # 1. Among the outputs of recompute, there are tensors that do not require gradients.
@@ -255,8 +256,7 @@ def recompute_hybrid(ctx, function, *args, **kwargs):
     Parameters:
         ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group) represents which group the activations are split in;
                    the key 'offload' (bool, optional, default=False) represents whether to offload to cpu; the key 'partition' (bool, optional, default=False)
-                   represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in
-                   'recompute_sequential' API.
+                   represents whether to split activations in the mp_group.
         function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
              whose intermediate activations will be released to save memory in the forward stage and will be recomputed
              in the backward stage for gradient calculation.
diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py
index a9d2a04542..c13908ba62 100644
--- a/python/paddle/distributed/fleet/utils/__init__.py
+++ b/python/paddle/distributed/fleet/utils/__init__.py
@@ -27,18 +27,20 @@ __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"]  # noqa
 
 def recompute(function, *args, **kwargs):
     """
-    recompute intermediate activations to save then memory.
+    recompute intermediate activations to save memory.
 
     Parameters:
         function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
              whose intermediate activations will be released to save memory in the forward stage and will be recomputed
              in the backward stage for gradient calculation.
         *args(Tensor): inputs to the function.
-        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
-              indicate whether to save the forward rng. If it is True, then the last forward rng value will be
-              restored when the forward recalculation of backpropagation is performed. The default
-              preserve_rng_state is True.
-
+        **kwargs(Dict): Kwargs should only contain two kinds of key-value params: key-value params that belong
+              to the wrapped function, plus the control flags ``preserve_rng_state`` and ``use_reentrant``.
+              ``preserve_rng_state`` (bool, default=True) indicates whether to save the forward rng; if it is
+              True, then the last forward rng value will be restored when the forward recalculation of
+              backpropagation is performed. ``use_reentrant`` (bool, default=True) selects which implementation
+              of recompute is used: ``use_reentrant=True`` means the PyLayer-based implementation, and
+              ``use_reentrant=False`` means the hook-based implementation.
 
     Returns:
         Output of function on args.
--
GitLab
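For context, a minimal usage sketch (editorial, not part of the patch) of the utils-level 'recompute' wrapper from the last hunk, exercising the newly documented use_reentrant flag; the import path is the file patched above:

    import paddle
    from paddle.distributed.fleet.utils import recompute

    block = paddle.nn.Sequential(
        paddle.nn.Linear(32, 32),
        paddle.nn.GELU(),
        paddle.nn.Linear(32, 32),
    )

    x = paddle.randn([4, 32])
    x.stop_gradient = False

    # use_reentrant=False selects the hook-based implementation;
    # use_reentrant=True (the default) selects the PyLayer-based one.
    out = recompute(block, x, preserve_rng_state=True, use_reentrant=False)
    out.mean().backward()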
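Likewise for 'recompute_hybrid', a sketch under stated assumptions (editorial, not part of the patch): it is only meaningful inside an initialized hybrid-parallel job, so this assumes a two-GPU run launched with python -m paddle.distributed.launch and that recompute_hybrid is exported from paddle.distributed.fleet:

    import paddle
    import paddle.distributed.fleet as fleet

    # Assumed setup: 2 GPUs, tensor (model) parallel degree 2.
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)

    hcg = fleet.get_hybrid_communicate_group()
    mp_group = hcg.get_model_parallel_group()

    layer = paddle.nn.Linear(64, 64)
    x = paddle.randn([8, 64])
    x.stop_gradient = False

    # ctx carries 'mp_group', 'offload' and 'partition', per the docstring above.
    # offload=True would move saved activations to CPU; partition=True would
    # additionally split them across the ranks of mp_group.
    out = fleet.recompute_hybrid(
        {"mp_group": mp_group, "offload": False, "partition": False}, layer, x
    )
    out.sum().backward()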