未验证 提交 b65e9326 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add PADDLE_THROW in take_along_axis kernel when the datatype of index is wrong. (#53556)

上级 08b44e67
...@@ -30,11 +30,6 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, ...@@ -30,11 +30,6 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
int axis, int axis,
DenseTensor* x_grad) { DenseTensor* x_grad) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU,
true,
errors::PreconditionNotMet("This kernel only runs on GPU."));
// We need to know the shape of input matrix to determine the shape of grad // We need to know the shape of input matrix to determine the shape of grad
// matrix of input. // matrix of input.
x_grad->Resize(x.dims()); x_grad->Resize(x.dims());
...@@ -55,6 +50,11 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, ...@@ -55,6 +50,11 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
phi::funcs::gpu_scatter_add_kernel<T, int64_t>( phi::funcs::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx); *x_grad, axis, index, out_grad, dev_ctx);
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("The data type of input index is expected "
"to be int32 or int64, but recieved %s.",
phi::DataTypeToString(index_type)));
} }
} }
......
...@@ -28,11 +28,6 @@ void TakeAlongAxisKernel(const Context& dev_ctx, ...@@ -28,11 +28,6 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
const DenseTensor& index, const DenseTensor& index,
int axis, int axis,
DenseTensor* out) { DenseTensor* out) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU,
true,
errors::PreconditionNotMet("This kernel only runs on GPU device."));
out->Resize(index.dims()); out->Resize(index.dims());
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
...@@ -41,6 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx, ...@@ -41,6 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
phi::funcs::gpu_gather_kernel<T, int32_t>(x, axis, index, *out, dev_ctx); phi::funcs::gpu_gather_kernel<T, int32_t>(x, axis, index, *out, dev_ctx);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
phi::funcs::gpu_gather_kernel<T, int64_t>(x, axis, index, *out, dev_ctx); phi::funcs::gpu_gather_kernel<T, int64_t>(x, axis, index, *out, dev_ctx);
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("The data type of input index is expected "
"to be int32 or int64, but recieved %s.",
phi::DataTypeToString(index_type)));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册