未验证 提交 f2ed4011 编写于 作者: L Leo Chen 提交者: GitHub

suppport optional input for unbind_grad (#54085)

上级 f4abe34b
...@@ -67,6 +67,7 @@ ops_to_fill_zero_for_empty_grads = { ...@@ -67,6 +67,7 @@ ops_to_fill_zero_for_empty_grads = {
"multiply_grad", "multiply_grad",
"divide_grad", "divide_grad",
"matmul_grad", "matmul_grad",
"unbind_grad",
} }
# For API dispatch used at python-level # For API dispatch used at python-level
......
...@@ -280,5 +280,18 @@ class TestUnbindBool(unittest.TestCase): ...@@ -280,5 +280,18 @@ class TestUnbindBool(unittest.TestCase):
np.testing.assert_array_equal(xs[0].numpy(), [True, True]) np.testing.assert_array_equal(xs[0].numpy(), [True, True])
class TestUnbindGradOptionalInput(unittest.TestCase):
def test_grad(self):
a = paddle.zeros([3, 2, 3])
a.stop_gradient = False
x, y = a.unbind(-2)
x.sum().backward() # y_grad is empty
a_grad = a.detach()
a_grad[:, 0, :] = 1
np.testing.assert_array_equal(a.grad.numpy(), a_grad.numpy())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册