diff --git a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu index f38ab6416693f306a70dd21d13496c6ee82113ed..6cea75928367305bf2c7fcbdb258b343aa31b2c1 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu @@ -30,11 +30,6 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, int axis, DenseTensor* x_grad) { - PADDLE_ENFORCE_EQ( - dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU, - true, - errors::PreconditionNotMet("This kernel only runs on GPU.")); - // We need to know the shape of input matrix to determine the shape of grad // matrix of input. x_grad->Resize(x.dims()); @@ -55,6 +50,11 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx, } else if (index_type == DataType::INT64) { phi::funcs::gpu_scatter_add_kernel( *x_grad, axis, index, out_grad, dev_ctx); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("The data type of input index is expected " + "to be int32 or int64, but recieved %s.", + phi::DataTypeToString(index_type))); } } diff --git a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu index 6ad828f2680b3cb3b866d1b9f23fe0feec0660b4..ba4c6ba27e68246ec247663d96b03334a9e9fcc6 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu @@ -28,11 +28,6 @@ void TakeAlongAxisKernel(const Context& dev_ctx, const DenseTensor& index, int axis, DenseTensor* out) { - PADDLE_ENFORCE_EQ( - dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU, - true, - errors::PreconditionNotMet("This kernel only runs on GPU device.")); - out->Resize(index.dims()); dev_ctx.template Alloc(out); @@ -41,6 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx, phi::funcs::gpu_gather_kernel(x, axis, index, *out, dev_ctx); } else if (index_type == DataType::INT64) { phi::funcs::gpu_gather_kernel(x, axis, index, *out, dev_ctx); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("The data type of input index is expected " + "to be int32 or int64, but recieved %s.", + phi::DataTypeToString(index_type))); } }