未验证 提交 1d128326 编写于 作者: X Xiaoxu Chen 提交者: GitHub

fix duplicate slice logic in _grad (#44396)

上级 4c1e77d1
...@@ -223,6 +223,11 @@ class TestVJP(TestAutogradFunctional): ...@@ -223,6 +223,11 @@ class TestVJP(TestAutogradFunctional):
self.func_vjp_nested() self.func_vjp_nested()
self.func_vjp_aliased_input() self.func_vjp_aliased_input()
def test_input_single_tensor(self):
self.assertIsInstance(
paddle.incubate.autograd.vjp(paddle.tanh, paddle.rand((3, 4)))[1],
paddle.fluid.framework.Variable)
@utils.place(config.DEVICES) @utils.place(config.DEVICES)
@utils.parameterize( @utils.parameterize(
......
...@@ -565,13 +565,15 @@ def _grad(ys, xs, v=None): ...@@ -565,13 +565,15 @@ def _grad(ys, xs, v=None):
inputs. inputs.
""" """
if paddle.fluid._non_static_mode(): if paddle.fluid._non_static_mode():
# paddle.grad returns a list though the inputs is a signle Tensor. The
# follow code snippet fixes the problem by return the first element of
# xs_grad when the xs is a signle Tensor.
xs_grad = paddle.grad(ys, xs, v, create_graph=True, allow_unused=True) xs_grad = paddle.grad(ys, xs, v, create_graph=True, allow_unused=True)
if isinstance(xs, paddle.fluid.framework.Variable) and isinstance(
xs_grad, typing.Sequence) and len(xs_grad) > 0:
xs_grad = xs_grad[0]
else: else:
xs_grad = paddle.incubate.autograd.grad(ys, xs, v) xs_grad = paddle.incubate.autograd.grad(ys, xs, v)
if isinstance(xs, paddle.fluid.framework.Variable):
xs_grad = xs_grad[0]
return _replace_none_with_zero_tensor(xs_grad, xs) return _replace_none_with_zero_tensor(xs_grad, xs)
......
...@@ -132,9 +132,16 @@ def grad(outputs, inputs, grad_outputs=None): ...@@ -132,9 +132,16 @@ def grad(outputs, inputs, grad_outputs=None):
paddle.incubate.autograd.disable_prim() paddle.incubate.autograd.disable_prim()
paddle.disable_static() paddle.disable_static()
""" """
if not utils.prim_enabled(): if not utils.prim_enabled():
return backward.gradients(outputs, inputs, grad_outputs) grad_inputs = backward.gradients(outputs, inputs, grad_outputs)
# backward.gradients returns a list though the inputs is a signle Tensor.
# The follow code snippet fixes the problem by return the first element
# of grad_inputs when the inputs is a signle Tensor.
if isinstance(inputs, framework.Variable) and isinstance(
grad_inputs, typing.Sequence) and len(grad_inputs) > 0:
return grad_inputs[0]
else:
return grad_inputs
if not isinstance(outputs, (framework.Variable, typing.Sequence)): if not isinstance(outputs, (framework.Variable, typing.Sequence)):
raise TypeError(f'Expected outputs is Tensor|Sequence[Tesnor], ' raise TypeError(f'Expected outputs is Tensor|Sequence[Tesnor], '
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册