recompute_hybrid.py 10.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.autograd import PyLayer
17 18
from paddle.fluid import core, framework

19
from ..meta_parallel.parallel_layers.random import get_rng_state_tracker
20
from ..meta_parallel.pp_utils import utils
21 22 23 24 25
from .recompute import (
    check_recompute_necessary,
    detach_variable,
    swith_rng_state_tracker,
)
26 27 28 29 30 31 32 33 34 35 36 37 38

__all__ = []


def _split_activation(tensor, mp_group):

    mp_degree = mp_group.nranks
    mp_rank = mp_group.rank
    if mp_degree < 2:
        return tensor

    tensor_numel = paddle.numel(tensor)
    assert tensor_numel != 0, "can't recompute zero element"
39 40 41 42 43
    assert (
        tensor_numel % mp_degree == 0
    ), "The capacity of the activation ({}) cannot be divisible by mp_degree({})".format(
        tensor_numel, mp_degree
    )
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78

    # use inplace operation to save memory
    data = tensor.flatten_()

    part_size = tensor_numel // mp_degree
    start = part_size * mp_rank
    end = start + part_size
    return data[start:end]


def _merge_activation(tensor, mp_group):
    mp_degree = mp_group.nranks
    mp_rank = mp_group.rank
    if mp_degree < 2:
        return tensor

    # adapt to new dygraph
    tensor_shape = list(tensor.shape)
    tensor_shape[0] *= mp_group.nranks
    out = paddle.empty(tensor_shape, tensor.dtype)
    task = mp_group.process_group.all_gather(tensor.cuda(), out)
    task.wait()
    return out


class _HPRecomputeFunction(PyLayer):
    """
    Recompute PyLayer used by hybrid parallel training.

    Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
    1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
    2. Offload support for activation
    3. Support MP segmentation of activation to further reduce cuda memory
    4. Adapt to the random state of MP
    """

    @staticmethod
    def forward(
        ctx,
        run_function,
        all_outputs,
        mp_group,
        offload,
        partition,
        *args,
        **kwargs
    ):
        """Run ``run_function`` without autograd and stash what backward needs.

        Args:
            ctx: PyLayer context carrying state to ``backward``.
            run_function (callable): the forward computation to recompute.
            all_outputs (list): caller-owned list; the outputs of this call
                are appended to it.
            mp_group (Group): model-parallel group used when ``partition`` is
                set.
            offload (bool): move the saved tensor inputs to CPU.
            partition (bool): save only this rank's shard of each tensor input.
            *args: positional inputs of ``run_function``; tensors among them
                are saved for backward, other values are kept as-is.
            **kwargs: keyword inputs of ``run_function``, stored on ``ctx``.

        Returns:
            The output tensor itself, or a tuple of the outputs.

        Raises:
            AssertionError: if the current device is not a GPU.
            ValueError: if the tracer reports an unsupported AMP level.
        """
        # store for recomputing
        ctx.run_function = run_function
        ctx.kwargs = kwargs

        # store the rng states so the recomputation reproduces e.g. dropout
        ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
        ctx.fwd_cuda_rng_state_tracker = (
            get_rng_state_tracker().get_states_tracker()
        )

        # save config info
        ctx.mp_group = mp_group
        ctx.offload = offload
        ctx.partition = partition

        # save input for backward
        ctx.inputs = []
        ctx.tensor_indices = []
        ctx.tensor_shapes = []
        tensor_inputs = []

        cur_device = paddle.get_device()
        assert (
            'gpu:' in cur_device
        ), "Recompute with RNG is not support current device: {}.".format(
            cur_device
        )

        # TODO support AMP
        # record the auto-cast configuration so backward can replay it
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = tracer._amp_level != core.AmpLevel.O0
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            ctx.amp_level = 'O1'
        else:
            raise ValueError(
                "unsupported amp level: {}".format(tracer._amp_level)
            )
        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args, **kwargs)

        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                state = arg.stop_gradient
                if partition:
                    ctx.tensor_shapes.append(arg.shape)
                    # Keep the shard in its own variable: rebinding the
                    # ``partition`` flag here would turn it into a Tensor and
                    # break the ``if partition`` test for the next argument.
                    arg_shard = _split_activation(
                        arg.detach(), mp_group
                    ).clone()
                    # TODO(shenliang03) not use calculate stream to D2H to speed
                    arg = arg_shard.cpu() if offload else arg_shard
                else:
                    arg = arg.cpu() if offload else arg
                arg.stop_gradient = state
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                # placeholder; the real tensor is restored in backward
                ctx.inputs.append(None)

                # In new dygraph mode, in some cases a subset of outputs is identical to a subset of inputs,
                #  which is an inplace operation. When the inputs' stop_gradient is True, an
                #  error occurs because stop_gradient=True and inplace ops are not
                #  supported at the same time. The solution is to mark the inputs non_differentiable
                #  if their stop_gradient is True.
                # Note:
                #  If not marked non_differentiable, all output tensors' attr `stop gradient`
                #  will be reset to `False` in c++ backend.
                #  See https://github.com/PaddlePaddle/Paddle/blob/9d62efb0e6e5373823039d9eda96cd5905426c0a/paddle/fluid/pybind/eager_py_layer.cc#L388
                if framework.in_dygraph_mode() and state:
                    ctx.mark_non_differentiable(arg)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        if paddle.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            return tuple(outputs)

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and backpropagate the incoming gradients.

        The saved inputs are restored (merged across ``ctx.mp_group`` when
        partitioned, moved back to the GPU when offloaded), ``run_function``
        is re-run under the saved RNG and auto-cast state, and
        ``paddle.autograd.backward`` is called on the outputs that require
        gradients.

        Args:
            ctx: context populated by ``forward``.
            *args: gradients w.r.t. each output of ``forward``, in order.

        Returns:
            tuple: the gradient of every tensor input, in argument order.

        Raises:
            AssertionError: if the recomputed output count differs from the
                number of incoming gradients.
            RuntimeError: if no recomputed output requires a gradient.
        """
        with paddle.fluid.dygraph.guard():
            # Restore inputs
            inputs = list(ctx.inputs)
            tensor_indices = ctx.tensor_indices
            tensor_shapes = ctx.tensor_shapes
            tensors = list(ctx.saved_tensor())

            device_id = paddle.distributed.ParallelEnv().device_id
            for i, idx in enumerate(tensor_indices):
                if ctx.partition:
                    state = tensors[i].stop_gradient
                    # gather the shards and restore the original input shape
                    tensors[i] = (
                        _merge_activation(tensors[i], ctx.mp_group)
                        .detach()
                        .reshape_(tensor_shapes[i])
                    )
                    tensors[i].stop_gradient = state
                inputs[idx] = (
                    tensors[i].cuda(device_id) if ctx.offload else tensors[i]
                )

            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # need restore auto_cast state as well as w/b list
            with swith_rng_state_tracker(
                ctx.fwd_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
            ):
                with paddle.amp.auto_cast(
                    enable=ctx.is_fw_autocast,
                    custom_white_list=ctx.amp_white_list,
                    custom_black_list=ctx.amp_black_list,
                    level=ctx.amp_level,
                ):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

            if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
                outputs = (outputs,)
            assert len(outputs) == len(args)

            # pair each differentiable output with its incoming gradient
            forward_outputs_with_grad = []
            backward_inputs = []

            for i in range(len(outputs)):
                if (
                    isinstance(outputs[i], (core.VarBase, core.eager.Tensor))
                    and not outputs[i].stop_gradient
                ):
                    forward_outputs_with_grad.append(outputs[i])
                    backward_inputs.append(args[i])

            if len(forward_outputs_with_grad) == 0:
                raise RuntimeError(
                    "none of output has stop_gradient=False, this recompute() is not necessary"
                )

            # actually backward
            paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
            grads = tuple(
                inp._grad_ivar()
                for inp in detached_inputs
                if isinstance(inp, (core.VarBase, core.eager.Tensor))
            )
            return grads


def recompute_hybrid(ctx, function, *args, **kwargs):
    """
    Run ``function`` with activation recomputation for hybrid parallel training.

    NOTE(shenliang03): the current hybrid parallel recompute has limitations.
    It cannot handle the following situations:
    1. Some tensors produced by the recomputed function do not require gradients.
    2. A forward output tensor has no gradient. This can be worked around with detach().
    3. Only the float dtype is used to decide whether an output tensor needs a gradient.

    Parameters:
        ctx(dict): configuration read through the keys 'mp_group', 'offload' and
            'partition'. 'mp_group' (Group) is the group across which activations
            are split; it must be present and not None. 'offload' (bool, optional,
            default=False) selects whether saved activations are moved to CPU.
            'partition' (bool, optional, default=False) selects whether saved
            activations are split across 'mp_group'. Other keys such as 'segments'
            and 'preserve_rng_state' are not read here; they are used by the
            'recompute_sequential' API.
        function(paddle.nn.Layer): a layer or sequence of layers describing part of
            the model's forward pass. Its intermediate activations are released in
            the forward stage and recomputed in the backward stage for gradient
            calculation.
        *args(Tensor): positional inputs to ``function``.
        **kwargs(Dict): keyword inputs to ``function``.

    Returns:
        The single output of ``function`` when it produces exactly one, otherwise
        a tuple of its outputs in which non-float tensors are marked
        ``stop_gradient=True``.
    """
    mp_group = ctx.get('mp_group', None)
    assert (
        mp_group is not None
    ), "ctx must contains mp_group and mp_group can not be None."

    offload, partition = ctx.get('offload', False), ctx.get('partition', False)

    if framework._dygraph_tracer()._has_grad:
        check_recompute_necessary(args)

    # _HPRecomputeFunction.forward appends every output to this list
    collected = []
    _HPRecomputeFunction.apply(
        function, collected, mp_group, offload, partition, *args, **kwargs
    )

    if len(collected) == 1:
        return collected[0]

    for out in collected:
        needs_no_grad = paddle.is_tensor(out) and not utils.is_float_tensor(out)
        if needs_no_grad:
            out.stop_gradient = True
    return tuple(collected)