未验证 提交 1508cae7 编写于 作者: R ronnywang 提交者: GitHub

[Zero-Dim] add where, atan2, median 0-Dim ut (#49692)

* add where, atan2, median 0d ut

* add where, atan2, median 0d ut

* update

* update

* update
上级 690d7a69
......@@ -43,6 +43,7 @@ void WhereKernel(const Context& ctx,
int ret = xpu::select(
ctx.x_context(), cond_data, x_data, y_data, out_data, cond_dims, x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "select");
}
......
......@@ -87,6 +87,7 @@ unary_api_list = [
paddle.lgamma,
paddle.poisson,
paddle.bernoulli,
paddle.median,
]
inplace_api_list = [
......@@ -1146,6 +1147,36 @@ class TestSundryAPI(unittest.TestCase):
y = paddle.full([], 0.6)
self.assertFalse(paddle.allclose(x, y))
def test_where(self):
    """0-D where: since cond (1 > 2) is False, y is selected; output and
    all gradients keep the 0-D shape []."""
    lhs = paddle.full([], 1)
    rhs = paddle.full([], 2)
    lhs.stop_gradient = False
    rhs.stop_gradient = False
    result = paddle.where(lhs > rhs, lhs, rhs)
    result.backward()
    # Every tensor on the forward/backward path stays 0-D.
    for t in (result, result.grad, lhs.grad, rhs.grad):
        self.assertEqual(t.shape, [])
    self.assertEqual(result.numpy(), 2)
    # Gradient flows only through the selected branch (rhs).
    self.assertEqual(lhs.grad.numpy(), 0)
    self.assertEqual(rhs.grad.numpy(), 1)
def test_atan2(self):
    """0-D atan2: atan2(0, 2) == 0; output and gradients are all 0-D.

    Analytic gradients: d/dx1 = x2 / (x1^2 + x2^2) = 0.5,
    d/dx2 = -x1 / (x1^2 + x2^2) = 0.
    """
    numerator = paddle.full([], 0)
    denominator = paddle.full([], 2)
    numerator.stop_gradient = False
    denominator.stop_gradient = False
    angle = paddle.atan2(numerator, denominator)
    angle.backward()
    # Every tensor on the forward/backward path stays 0-D.
    for t in (angle, angle.grad, numerator.grad, denominator.grad):
        self.assertEqual(t.shape, [])
    self.assertEqual(angle.numpy(), 0)
    self.assertEqual(numerator.grad.numpy(), 0.5)
    self.assertEqual(denominator.grad.numpy(), 0)
class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -1785,6 +1816,45 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[4].shape, (2,))
self.assertEqual(res[5].shape, (3,))
@prog_scope()
def test_where(self):
    """Static-graph where on 0-D inputs: forward result and every
    fetched gradient come back as rank-0 numpy arrays (shape ())."""
    a = paddle.full([], 1, 'float32')
    b = paddle.full([], 2, 'float32')
    a.stop_gradient = False
    b.stop_gradient = False
    selected = paddle.where(a > b, a, b)
    paddle.static.append_backward(paddle.mean(selected))
    fetched = self.exe.run(
        paddle.static.default_main_program(),
        feed={},
        fetch_list=[selected, selected.grad_name, a.grad_name, b.grad_name],
    )
    out_v, out_g, a_g, b_g = fetched
    # All fetched values are scalars (rank-0 arrays).
    for v in (out_v, out_g, a_g, b_g):
        self.assertEqual(v.shape, ())
    self.assertEqual(out_v, 2)
    # Gradient flows only through the selected branch (b).
    self.assertEqual(a_g, 0)
    self.assertEqual(b_g, 1)
@prog_scope()
def test_atan2(self):
    """Static-graph atan2 on 0-D inputs: forward result and input
    gradients are all rank-0 (shape ()).

    The inputs have stop_gradient disabled and append_backward is run,
    so the gradients are fetched and checked as well — mirroring the
    dynamic-graph test_atan2 and the sibling static test_where, which
    previously this test omitted.
    """
    x1 = paddle.full([], 0, 'float32')
    x2 = paddle.full([], 2, 'float32')
    x1.stop_gradient = False
    x2.stop_gradient = False
    out = paddle.atan2(x1, x2)
    paddle.static.append_backward(out)

    prog = paddle.static.default_main_program()
    # Fetch the gradients too so the 0-D backward path is exercised.
    res = self.exe.run(
        prog, feed={}, fetch_list=[out, x1.grad_name, x2.grad_name]
    )
    self.assertEqual(res[0].shape, ())
    self.assertEqual(res[0], 0)
    # d(atan2)/dx1 = x2 / (x1^2 + x2^2) = 2 / 4 = 0.5
    self.assertEqual(res[1].shape, ())
    self.assertEqual(res[1], 0.5)
    # d(atan2)/dx2 = -x1 / (x1^2 + x2^2) = 0
    self.assertEqual(res[2].shape, ())
    self.assertEqual(res[2], 0)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
......@@ -87,6 +87,7 @@ unary_api_list = [
paddle.lgamma,
paddle.poisson,
paddle.bernoulli,
paddle.median,
]
inplace_api_list = [
......@@ -759,6 +760,8 @@ class TestSundryAPI(unittest.TestCase):
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
def setUp(self):
paddle.disable_static()
......@@ -905,6 +908,13 @@ class TestNoBackwardAPI(unittest.TestCase):
self.assertEqual(one_hot_label.shape, [4])
self.assertEqual(one_hot_label.numpy()[2], 1)
def test_where(self):
    """0-D where without backward: cond (1 > 2) is False, so the second
    operand is selected and the result keeps the 0-D shape []."""
    first = paddle.full([], 1)
    second = paddle.full([], 2)
    picked = paddle.where(first > second, first, second)
    self.assertEqual(picked.shape, [])
    self.assertEqual(picked.numpy(), 2)
if __name__ == "__main__":
unittest.main()
......@@ -404,6 +404,10 @@ def median(x, axis=None, keepdim=False, name=None):
"""
if not isinstance(x, Variable):
raise TypeError("In median, the input x should be a Tensor.")
if len(x.shape) == 0:
return x.clone()
is_flatten = axis is None
dims = len(x.shape)
if is_flatten:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册