diff --git a/paddle/phi/kernels/xpu/where_kernel.cc b/paddle/phi/kernels/xpu/where_kernel.cc index ed32d1c631b7b373cf300d2fddaa5af4a8243e6e..e322fece53add10f20d97a689b05fcf3200c745c 100644 --- a/paddle/phi/kernels/xpu/where_kernel.cc +++ b/paddle/phi/kernels/xpu/where_kernel.cc @@ -43,6 +43,7 @@ void WhereKernel(const Context& ctx, int ret = xpu::select( ctx.x_context(), cond_data, x_data, y_data, out_data, cond_dims, x_dims); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "select"); } diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 8bea782b74425efc82043fc059ed4236788b0362..dc02baa6d18deb1584c114d32ba50df7b7827f87 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -87,6 +87,7 @@ unary_api_list = [ paddle.lgamma, paddle.poisson, paddle.bernoulli, + paddle.median, ] inplace_api_list = [ @@ -1146,6 +1147,36 @@ class TestSundryAPI(unittest.TestCase): y = paddle.full([], 0.6) self.assertFalse(paddle.allclose(x, y)) + def test_where(self): + x1 = paddle.full([], 1) + x2 = paddle.full([], 2) + x1.stop_gradient = False + x2.stop_gradient = False + out = paddle.where(x1 > x2, x1, x2) + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 2) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 0) + self.assertEqual(x2.grad.numpy(), 1) + + def test_atan2(self): + x1 = paddle.full([], 0) + x2 = paddle.full([], 2) + x1.stop_gradient = False + x2.stop_gradient = False + out = paddle.atan2(x1, x2) + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 0) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 0.5) + self.assertEqual(x2.grad.numpy(), 0) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -1785,6 +1816,45 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[4].shape, (2,)) self.assertEqual(res[5].shape, (3,)) + @prog_scope() + def test_where(self): + x1 = paddle.full([], 1, 'float32') + x2 = paddle.full([], 2, 'float32') + x1.stop_gradient = False + x2.stop_gradient = False + out = paddle.where(x1 > x2, x1, x2) + loss = paddle.mean(out) + paddle.static.append_backward(loss) + + prog = paddle.static.default_main_program() + res = self.exe.run( + prog, + feed={}, + fetch_list=[out, out.grad_name, x1.grad_name, x2.grad_name], + ) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[0], 2) + self.assertEqual(res[1].shape, ()) + self.assertEqual(res[2].shape, ()) + self.assertEqual(res[2], 0) + self.assertEqual(res[3].shape, ()) + self.assertEqual(res[3], 1) + + @prog_scope() + def test_atan2(self): + x1 = paddle.full([], 0, 'float32') + x2 = paddle.full([], 2, 'float32') + x1.stop_gradient = False + x2.stop_gradient = False + out = paddle.atan2(x1, x2) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, feed={}, fetch_list=[out]) + + self.assertEqual(res[0].shape, ()) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 97925b72beaa76f7c2e69a5f6d467a6ec1b7c976..c87fe306a6b02958c31e74e0dcecd3570f8429f1 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -87,6 +87,7 @@ unary_api_list = [ paddle.lgamma, paddle.poisson, paddle.bernoulli, + paddle.median, ] inplace_api_list = [ @@ -759,6 +760,8 @@ class TestSundryAPI(unittest.TestCase): # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. + + class TestNoBackwardAPI(unittest.TestCase): def setUp(self): paddle.disable_static() @@ -905,6 +908,13 @@ class TestNoBackwardAPI(unittest.TestCase): self.assertEqual(one_hot_label.shape, [4]) self.assertEqual(one_hot_label.numpy()[2], 1) + def test_where(self): + x1 = paddle.full([], 1) + x2 = paddle.full([], 2) + out = paddle.where(x1 > x2, x1, x2) + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 2) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index e2dcbd178ea46c01b652461dced4db5960e919e4..e152c2a366072de90d1d6ce09e32c60e71ebdda0 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -404,6 +404,10 @@ def median(x, axis=None, keepdim=False, name=None): """ if not isinstance(x, Variable): raise TypeError("In median, the input x should be a Tensor.") + + if len(x.shape) == 0: + return x.clone() + is_flatten = axis is None dims = len(x.shape) if is_flatten: