未验证 提交 9ed8bafd 编写于 作者: X xiaoye 提交者: GitHub

[fix] move exception throw out of omp parallel for loop (#55064)

上级 eb12739e
......@@ -48,14 +48,8 @@ struct EmbeddingCPUFunctor {
dev_ctx_.template Alloc<T>(out_);
auto* output = out_->data<T>();
#if defined(_OPENMP) && !defined(PADDLE_WITH_CUDA)
#pragma omp parallel for
#endif
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
if (padding_idx_ == kNoPadding && ids[i] != padding_idx_) {
PADDLE_ENFORCE_LT(
ids[i],
row_number,
......@@ -74,6 +68,17 @@ struct EmbeddingCPUFunctor {
"value.",
row_number,
ids[i]));
}
}
#if defined(_OPENMP) && !defined(PADDLE_WITH_CUDA)
#pragma omp parallel for
#endif
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
memcpy(output + i * row_width,
table + ids[i] * row_width,
row_width * sizeof(T));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册