From e362095e451fba7783fff2ffe2df45f0c6e443ee Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Thu, 16 Jul 2020 10:43:11 +0800 Subject: [PATCH] fix softmax with cross entropy out of bound; test=develop (#25549) --- paddle/fluid/operators/softmax_with_cross_entropy_op.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 344dfe2399..ba56e5e36f 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -27,9 +27,11 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, CUDA_KERNEL_LOOP(index, n * remain) { int idx_n = index / remain; int idx_remain = index % remain; - int idx = idx_n * d + labels[index] * remain + idx_remain; - logit_grad[idx] -= - ignore_index == labels[index] ? static_cast(0.) : static_cast(1.); + int tmp = labels[index]; + if (ignore_index != tmp) { + int idx = idx_n * d + tmp * remain + idx_remain; + logit_grad[idx] -= static_cast(1.); + } } } -- GitLab