From f1bb80ccaf84e487f5e000763fedfdc70de20afd Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Wed, 4 Jan 2023 08:39:53 +0000 Subject: [PATCH] fix grad and add value assertion --- paddle/phi/kernels/gpu/argsort_grad_kernel.cu | 2 +- paddle/phi/kernels/xpu/argsort_grad_kernel.cc | 2 +- .../paddle/fluid/tests/unittests/test_zero_dim_tensor.py | 8 ++++++++ .../fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py | 4 ++++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu index f28da8704cb..b8d9df64c23 100644 --- a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu @@ -150,7 +150,7 @@ void ArgsortGradKernel(const Context& dev_ctx, int64_t size = in_grad->numel(); if (rank == 0) { - phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad); return; } diff --git a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc index 00c679f0ab9..96cce046178 100644 --- a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc @@ -43,7 +43,7 @@ void ArgsortGradKernel(const Context& dev_ctx, if (out_grad.numel() == 0) return; if (rank == 0) { - phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad); return; } diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 5aec0c8010f..546c0c48f9b 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -868,10 +868,14 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out1.shape, []) self.assertEqual(out2.shape, []) + self.assertEqual(out1.numpy(), x1.numpy()) + self.assertEqual(out2.numpy(), x2.numpy()) self.assertEqual(out1.grad.shape, []) self.assertEqual(out2.grad.shape, []) self.assertEqual(x1.grad.shape, []) self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 1) + self.assertEqual(x2.grad.numpy(), 1) def test_argsort(self): x1 = paddle.rand([]) @@ -886,10 +890,14 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out1.shape, []) self.assertEqual(out2.shape, []) + self.assertEqual(out1.numpy(), 0) + self.assertEqual(out2.numpy(), 0) self.assertEqual(out1.grad.shape, []) self.assertEqual(out2.grad.shape, []) self.assertEqual(x1.grad.shape, []) self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 0) + self.assertEqual(x2.grad.numpy(), 0) class TestSundryAPIStatic(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 8cb27ecf099..221c46228ea 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -659,10 +659,14 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out1.shape, []) self.assertEqual(out2.shape, []) + self.assertEqual(out1.numpy(), x1.numpy()) + self.assertEqual(out2.numpy(), x2.numpy()) self.assertEqual(out1.grad.shape, []) self.assertEqual(out2.grad.shape, []) self.assertEqual(x1.grad.shape, []) self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 1) + self.assertEqual(x2.grad.numpy(), 1) def test_argsort(self): x1 = paddle.rand([]) -- GitLab