recompute.py 7.5 KB
Newer Older
J
JZ-LIANG 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import core
from paddle.autograd import PyLayer
from paddle.fluid import framework
import contextlib

import logging
22 23 24 25 26 27
# Module-level logger. Records go through a dedicated stream handler and are
# laid out as "<timestamp> <level, left-aligned to 8 chars> <message>".
logger = logging.getLogger(__name__)
formatter = logging.Formatter(
    fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
# NOTE: the handler is attached as a side effect of importing this module.
logger.addHandler(ch)
J
JZ-LIANG 已提交
28

29 30
# Empty export list: ``from <this module> import *`` brings in no names.
__all__ = []

J
JZ-LIANG 已提交
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47

def detach_variable(inputs):
    """Return a tuple mirroring ``inputs`` with every tensor detached.

    Args:
        inputs: iterable of arbitrary values.

    Returns:
        tuple: entries that are ``core.VarBase`` instances are replaced by the
        result of ``detach()`` with the source's ``stop_gradient`` flag copied
        onto it; all other entries are passed through untouched, in order.
    """

    def _detach_one(item):
        # Non-tensor arguments are forwarded as-is.
        if not isinstance(item, core.VarBase):
            return item
        detached = item.detach()
        # Carry over the original flag onto the detached tensor.
        detached.stop_gradient = item.stop_gradient
        return detached

    return tuple(_detach_one(item) for item in inputs)


def check_recompute_necessary(inputs):
    """Log a warning when no tensor in ``inputs`` needs a gradient.

    Args:
        inputs: iterable of the arguments given to a recompute block. Entries
            that are not ``paddle.Tensor`` instances are ignored.

    Returns:
        None. The warning is emitted when every tensor entry has
        ``stop_gradient`` set, and also when there is no tensor entry at all.
    """
    any_needs_grad = any(not input_.stop_gradient for input_ in inputs
                         if isinstance(input_, paddle.Tensor))
    if not any_needs_grad:
        # Logger.warn is a deprecated alias; Logger.warning is the supported call.
        logger.warning(
            "[Recompute]: None of the inputs to current recompute block need grad, "
            "therefore there is NO need to recompute this block in backward !")


@contextlib.contextmanager
def swith_rng_state(rng_state):
    """Context manager that runs its body under the given CUDA RNG state.

    On entry the current state from ``paddle.get_cuda_rng_state()`` is
    remembered and ``rng_state`` is installed with
    ``paddle.set_cuda_rng_state``. On exit the remembered state is put back,
    including when the body raises.

    Args:
        rng_state: value accepted by ``paddle.set_cuda_rng_state``.
    """
    saved_state = paddle.get_cuda_rng_state()
    paddle.set_cuda_rng_state(rng_state)
    try:
        yield
    finally:
        # Always restore, even if the managed body failed.
        paddle.set_cuda_rng_state(saved_state)


class RecomputeFunction(PyLayer):
    """PyLayer that replays a forward segment during backward.

    ``forward`` evaluates the wrapped callable under ``paddle.no_grad()`` and
    records what is needed to run it again. ``backward`` re-runs the callable
    on detached copies of the inputs with gradient tracking switched on, then
    back-propagates the incoming gradients through the replayed outputs.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """Evaluate ``run_function(*args)`` under ``no_grad`` and stash replay state.

        Args:
            ctx: PyLayer context used to carry state over to ``backward``.
            run_function: callable evaluated here and replayed in ``backward``.
            preserve_rng_state: when true, the CUDA RNG state read here is
                re-installed while replaying in ``backward``.
            *args: positional inputs for ``run_function``; tensors and
                non-tensor values may be mixed.

        Returns:
            Whatever ``run_function(*args)`` returns.

        Raises:
            RuntimeError: if ``preserve_rng_state`` is true and the string
                returned by ``paddle.get_device()`` does not contain ``'gpu:'``.
        """
        check_recompute_necessary(args)

        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
        # None tensor inputs will be filtered in backward inputs.

        # Save inputs for backward: tensors go through save_for_backward and
        # leave a None placeholder in ctx.inputs; everything else is stored
        # in ctx.inputs directly.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for position, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(position)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE recompute with restore RNG only supports one scenario: one process for one cuda gpu.
        # one process with multiple gpus and mixed gpu-cpu scenarios are not supported
        if ctx.preserve_rng_state:
            cur_device = paddle.get_device()
            if 'gpu:' not in cur_device:
                raise RuntimeError(
                    "Recompute with RNG perserve is not support current device: {}.".
                    format(cur_device))
            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()

        # Record the AMP state of the forward pass so backward can restore it.
        # NOTE(review): amp_mode is hard-coded to 'O1' whatever the tracer's
        # level is -- the original carried a "TODO support AMP" here; confirm
        # whether other levels need to be propagated.
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = not (tracer._amp_level == 0)
        ctx.amp_mode = 'O1'
        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Replay the forward callable and return gradients of its tensor inputs.

        Args:
            ctx: context populated by ``forward``.
            *args: incoming gradients; there must be exactly one per replayed
                output.

        Returns:
            list: ``_grad_ivar()`` of every detached ``core.VarBase`` input, in
            input order.

        Raises:
            RuntimeError: if none of the replayed outputs is a
                ``core.VarBase`` with ``stop_gradient`` unset.
        """
        with paddle.fluid.dygraph.guard():
            # TODO need to check the recompute calling is valid or not

            # Restore inputs: put the saved tensors back at the positions
            # recorded in forward.
            inputs = list(ctx.inputs)
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(ctx.tensor_indices):
                inputs[idx] = tensors[i]

            # Switch gradient tracking on at tracer level
            # (stands in for paddle.enable_grad()).
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            def _replay_forward():
                # NOTE support AMP: restore the auto_cast state as well as
                # the white/black op lists recorded in forward.
                with paddle.amp.auto_cast(
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_mode):
                    detached = detach_variable(tuple(inputs))
                    return detached, ctx.run_function(*detached)

            # The replay is identical in both cases; only the surrounding
            # RNG-state context differs.
            if ctx.preserve_rng_state:
                with swith_rng_state(ctx.fw_cuda_rng_state):
                    detached_inputs, outputs = _replay_forward()
            else:
                detached_inputs, outputs = _replay_forward()

            if isinstance(outputs, core.VarBase):
                outputs = (outputs, )
            assert len(outputs) == len(args)

            # run backward() with only tensor that requires grad
            forward_outputs_with_grad = []
            # NOTE In Transformer-like network, if user put the attention mask into the recompute segment output,
            # pylayer will force the stop_gradient of attention mask to be False, which will make the number of
            # tensor that need grad does not match.
            # the following backward_inputs_with_grad is used to avoid this case.
            backward_inputs_with_grad = []
            for output, grad in zip(outputs, args):
                if isinstance(output, core.VarBase) and not output.stop_gradient:
                    forward_outputs_with_grad.append(output)
                    backward_inputs_with_grad.append(grad)

            if not forward_outputs_with_grad:
                raise RuntimeError(
                    "none of output has requires_grad=True, this recompute() is not necessary"
                )

            # actually backward
            paddle.autograd.backward(forward_outputs_with_grad,
                                     backward_inputs_with_grad)

            grads = [
                inp._grad_ivar() for inp in detached_inputs
                if isinstance(inp, core.VarBase)
            ]
            return grads


def recompute(function, *args, **kwargs):
    """
    Recompute intermediate activations in backward to save memory.

    Args:
        function: layer or sequence of layers describing part of the forward pass of the model, whose
        intermediate activations will be released to save memory in the forward stage and will be
        recomputed in the backward stage for gradient calculation.
        preserve_rng_state(bool, optional): keyword-only; whether to preserve the RNG state of forward
        and restore it in backward. Defaults to True.
        args: inputs to the function

    Returns:
        Output of function on args

    Raises:
        ValueError: if any keyword argument other than ``preserve_rng_state`` is given.
    """
    # ``preserve_rng_state`` is taken out of **kwargs so it can sit next to *args
    preserve_rng_state = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        # Anything still left in kwargs is not a supported option.
        unexpected = ",".join(kwargs)
        raise ValueError("Unexpected keyword arguments: " + unexpected)

    return RecomputeFunction.apply(function, preserve_rng_state, *args)