Recompute intermediate activations to save memory for 'Sequential' models.
Parameters:
ctx(dict): includes the 'segments' and 'preserve_rng_state' keys. The key 'segments' (int, default 1) is the number of chunks to create in the model.
The key 'preserve_rng_state' (bool, optional, default=True) indicates whether to save the forward RNG state. If it is True, the saved forward RNG state is
restored when the forward pass is recomputed during backpropagation. Keys such as 'mp_group', 'offload' and 'partition' have no effect here;
they are only used by the 'recompute_hybrid' API.
functions(paddle.nn.Sequential): layer or sequence of layers that describes part of the forward pass of the model
whose intermediate activations will be released to save memory in forward stage and will be recomputed
# NOTE(shenliang03): The current hybrid parallel recompute has limitations.
# It cannot handle the following situations:
# 1. The outputs of the recomputed function include tensors that do not require gradients.
# 2. The forward output tensor has no gradient. This can be worked around temporarily with detach().
# 3. Only a float dtype is used to decide whether an output tensor needs a gradient.
Parameters:
ctx(dict): includes the 'mp_group', 'offload', and 'partition' keys. The key 'mp_group' (Group) is the group across which the activations are split.
The key 'offload' (bool, optional, default=False) indicates whether to offload to CPU. The key 'partition' (bool, optional, default=False)
indicates whether to split activations in the mp_group. Keys such as 'segments' and 'preserve_rng_state' have no effect here; they are only used by the
'recompute_sequential' API.
function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
whose intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation.
*args(Tensor): inputs(tuple) to the function.
**kwargs(Dict): inputs(dict) to the function.
Returns:
Output of function on args and kwargs.
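Examples:
The framework-free sketch below illustrates, under stated assumptions, how the ctx keys described above might be read with their documented defaults. The helper name read_hybrid_ctx is hypothetical and is not part of Paddle, and a plain object stands in for a real communication Group.

```python
def read_hybrid_ctx(ctx):
    # Hypothetical helper (not a Paddle API): collects the documented ctx keys.
    mp_group = ctx.get('mp_group', None)
    # 'mp_group' is mandatory, mirroring the check at the start of this function.
    assert mp_group is not None, "ctx must contain a non-None 'mp_group'"
    return {
        'mp_group': mp_group,
        # Documented defaults: both options are off unless requested.
        'offload': ctx.get('offload', False),
        'partition': ctx.get('partition', False),
    }


# Assumption: object() is only a placeholder for a real Group instance.
placeholder_group = object()
settings = read_hybrid_ctx({'mp_group': placeholder_group, 'offload': True})
```

Keys that belong to 'recompute_sequential', such as 'segments', are simply not read in this sketch, which matches the statement above that they have no effect here.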
"""
mp_group = ctx.get('mp_group', None)
assert mp_group is not None, "ctx must contains mp_group and mp_group can not be None."