diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py
index b3bf3889a347b5c487ea454543f324a1edfdc63e..ba22372e7914753c55964856274211f48edf3758 100755
--- a/python/paddle/distributed/fleet/recompute/recompute.py
+++ b/python/paddle/distributed/fleet/recompute/recompute.py
@@ -31,10 +31,24 @@ __all__ = []
 def detach_variable(inputs):
     out = []
     for inp in inputs:
-        if not isinstance(inp, core.eager.Tensor):
+        if not isinstance(inp, core.eager.Tensor) and (
+            type(inp) is not tuple or not isinstance(inp[0], core.eager.Tensor)
+        ):
+            # inp is neither a tensor nor a tuple of tensors
             out.append(inp)
             continue
 
+        if type(inp) is tuple:
+            detach_inp = []
+            for i in inp:
+                # detach every tensor in the tuple
+                assert isinstance(i, core.eager.Tensor)
+                tmp_i = i.detach()
+                tmp_i.stop_gradient = i.stop_gradient
+                detach_inp.append(tmp_i)
+            out.append(tuple(detach_inp))
+            continue
+
         x = inp.detach()
         x.stop_gradient = inp.stop_gradient
         out.append(x)
@@ -42,11 +56,16 @@ def detach_variable(inputs):
 
 def check_recompute_necessary(inputs):
-    if not any(
-        not input_.stop_gradient
-        for input_ in inputs
-        if isinstance(input_, (core.eager.Tensor, paddle.Tensor))
-    ):
+    necessary_for_each_input = []
+    for input_ in inputs:
+        if isinstance(input_, (core.eager.Tensor, paddle.Tensor)):
+            necessary_for_each_input.append(input_.stop_gradient)
+        elif type(input_) is tuple:
+            for i in input_:
+                # traverse every tensor in the tuple
+                if isinstance(i, (core.eager.Tensor, paddle.Tensor)):
+                    necessary_for_each_input.append(i.stop_gradient)
+    if all(necessary_for_each_input):
         logger.warning(
             "[Recompute]: None of the inputs to current recompute block need grad, "
             "therefore there is NO need to recompute this block in backward !"
         )
@@ -81,12 +100,37 @@ class RecomputeFunction(PyLayer):
         # save input for backward
         ctx.inputs = []
         ctx.tensor_indices = []
+        ctx.duplicate_tensor = [False for _ in range(len(args))]
         tensor_inputs = []
         for i, arg in enumerate(args):
             if paddle.is_tensor(arg):
                 tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
                 ctx.inputs.append(None)
+            elif type(arg) is tuple:
+                is_tensors = [paddle.is_tensor(a) for a in arg]
+                if all(is_tensors):
+                    # every element of the tuple is a tensor
+                    tensors_stop_gradient = [a.stop_gradient for a in arg]
+                    if not all(tensors_stop_gradient) and any(
+                        tensors_stop_gradient
+                    ):
+                        # tensors in the tuple hold different stop_gradient values, which PyLayer doesn't support
+                        raise ValueError(
+                            "Recompute received a tuple containing tensors with different stop_gradient values."
+                        )
+                    tensor_inputs.append(arg)
+                    ctx.tensor_indices.append(i)
+                    # mark this argument as a tuple of tensors
+                    ctx.duplicate_tensor[i] = True
+                    ctx.inputs.append(None)
+                elif any(is_tensors):
+                    # the tuple mixes tensor and non-tensor values
+                    raise ValueError(
+                        "Recompute received a tuple containing both tensor and non-tensor values."
+                    )
+                else:
+                    ctx.inputs.append(arg)
             else:
                 ctx.inputs.append(arg)
         ctx.save_for_backward(*tensor_inputs)
@@ -132,6 +176,7 @@ class RecomputeFunction(PyLayer):
             # Restore inputs
             inputs = list(ctx.inputs)
             tensor_indices = ctx.tensor_indices
+            duplicate_tensor = ctx.duplicate_tensor
             tensors = ctx.saved_tensor()
             for i, idx in enumerate(tensor_indices):
                 inputs[idx] = tensors[i]
@@ -198,18 +243,23 @@ class RecomputeFunction(PyLayer):
                 forward_outputs_with_grad, backward_inputs_with_grad
             )
 
+            grads = []
+            for idx, inp in enumerate(detached_inputs):
+                if isinstance(inp, core.eager.Tensor):
+                    grads.append(inp._grad_ivar())
+                elif type(inp) is tuple and duplicate_tensor[idx]:
+                    # the input is a tuple of tensors
+                    if all(i.stop_gradient for i in inp):
+                        # no tensor in the tuple needs grad, so return a single None for the whole tuple
+                        grads.append(None)
+                    else:
+                        # all tensors in the tuple need grad, so return a tuple of grads
+                        grads.append(tuple(i._grad_ivar() for i in inp))
+
             if in_dynamic_mode():
-                grads = tuple(
-                    inp._grad_ivar()
-                    for inp in detached_inputs
-                    if isinstance(inp, core.eager.Tensor)
-                )
+                grads = tuple(grads)
             else:
-                grads = [
-                    inp._grad_ivar()
-                    for inp in detached_inputs
-                    if isinstance(inp, core.eager.Tensor)
-                ]
+                grads = list(grads)
             return grads
 
 
diff --git a/test/legacy_test/test_recompute_with_tuple_input.py b/test/legacy_test/test_recompute_with_tuple_input.py
new file mode 100644
index 0000000000000000000000000000000000000000..90b6c37dca14a9c56920eafec2f6d5a0957bbf91
--- /dev/null
+++ b/test/legacy_test/test_recompute_with_tuple_input.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+from paddle.distributed.fleet.utils import recompute
+
+
+class Layer(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = paddle.nn.Linear(10, 10)
+        self.linear2 = paddle.nn.Linear(10, 10)
+        self.linear3 = paddle.nn.Linear(10, 10)
+        self.silu1 = paddle.nn.Silu()
+        self.silu2 = paddle.nn.Silu()
+        self.silu3 = paddle.nn.Silu()
+
+    def forward(self, x, y):
+        assert type(x) is tuple
+        assert len(x) == 2
+        o1 = self.silu1(self.linear1(x[0]))
+        o2 = self.silu2(self.linear2(x[1]))
+        o3 = self.silu3(self.linear3(y))
+        o = o1 + o2 + o3
+        return o
+
+
+class TestRecomputeWithTupleInput(unittest.TestCase):
+    def test_tuple_input(self):
+        layer = Layer()
+        x1 = paddle.rand(shape=[10, 10])
+        x1.stop_gradient = False
+        x2 = paddle.rand(shape=[10, 10])
+        x2.stop_gradient = False
+        y = paddle.rand(shape=[10, 10])
+        y.stop_gradient = False
+        o = recompute(layer, (x1, x2), y)
+        loss = paddle.mean(o, keepdim=True)
+        loss.backward()
+
+    def test_tuple_input_with_non_tensor(self):
+        layer = Layer()
+        x1 = paddle.rand(shape=[10, 10])
+        x1.stop_gradient = False
+        y = paddle.rand(shape=[10, 10])
+        y.stop_gradient = False
+        # mixing tensor and non-tensor values in one tuple is rejected
+        with self.assertRaises(ValueError):
+            recompute(layer, (x1, True), y)
+
+    def test_tuple_input_with_different_stop_gradient(self):
+        layer = Layer()
+        x1 = paddle.rand(shape=[10, 10])
+        x1.stop_gradient = False
+        x2 = paddle.rand(shape=[10, 10])
+        y = paddle.rand(shape=[10, 10])
+        y.stop_gradient = False
+        # tensors in one tuple must share the same stop_gradient value
+        with self.assertRaises(ValueError):
+            recompute(layer, (x1, x2), y)
+
+    def test_tuple_input_all_no_gradient(self):
+        layer = Layer()
+        x1 = paddle.rand(shape=[10, 10])
+        x2 = paddle.rand(shape=[10, 10])
+        y = paddle.rand(shape=[10, 10])
+        y.stop_gradient = False
+        o = recompute(layer, (x1, x2), y)
+        loss = paddle.mean(o, keepdim=True)
+        loss.backward()
+
+
+if __name__ == '__main__':
+    unittest.main()