未验证 提交 bb2310a6 编写于 作者: Y Yuang Liu 提交者: GitHub

recompute support tuple (#56793)

上级 23bc4c26
......@@ -31,10 +31,24 @@ __all__ = []
def detach_variable(inputs):
out = []
for inp in inputs:
if not isinstance(inp, core.eager.Tensor):
if not isinstance(inp, core.eager.Tensor) and (
type(inp) is not tuple or not isinstance(inp[0], core.eager.Tensor)
):
# the inp is not a tensor or not a tuple of tensors
out.append(inp)
continue
if type(inp) is tuple:
detach_inp = []
for i in inp:
# detach all tensors in the tuple
assert isinstance(i, core.eager.Tensor)
tmp_i = i.detach()
tmp_i.stop_gradient = i.stop_gradient
detach_inp.append(tmp_i)
out.append(tuple(detach_inp))
continue
x = inp.detach()
x.stop_gradient = inp.stop_gradient
out.append(x)
......@@ -42,11 +56,16 @@ def detach_variable(inputs):
def check_recompute_necessary(inputs):
if not any(
not input_.stop_gradient
for input_ in inputs
if isinstance(input_, (core.eager.Tensor, paddle.Tensor))
):
necessary_for_each_input = []
for input_ in inputs:
if isinstance(input_, (core.eager.Tensor, paddle.Tensor)):
necessary_for_each_input.append(input_.stop_gradient)
elif type(input_) is tuple:
for i in input_:
# traverse all tensors in the tuple
if isinstance(i, (core.eager.Tensor, paddle.Tensor)):
necessary_for_each_input.append(i.stop_gradient)
if all(necessary_for_each_input):
logger.warning(
"[Recompute]: None of the inputs to current recompute block need grad, "
"therefore there is NO need to recompute this block in backward !"
......@@ -81,12 +100,37 @@ class RecomputeFunction(PyLayer):
# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
ctx.duplicate_tensor = [False for _ in range(len(args))]
tensor_inputs = []
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
elif type(arg) is tuple:
is_tensors = [paddle.is_tensor(a) for a in arg]
if all(is_tensors):
# the tuple is a tuple of tensors
tensors_stop_gradient = [a.stop_gradient for a in arg]
if not all(tensors_stop_gradient) and any(
tensors_stop_gradient
):
# tensors in the tuple have different stop_gradient value, which pylayer doesn't support
raise ValueError(
"Recompute receive a tuple containing tensor holds different stop gradient."
)
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
# Mark the tuple is a tuple of tensors
ctx.duplicate_tensor[i] = True
ctx.inputs.append(None)
elif any(is_tensors):
# the tuple contains tensors and non-tensor values
raise ValueError(
"Recompute receive a tuple containing tensor and non-tensor at same time."
)
else:
ctx.inputs.append(arg)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
......@@ -132,6 +176,7 @@ class RecomputeFunction(PyLayer):
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
duplicate_tensor = ctx.duplicate_tensor
tensors = ctx.saved_tensor()
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
......@@ -198,18 +243,23 @@ class RecomputeFunction(PyLayer):
forward_outputs_with_grad, backward_inputs_with_grad
)
grads = []
for idx, inp in enumerate(detached_inputs):
if isinstance(inp, core.eager.Tensor):
grads.append(inp._grad_ivar())
elif type(inp) is tuple and duplicate_tensor[idx]:
# input is a tuple and is a tuple of tensors
if all(i.stop_gradient for i in inp):
# all tensors in the tuple doesn't need grad, only return a None for the whole tuple
grads.append(None)
else:
# all tensors in the tuple nees grad, should return a tuple of grads
grads.append(tuple(i._grad_ivar() for i in inp))
if in_dynamic_mode():
grads = tuple(
inp._grad_ivar()
for inp in detached_inputs
if isinstance(inp, core.eager.Tensor)
)
grads = tuple(grads)
else:
grads = [
inp._grad_ivar()
for inp in detached_inputs
if isinstance(inp, core.eager.Tensor)
]
grads = list(grads)
return grads
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.distributed.fleet.utils import recompute
class Layer(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.linear1 = paddle.nn.Linear(10, 10)
self.linear2 = paddle.nn.Linear(10, 10)
self.linear3 = paddle.nn.Linear(10, 10)
self.silu1 = paddle.nn.Silu()
self.silu2 = paddle.nn.Silu()
self.silu3 = paddle.nn.Silu()
def forward(self, x, y):
assert type(x) is tuple
assert len(x) == 2
o1 = self.silu1(self.linear1(x[0]))
o2 = self.silu2(self.linear2(x[1]))
o3 = self.silu3(self.linear3(y))
o = o1 + o2 + o3
return o
class TestPyLayer(unittest.TestCase):
def test_tuple_input(self):
layer = Layer()
x1 = paddle.rand(shape=[10, 10])
x1.stop_gradient = False
x2 = paddle.rand(shape=[10, 10])
x2.stop_gradient = False
y = paddle.rand(shape=[10, 10])
y.stop_gradient = False
o = recompute(layer, (x1, x2), y)
loss = paddle.mean(o, keepdim=True)
loss.backward()
def test_tuple_input_with_non_tensor(self):
layer = Layer()
x1 = paddle.rand(shape=[10, 10])
x1.stop_gradient = False
y = paddle.rand(shape=[10, 10])
y.stop_gradient = False
try:
o = recompute(layer, (x1, True), y)
except ValueError:
pass
def test_tuple_input_with_different_stop_gradient(self):
layer = Layer()
x1 = paddle.rand(shape=[10, 10])
x1.stop_gradient = False
x2 = paddle.rand(shape=[10, 10])
y = paddle.rand(shape=[10, 10])
y.stop_gradient = False
try:
o = recompute(layer, (x1, True), y)
except ValueError:
pass
def test_tuple_input_all_no_gradient(self):
layer = Layer()
x1 = paddle.rand(shape=[10, 10])
x2 = paddle.rand(shape=[10, 10])
y = paddle.rand(shape=[10, 10])
y.stop_gradient = False
o = recompute(layer, (x1, x2), y)
loss = paddle.mean(o, keepdim=True)
loss.backward()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册