From 9ed8bafd2abf0b74ca4c14056120e2c0c386cdf5 Mon Sep 17 00:00:00 2001 From: xiaoye <50870160+xiaoyewww@users.noreply.github.com> Date: Fri, 7 Jul 2023 17:07:31 +0800 Subject: [PATCH] [fix] move exception throw out of omp parallel for loop (#55064) --- paddle/phi/kernels/cpu/embedding_kernel.cc | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/paddle/phi/kernels/cpu/embedding_kernel.cc b/paddle/phi/kernels/cpu/embedding_kernel.cc index cfba1787a15..0d937e6364e 100644 --- a/paddle/phi/kernels/cpu/embedding_kernel.cc +++ b/paddle/phi/kernels/cpu/embedding_kernel.cc @@ -48,14 +48,8 @@ struct EmbeddingCPUFunctor { dev_ctx_.template Alloc(out_); auto* output = out_->data(); -#if defined(_OPENMP) && !defined(PADDLE_WITH_CUDA) -#pragma omp parallel for -#endif - for (int64_t i = 0; i < ids_numel; ++i) { - if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) { - memset(output + i * row_width, 0, row_width * sizeof(T)); - } else { + if (padding_idx_ == kNoPadding && ids[i] != padding_idx_) { PADDLE_ENFORCE_LT( ids[i], row_number, @@ -74,6 +68,17 @@ struct EmbeddingCPUFunctor { "value.", row_number, ids[i])); + } + } + +#if defined(_OPENMP) && !defined(PADDLE_WITH_CUDA) +#pragma omp parallel for +#endif + + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { memcpy(output + i * row_width, table + ids[i] * row_width, row_width * sizeof(T)); -- GitLab