# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import contextlib
W
wuhuachaocoding 已提交
16
import weakref
17

J
JZ-LIANG 已提交
18
import paddle
W
wuhuachaocoding 已提交
19
from paddle import framework
20
from paddle.autograd import PyLayer
W
wuhuachaocoding 已提交
21 22 23 24
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
    get_rng_state_tracker,
)
from paddle.framework import core, in_dygraph_mode
J
JZ-LIANG 已提交
25

R
Roc 已提交
26
from ..utils.log_util import logger
J
JZ-LIANG 已提交
27

28 29
__all__ = []

J
JZ-LIANG 已提交
30 31 32 33

def detach_variable(inputs):
    """Detach every tensor in ``inputs`` from the autograd graph.

    Non-tensor items are passed through unchanged. Each detached tensor is
    given the ``stop_gradient`` flag of the tensor it came from, so inputs
    that required grad before detaching still do afterwards.

    Args:
        inputs (Iterable): positional inputs of a recompute block.

    Returns:
        tuple: the inputs, with every tensor replaced by its detached version.
    """
    out = []
    for inp in inputs:
        if not isinstance(inp, (core.eager.Tensor, core.VarBase)):
            out.append(inp)
            continue

        x = inp.detach()
        # Carry over the gradient requirement of the original tensor.
        x.stop_gradient = inp.stop_gradient
        out.append(x)
    return tuple(out)


def check_recompute_necessary(inputs):
    """Warn when no tensor input of a recompute block requires gradients.

    If every tensor in ``inputs`` has ``stop_gradient=True``, recomputing the
    block during backward is pointless, so a warning is logged. Non-tensor
    inputs are ignored.

    Args:
        inputs (Iterable): positional inputs of the recompute block.
    """
    any_input_needs_grad = any(
        not input_.stop_gradient
        for input_ in inputs
        if isinstance(input_, (core.eager.Tensor, paddle.Tensor))
    )
    if not any_input_needs_grad:
        logger.warning(
            "[Recompute]: None of the inputs to current recompute block need grad, "
            "therefore there is NO need to recompute this block in backward !"
        )
J
JZ-LIANG 已提交
54 55 56


@contextlib.contextmanager
def swith_rng_state_tracker(rng_state, tracker):
    """Temporarily install ``rng_state`` and ``tracker`` as the active RNG state.

    Both the global RNG state (``paddle.get_rng_state``/``paddle.set_rng_state``)
    and the states tracker from ``get_rng_state_tracker()`` are saved on entry
    and restored on exit, including when the ``with`` body raises.

    The name keeps its historical spelling ("swith") for compatibility.

    Args:
        rng_state: an RNG state previously obtained from the RNG state getter.
        tracker: a states tracker previously obtained from
            ``get_rng_state_tracker().get_states_tracker()``.
    """
    orig_rng_state = paddle.get_rng_state()
    orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
    try:
        # Install the new states inside the ``try`` so that a failure part-way
        # (e.g. while setting the tracker) still restores the original state.
        paddle.set_rng_state(rng_state)
        get_rng_state_tracker().set_states_tracker(tracker)
        yield
    finally:
        paddle.set_rng_state(orig_rng_state)
        get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
J
JZ-LIANG 已提交
67 68 69 70


class RecomputeFunction(PyLayer):
    """PyLayer implementing re-entrant activation recomputation.

    ``forward`` runs ``run_function`` under ``paddle.no_grad()``, so no graph
    (and hence no intermediate activation) is recorded; only the positional
    inputs are kept. ``backward`` re-runs ``run_function`` on detached copies of
    those inputs with gradients enabled, under the forward pass's AMP state
    (and optionally its RNG state). It then back-propagates the incoming
    gradients through the recomputed graph and returns the input gradients.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
        """Run ``run_function(*args, **kwargs)`` without recording a graph.

        Args:
            ctx: PyLayer context used to stash state for ``backward``.
            run_function (callable): the block to run now and recompute later.
            preserve_rng_state (bool): if True, record the current RNG state and
                RNG states tracker so ``backward`` replays the same randomness.
            *args: positional inputs of ``run_function``.
            **kwargs: keyword inputs of ``run_function``.

        Returns:
            The outputs of ``run_function``.

        Raises:
            ValueError: if the current AMP level or AMP dtype is not supported.
        """
        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.kwargs = kwargs

        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
        # None tensor inputs will be filtered in backward inputs.

        # Save inputs for backward: tensors go through save_for_backward, while
        # non-tensor args are kept in ctx.inputs (tensor slots hold None).
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE recompute with RNG restore only supports the scenario of one process per GPU;
        # one process with multiple GPUs, or mixed GPU/CPU scenarios, are not supported.
        if ctx.preserve_rng_state:
            ctx.fw_rng_state = paddle.get_rng_state()
            ctx.fwd_rng_state_tracker = (
                get_rng_state_tracker().get_states_tracker()
            )

        # Record the forward AMP configuration so that backward recomputes with
        # the same auto-cast behaviour.
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = tracer._amp_level != core.AmpLevel.O0
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            # For O0 auto_cast is later entered with enable=False, so 'O1' only
            # serves as a valid placeholder level.
            ctx.amp_level = 'O1'
        else:
            raise ValueError(
                "unsupported amp level: {}".format(tracer._amp_level)
            )

        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            # NOTE(review): 'float32' presumably means AMP is off; 'bfloat16' is
            # then only a placeholder for auto_cast(enable=False) — confirm.
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError(
                "unsupported amp dtype: {}".format(tracer._amp_dtype)
            )

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args, **kwargs)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and back-propagate through it.

        Args:
            ctx: the PyLayer context filled in by ``forward``.
            *args: gradients w.r.t. each output of ``forward``, in order.

        Returns:
            Gradients of the tensor inputs of ``forward``, in the order those
            tensors appeared (a tuple in eager mode, a list otherwise).

        Raises:
            RuntimeError: if no recomputed output requires grad.
        """
        with paddle.fluid.dygraph.guard():
            # TODO need to check the recompute calling is valid or not

            # Restore inputs: put the saved tensors back at their positions.
            inputs = list(ctx.inputs)
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(ctx.tensor_indices):
                inputs[idx] = tensors[i]

            # Equivalent of paddle.enable_grad(): the rerun must record a graph.
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # Replay the forward under the recorded RNG state (optional) and the
            # recorded auto_cast state, including the white/black op lists.
            # ExitStack exits in reverse order: auto_cast first, then RNG.
            with contextlib.ExitStack() as stack:
                if ctx.preserve_rng_state:
                    stack.enter_context(
                        swith_rng_state_tracker(
                            ctx.fw_rng_state, ctx.fwd_rng_state_tracker
                        )
                    )
                stack.enter_context(
                    paddle.amp.auto_cast(
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_level,
                        dtype=ctx.amp_dtype,
                    )
                )
                detached_inputs = detach_variable(tuple(inputs))
                outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

            if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
                outputs = (outputs,)
            assert len(outputs) == len(
                args
            ), "recompute: got {} recomputed outputs but {} output grads".format(
                len(outputs), len(args)
            )

            # run backward() with only tensor that requires grad
            forward_outputs_with_grad = []
            # NOTE In Transformer-like network, if user put the attention mask into the recompute segment output,
            # pylayer will force the stop_gradient of attention mask to be False, which will make the number of
            # tensor that need grad does not match.
            # the following backward_inputs_with_grad is used to avoid this case.
            backward_inputs_with_grad = []
            for i, output in enumerate(outputs):
                if (
                    isinstance(output, (core.VarBase, core.eager.Tensor))
                    and not output.stop_gradient
                ):
                    forward_outputs_with_grad.append(output)
                    backward_inputs_with_grad.append(args[i])

            if not forward_outputs_with_grad:
                raise RuntimeError(
                    "none of output has requires_grad=True, this recompute() is not necessary"
                )

            # actually backward
            with paddle.amp.auto_cast(enable=False):
                paddle.autograd.backward(
                    forward_outputs_with_grad, backward_inputs_with_grad
                )

            input_grads = (
                inp._grad_ivar()
                for inp in detached_inputs
                if isinstance(inp, (core.VarBase, core.eager.Tensor))
            )
            # Eager mode returns a tuple, legacy dygraph mode a list.
            if in_dygraph_mode():
                grads = tuple(input_grads)
            else:
                grads = list(input_grads)
            return grads


W
wuhuachaocoding 已提交
220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
def _recompute_without_reentrant(
    function, preserve_rng_state=True, *args, **kwargs
):
    """
    Recompute without reentrant: uses saved-tensors hooks to implement
    recomputation instead of a re-entrant autograd PyLayer.

    During the forward pass, every tensor autograd wants to save is replaced by
    an empty placeholder (``pack``). When backward first needs one of them
    (``unpack``), ``function`` is re-run with gradients enabled under the
    recorded AMP state (and optionally RNG state). The tensors saved by that
    rerun are then stored against the placeholders that are still alive.

    Args:
        function (callable): the block to run now and recompute in backward.
        preserve_rng_state (bool): if True, record the CUDA RNG state and the
            RNG states tracker now and replay them during recomputation. Only
            GPU devices are supported in this mode.
        *args, **kwargs: inputs forwarded to ``function``.

    Returns:
        The outputs of ``function(*args, **kwargs)``.

    Raises:
        RuntimeError: if ``preserve_rng_state`` is True and the current device
            is not a GPU.
        ValueError: if the current AMP level or AMP dtype is not supported.
    """

    if preserve_rng_state:
        cur_device = paddle.get_device()
        if 'gpu:' not in cur_device:
            raise RuntimeError(
                "Recompute with RNG perserve is not support current device: {}.".format(
                    cur_device
                )
            )
        fw_cuda_rng_state = paddle.get_cuda_rng_state()
        fwd_cuda_rng_state_tracker = (
            get_rng_state_tracker().get_states_tracker()
        )

    # Record the forward AMP configuration; fail fast (instead of leaving the
    # names unbound until backward) when it is not supported.
    tracer = framework._dygraph_tracer()
    is_fw_autocast = tracer._amp_level != core.AmpLevel.O0
    if tracer._amp_level == core.AmpLevel.O2:
        amp_level = 'O2'
    elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
        # For O0 auto_cast is entered with enable=False; 'O1' is a placeholder.
        amp_level = 'O1'
    else:
        raise ValueError(
            "unsupported amp level: {}".format(tracer._amp_level)
        )

    if tracer._amp_dtype == 'float16':
        amp_dtype = 'float16'
    elif tracer._amp_dtype in ('bfloat16', 'float32'):
        amp_dtype = 'bfloat16'
    else:
        raise ValueError(
            "unsupported amp dtype: {}".format(tracer._amp_dtype)
        )

    amp_white_list, amp_black_list = tracer._get_amp_op_list()

    class Intermediate_Holder:
        """Opaque placeholder handed to autograd instead of a saved tensor."""

    # Placeholder -> tensor recomputed for it. Weak keys, so an entry vanishes
    # once autograd releases the placeholder.
    storage = weakref.WeakKeyDictionary()
    # Placeholders in creation order; weak references so this list does not
    # keep them alive.
    holder_list = []

    def pack(x):
        # Forward: drop ``x`` and give autograd an empty placeholder.
        res = Intermediate_Holder()
        holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        # Backward: ``x`` is a placeholder produced by ``pack``. When storage
        # is empty, re-run ``function`` once to refill it.
        unpack_counter = 0
        if not storage:

            def inner_pack(inner_x):
                # Assumes the rerun saves tensors in the same order as the
                # original forward: the n-th saved tensor maps to the n-th
                # placeholder.
                nonlocal unpack_counter
                unpack_counter += 1

                # Dereference once so the holder cannot vanish in between.
                holder = holder_list[unpack_counter - 1]()
                if holder is None:
                    # Autograd no longer needs this tensor; skip storing it.
                    return

                tmp_tensor = core.eager.Tensor(
                    inner_x.dtype,
                    inner_x.shape,
                    inner_x.name + "cpy",
                    core.VarDesc.VarType.LOD_TENSOR,
                    inner_x.persistable,
                )
                inner_x._share_buffer_to(tmp_tensor)
                storage[holder] = tmp_tensor
                return

            def inner_unpack(inner_x):
                raise Exception("An unexcepted backward called on a tensor!")

            # Same nesting as the original forward context: (RNG), grad
            # enabled, auto_cast, then the inner saved-tensors hooks.
            with contextlib.ExitStack() as stack:
                if preserve_rng_state:
                    stack.enter_context(
                        swith_rng_state_tracker(
                            fw_cuda_rng_state, fwd_cuda_rng_state_tracker
                        )
                    )
                stack.enter_context(paddle.set_grad_enabled(True))
                stack.enter_context(
                    paddle.amp.auto_cast(
                        enable=is_fw_autocast,
                        custom_white_list=amp_white_list,
                        custom_black_list=amp_black_list,
                        level=amp_level,
                        dtype=amp_dtype,
                    )
                )
                stack.enter_context(
                    paddle.autograd.saved_tensors_hooks(
                        inner_pack, inner_unpack
                    )
                )
                # Only the tensors saved through inner_pack matter here.
                function(*args, **kwargs)

        if x not in storage:
            raise Exception(
                "Not supported to retrieve a tensor saved by autograd multiple times that is no need to recompute."
            )

        return storage[x]

    with paddle.autograd.saved_tensors_hooks(pack, unpack):
        outputs = function(*args, **kwargs)

    return outputs


J
JZ-LIANG 已提交
330 331 332 333
def recompute(function, *args, **kwargs):
    """
    Recompute intermediate activations to save memory.

    Parameters:
        function(paddle.nn.Layer): layer or sequence of layers that describes part of the forward pass of the model
              whose intermediate activations will be released to save memory in the forward stage and will be
              recomputed in the backward stage for gradient calculation.
        *args(Tensor): inputs to the function.
        **kwargs(Dict): two kinds of key-value params are accepted.
                        Recompute options:
                          - 'preserve_rng_state' (bool, default True): whether to save the forward RNG state. If True,
                            it is restored when the forward pass is recomputed during backpropagation.
                          - 'use_reentrant' (bool, default True): which implementation of recompute to use.
                            'use_reentrant=True' uses the PyLayer implementation, 'use_reentrant=False' uses the
                            hook implementation.
                        Every other key-value pair is passed to ``function``; this requires 'use_reentrant=False'.

    Returns:
        Output of function on args.

    Raises:
        ValueError: if keyword arguments for ``function`` are given while 'use_reentrant' is True.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed.fleet.utils import recompute
            import random
            # required: gpu
            def get_fc_block(block_idx, input_size, is_last=False):
                block_name = "block_" + str(block_idx)
                block = paddle.nn.Sequential(
                    (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
                    (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
                    (block_name + "_relu_1", paddle.nn.ReLU()),
                    (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
                    (block_name + "_relu_2", paddle.nn.ReLU()),
                )
                if is_last:
                    block.add_sublayer(
                        block_name + "_fc_2",
                        paddle.nn.Linear(
                            input_size, 1, bias_attr=False
                        )
                    )
                else:
                    block.add_sublayer(
                        block_name + "_fc_2",
                        paddle.nn.Linear(input_size, input_size, bias_attr=False)
                    )
                return block
            class Naive_fc_net(paddle.nn.Layer):
                def __init__(self, input_size=10,
                            recompute_blocks=[1, 3],
                            recompute_kwargs={}):
                    super(Naive_fc_net, self).__init__()
                    self.recompute_blocks = recompute_blocks
                    self.recompute_kwargs = recompute_kwargs
                    self.runfunc0 = get_fc_block(0, input_size, is_last=False)
                    self.runfunc1 = get_fc_block(1, input_size, is_last=False)
                    self.runfunc2 = get_fc_block(2, input_size, is_last=False)
                    self.runfunc3 = get_fc_block(3, input_size, is_last=False)
                    self.runfunc4 = get_fc_block(4, input_size, is_last=True)
                    self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]
                def forward(self, inputs):
                    nums = len(self.total_func)
                    for i in range(nums):
                        if i in self.recompute_blocks:
                            inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True})
                        else:
                            inputs = self.total_func[i](inputs)
                    return inputs
            def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
                gen = paddle.seed(10)
                gen.manual_seed(10)
                random.seed(10)
                if cuda_state:
                    paddle.set_cuda_rng_state(cuda_state)
                batch_size, input_size = 1, 10
                model = Naive_fc_net(
                    input_size,
                    recompute_blocks=recompute_block,
                    recompute_kwargs=recompute_kwargs)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                loss_ = []
                param_ = []
                grad_ = []
                for _ in range(5):
                    x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
                    y_pred = model(x)
                    loss = y_pred.mean()
                    loss_.append(loss.item())
                    loss.backward()
                    optimizer.step()
                    param_.append(model.parameters()[9])
                    grad_.append(model.parameters()[3]._grad_ivar())
                    optimizer.clear_grad()
                return loss_, param_, grad_
            cuda_state = paddle.get_cuda_rng_state()
            # without recompute
            loss_ref, param_ref, grad_ref = run_model(
                cuda_state, recompute_block=[]
            )
            loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
            print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
            # The result of the recompute_loss should be the same as the normal_loss.
    """
    # Pop the recompute options so that only the function's own keyword
    # arguments remain in ``kwargs``.
    preserve = kwargs.pop('preserve_rng_state', True)

    # whether to use reentrant method to implement recompute
    use_reentrant = kwargs.pop('use_reentrant', True)

    # The PyLayer (reentrant) implementation cannot forward keyword arguments.
    if kwargs and use_reentrant:
        raise ValueError(
            "Error, if you want to send kwargs(dict parameter) to function, please set use_reentrant=False."
        )

    if framework._dygraph_tracer()._has_grad:
        check_recompute_necessary(args)

    if use_reentrant:
        return RecomputeFunction.apply(function, preserve, *args)
    else:
        return _recompute_without_reentrant(function, preserve, *args, **kwargs)
452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496


def recompute_sequential(ctx, functions, *args, **kwargs):
    """
    Recompute intermediate activations to save memory for 'Sequential' models.

    The layers in ``functions`` are split into ``segments`` consecutive chunks
    of ``len(functions) // segments`` layers; the last chunk also takes any
    remaining layers. Every chunk except the last is run through
    :func:`recompute`, and the last chunk runs normally. The positional ``args``
    feed the first layer. Each layer's output is passed as the single input of
    the next layer, including across chunk boundaries.

    Parameters:
        ctx(dict): include 'segments' and 'preserve_rng_state' keys. The key 'segments' (int, default 1) is the number of
                   chunks to create in the model and must be between 1 and the number of layers. The key
                   'preserve_rng_state' (bool, optional, default=True) indicates whether to save the forward rng. If it is
                   True, the last forward rng value is restored when the forward recalculation of backpropagation is
                   performed. Keys such as 'mp_group', 'offload' and 'partition' are ignored here; they are used by the
                   'recompute_hybrid' API.
        functions(paddle.nn.Sequential|list): layer of sequence of layers that describes part of forward pass of the model
              whose intermediate activations will be released to save memory in forward stage and will be recomputed
              in backward stage for gradient calculation.
        *args(Tensor): inputs(tuple) to the first layer.
        **kwargs(Dict): forwarded to each :func:`recompute` call for the recomputed chunks (e.g. 'use_reentrant').

    Returns:
        Output of the last layer.

    Raises:
        ValueError: if 'segments' is smaller than 1 or larger than the number of layers.

    Examples:
        .. code-block:: python

            model = paddle.nn.Sequential(...)
            input = recompute_sequential({'segments' : 1}, model, input)
    """
    segments = ctx.get('segments', 1)
    preserve_rng_state = ctx.get('preserve_rng_state', True)

    def _run_func(begin, end, funcs):
        # Build a callable that runs funcs[begin..end] (inclusive) in order.
        def do_run(*inputs):
            # The first layer receives all positional inputs; every later
            # layer receives the previous layer's output as its only argument.
            output = funcs[begin](*inputs)
            for i in range(begin + 1, end + 1):
                output = funcs[i](output)
            return output

        return do_run

    if isinstance(functions, paddle.nn.Sequential):
        functions = list(functions.children())

    num_layers = len(functions)
    if segments < 1 or segments > num_layers:
        raise ValueError(
            "recompute_sequential: 'segments' must be between 1 and the number "
            "of layers ({}), but got {}.".format(num_layers, segments)
        )
    # segment_size >= 1 here, so every chunk (including the last) is non-empty.
    segment_size = num_layers // segments

    inputs = args
    end = -1
    for begin in range(0, segment_size * (segments - 1), segment_size):
        end = begin + segment_size - 1
        output = recompute(
            _run_func(begin, end, functions),
            *inputs,
            preserve_rng_state=preserve_rng_state,
            **kwargs
        )
        # The next chunk takes this chunk's output as its single input. It is
        # wrapped rather than splatted, since ``*tensor`` would iterate dim 0.
        inputs = (output,)
    return _run_func(end + 1, num_layers - 1, functions)(*inputs)