From 1d12832669460a7a6f822e7b454010c0918271fe Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Mon, 18 Jul 2022 18:33:23 +0800 Subject: [PATCH] fix duplicate slice logic in _grad (#44396) --- .../autograd/test_autograd_functional_dynamic.py | 5 +++++ python/paddle/incubate/autograd/functional.py | 10 ++++++---- python/paddle/incubate/autograd/primapi.py | 11 +++++++++-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py index 4b61580452..5eda21eb4c 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py @@ -223,6 +223,11 @@ class TestVJP(TestAutogradFunctional): self.func_vjp_nested() self.func_vjp_aliased_input() + def test_input_single_tensor(self): + self.assertIsInstance( + paddle.incubate.autograd.vjp(paddle.tanh, paddle.rand((3, 4)))[1], + paddle.fluid.framework.Variable) + @utils.place(config.DEVICES) @utils.parameterize( diff --git a/python/paddle/incubate/autograd/functional.py b/python/paddle/incubate/autograd/functional.py index 6c740005f8..3be95c88d1 100644 --- a/python/paddle/incubate/autograd/functional.py +++ b/python/paddle/incubate/autograd/functional.py @@ -565,13 +565,15 @@ def _grad(ys, xs, v=None): inputs. """ if paddle.fluid._non_static_mode(): + # paddle.grad returns a list though the inputs is a signle Tensor. The + # follow code snippet fixes the problem by return the first element of + # xs_grad when the xs is a signle Tensor. xs_grad = paddle.grad(ys, xs, v, create_graph=True, allow_unused=True) + if isinstance(xs, paddle.fluid.framework.Variable) and isinstance( + xs_grad, typing.Sequence) and len(xs_grad) > 0: + xs_grad = xs_grad[0] else: xs_grad = paddle.incubate.autograd.grad(ys, xs, v) - - if isinstance(xs, paddle.fluid.framework.Variable): - xs_grad = xs_grad[0] - return _replace_none_with_zero_tensor(xs_grad, xs) diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py index 5b3ad0dd78..a319874e25 100644 --- a/python/paddle/incubate/autograd/primapi.py +++ b/python/paddle/incubate/autograd/primapi.py @@ -132,9 +132,16 @@ def grad(outputs, inputs, grad_outputs=None): paddle.incubate.autograd.disable_prim() paddle.disable_static() """ - if not utils.prim_enabled(): - return backward.gradients(outputs, inputs, grad_outputs) + grad_inputs = backward.gradients(outputs, inputs, grad_outputs) + # backward.gradients returns a list though the inputs is a signle Tensor. + # The follow code snippet fixes the problem by return the first element + # of grad_inputs when the inputs is a signle Tensor. + if isinstance(inputs, framework.Variable) and isinstance( + grad_inputs, typing.Sequence) and len(grad_inputs) > 0: + return grad_inputs[0] + else: + return grad_inputs if not isinstance(outputs, (framework.Variable, typing.Sequence)): raise TypeError(f'Expected outputs is Tensor|Sequence[Tesnor], ' -- GitLab