From 625dd72276d9673f16ebcc889f145340a73fe679 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Mon, 4 Apr 2022 18:41:53 +0800
Subject: [PATCH] fix recompute (#41396)

---
 .../distributed/fleet/utils/recompute.py      | 147 +++++++++++++++++-
 .../tests/unittests/test_dygraph_recompute.py | 111 ++++++-------
 2 files changed, 191 insertions(+), 67 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py
index 4ccb48ef72e..c767be77d83 100755
--- a/python/paddle/distributed/fleet/utils/recompute.py
+++ b/python/paddle/distributed/fleet/utils/recompute.py
@@ -14,9 +14,11 @@
 
 import paddle
 from paddle.fluid import core
-from paddle.autograd import PyLayer
+from paddle.autograd import PyLayer, EagerPyLayer
+
 from paddle.fluid import framework
 import contextlib
+from paddle.fluid.framework import in_dygraph_mode
 
 import logging
 logger = logging.getLogger(__name__)
@@ -32,7 +34,7 @@ __all__ = []
 def detach_variable(inputs):
     out = []
     for inp in inputs:
-        if not isinstance(inp, core.VarBase):
+        if not isinstance(inp, (core.eager.Tensor, core.VarBase)):
             out.append(inp)
             continue
 
@@ -44,7 +46,7 @@ def detach_variable(inputs):
 
 def check_recompute_necessary(inputs):
     if not any(input_.stop_gradient == False for input_ in inputs
-               if isinstance(input_, paddle.Tensor)):
+               if isinstance(input_, (core.eager.Tensor, paddle.Tensor))):
         logger.warn(
             "[Recompute]: None of the inputs to current recompute block need grad, "
             "therefore there is NO need to recompute this block in backward !")
@@ -60,6 +62,140 @@ def swith_rng_state(rng_state):
         paddle.set_cuda_rng_state(orig_cuda_rng_state)
 
 
+class EagerRecomputeFunction(EagerPyLayer):
+    @staticmethod
+    def forward(ctx, run_function, preserve_rng_state, *args):
+        if framework._dygraph_tracer()._has_grad:
+            check_recompute_necessary(args)
+
+        # store for recomputing
+        ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
+
+        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
+        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
+        # None tensor inputs will be filtered in backward inputs.
+
+        # save input for backward
+        ctx.inputs = []
+        ctx.tensor_indices = []
+        tensor_inputs = []
+        for i, arg in enumerate(args):
+            if paddle.is_tensor(arg):
+                tensor_inputs.append(arg)
+                ctx.tensor_indices.append(i)
+                ctx.inputs.append(None)
+            else:
+                ctx.inputs.append(arg)
+        ctx.save_for_backward(*tensor_inputs)
+
+        # NOTE recompute with RNG restore only supports one scenario: one process per CUDA GPU.
+        # One process with multiple GPUs and mixed GPU-CPU scenarios are not supported.
+        if ctx.preserve_rng_state:
+            cur_device = paddle.get_device()
+            if 'gpu:' not in cur_device:
+                raise RuntimeError(
+                    "Recompute with RNG perserve is not support current device: {}.".
+                    format(cur_device))
+            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
+
+        # TODO support AMP
+        tracer = framework._dygraph_tracer()
+        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        if tracer._amp_level == core.AmpLevel.O2:
+            ctx.amp_level = 'O2'
+        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
+            ctx.amp_level = 'O1'
+        else:
+            raise ValueError("unsupported amp level: {}".format(
+                tracer._amp_level))
+
+        if tracer._amp_dtype == 'float16':
+            ctx.amp_dtype = 'float16'
+        elif tracer._amp_dtype in ('bfloat16', 'float32'):
+            ctx.amp_dtype = 'bfloat16'
+        else:
+            raise ValueError("unsupported amp dtype: {}".format(
+                tracer._amp_dtype))
+
+        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
+
+        with paddle.no_grad():
+            outputs = run_function(*args)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        with paddle.fluid.dygraph.guard():
+            # TODO need to check whether the recompute call is valid or not
+
+            # Restore inputs
+            inputs = list(ctx.inputs)
+            tensor_indices = ctx.tensor_indices
+            tensors = ctx.saved_tensor()
+            for i, idx in enumerate(tensor_indices):
+                inputs[idx] = tensors[i]
+
+            # paddle.enable_grad()
+            tracer = framework._dygraph_tracer()
+            tracer._has_grad = True
+
+            # NOTE support AMP
+            # need restore auto_cast state as well as w/b list
+            if ctx.preserve_rng_state:
+                with swith_rng_state(ctx.fw_cuda_rng_state):
+                    with paddle.amp.auto_cast(
+                            enable=ctx.is_fw_autocast,
+                            custom_white_list=ctx.amp_white_list,
+                            custom_black_list=ctx.amp_black_list,
+                            level=ctx.amp_level,
+                            dtype=ctx.amp_dtype):
+                        detached_inputs = detach_variable(tuple(inputs))
+                        outputs = ctx.run_function(*detached_inputs)
+            else:
+                with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list,
+                        level=ctx.amp_level,
+                        dtype=ctx.amp_dtype):
+                    detached_inputs = detach_variable(tuple(inputs))
+                    outputs = ctx.run_function(*detached_inputs)
+
+            if isinstance(outputs, core.eager.Tensor):
+                outputs = (outputs, )
+            assert len(outputs) == len(args)
+
+            # run backward() with only the tensors that require grad
+            forward_outputs_with_grad = []
+            # NOTE In Transformer-like networks, if the user puts the attention mask into the recompute segment output,
+            # PyLayer will force the stop_gradient of the attention mask to be False, which makes the number of
+            # tensors that need grad mismatch.
+            # the following backward_inputs_with_grad is used to avoid this case.
+            backward_inputs_with_grad = []
+            for i in range(len(outputs)):
+                if isinstance(
+                        outputs[i],
+                        core.eager.Tensor) and not outputs[i].stop_gradient:
+                    forward_outputs_with_grad.append(outputs[i])
+                    backward_inputs_with_grad.append(args[i])
+
+            if len(forward_outputs_with_grad) == 0:
+                raise RuntimeError(
+                    "none of output has requires_grad=True, this recompute() is not necessary"
+                )
+
+            # actually backward
+            with paddle.amp.auto_cast(enable=False):
+                paddle.autograd.backward(forward_outputs_with_grad,
+                                         backward_inputs_with_grad)
+
+            grads = tuple(
+                inp.grad for inp in detached_inputs
+                if isinstance(inp, core.eager.Tensor))
+            return grads
+
+
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
@@ -315,4 +451,7 @@ def recompute(function, *args, **kwargs):
         raise ValueError("Unexpected keyword arguments: " + ",".join(
             arg for arg in kwargs))
 
-    return RecomputeFunction.apply(function, preserve, *args)
+    if in_dygraph_mode():
+        return EagerRecomputeFunction.apply(function, preserve, *args)
+    else:
+        return RecomputeFunction.apply(function, preserve, *args)
diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py
index 4a4bcd2b816..fa9ea5d086c 100755
--- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py
+++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py
@@ -23,6 +23,7 @@ from paddle.distributed.fleet.utils import recompute
 import random
 
 import paddle.fluid.layers as layers
+from paddle.fluid.framework import _test_eager_guard
 
 
 def get_fc_block(block_idx, input_size, is_last=False):
@@ -141,96 +142,75 @@ def run_model(recompute_block=[],
 
 
 class TestPyLayer(unittest.TestCase):
-    def test_fc_net_with_dropout(self):
+    def test_base_case(self, enable_autocast=False, pure_fp16=False):
         def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
             self.assertEqual(loss_ref, loss)
             self.assertEqual(param_ref, param)
             self.assertEqual(grad_ref, grad)
 
         # without recompute
-        loss_ref, param_ref, grad_ref = run_model(recompute_block=[])
-
-        # recompute second block
-        loss, param, grad = run_model(recompute_block=[1])
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute fourth block
-        loss, param, grad = run_model(recompute_block=[3])
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second to fourth block
-        loss, param, grad = run_model(recompute_block=[1, 2, 3])
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second & fourth block
-        loss, param, grad = run_model(recompute_block=[1, 3])
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-    def test_fc_net_without_restore_rng(self):
         loss_ref, param_ref, grad_ref = run_model(
-            recompute_block=[2],
-            recompute_kwargs={"preserve_rng_state": False},
-            enable_autocast=True)
-
-    def test_fc_net_with_amp(self):
-        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
-            self.assertEqual(loss_ref, loss)
-            self.assertEqual(param_ref, param)
-            self.assertEqual(grad_ref, grad)
-
-        # without recompute
-        loss_ref, param_ref, grad_ref = run_model(
-            recompute_block=[], enable_autocast=True)
+            recompute_block=[],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
 
         # recompute second block
-        loss, param, grad = run_model(recompute_block=[1], enable_autocast=True)
+        loss, param, grad = run_model(
+            recompute_block=[1],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
         # recompute fourth block
-        loss, param, grad = run_model(recompute_block=[3], enable_autocast=True)
+        loss, param, grad = run_model(
+            recompute_block=[3],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
         # recompute second to fourth block
         loss, param, grad = run_model(
-            recompute_block=[1, 2, 3], enable_autocast=True)
+            recompute_block=[1, 2, 3],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
         # recompute second & fourth block
         loss, param, grad = run_model(
-            recompute_block=[1, 3], enable_autocast=True)
+            recompute_block=[1, 3],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
-    def test_fc_net_with_fp16(self):
-        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
-            self.assertEqual(loss_ref, loss)
-            self.assertEqual(param_ref, param)
-            self.assertEqual(grad_ref, grad)
-
-        # without recompute
-        loss_ref, param_ref, grad_ref = run_model(
-            recompute_block=[], enable_autocast=True, pure_fp16=True)
-
-        # recompute second block
-        loss, param, grad = run_model(
-            recompute_block=[1], enable_autocast=True, pure_fp16=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+    def test_fc_net_with_dropout(self):
+        with _test_eager_guard():
+            self.test_base_case()
+        self.test_base_case()
 
-        # recompute fourth block
-        loss, param, grad = run_model(
-            recompute_block=[3], enable_autocast=True, pure_fp16=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+    def test_fc_net_without_restore_rng(self):
+        with _test_eager_guard():
+            loss_ref, param_ref, grad_ref = run_model(
+                recompute_block=[2],
+                recompute_kwargs={"preserve_rng_state": False},
+                enable_autocast=True)
 
-        # recompute second to fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 2, 3], enable_autocast=True, pure_fp16=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+    def test_fc_net_with_amp(self):
+        with _test_eager_guard():
+            self.test_base_case(enable_autocast=True)
+        self.test_base_case(enable_autocast=True)
 
-        # recompute second & fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 3], enable_autocast=True, pure_fp16=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+    def test_fc_net_with_fp16(self):
+        with _test_eager_guard():
+            self.test_base_case(enable_autocast=True, pure_fp16=True)
+        self.test_base_case(enable_autocast=True, pure_fp16=True)
 
     def test_recompute_kwargs(self):
+        with _test_eager_guard():
+            paddle.set_device("gpu")
+            kwargs = {"is_test": False}
+            with self.assertRaises(ValueError):
+                loss_ref, param_ref, grad_ref = run_model(
+                    recompute_block=[2], recompute_kwargs=kwargs)
         paddle.set_device("gpu")
         kwargs = {"is_test": False}
         with self.assertRaises(ValueError):
@@ -238,6 +218,11 @@ class TestPyLayer(unittest.TestCase):
             recompute_block=[2], recompute_kwargs=kwargs)
 
     def test_recompute_cpu_rng(self):
+        with _test_eager_guard():
+            paddle.set_device("cpu")
+            with self.assertRaises(RuntimeError):
+                loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
+
         paddle.set_device("cpu")
         with self.assertRaises(RuntimeError):
            loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
-- 
GitLab
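
Usage note (not part of the patch): the sketch below shows how fleet's recompute is typically called on one block of a dygraph model, which is the code path this patch routes to EagerRecomputeFunction under eager mode. The model shape, layer names, and tensor sizes are illustrative assumptions, not taken from the PR.

import paddle
from paddle.distributed.fleet.utils import recompute


class TwoBlockNet(paddle.nn.Layer):
    # Toy two-block model; only the second block is wrapped with recompute.
    def __init__(self):
        super().__init__()
        self.block0 = paddle.nn.Linear(16, 16)
        self.block1 = paddle.nn.Linear(16, 16)

    def forward(self, x):
        x = self.block0(x)
        # block1's activations are dropped in the forward pass and rebuilt
        # during backward; with this patch, eager mode dispatches to
        # EagerRecomputeFunction.apply instead of RecomputeFunction.apply.
        return recompute(self.block1, x)


# Assumption: a CUDA device is available. The default preserve_rng_state=True
# path requires a GPU (the patch raises RuntimeError on CPU).
paddle.set_device("gpu")
net = TwoBlockNet()
x = paddle.randn([4, 16])
loss = net(x).sum()
loss.backward()
print(net.block1.weight.grad.shape)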