recompute.py 22.2 KB
Newer Older
J
JZ-LIANG 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import core
17 18
from paddle.autograd import PyLayer
from paddle.autograd.py_layer import LegacyPyLayer
S
ShenLiang 已提交
19

J
JZ-LIANG 已提交
20 21
from paddle.fluid import framework
import contextlib
S
ShenLiang 已提交
22
from paddle.fluid.framework import in_dygraph_mode
J
JZ-LIANG 已提交
23 24

import logging
R
Roc 已提交
25
from ..utils.log_util import logger
J
JZ-LIANG 已提交
26

27 28
__all__ = []

J
JZ-LIANG 已提交
29 30 31 32

def detach_variable(inputs):
    """Return a tuple mirroring ``inputs`` with every tensor detached.

    Elements that are ``core.eager.Tensor`` or ``core.VarBase`` are replaced by
    their ``detach()`` result, which inherits the source's ``stop_gradient``
    flag. Any other element is passed through untouched.

    Args:
        inputs (iterable): mixed tensors and plain Python values.

    Returns:
        tuple: same length and order as ``inputs``.
    """

    def _detach_one(item):
        # Non-tensors (ints, None, configs, ...) are returned as-is.
        if not isinstance(item, (core.eager.Tensor, core.VarBase)):
            return item
        detached = item.detach()
        # detach() gives a fresh tensor; keep the caller's grad requirement.
        detached.stop_gradient = item.stop_gradient
        return detached

    return tuple(_detach_one(item) for item in inputs)


def check_recompute_necessary(inputs):
    """Warn when no tensor input of a recompute block requires a gradient.

    If none of the tensors in ``inputs`` has ``stop_gradient`` set to False,
    recomputing the block in backward is pointless. A warning is logged in
    that case; nothing is raised and nothing is returned.

    Args:
        inputs (iterable): positional inputs of the recompute block; values
            that are not tensors are ignored.
    """
    needs_grad = any(
        not input_.stop_gradient
        for input_ in inputs
        if isinstance(input_, (core.eager.Tensor, paddle.Tensor))
    )
    if not needs_grad:
        logger.warning(
            "[Recompute]: None of the inputs to current recompute block need grad, "
            "therefore there is NO need to recompute this block in backward !"
        )
J
JZ-LIANG 已提交
53 54 55


@contextlib.contextmanager
def swith_rng_state_tracker(rng_state, tracker):
    """Temporarily install a CUDA RNG state and an RNG states tracker.

    The current CUDA RNG state and the current tracker states are saved on
    entry. They are restored on exit, including when the body, or the
    installation of the new state, raises.

    Args:
        rng_state: CUDA RNG state to install, e.g. one captured earlier with
            ``paddle.get_cuda_rng_state()``.
        tracker: tracker states to install, e.g. ones captured earlier with
            ``get_rng_state_tracker().get_states_tracker()``.
    """
    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
        get_rng_state_tracker,
    )

    orig_cuda_rng_state = paddle.get_cuda_rng_state()
    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()

    try:
        # Install inside ``try``: if setting the tracker fails after the CUDA
        # RNG state was already replaced, ``finally`` still puts it back.
        paddle.set_cuda_rng_state(rng_state)
        get_rng_state_tracker().set_states_tracker(tracker)
        yield
    finally:
        paddle.set_cuda_rng_state(orig_cuda_rng_state)
        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
J
JZ-LIANG 已提交
71 72


73
class LegacyRecomputeFunction(LegacyPyLayer):
    """Recompute layer built on ``LegacyPyLayer`` (``core.VarBase`` tensors).

    ``forward`` runs ``run_function`` under ``paddle.no_grad()``, so no
    intermediate activations are kept. ``backward`` replays ``run_function``
    with gradient recording enabled, restoring the captured RNG and AMP state
    first, and then back-propagates through the replay. Only positional inputs
    are supported.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """Run ``run_function(*args)`` without recording a graph.

        Args:
            ctx: layer context used to carry state to ``backward``.
            run_function (callable): the segment to run and later recompute.
            preserve_rng_state (bool): if True, capture the CUDA RNG state and
                the RNG states tracker so the recomputation replays the same
                random numbers.
            *args: positional inputs. Tensors are stored with
                ``save_for_backward``; other values are kept on ``ctx``.

        Returns:
            The outputs of ``run_function(*args)``.

        Raises:
            RuntimeError: if ``preserve_rng_state`` is set and the current
                device is not a GPU.
            ValueError: if the tracer's AMP level or AMP dtype is unsupported.
        """
        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
            get_rng_state_tracker,
        )

        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # NOTE backward() must return one gradient per tensor in forward()'s
        # input, in the same order. Non-tensor inputs are filtered out of
        # backward()'s outputs.
        # Tensors go through save_for_backward and leave a None placeholder in
        # ctx.inputs; every other value is stored in ctx.inputs directly.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE restoring the RNG state only supports one process per CUDA GPU.
        # One process with multiple GPUs, or mixed GPU/CPU setups, are not
        # supported.
        if ctx.preserve_rng_state:
            cur_device = paddle.get_device()
            if 'gpu:' not in cur_device:
                raise RuntimeError(
                    "Recompute with RNG perserve is not support current device: {}.".format(
                        cur_device
                    )
                )
            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
            ctx.fwd_cuda_rng_state_tracker = (
                get_rng_state_tracker().get_states_tracker()
            )

        # Capture the AMP configuration so backward replays under the same
        # autocast settings.
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = tracer._amp_level != core.AmpLevel.O0
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            ctx.amp_level = 'O1'
        else:
            raise ValueError(
                "unsupported amp level: {}".format(tracer._amp_level)
            )

        # NOTE(review): 'float32' is mapped to 'bfloat16' here; kept as-is,
        # confirm this is intended.
        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError(
                "unsupported amp dtype: {}".format(tracer._amp_dtype)
            )

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and return gradients of the tensor inputs.

        Args:
            ctx: context populated by ``forward``.
            *args: incoming gradients, one per output of ``forward``.

        Returns:
            list: ``_grad_ivar()`` of every ``core.VarBase`` input, in input
            order.

        Raises:
            RuntimeError: if none of the recomputed outputs requires grad.
        """
        with paddle.fluid.dygraph.guard():
            # TODO need to check the recompute calling is vaild or not

            # Restore inputs: put the saved tensors back into their slots.
            inputs = list(ctx.inputs)
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(ctx.tensor_indices):
                inputs[idx] = tensors[i]

            # paddle.enable_grad()
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # Replay the forward pass under the RNG state (when preserved) and
            # the AMP configuration captured in forward(). The RNG context is
            # entered first, so it is the outer one and is restored last.
            with contextlib.ExitStack() as stack:
                if ctx.preserve_rng_state:
                    stack.enter_context(
                        swith_rng_state_tracker(
                            ctx.fw_cuda_rng_state,
                            ctx.fwd_cuda_rng_state_tracker,
                        )
                    )
                stack.enter_context(
                    paddle.amp.auto_cast(
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_level,
                        dtype=ctx.amp_dtype,
                    )
                )
                detached_inputs = detach_variable(tuple(inputs))
                outputs = ctx.run_function(*detached_inputs)

            if isinstance(outputs, core.VarBase):
                outputs = (outputs,)
            assert len(outputs) == len(args)

            # Run backward() only with the tensors that require grad.
            # NOTE In Transformer-like networks, if the user puts the attention
            # mask into the recompute segment output, the layer forces the mask's
            # stop_gradient to False, and the number of tensors that need grad no
            # longer matches. Pairing each output with its own incoming gradient
            # avoids this.
            forward_outputs_with_grad = []
            backward_inputs_with_grad = []
            for output, grad in zip(outputs, args):
                if isinstance(output, core.VarBase) and not output.stop_gradient:
                    forward_outputs_with_grad.append(output)
                    backward_inputs_with_grad.append(grad)

            if not forward_outputs_with_grad:
                raise RuntimeError(
                    "none of output has requires_grad=True, this recompute() is not necessary"
                )

            # actually backward
            with paddle.amp.auto_cast(enable=False):
                paddle.autograd.backward(
                    forward_outputs_with_grad, backward_inputs_with_grad
                )

            return [
                inp._grad_ivar()
                for inp in detached_inputs
                if isinstance(inp, core.VarBase)
            ]


J
JZ-LIANG 已提交
229 230
class RecomputeFunction(PyLayer):
    """Recompute PyLayer for both ``core.VarBase`` and ``core.eager.Tensor``.

    ``forward`` runs ``run_function`` under ``paddle.no_grad()``, so no
    intermediate activations are kept. ``backward`` replays ``run_function``
    with gradient recording enabled, restoring the captured RNG and AMP state
    first, and then back-propagates through the replay. Keyword arguments are
    stored on ``ctx`` and passed to ``run_function`` again in the replay.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
        """Run ``run_function(*args, **kwargs)`` without recording a graph.

        Args:
            ctx: PyLayer context used to carry state to ``backward``.
            run_function (callable): the segment to run and later recompute.
            preserve_rng_state (bool): if True, capture the CUDA RNG state and
                the RNG states tracker so the recomputation replays the same
                random numbers.
            *args: positional inputs. Tensors are stored with
                ``save_for_backward``; other values are kept on ``ctx``.
            **kwargs: keyword inputs, stored on ``ctx`` as-is and passed again
                in the recomputation.

        Returns:
            The outputs of ``run_function(*args, **kwargs)``.

        Raises:
            RuntimeError: if ``preserve_rng_state`` is set and the current
                device is not a GPU.
            ValueError: if the tracer's AMP level or AMP dtype is unsupported.
        """
        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
            get_rng_state_tracker,
        )

        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.kwargs = kwargs

        # NOTE backward() must return one gradient per tensor in forward()'s
        # input, in the same order. Non-tensor inputs are filtered out of
        # backward()'s outputs.
        # Tensors go through save_for_backward and leave a None placeholder in
        # ctx.inputs; every other value is stored in ctx.inputs directly.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE restoring the RNG state only supports one process per CUDA GPU.
        # One process with multiple GPUs, or mixed GPU/CPU setups, are not
        # supported.
        if ctx.preserve_rng_state:
            cur_device = paddle.get_device()
            if 'gpu:' not in cur_device:
                raise RuntimeError(
                    "Recompute with RNG perserve is not support current device: {}.".format(
                        cur_device
                    )
                )
            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
            ctx.fwd_cuda_rng_state_tracker = (
                get_rng_state_tracker().get_states_tracker()
            )

        # Capture the AMP configuration so backward replays under the same
        # autocast settings.
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = tracer._amp_level != core.AmpLevel.O0
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            ctx.amp_level = 'O1'
        else:
            raise ValueError(
                "unsupported amp level: {}".format(tracer._amp_level)
            )

        # NOTE(review): 'float32' is mapped to 'bfloat16' here; kept as-is,
        # confirm this is intended.
        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError(
                "unsupported amp dtype: {}".format(tracer._amp_dtype)
            )

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args, **kwargs)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and return gradients of the tensor inputs.

        Args:
            ctx: context populated by ``forward``.
            *args: incoming gradients, one per output of ``forward``.

        Returns:
            ``_grad_ivar()`` of every tensor positional input, in input order.
            The result is a tuple when ``in_dygraph_mode()`` is True and a
            list otherwise.

        Raises:
            RuntimeError: if none of the recomputed outputs requires grad.
        """
        with paddle.fluid.dygraph.guard():
            # TODO need to check the recompute calling is vaild or not

            # Restore inputs: put the saved tensors back into their slots.
            inputs = list(ctx.inputs)
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(ctx.tensor_indices):
                inputs[idx] = tensors[i]

            # paddle.enable_grad()
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # Replay the forward pass under the RNG state (when preserved) and
            # the AMP configuration captured in forward(). The RNG context is
            # entered first, so it is the outer one and is restored last.
            with contextlib.ExitStack() as stack:
                if ctx.preserve_rng_state:
                    stack.enter_context(
                        swith_rng_state_tracker(
                            ctx.fw_cuda_rng_state,
                            ctx.fwd_cuda_rng_state_tracker,
                        )
                    )
                stack.enter_context(
                    paddle.amp.auto_cast(
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_level,
                        dtype=ctx.amp_dtype,
                    )
                )
                detached_inputs = detach_variable(tuple(inputs))
                outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

            tensor_types = (core.VarBase, core.eager.Tensor)

            if isinstance(outputs, tensor_types):
                outputs = (outputs,)
            assert len(outputs) == len(args)

            # Run backward() only with the tensors that require grad.
            # NOTE In Transformer-like networks, if the user puts the attention
            # mask into the recompute segment output, PyLayer forces the mask's
            # stop_gradient to False, and the number of tensors that need grad
            # no longer matches. Pairing each output with its own incoming
            # gradient avoids this.
            forward_outputs_with_grad = []
            backward_inputs_with_grad = []
            for output, grad in zip(outputs, args):
                if isinstance(output, tensor_types) and not output.stop_gradient:
                    forward_outputs_with_grad.append(output)
                    backward_inputs_with_grad.append(grad)

            if not forward_outputs_with_grad:
                raise RuntimeError(
                    "none of output has requires_grad=True, this recompute() is not necessary"
                )

            # actually backward
            with paddle.amp.auto_cast(enable=False):
                paddle.autograd.backward(
                    forward_outputs_with_grad, backward_inputs_with_grad
                )

            grads = [
                inp._grad_ivar()
                for inp in detached_inputs
                if isinstance(inp, tensor_types)
            ]
            # Eager mode returns a tuple and legacy dygraph a list, as before.
            if in_dygraph_mode():
                return tuple(grads)
            return grads


def recompute(function, *args, **kwargs):
    """
    recompute intermediate activations to save the memory.

    Parameters:
        function(paddle.nn.Layer): layer or sequence of layers that describes part of forward pass of the model
              whose intermediate activations will be released to save memory in forward stage and will be recomputed
              in backward stage for gradient calculation.
        *args(Tensor): inputs to the function.
        **kwargs(Dict): the key ``preserve_rng_state`` (bool, default True) is consumed by recompute itself and
              indicates whether to save the forward rng. If it is True, then the last forward rng value will be
              restored when the forward recalculation of backpropagation is performed. All remaining key-value
              pairs are passed to ``function`` as keyword arguments, both in the forward pass and in the
              recomputation.

    Returns:
        Output of function on args and kwargs.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed.fleet.utils import recompute
            import random
            # required: gpu
            def get_fc_block(block_idx, input_size, is_last=False):
                block_name = "block_" + str(block_idx)
                block = paddle.nn.Sequential(
                    (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
                    (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
                    (block_name + "_relu_1", paddle.nn.ReLU()),
                    (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
                    (block_name + "_relu_2", paddle.nn.ReLU()),
                )
                if is_last:
                    block.add_sublayer(
                        block_name + "_fc_2",
                        paddle.nn.Linear(
                            input_size, 1, bias_attr=False
                        )
                    )
                else:
                    block.add_sublayer(
                        block_name + "_fc_2",
                        paddle.nn.Linear(input_size, input_size, bias_attr=False)
                    )
                return block
            class Naive_fc_net(paddle.nn.Layer):
                def __init__(self, input_size=10,
                            recompute_blocks=[1, 3],
                            recompute_kwargs={}):
                    super(Naive_fc_net, self).__init__()
                    self.recompute_blocks = recompute_blocks
                    self.recompute_kwargs = recompute_kwargs
                    self.runfunc0 = get_fc_block(0, input_size, is_last=False)
                    self.runfunc1 = get_fc_block(1, input_size, is_last=False)
                    self.runfunc2 = get_fc_block(2, input_size, is_last=False)
                    self.runfunc3 = get_fc_block(3, input_size, is_last=False)
                    self.runfunc4 = get_fc_block(4, input_size, is_last=True)
                    self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]
                def forward(self, inputs):
                    nums = len(self.total_func)
                    for i in range(nums):
                        if i in self.recompute_blocks:
                            inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True})
                        else:
                            inputs = self.total_func[i](inputs)
                    return inputs
            def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
                gen = paddle.seed(10)
                gen.manual_seed(10)
                random.seed(10)
                if cuda_state:
                    paddle.set_cuda_rng_state(cuda_state)
                batch_size, input_size = 1, 10
                model = Naive_fc_net(
                    input_size,
                    recompute_blocks=recompute_block,
                    recompute_kwargs=recompute_kwargs)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                loss_ = []
                param_ = []
                grad_ = []
                for _ in range(5):
                    x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
                    y_pred = model(x)
                    loss = y_pred.mean()
                    loss_.append(loss.item())
                    loss.backward()
                    optimizer.step()
                    param_.append(model.parameters()[9])
                    grad_.append(model.parameters()[3]._grad_ivar())
                    optimizer.clear_grad()
                return loss_, param_, grad_
            cuda_state = paddle.get_cuda_rng_state()
            # without recompute
            loss_ref, param_ref, grad_ref = run_model(
                cuda_state, recompute_block=[]
            )
            loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
            print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
            # The result of the recompute_loss should be the same as the normal_loss.
    """
    # ``preserve_rng_state`` is recompute's own option; every other kwarg is
    # forwarded to ``function``.
    preserve = kwargs.pop('preserve_rng_state', True)

    # Only check for a useless recompute while gradients are being recorded.
    if framework._dygraph_tracer()._has_grad:
        check_recompute_necessary(args)

    return RecomputeFunction.apply(function, preserve, *args, **kwargs)


def recompute_sequential(ctx, functions, *args, **kwargs):
    """
    recompute intermediate activations to save the memory for 'Sequential' models.

    The functions are split into ``segments`` consecutive chunks. Every chunk
    except the last one is run through ``recompute``; the last one runs
    normally. Inside a chunk, and between chunks, each function receives the
    previous function's output as its single positional argument. Only the
    very first function receives ``*args`` and ``**kwargs``.

    Parameters:
        ctx(dict): include 'segments' and 'preserve_rng_state' keys. The key 'segments' (int, default 1) represents
                   the number of chunks to create in the model; it is capped at the number of functions. The key
                   'preserve_rng_state' (bool, optional, default=True) indicates whether to save the forward rng. If it
                   is True, then the last forward rng value will be restored when the forward recalculation of
                   backpropagation is performed. Some keys such as 'mp_group', 'offload' and 'partition' are invalid
                   here; they are useful in the 'recompute_hybrid' API.
        functions(paddle.nn.Sequential|list): layer or sequence of layers that describes part of forward pass of the
              model whose intermediate activations will be released to save memory in forward stage and will be
              recomputed in backward stage for gradient calculation.
        *args(Tensor): positional inputs to the first function.
        **kwargs(Dict): keyword inputs to the first function.

    Returns:
        Output of the last function.

    Raises:
        ValueError: if 'segments' is smaller than 1 or ``functions`` is empty.

    Examples:
        .. code-block:: python

            model = paddle.nn.Sequential(...)
            input = recompute_sequential({'segments' : 1}, model, input)
    """
    segments = ctx.get('segments', 1)
    preserve_rng_state = ctx.get('preserve_rng_state', True)

    def _run_func(begin, end, funcs):
        # Build a callable that runs funcs[begin..end] (inclusive) in order.
        # The first function receives the segment inputs unpacked; every later
        # function receives its predecessor's output.
        def do_run(*inputs, **input_kwargs):
            output = funcs[begin](*inputs, **input_kwargs)
            for i in range(begin + 1, end + 1):
                output = funcs[i](output)
            return output

        return do_run

    if isinstance(functions, paddle.nn.Sequential):
        functions = list(functions.children())

    if segments < 1:
        raise ValueError(
            "'segments' must be a positive integer, but got {}".format(segments)
        )
    if len(functions) == 0:
        raise ValueError("recompute_sequential needs at least one function")
    # More segments than functions would give a zero segment size.
    segments = min(segments, len(functions))
    segment_size = len(functions) // segments

    # Inputs of the next segment: the user's inputs for the first segment,
    # afterwards the previous segment's output as one positional argument.
    # Passing the output whole avoids star-unpacking a bare Tensor.
    inputs, input_kwargs = args, kwargs
    end = -1
    for begin in range(0, segment_size * (segments - 1), segment_size):
        end = begin + segment_size - 1
        output = recompute(
            _run_func(begin, end, functions),
            *inputs,
            preserve_rng_state=preserve_rng_state,
            **input_kwargs
        )
        inputs, input_kwargs = (output,), {}
    return _run_func(end + 1, len(functions) - 1, functions)(
        *inputs, **input_kwargs
    )