# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import core
S
ShenLiang 已提交
17 18
from paddle.autograd import PyLayer, EagerPyLayer

J
JZ-LIANG 已提交
19 20
from paddle.fluid import framework
import contextlib
S
ShenLiang 已提交
21
from paddle.fluid.framework import in_dygraph_mode
J
JZ-LIANG 已提交
22 23

import logging
# Module-level logger used by the recompute helpers below. Messages are
# printed through a dedicated console handler with a timestamped format.
logger = logging.getLogger(__name__)
formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
                              datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
# Attach the console handler only once: re-executing this module (for
# example via importlib.reload) would otherwise stack duplicate handlers
# and print every message several times.
if not logger.handlers:
    logger.addHandler(ch)

# This module exports no public names via ``from ... import *``.
__all__ = []

def detach_variable(inputs):
    """Return a tuple that mirrors ``inputs`` with every tensor detached.

    Tensors (``core.eager.Tensor`` or ``core.VarBase``) are replaced by
    ``tensor.detach()`` copies that keep the original ``stop_gradient``
    flag. Any other element is passed through unchanged, in its original
    position.

    Args:
        inputs (Iterable): values to process, typically the positional
            arguments of a recompute segment.

    Returns:
        tuple: the processed values, in input order.
    """

    def _detach_one(item):
        # Non-tensor arguments (ints, None, custom objects) pass through.
        if not isinstance(item, (core.eager.Tensor, core.VarBase)):
            return item
        detached = item.detach()
        # detach() produces a new tensor; carry over the caller's choice
        # of whether it should receive gradients.
        detached.stop_gradient = item.stop_gradient
        return detached

    return tuple(_detach_one(item) for item in inputs)


def check_recompute_necessary(inputs):
    """Log a warning when no tensor input of a recompute block needs grad.

    Recomputation only matters if gradients will flow back through the
    block. When every tensor in ``inputs`` has ``stop_gradient=True`` (or
    ``inputs`` contains no tensors at all), a warning is emitted through the
    module logger. Nothing is returned and no exception is raised.

    Args:
        inputs (Iterable): positional arguments passed to ``recompute``.
            Entries that are not ``core.eager.Tensor``/``paddle.Tensor``
            instances are ignored.
    """
    needs_grad = any(
        not input_.stop_gradient for input_ in inputs
        if isinstance(input_, (core.eager.Tensor, paddle.Tensor)))
    if not needs_grad:
        # Logger.warn is a deprecated alias; Logger.warning is the
        # supported spelling.
        logger.warning(
            "[Recompute]: None of the inputs to current recompute block need grad, "
            "therefore there is NO need to recompute this block in backward !")


@contextlib.contextmanager
def swith_rng_state_tracker(rng_state, tracker):
    """Temporarily install a CUDA RNG state and an RNG-state tracker.

    On entry, the current CUDA RNG state (``paddle.get_cuda_rng_state()``)
    and the current tracker states
    (``get_rng_state_tracker().get_states_tracker()``) are saved, and the
    given ``rng_state`` / ``tracker`` are installed. On exit, normal or
    exceptional, the saved originals are put back.

    The misspelled name ("swith") is kept for backward compatibility with
    existing callers.

    Args:
        rng_state: a value previously obtained from
            ``paddle.get_cuda_rng_state()``.
        tracker: a value previously obtained from
            ``get_rng_state_tracker().get_states_tracker()``.
    """
    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
    orig_cuda_rng_state = paddle.get_cuda_rng_state()
    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()

    try:
        # Install the new states inside ``try``. If the second call fails
        # after the CUDA state was already replaced, ``finally`` still
        # restores the original global RNG state.
        paddle.set_cuda_rng_state(rng_state)
        get_rng_state_tracker().set_states_tracker(tracker)
        yield
    finally:
        paddle.set_cuda_rng_state(orig_cuda_rng_state)
        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)


class EagerRecomputeFunction(EagerPyLayer):
    """Activation-recompute PyLayer for eager (``core.eager.Tensor``) mode.

    ``forward`` runs ``run_function`` under ``paddle.no_grad()``, so its
    intermediate activations are not kept. It records what is needed to
    replay the call: the inputs, and optionally the CUDA RNG state plus the
    AMP configuration. ``backward`` re-runs ``run_function`` with gradients
    enabled under the recorded state, back-propagates the incoming grads
    through the recomputed graph, and returns the grads of the tensor
    inputs.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker

        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
        # None tensor inputs will be filtered in backward inputs.

        # Save inputs for backward. Tensors go through save_for_backward;
        # non-tensor args are kept in ctx.inputs, with a None placeholder at
        # each tensor position (recorded in ctx.tensor_indices).
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE recompute with RNG restore only supports the scenario of one process per CUDA GPU;
        # one process with multiple GPUs and mixed GPU/CPU scenarios are not supported.
        if ctx.preserve_rng_state:
            cur_device = paddle.get_device()
            if 'gpu:' not in cur_device:
                raise RuntimeError(
                    "Recompute with RNG perserve is not support current device: {}."
                    .format(cur_device))
            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
            ).get_states_tracker()

        # Record the AMP (auto-cast) configuration active during forward so
        # backward can recompute under identical casting rules.
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = tracer._amp_level != core.AmpLevel.O0
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            # O0 is recorded as 'O1'; in that case is_fw_autocast is False,
            # so the level is presumably not consulted by auto_cast.
            ctx.amp_level = 'O1'
        else:
            raise ValueError("unsupported amp level: {}".format(
                tracer._amp_level))

        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            # NOTE(review): a 'float32' tracer dtype is recorded as
            # 'bfloat16' -- presumably only reached when autocast is off;
            # confirm.
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError("unsupported amp dtype: {}".format(
                tracer._amp_dtype))

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        # Run without recording a graph: activations are recomputed later.
        with paddle.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        # ``args`` holds the incoming gradients, one per forward output.
        with paddle.fluid.dygraph.guard():
            # TODO need to check whether the recompute call is valid or not

            # Restore inputs: put saved tensors back at their original
            # positions among the non-tensor args.
            inputs = list(ctx.inputs)
            tensor_indices = ctx.tensor_indices
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(tensor_indices):
                inputs[idx] = tensors[i]

            # Enable gradient recording on the current tracer (same intent
            # as paddle.enable_grad()).
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # Replay the forward computation under the recorded AMP state
            # and, if requested, the recorded RNG state (entered first,
            # exited last, as in a nested ``with``).
            with contextlib.ExitStack() as stack:
                if ctx.preserve_rng_state:
                    stack.enter_context(
                        swith_rng_state_tracker(
                            ctx.fw_cuda_rng_state,
                            ctx.fwd_cuda_rng_state_tracker))
                with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
                                          custom_white_list=ctx.amp_white_list,
                                          custom_black_list=ctx.amp_black_list,
                                          level=ctx.amp_level,
                                          dtype=ctx.amp_dtype):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs)

            if isinstance(outputs, core.eager.Tensor):
                outputs = (outputs, )
            assert len(outputs) == len(args)

            # Run backward() only through outputs that require grad.
            # NOTE In Transformer-like networks, if the user puts the attention mask into the recompute segment
            # output, PyLayer forces its stop_gradient to False, so the number of tensors that need grad would
            # not match. Pairing each kept output with its own incoming grad avoids this mismatch.
            forward_outputs_with_grad = []
            backward_inputs_with_grad = []
            for out, out_grad in zip(outputs, args):
                if isinstance(out, core.eager.Tensor) and not out.stop_gradient:
                    forward_outputs_with_grad.append(out)
                    backward_inputs_with_grad.append(out_grad)

            if not forward_outputs_with_grad:
                raise RuntimeError(
                    "none of output has requires_grad=True, this recompute() is not necessary"
                )

            # actually backward
            with paddle.amp.auto_cast(enable=False):
                paddle.autograd.backward(forward_outputs_with_grad,
                                         backward_inputs_with_grad)

            # One grad per tensor input, in forward-input order.
            grads = tuple(inp.grad for inp in detached_inputs
                          if isinstance(inp, core.eager.Tensor))
            return grads


class RecomputeFunction(PyLayer):
    """Activation-recompute PyLayer for legacy (``core.VarBase``) dygraph.

    ``forward`` runs ``run_function`` under ``paddle.no_grad()``, so its
    intermediate activations are not kept. It records what is needed to
    replay the call: the inputs, and optionally the CUDA RNG state plus the
    AMP configuration. ``backward`` re-runs ``run_function`` with gradients
    enabled under the recorded state, back-propagates the incoming grads
    through the recomputed graph, and returns the grads of the tensor
    inputs as a list.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker

        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
        # None tensor inputs will be filtered in backward inputs.

        # Save inputs for backward. Tensors go through save_for_backward;
        # non-tensor args are kept in ctx.inputs, with a None placeholder at
        # each tensor position (recorded in ctx.tensor_indices).
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE recompute with RNG restore only supports the scenario of one process per CUDA GPU;
        # one process with multiple GPUs and mixed GPU/CPU scenarios are not supported.
        if ctx.preserve_rng_state:
            cur_device = paddle.get_device()
            if 'gpu:' not in cur_device:
                raise RuntimeError(
                    "Recompute with RNG perserve is not support current device: {}."
                    .format(cur_device))
            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
            ).get_states_tracker()

        # Record the AMP (auto-cast) configuration active during forward so
        # backward can recompute under identical casting rules.
        tracer = framework._dygraph_tracer()
        ctx.is_fw_autocast = tracer._amp_level != core.AmpLevel.O0
        if tracer._amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            # O0 is recorded as 'O1'; in that case is_fw_autocast is False,
            # so the level is presumably not consulted by auto_cast.
            ctx.amp_level = 'O1'
        else:
            raise ValueError("unsupported amp level: {}".format(
                tracer._amp_level))

        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            # NOTE(review): a 'float32' tracer dtype is recorded as
            # 'bfloat16' -- presumably only reached when autocast is off;
            # confirm.
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError("unsupported amp dtype: {}".format(
                tracer._amp_dtype))

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        # Run without recording a graph: activations are recomputed later.
        with paddle.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        # ``args`` holds the incoming gradients, one per forward output.
        with paddle.fluid.dygraph.guard():
            # TODO need to check whether the recompute call is valid or not

            # Restore inputs: put saved tensors back at their original
            # positions among the non-tensor args.
            inputs = list(ctx.inputs)
            tensor_indices = ctx.tensor_indices
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(tensor_indices):
                inputs[idx] = tensors[i]

            # Enable gradient recording on the current tracer (same intent
            # as paddle.enable_grad()).
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # Replay the forward computation under the recorded AMP state
            # and, if requested, the recorded RNG state (entered first,
            # exited last, as in a nested ``with``).
            with contextlib.ExitStack() as stack:
                if ctx.preserve_rng_state:
                    stack.enter_context(
                        swith_rng_state_tracker(
                            ctx.fw_cuda_rng_state,
                            ctx.fwd_cuda_rng_state_tracker))
                with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
                                          custom_white_list=ctx.amp_white_list,
                                          custom_black_list=ctx.amp_black_list,
                                          level=ctx.amp_level,
                                          dtype=ctx.amp_dtype):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs)

            if isinstance(outputs, core.VarBase):
                outputs = (outputs, )
            assert len(outputs) == len(args)

            # Run backward() only through outputs that require grad.
            # NOTE In Transformer-like networks, if the user puts the attention mask into the recompute segment
            # output, PyLayer forces its stop_gradient to False, so the number of tensors that need grad would
            # not match. Pairing each kept output with its own incoming grad avoids this mismatch.
            forward_outputs_with_grad = []
            backward_inputs_with_grad = []
            for out, out_grad in zip(outputs, args):
                if isinstance(out, core.VarBase) and not out.stop_gradient:
                    forward_outputs_with_grad.append(out)
                    backward_inputs_with_grad.append(out_grad)

            if not forward_outputs_with_grad:
                raise RuntimeError(
                    "none of output has requires_grad=True, this recompute() is not necessary"
                )

            # actually backward
            with paddle.amp.auto_cast(enable=False):
                paddle.autograd.backward(forward_outputs_with_grad,
                                         backward_inputs_with_grad)

            # One grad per tensor input, in forward-input order.
            grads = list(inp._grad_ivar() for inp in detached_inputs
                         if isinstance(inp, core.VarBase))
            return grads


def recompute(function, *args, **kwargs):
    """
    Recompute intermediate activations to save memory.

    Parameters:
        function(paddle.nn.Sequential): layer or sequence of layers that describes part of the forward pass of the
              model, whose intermediate activations will be released to save memory in the forward stage and will
              be recomputed in the backward stage for gradient calculation.
        *args(Tensor): inputs to the function.
        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
              indicate whether to save the forward rng. If it is True, then the last forward rng value will be
              restored when the forward recalculation of backpropagation is performed. The default
              preserve_rng_state is True.

    Returns:
        Output of function on args.

    Raises:
        ValueError: if kwargs contains any key other than preserve_rng_state.
        RuntimeError: if preserve_rng_state is True and the current device is not a GPU.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.distributed.fleet.utils import recompute
            import random

            # required: gpu

            def get_fc_block(block_idx, input_size, is_last=False):
                block_name = "block_" + str(block_idx)
                block = paddle.nn.Sequential(
                    (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
                    (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
                    (block_name + "_relu_1", paddle.nn.ReLU()),
                    (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
                    (block_name + "_relu_2", paddle.nn.ReLU()),
                )
                if is_last:
                    block.add_sublayer(
                        block_name + "_fc_2",
                        paddle.nn.Linear(
                            input_size, 1, bias_attr=False
                        )
                    )
                else:
                    block.add_sublayer(
                        block_name + "_fc_2",
                        paddle.nn.Linear(input_size, input_size, bias_attr=False)
                    )

                return block


            class Naive_fc_net(paddle.nn.Layer):
                def __init__(self, input_size=10,
                            recompute_blocks=[1, 3],
                            recompute_kwargs={}):
                    super(Naive_fc_net, self).__init__()
                    self.recompute_blocks = recompute_blocks
                    self.recompute_kwargs = recompute_kwargs
                    self.runfunc0 = get_fc_block(0, input_size, is_last=False)
                    self.runfunc1 = get_fc_block(1, input_size, is_last=False)
                    self.runfunc2 = get_fc_block(2, input_size, is_last=False)
                    self.runfunc3 = get_fc_block(3, input_size, is_last=False)
                    self.runfunc4 = get_fc_block(4, input_size, is_last=True)
                    self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]

                def forward(self, inputs):
                    nums = len(self.total_func)
                    for i in range(nums):
                        if i in self.recompute_blocks:
                            inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True})
                        else:
                            inputs = self.total_func[i](inputs)
                    return inputs

            def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
                gen = paddle.seed(10)
                gen.manual_seed(10)
                np.random.seed(10)
                random.seed(10)
                if cuda_state:
                    paddle.set_cuda_rng_state(cuda_state)

                batch_size, input_size = 1, 10
                model = Naive_fc_net(
                    input_size,
                    recompute_blocks=recompute_block,
                    recompute_kwargs=recompute_kwargs)
                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                loss_ = []
                param_ = []
                grad_ = []
                for _ in range(5):
                    x_data = np.random.randn(batch_size, input_size).astype(np.float32)
                    x = paddle.to_tensor(x_data)
                    y_pred = model(x)
                    loss = y_pred.mean()
                    loss_.append(np.asarray(loss).tolist())
                    loss.backward()
                    optimizer.step()
                    param_.append(np.asarray(model.parameters()[9]).tolist())
                    grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())
                    optimizer.clear_grad()

                return loss_, param_, grad_

            cuda_state = paddle.get_cuda_rng_state()
            # without recompute
            loss_ref, param_ref, grad_ref = run_model(
                cuda_state, recompute_block=[]
            )

            loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
            print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
            # The result of the recompute_loss should be the same as the normal_loss.

    """
    # preserve_rng_state is read out of **kwargs rather than declared as a
    # keyword-only parameter, to keep the historical call signature.
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " +
                         ",".join(kwargs))

    # Only check (and possibly warn about) an unnecessary recompute when the
    # current tracer is recording gradients.
    if framework._dygraph_tracer()._has_grad:
        check_recompute_necessary(args)

    # Dispatch to the eager PyLayer under in_dygraph_mode(), otherwise to
    # the legacy VarBase-based one.
    if in_dygraph_mode():
        return EagerRecomputeFunction.apply(function, preserve, *args)
    return RecomputeFunction.apply(function, preserve, *args)