提交 f1bb80cc 编写于 作者: D DesmonDay

fix grad and add value assertion

上级 7d315faa
...@@ -150,7 +150,7 @@ void ArgsortGradKernel(const Context& dev_ctx, ...@@ -150,7 +150,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
int64_t size = in_grad->numel(); int64_t size = in_grad->numel();
if (rank == 0) { if (rank == 0) {
phi::funcs::set_constant(dev_ctx, in_grad, 1.0); phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad);
return; return;
} }
......
...@@ -43,7 +43,7 @@ void ArgsortGradKernel(const Context& dev_ctx, ...@@ -43,7 +43,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
if (out_grad.numel() == 0) return; if (out_grad.numel() == 0) return;
if (rank == 0) { if (rank == 0) {
phi::funcs::set_constant(dev_ctx, in_grad, 1.0); phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad);
return; return;
} }
......
...@@ -868,10 +868,14 @@ class TestSundryAPI(unittest.TestCase): ...@@ -868,10 +868,14 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out1.shape, []) self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, []) self.assertEqual(out2.shape, [])
self.assertEqual(out1.numpy(), x1.numpy())
self.assertEqual(out2.numpy(), x2.numpy())
self.assertEqual(out1.grad.shape, []) self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.grad.shape, []) self.assertEqual(out2.grad.shape, [])
self.assertEqual(x1.grad.shape, []) self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, []) self.assertEqual(x2.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 1)
self.assertEqual(x2.grad.numpy(), 1)
def test_argsort(self): def test_argsort(self):
x1 = paddle.rand([]) x1 = paddle.rand([])
...@@ -886,10 +890,14 @@ class TestSundryAPI(unittest.TestCase): ...@@ -886,10 +890,14 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out1.shape, []) self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, []) self.assertEqual(out2.shape, [])
self.assertEqual(out1.numpy(), 0)
self.assertEqual(out2.numpy(), 0)
self.assertEqual(out1.grad.shape, []) self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.grad.shape, []) self.assertEqual(out2.grad.shape, [])
self.assertEqual(x1.grad.shape, []) self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, []) self.assertEqual(x2.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 0)
class TestSundryAPIStatic(unittest.TestCase): class TestSundryAPIStatic(unittest.TestCase):
......
...@@ -659,10 +659,14 @@ class TestSundryAPI(unittest.TestCase): ...@@ -659,10 +659,14 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out1.shape, []) self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, []) self.assertEqual(out2.shape, [])
self.assertEqual(out1.numpy(), x1.numpy())
self.assertEqual(out2.numpy(), x2.numpy())
self.assertEqual(out1.grad.shape, []) self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.grad.shape, []) self.assertEqual(out2.grad.shape, [])
self.assertEqual(x1.grad.shape, []) self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, []) self.assertEqual(x2.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 1)
self.assertEqual(x2.grad.numpy(), 1)
def test_argsort(self): def test_argsort(self):
x1 = paddle.rand([]) x1 = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册