提交 f1bb80cc 编写于 作者: D DesmonDay

fix grad and add value assertion

上级 7d315faa
......@@ -150,7 +150,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
int64_t size = in_grad->numel();
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, in_grad, 1.0);
phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad);
return;
}
......
......@@ -43,7 +43,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
if (out_grad.numel() == 0) return;
if (rank == 0) {
phi::funcs::set_constant(dev_ctx, in_grad, 1.0);
phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad);
return;
}
......
......@@ -868,10 +868,14 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
self.assertEqual(out1.numpy(), x1.numpy())
self.assertEqual(out2.numpy(), x2.numpy())
self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 1)
self.assertEqual(x2.grad.numpy(), 1)
def test_argsort(self):
x1 = paddle.rand([])
......@@ -886,10 +890,14 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
self.assertEqual(out1.numpy(), 0)
self.assertEqual(out2.numpy(), 0)
self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 0)
class TestSundryAPIStatic(unittest.TestCase):
......
......@@ -659,10 +659,14 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
self.assertEqual(out1.numpy(), x1.numpy())
self.assertEqual(out2.numpy(), x2.numpy())
self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 1)
self.assertEqual(x2.grad.numpy(), 1)
def test_argsort(self):
x1 = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册