未验证 提交 3641c5ec 编写于 作者: W wuhuachaocoding 提交者: GitHub

add throw exception when index type is wrong. (#53674)

上级 f48611f3
......@@ -55,6 +55,11 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
} else if (index_type == DataType::INT64) {
phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx);
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("The data type of input index is expected "
"to be int32 or int64, but recieved %s.",
phi::DataTypeToString(index_type)));
}
}
......
......@@ -41,6 +41,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
phi::funcs::gpu_gather_kernel<T, int32_t>(x, axis, index, *out, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::gpu_gather_kernel<T, int64_t>(x, axis, index, *out, dev_ctx);
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("The data type of input index is expected "
"to be int32 or int64, but recieved %s.",
phi::DataTypeToString(index_type)));
}
}
......
......@@ -107,6 +107,18 @@ class TestTakeAlongAxisAPI(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001)
paddle.enable_static()
def test_api_dygraph_dtype(self):
paddle.disable_static(self.place[0])
with self.assertRaises(AssertionError):
x_tensor = paddle.to_tensor(self.x_np)
self.index = paddle.to_tensor(self.index_np).astype("float32")
out = paddle.take_along_axis(x_tensor, self.index, self.axis)
out_ref = np.array(
np.take_along_axis(self.x_np, self.index_np, self.axis)
)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001)
paddle.enable_static()
class TestTakeAlongAxisAPICase1(TestTakeAlongAxisAPI):
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册