diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
index f429cf306268f9627753f09814bf9a80a773a5ea..ae4bbae6a69a08c182f6bd20f873d0022ec2b57d 100755
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -13,12 +13,16 @@
 # limitations under the License.
 
 import contextlib
+import weakref
 
 import paddle
+from paddle import framework
 from paddle.autograd import PyLayer
 from paddle.autograd.py_layer import LegacyPyLayer
-from paddle.fluid import core, framework
-from paddle.fluid.framework import in_dygraph_mode
+from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+    get_rng_state_tracker,
+)
+from paddle.framework import core, in_dygraph_mode
 
 from ..utils.log_util import logger
 
@@ -52,10 +56,6 @@ def check_recompute_necessary(inputs):
 
 @contextlib.contextmanager
 def swith_rng_state_tracker(rng_state, tracker):
-    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
-        get_rng_state_tracker,
-    )
-
     orig_cuda_rng_state = paddle.get_cuda_rng_state()
     orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
 
@@ -71,10 +71,6 @@ def swith_rng_state_tracker(rng_state, tracker):
 class LegacyRecomputeFunction(LegacyPyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
-            get_rng_state_tracker,
-        )
-
         # store for recomputing
         ctx.run_function = run_function
         ctx.preserve_rng_state = preserve_rng_state
@@ -223,10 +219,6 @@ class LegacyRecomputeFunction(LegacyPyLayer):
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
-            get_rng_state_tracker,
-        )
-
         # store for recomputing
         ctx.run_function = run_function
         ctx.preserve_rng_state = preserve_rng_state
@@ -382,6 +374,116 @@ class RecomputeFunction(PyLayer):
         return grads
 
 
+def _recompute_without_reentrant(
+    function, preserve_rng_state=True, *args, **kwargs
+):
+    """
+    Recompute without reentrant: use saved-tensor hooks to implement recompute instead of re-entrant autograd (PyLayer).
+    """
+
+    if preserve_rng_state:
+        cur_device = paddle.get_device()
+        if 'gpu:' not in cur_device:
+            raise RuntimeError(
+                "Recompute with RNG state preservation is not supported on the current device: {}.".format(
+                    cur_device
+                )
+            )
+        fw_cuda_rng_state = paddle.get_cuda_rng_state()
+        fwd_cuda_rng_state_tracker = (
+            get_rng_state_tracker().get_states_tracker()
+        )
+    tracer = framework._dygraph_tracer()
+    is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+    if tracer._amp_level == core.AmpLevel.O2:
+        amp_level = 'O2'
+    elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
+        amp_level = 'O1'
+
+    if tracer._amp_dtype == 'float16':
+        amp_dtype = 'float16'
+    elif tracer._amp_dtype in ('bfloat16', 'float32'):
+        amp_dtype = 'bfloat16'
+
+    amp_white_list, amp_black_list = tracer._get_amp_op_list()
+
+    class Intermediate_Holder:
+        pass
+
+    storage = weakref.WeakKeyDictionary()
+    holder_list = []
+
+    def pack(x):
+        res = Intermediate_Holder()
+        holder_list.append(weakref.ref(res))
+        return res
+
+    def unpack(x):
+        unpack_counter = 0
+        if len(storage) == 0:
+
+            def inner_pack(inner_x):
+                nonlocal unpack_counter
+                unpack_counter += 1
+
+                if holder_list[unpack_counter - 1]() is None:
+                    return
+
+                tmp_tensor = core.eager.Tensor(
+                    inner_x.dtype,
+                    inner_x.shape,
+                    inner_x.name + "cpy",
+                    core.VarDesc.VarType.LOD_TENSOR,
+                    inner_x.persistable,
+                )
+                inner_x._share_buffer_to(tmp_tensor)
+                storage[holder_list[unpack_counter - 1]()] = tmp_tensor
+                return
+
+            def inner_unpack(inner_x):
+                raise Exception("An unexpected backward was called on a tensor!")
+
+            if preserve_rng_state:
+                with swith_rng_state_tracker(
+                    fw_cuda_rng_state, fwd_cuda_rng_state_tracker
+                ):
+                    with paddle.set_grad_enabled(True):
+                        with paddle.amp.auto_cast(
+                            enable=is_fw_autocast,
+                            custom_white_list=amp_white_list,
+                            custom_black_list=amp_black_list,
+                            level=amp_level,
+                            dtype=amp_dtype,
+                        ):
+                            with paddle.autograd.saved_tensors_hooks(
+                                inner_pack, inner_unpack
+                            ):
+                                unused_outputs = function(*args, **kwargs)
+            else:
+                with paddle.set_grad_enabled(True), paddle.amp.auto_cast(
+                    enable=is_fw_autocast,
+                    custom_white_list=amp_white_list,
+                    custom_black_list=amp_black_list,
+                    level=amp_level,
+                    dtype=amp_dtype,
+                ), paddle.autograd.saved_tensors_hooks(
+                    inner_pack, inner_unpack
+                ):
+                    unused_outputs = function(*args, **kwargs)
+
+        if x not in storage:
+            raise Exception(
+                "Retrieving a tensor saved by autograd multiple times without recomputation is not supported."
+            )
+
+        return storage[x]
+
+    with paddle.autograd.saved_tensors_hooks(pack, unpack):
+        outputs = function(*args, **kwargs)
+
+    return outputs
+
+
 def recompute(function, *args, **kwargs):
     """
     recompute intermediate activations to save then memory.
@@ -391,11 +493,13 @@ def recompute(function, *args, **kwargs):
               whose intermediate activations will be released to save memory in forward stage and will be recomputed
               in backward stage for gradient calculation.
         *args(Tensor): inputs to the function.
-        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
-                        indicate whether to save the forward rng. If it is True, then the last forward rng value will be
-                        restored when the forward recalculation of backpropagation is performed. The default
-                        preserve_rng_state is True.
-
+        **kwargs(Dict): Kwargs should only contain two kinds of key-value params: keyword arguments that belong to
+                        function itself, and the control keys 'preserve_rng_state' and 'use_reentrant'.
+                        preserve_rng_state indicates whether to save the forward RNG state. If it is True, the saved
+                        forward RNG state will be restored when the forward pass is recomputed during backpropagation.
+                        Its default value is True. use_reentrant selects which recompute implementation is used:
+                        'use_reentrant=True' uses the PyLayer implementation, 'use_reentrant=False' uses the Hook
+                        (non-reentrant) implementation. Its default value is True.
 
     Returns:
         Output of function on args.
@@ -487,10 +591,21 @@ def recompute(function, *args, **kwargs):
     # Hack to mix *args with **kwargs in a python 2.7-compliant way
     preserve = kwargs.pop('preserve_rng_state', True)
 
+    # whether to use the reentrant method to implement recompute
+    use_reentrant = kwargs.pop('use_reentrant', True)
+
+    if kwargs and use_reentrant:
+        raise ValueError(
+            "Error: if you want to pass kwargs (dict parameters) to the function, please set use_reentrant=False."
+        )
+
     if framework._dygraph_tracer()._has_grad:
         check_recompute_necessary(args)
 
-    return RecomputeFunction.apply(function, preserve, *args, **kwargs)
+    if use_reentrant:
+        return RecomputeFunction.apply(function, preserve, *args)
+    else:
+        return _recompute_without_reentrant(function, preserve, *args, **kwargs)
 
 
 def recompute_sequential(ctx, functions, *args, **kwargs):
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py
index 39646b8066d4b5d7d7f4fa5be45bef2d7393b882..4c0db922338ed36586707948062d88db0f0f031d 100755
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py
@@ -272,7 +272,7 @@ class TestPyLayer(unittest.TestCase):
     def test_recompute_kwargs(self):
         paddle.set_device("gpu")
         kwargs = {"is_test": False}
-        with self.assertRaises(TypeError):
+        with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
                 recompute_block=[2], recompute_kwargs=kwargs
             )
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py
index 087609b62da070934addf8b6467502367ace4af9..5e982587c252c4a365213bd050a7e4c6320f5863 100755
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py
@@ -21,32 +21,44 @@ import paddle
 
 from paddle.distributed.fleet.utils import recompute
 
 
+class Model(paddle.nn.Layer):
+    def __init__(self, block_idx, input_size, is_last=False):
+        super(Model, self).__init__()
+        block_name = "block_" + str(block_idx)
+        self.block = paddle.nn.Sequential(
+            (
+                block_name + "_fc_0",
+                paddle.nn.Linear(input_size, input_size, bias_attr=False),
+            ),
+            (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
+            (block_name + "_relu_1", paddle.nn.ReLU()),
+            (
+                block_name + "_fc_1",
+                paddle.nn.Linear(input_size, input_size, bias_attr=False),
+            ),
+            (block_name + "_relu_2", paddle.nn.ReLU()),
+        )
+        if is_last:
+            self.block.add_sublayer(
+                block_name + "_fc_2",
+                paddle.nn.Linear(input_size, 1, bias_attr=False),
+            )  # add sublayer
+        else:
+            self.block.add_sublayer(
+                block_name + "_fc_2",
+                paddle.nn.Linear(input_size, input_size, bias_attr=False),
+            )  # add sublayer
+
+    # add a pos param to test the kwargs of recompute.
+    def forward(self, x, pos=None):
+        if pos is None:
+            return self.block(x)
+        else:
+            return self.block(x) + pos
+
+
 def get_fc_block(block_idx, input_size, is_last=False):
-    block_name = "block_" + str(block_idx)
-    block = paddle.nn.Sequential(
-        (
-            block_name + "_fc_0",
-            paddle.nn.Linear(input_size, input_size, bias_attr=False),
-        ),
-        (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
-        (block_name + "_relu_1", paddle.nn.ReLU()),
-        (
-            block_name + "_fc_1",
-            paddle.nn.Linear(input_size, input_size, bias_attr=False),
-        ),
-        (block_name + "_relu_2", paddle.nn.ReLU()),
-    )
-    if is_last:
-        block.add_sublayer(
-            block_name + "_fc_2",
-            paddle.nn.Linear(input_size, 1, bias_attr=False),
-        )  # add sublayer
-    else:
-        block.add_sublayer(
-            block_name + "_fc_2",
-            paddle.nn.Linear(input_size, input_size, bias_attr=False),
-        )  # add sublayer
-    return block
+    return Model(block_idx, input_size, is_last=is_last)
 
 
 class Naive_fc_net(paddle.nn.Layer):
@@ -143,6 +155,10 @@ def run_model(
         segments=segments,
         recompute_kwargs=recompute_kwargs,
     )
+
+    if pure_fp16:
+        model = paddle.amp.decorate(models=model, level='O2')
+
    loss_fn = paddle.nn.MSELoss(reduction='mean')
     optimizer = paddle.optimizer.SGD(
         learning_rate=0.01, parameters=model.parameters()
     )
@@ -158,7 +174,7 @@ def run_model(
         x_data = np.random.randn(batch_size, input_size).astype(np.float32)
         x = paddle.to_tensor(x_data)
-        # x.stop_gradient = False
+        x.stop_gradient = False
         level = 'O2' if pure_fp16 else 'O1'
         with paddle.amp.auto_cast(True, level=level):
             y_pred = model(x)
@@ -178,7 +194,7 @@ def run_model(
     return loss_, param_, grad_
 
 
-class TestPyLayer(unittest.TestCase):
+class TestRecompute(unittest.TestCase):
     def test_base_case(self, enable_autocast=False, pure_fp16=False):
         def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
             self.assertEqual(loss_ref, loss)
@@ -192,46 +208,55 @@ class TestPyLayer(unittest.TestCase):
             pure_fp16=pure_fp16,
         )
 
-        # recompute second block
-        loss, param, grad = run_model(
-            recompute_block=[1],
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute fourth block
-        loss, param, grad = run_model(
-            recompute_block=[3],
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second to fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 2, 3],
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second & fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 3],
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute_sequential with segments=1 using fleet
-        loss, param, grad = run_model(
-            recompute_block=[],
-            use_fleet_sq=True,
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+        # test both implementations of recompute
+        # True: the PyLayer (reentrant) implementation
+        # False: the Hook (non-reentrant) implementation
+        for flag in [True, False]:
+            # recompute second block
+            loss, param, grad = run_model(
+                recompute_block=[1],
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+            # recompute fourth block
+            loss, param, grad = run_model(
+                recompute_block=[3],
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+            # recompute second to fourth block
+            loss, param, grad = run_model(
+                recompute_block=[1, 2, 3],
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+            # recompute second & fourth block
+            loss, param, grad = run_model(
+                recompute_block=[1, 3],
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+            # recompute_sequential with segments=1 using fleet
+            loss, param, grad = run_model(
+                recompute_block=[],
+                use_fleet_sq=True,
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
         # with base recompute, and segments=2
         loss_ref, param_ref, grad_ref = run_model(
@@ -255,11 +280,15 @@ class TestPyLayer(unittest.TestCase):
         self.test_base_case()
 
     def test_fc_net_without_restore_rng(self):
-        loss_ref, param_ref, grad_ref = run_model(
-            recompute_block=[2],
-            recompute_kwargs={"preserve_rng_state": False},
-            enable_autocast=True,
-        )
+        for flag in [True, False]:
+            loss_ref, param_ref, grad_ref = run_model(
+                recompute_block=[2],
+                recompute_kwargs={
+                    "preserve_rng_state": False,
+                    "use_reentrant": flag,
+                },
+                enable_autocast=True,
+            )
 
     def test_fc_net_with_amp(self):
         self.test_base_case(enable_autocast=True)
@@ -269,16 +298,28 @@ class TestPyLayer(unittest.TestCase):
 
     def test_recompute_kwargs(self):
         paddle.set_device("gpu")
-        kwargs = {"is_test": False}
-        with self.assertRaises(TypeError):
+        pos = paddle.randn(shape=[10, 10], dtype="float32")
+        pos.stop_gradient = False
+
+        kwargs = {"pos": pos, "use_reentrant": True}
+        with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
                 recompute_block=[2], recompute_kwargs=kwargs
             )
 
+        kwargs = {"pos": pos, "use_reentrant": False}
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[2], recompute_kwargs=kwargs
+        )
+
     def test_recompute_cpu_rng(self):
         paddle.set_device("cpu")
-        with self.assertRaises(RuntimeError):
-            loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
+        for flag in [True, False]:
+            with self.assertRaises(RuntimeError):
+                loss_ref, param_ref, grad_ref = run_model(
+                    recompute_block=[2],
+                    recompute_kwargs={"use_reentrant": flag},
+                )
 
 
 if __name__ == '__main__':
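
A minimal usage sketch of the API added by this patch: with use_reentrant=False, recompute forwards extra keyword arguments to the wrapped function, while the default reentrant path rejects them with a ValueError. The layer, function, and tensor names below are illustrative only and are not taken from the patch; as in the new tests, a GPU device is assumed because the default preserve_rng_state=True path requires one.

import paddle
from paddle.distributed.fleet.utils import recompute

# Illustrative sketch; `block`, `x`, and `pos` are made-up names, not part of the patch.
paddle.set_device("gpu")  # the default preserve_rng_state=True path requires a GPU device

linear = paddle.nn.Linear(8, 8)


def block(x, pos=None):
    # A tiny stand-in for the Model.forward(x, pos=None) used in the new eager test.
    y = linear(x)
    return y if pos is None else y + pos


x = paddle.randn([4, 8])
x.stop_gradient = False
pos = paddle.randn([4, 8])

# Hook-based (non-reentrant) recompute: the extra `pos` kwarg is forwarded to `block`.
# With the default use_reentrant=True, passing `pos` would raise ValueError instead.
y = recompute(block, x, use_reentrant=False, pos=pos)
y.mean().backward()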