diff --git a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu index c4fffce6ccc076372361a3d1a5431a4bf76dc8fc..6a9b1014e9f8c3e49d4f531a0eadad3b2449fae4 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu @@ -55,6 +55,11 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, } else if (index_type == DataType::INT64) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, index, out_grad, dev_ctx); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("The data type of input index is expected " + "to be int32 or int64, but recieved %s.", + phi::DataTypeToString(index_type))); } } diff --git a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu index 3c919f5acd1cdf0b10ad4cfd712ecabd46701909..a9ff0d99db573b615b2690a6e3486961637c2a7b 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu @@ -41,6 +41,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx, phi::funcs::gpu_gather_kernel(x, axis, index, *out, dev_ctx); } else if (index_type == DataType::INT64) { phi::funcs::gpu_gather_kernel(x, axis, index, *out, dev_ctx); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("The data type of input index is expected " + "to be int32 or int64, but recieved %s.", + phi::DataTypeToString(index_type))); } } diff --git a/python/paddle/fluid/tests/unittests/test_take_along_axis_op.py b/python/paddle/fluid/tests/unittests/test_take_along_axis_op.py index b5a9c2169ff96b55fa97ea88bd281ddc8565eb1d..6c6c084bcf610fecc9cab3a89506e449d4aa8c63 100644 --- a/python/paddle/fluid/tests/unittests/test_take_along_axis_op.py +++ b/python/paddle/fluid/tests/unittests/test_take_along_axis_op.py @@ -107,6 +107,18 @@ class TestTakeAlongAxisAPI(unittest.TestCase): np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) paddle.enable_static() + def test_api_dygraph_dtype(self): + paddle.disable_static(self.place[0]) + with self.assertRaises(AssertionError): + x_tensor = paddle.to_tensor(self.x_np) + self.index = paddle.to_tensor(self.index_np).astype("float32") + out = paddle.take_along_axis(x_tensor, self.index, self.axis) + out_ref = np.array( + np.take_along_axis(self.x_np, self.index_np, self.axis) + ) + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + paddle.enable_static() + class TestTakeAlongAxisAPICase1(TestTakeAlongAxisAPI): def setUp(self):