diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 774e8db0df52c9f785dfc2a68d22370b2e96f1a6..0a47750ead7ec98bd3275e107605a96cddec1e6b 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -14,3 +14,4 @@ from .fs import LocalFS, HDFSClient from .ps_util import DistributedInfer +from .recompute import recompute diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc305ec77d51728cdde20bb504371127823d61d --- /dev/null +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -0,0 +1,177 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid import core +from paddle.autograd import PyLayer +from paddle.fluid import framework +import contextlib + +import logging +logging.basicConfig( + format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + +def detach_variable(inputs): + out = [] + for inp in inputs: + if not isinstance(inp, core.VarBase): + out.append(inp) + continue + + x = inp.detach() + x.stop_gradient = inp.stop_gradient + out.append(x) + return tuple(out) + + +def check_recompute_necessary(inputs): + if not any(input_.stop_gradient == False for input_ in inputs + if isinstance(input_, paddle.Tensor)): + logging.warn( + "[Recompute]: None of the inputs to current recompute block need grad, " + "therefore there is NO need to recompute this block in backward !") + + +@contextlib.contextmanager +def swith_rng_state(rng_state): + orig_cuda_rng_state = paddle.get_cuda_rng_state() + paddle.set_cuda_rng_state(rng_state) + try: + yield + finally: + paddle.set_cuda_rng_state(orig_cuda_rng_state) + + +class RecomputeFunction(PyLayer): + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + check_recompute_necessary(args) + + # store for recomputing + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + + # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input + # the order of tensors in backward()'s output should be the same as tensors in forward()'s input + # None tensor inputs will be filtered in backward inputs. + + # save input for backward + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if paddle.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + ctx.save_for_backward(*tensor_inputs) + + # NOTE recompute with restore RNG only support one senario where one process for one cuda gpu. + # one process with multiple gpu and mix-gpu-cpu senarios are not support + if ctx.preserve_rng_state: + cur_device = paddle.get_device() + if 'gpu:' not in cur_device: + raise RuntimeError( + "Recompute with RNG perserve is not support current device: {}.". + format(cur_device)) + ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state() + + # TODO support AMP + + with paddle.no_grad(): + outputs = run_function(*args) + + return outputs + + @staticmethod + def backward(ctx, *args): + with paddle.fluid.dygraph.guard(): + # TODO need to check the recompute calling is vaild or not + + # Restore inputs + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensor() + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # paddle.enable_grad() + tracer = framework._dygraph_tracer() + tracer._has_grad = True + + # TODO support AMP + + if ctx.preserve_rng_state: + with swith_rng_state(ctx.fw_cuda_rng_state): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs) + else: + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, core.VarBase): + outputs = (outputs, ) + assert len(outputs) == len(args) + + # run backward() with only tensor that requires grad + forward_outputs_with_grad = [] + backward_inputs = list(args) + for i in range(len(outputs)): + if isinstance(outputs[i], + core.VarBase) and not outputs[i].stop_gradient: + forward_outputs_with_grad.append(outputs[i]) + if len(forward_outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True, this recompute() is not necessary" + ) + + assert len(backward_inputs) == len( + forward_outputs_with_grad + ), "number of forward outputs is [{}], but the backward got [{}] inputs".format( + len(forward_outputs_with_grad), len(backward_inputs)) + + # actually backward + paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) + + grads = list(inp._grad_ivar() for inp in detached_inputs + if isinstance(inp, core.VarBase)) + + return grads + + +def recompute(function, *args, **kwargs): + """ + recompute intermediate activations to save then memory. + + Args: + function: layer of sequence of layers that describes part of forward pass of the model whose + intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + preserve_rng_state(bool, optional): if preserve the RNG state of forward and restore it in backward. + args: inputs to the function + + Returns: + Output of function on args + """ + # Hack to mix *args with **kwargs in a python 2.7-compliant way + preserve = kwargs.pop('preserve_rng_state', True) + if kwargs: + raise ValueError("Unexpected keyword arguments: " + ",".join( + arg for arg in kwargs)) + + return RecomputeFunction.apply(function, preserve, *args) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2e68dd899ee7be0a92b005e4626440e977f57226..ae3cf5f285836e0c8a680085440402b98232d131 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -176,6 +176,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_layer) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single) + LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute) elseif(WITH_GPU) if (${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py new file mode 100755 index 0000000000000000000000000000000000000000..6de04c14bfa7080bcbf5e3b4c55f98da0f09a863 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.autograd import PyLayer +from paddle.distributed.fleet.utils import recompute +import random + +import paddle.fluid.layers as layers + + +def get_fc_block(block_idx, input_size, is_last=False): + block_name = "block_" + str(block_idx) + block = paddle.nn.Sequential( + (block_name + "_fc_0", paddle.nn.Linear( + input_size, input_size, bias_attr=False)), + (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + (block_name + "_relu_1", paddle.nn.ReLU()), + (block_name + "_fc_1", paddle.nn.Linear( + input_size, input_size, bias_attr=False)), + (block_name + "_relu_2", paddle.nn.ReLU()), ) + if is_last: + block.add_sublayer( + block_name + "_fc_2", + paddle.nn.Linear( + input_size, 1, bias_attr=False)) # add sublayer + else: + block.add_sublayer( + block_name + "_fc_2", + paddle.nn.Linear( + input_size, input_size, bias_attr=False)) # add sublayer + return block + + +class Naive_fc_net(paddle.nn.Layer): + def __init__(self, + input_size=10, + recompute_blocks=[1, 3], + recompute_kwargs={}): + super(Naive_fc_net, self).__init__() + self.recompute_blocks = recompute_blocks + self.recompute_kwargs = recompute_kwargs + self.runfunc0 = get_fc_block(0, input_size, is_last=False) + self.runfunc1 = get_fc_block(1, input_size, is_last=False) + self.runfunc2 = get_fc_block(2, input_size, is_last=False) + self.runfunc3 = get_fc_block(3, input_size, is_last=False) + self.runfunc4 = get_fc_block(4, input_size, is_last=True) + + def forward(self, inputs): + + if 0 in self.recompute_blocks: + inputs = recompute(self.runfunc0, inputs) + else: + inputs = self.runfunc0(inputs) + + if 1 in self.recompute_blocks: + inputs = recompute(self.runfunc1, inputs) + else: + inputs = self.runfunc1(inputs) + + if 2 in self.recompute_blocks: + inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) + else: + inputs = self.runfunc2(inputs) + + if 3 in self.recompute_blocks: + inputs = recompute(self.runfunc3, inputs) + else: + inputs = self.runfunc3(inputs) + + if 4 in self.recompute_blocks: + inputs = recompute(self.runfunc4, inputs) + else: + inputs = self.runfunc4(inputs) + + return inputs + + +def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): + gen = paddle.seed(10) + gen.manual_seed(10) + np.random.seed(10) + random.seed(10) + + if cuda_state: + paddle.set_cuda_rng_state(cuda_state) + + batch_size, input_size = 1, 10 + model = Naive_fc_net( + input_size, + recompute_blocks=recompute_block, + recompute_kwargs=recompute_kwargs) + loss_fn = paddle.nn.MSELoss(reduction='mean') + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + + loss_ = [] + param_ = [] + grad_ = [] + for step in range(10): + x_data = np.random.randn(batch_size, input_size).astype(np.float32) + x = paddle.to_tensor(x_data) + # x.stop_gradient = False + y_pred = model(x) + loss = y_pred.mean() + + loss_.append(np.asarray(loss).tolist()) + loss.backward() + optimizer.step() + + param_.append(np.asarray(model.parameters()[9]).tolist()) + grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) + + optimizer.clear_grad() + return loss_, param_, grad_ + + +class TestPyLayer(unittest.TestCase): + def test_fc_net_with_dropout(self): + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + + cuda_state = paddle.get_cuda_rng_state() + # without recompute + loss_ref, param_ref, grad_ref = run_model( + cuda_state, recompute_block=[]) + + # recompute second block + loss, param, grad = run_model(cuda_state, recompute_block=[1, 3]) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute fourth block + loss, param, grad = run_model(cuda_state, recompute_block=[3]) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second to fourth block + loss, param, grad = run_model(cuda_state, recompute_block=[1, 2, 3]) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second & fourth block + loss, param, grad = run_model(cuda_state, recompute_block=[1, 3]) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + def test_recompute_kwargs(self): + paddle.set_device("gpu") + kwargs = {"is_test": False} + with self.assertRaises(ValueError): + loss_ref, param_ref, grad_ref = run_model( + None, recompute_block=[2], recompute_kwargs=kwargs) + + def test_recompute_cpu_rng(self): + paddle.set_device("cpu") + with self.assertRaises(RuntimeError): + loss_ref, param_ref, grad_ref = run_model(None, recompute_block=[2]) + + +if __name__ == '__main__': + unittest.main()