From a522f8c6838b6773c5bbaad8a70ca6af2b0ec423 Mon Sep 17 00:00:00 2001 From: Feng Xing <79969986+xingfeng01@users.noreply.github.com> Date: Mon, 1 Nov 2021 20:14:12 +0800 Subject: [PATCH] negative label in softmax cross entropy (#36891) --- .../softmax_with_cross_entropy_op.cu | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index b81a37a687..6a9dca9fe2 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -73,17 +73,21 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax, // thread ids compute loss[ids] using softmax[idx] if (ids < n * d) { - int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d; - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (labels[ids] == ignore_idx) { - loss[ids] = static_cast(0.0); + if (labels[ids] < 0) { // label is negative + loss[ids] = static_cast(0.0); + } else { // label is positive of zero + int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d; + if (IgnoreIndex == true) { + // IgnoreIndex is true + if (labels[ids] == ignore_idx) { + loss[ids] = static_cast(0.0); + } else { + loss[ids] = -Log(softmax[idx]); + } } else { + // IgnoreIndex is false loss[ids] = -Log(softmax[idx]); } - } else { - // IgnoreIndex is false - loss[ids] = -Log(softmax[idx]); } } } -- GitLab