From b65e93269744efed630e409507a489cee37e6125 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Sat, 6 May 2023 21:38:37 +0800 Subject: [PATCH] Add PADDLE_THROW in take_along_axis kernel when the datatype of index is wrong. (#53556) --- paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu | 10 +++++----- paddle/phi/kernels/gpu/take_along_axis_kernel.cu | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu index f38ab641669..6cea7592836 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu @@ -30,11 +30,6 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, int axis, DenseTensor* x_grad) { - PADDLE_ENFORCE_EQ( - dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU, - true, - errors::PreconditionNotMet("This kernel only runs on GPU.")); - // We need to know the shape of input matrix to determine the shape of grad // matrix of input. x_grad->Resize(x.dims()); @@ -55,6 +50,11 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, } else if (index_type == DataType::INT64) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, index, out_grad, dev_ctx); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("The data type of input index is expected " + "to be int32 or int64, but recieved %s.", + phi::DataTypeToString(index_type))); } } diff --git a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu index 6ad828f2680..ba4c6ba27e6 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu @@ -28,11 +28,6 @@ void TakeAlongAxisKernel(const Context& dev_ctx, const DenseTensor& index, int axis, DenseTensor* out) { - PADDLE_ENFORCE_EQ( - dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU, - true, - errors::PreconditionNotMet("This kernel only runs on GPU device.")); - out->Resize(index.dims()); dev_ctx.template Alloc(out); @@ -41,6 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx, phi::funcs::gpu_gather_kernel(x, axis, index, *out, dev_ctx); } else if (index_type == DataType::INT64) { phi::funcs::gpu_gather_kernel(x, axis, index, *out, dev_ctx); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("The data type of input index is expected " + "to be int32 or int64, but recieved %s.", + phi::DataTypeToString(index_type))); } } -- GitLab