diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu index 11c16581d60c41148c870f99ad544802e2ef64d2..987cd1addec04ca42a2d44de7ad48fecbe03a783 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu @@ -27,7 +27,7 @@ __global__ void CrossEntropyWithSparseKernel(const T *logits, const S *labels, c for (size_t i = 0; i < batch_size; ++i) { T logit = logits[i * class_num + labels[i]]; if (logit <= 0) { - logit += epsilon; + logit = epsilon; } total_loss += -logf(logit); } @@ -54,8 +54,9 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label template __global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) { losses[threadIdx.x] = 0; + T epsilon = 1e-6; for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) { - losses[threadIdx.x] -= logf(logits[i]) * labels[i]; + losses[threadIdx.x] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i]; dlogits[i] = logits[i] - labels[i]; } }