From d68c38efd488e2a19134f4c9d46a2c1cd0543cca Mon Sep 17 00:00:00 2001 From: seemingwang Date: Tue, 18 Oct 2022 19:22:45 +0800 Subject: [PATCH] add embedding range check (#46991) * add embedding range check * change head file * change head file * fix --- paddle/phi/kernels/gpu/embedding_kernel.cu | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu index bb22fea5f6..b6bf4bce42 100644 --- a/paddle/phi/kernels/gpu/embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -13,13 +13,12 @@ // limitations under the License. #include "paddle/phi/kernels/embedding_kernel.h" - #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/embedding_util.h" - namespace phi { template @@ -35,6 +34,16 @@ __global__ void EmbeddingFW(T *output, while (idy < K) { auto id = static_cast(ids[idy]); + if (PaddingFlag == false || id != padding_idx) { + PADDLE_ENFORCE(id >= 0, + "Id should no less than 0 but received an id value: %lld.", + id); + PADDLE_ENFORCE( + id < N, + "Id should smaller than %lld but received an id value: %lld.", + N, + id); + } T *out = output + idy * D; const T *tab = table + id * D; for (int i = idx; i < D; i += blockDim.x) { -- GitLab