From f2ed40112e2604237024cce6ee3f2415bdf4c3d8 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 24 May 2023 23:18:49 +0800 Subject: [PATCH] suppport optional input for unbind_grad (#54085) --- .../auto_code_generator/generator/codegen_utils.py | 1 + .../paddle/fluid/tests/unittests/test_unbind_op.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py index f51d270cc8a..0ec006555e4 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py @@ -67,6 +67,7 @@ ops_to_fill_zero_for_empty_grads = { "multiply_grad", "divide_grad", "matmul_grad", + "unbind_grad", } # For API dispatch used at python-level diff --git a/python/paddle/fluid/tests/unittests/test_unbind_op.py b/python/paddle/fluid/tests/unittests/test_unbind_op.py index 989eb43b050..763aa2c3f24 100644 --- a/python/paddle/fluid/tests/unittests/test_unbind_op.py +++ b/python/paddle/fluid/tests/unittests/test_unbind_op.py @@ -280,5 +280,18 @@ class TestUnbindBool(unittest.TestCase): np.testing.assert_array_equal(xs[0].numpy(), [True, True]) +class TestUnbindGradOptionalInput(unittest.TestCase): + def test_grad(self): + a = paddle.zeros([3, 2, 3]) + a.stop_gradient = False + x, y = a.unbind(-2) + x.sum().backward() # y_grad is empty + + a_grad = a.detach() + a_grad[:, 0, :] = 1 + + np.testing.assert_array_equal(a.grad.numpy(), a_grad.numpy()) + + if __name__ == '__main__': unittest.main() -- GitLab