From 4341ebd9abbb6ca97d0c2e4a5f8129256414fc6f Mon Sep 17 00:00:00 2001 From: Ghost Screaming Date: Thu, 13 Apr 2023 16:53:51 +0800 Subject: [PATCH] Fix ignore index of c_softmax_with_cross_entropy_op. (#52835) * Fix bug of reduce_sum op. When input.numel() > INT32_MAX, its result is wrong. * Remove climits. * Fix bug of c_softmax_with_cross_entropy_op. Support ignore_index is negative number. --- .../c_softmax_with_cross_entropy_op.cu | 69 ++++++++++--------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu index c37266a9b42..6a2dab9005a 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu @@ -40,6 +40,7 @@ template __global__ void MaskLabelByIndex(T* predicted_logits, const T* logit, const IndexT* label, + const IndexT ignore_index, const int start_index, const int end_index, const int64_t N, @@ -47,13 +48,15 @@ __global__ void MaskLabelByIndex(T* predicted_logits, const int nranks) { CUDA_KERNEL_LOOP(i, N) { auto real_label = label[i]; - PADDLE_ENFORCE((real_label < D * nranks) && (real_label >= 0), + PADDLE_ENFORCE(((real_label < D * nranks) && (real_label >= 0)) || + (real_label == ignore_index), "The index is out of bounds, " "please check whether the value of label and " "input meet the class number. It should " - "be less than [%d], but received [%d]", - D * nranks, - real_label); + "be less than [%ld] or equal to [%ld], but received [%ld]", + static_cast(D * nranks), + static_cast(ignore_index), + static_cast(real_label)); if (real_label >= start_index && real_label < end_index) { predicted_logits[i] = logit[i * D + real_label - start_index]; @@ -204,20 +207,22 @@ struct CSoftmaxWithCrossEntropyFunctor { const auto& label_type = framework::TransToProtoVarType(labels->dtype()); if (label_type == framework::proto::VarType::INT32) { - MaskLabelByIndex - <<>>(predicted_logits.data(), - softmax_2d.data(), - labels->data(), - start_index, - end_index, - N, - D, - nranks); + MaskLabelByIndex<<>>( + predicted_logits.data(), + softmax_2d.data(), + labels->data(), + static_cast(ignore_index), + start_index, + end_index, + N, + D, + nranks); } else if (label_type == framework::proto::VarType::INT64) { MaskLabelByIndex <<>>(predicted_logits.data(), softmax_2d.data(), labels->data(), + ignore_index, start_index, end_index, N, @@ -362,25 +367,27 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { const auto& label_type = framework::TransToProtoVarType(labels->dtype()); if (label_type == framework::proto::VarType::INT32) { - MaskLabelByIndex - <<>>(predicted_logits.data(), - softmax_2d.data(), - labels->data(), - start_index, - end_index, - N, - D, - nranks); + MaskLabelByIndex<<>>( + predicted_logits.data(), + softmax_2d.data(), + labels->data(), + static_cast(ignore_index), + start_index, + end_index, + N, + D, + nranks); } else if (label_type == framework::proto::VarType::INT64) { - MaskLabelByIndex - <<>>(predicted_logits.data(), - softmax_2d.data(), - labels->data(), - start_index, - end_index, - N, - D, - nranks); + MaskLabelByIndex<<>>( + predicted_logits.data(), + softmax_2d.data(), + labels->data(), + static_cast(ignore_index), + start_index, + end_index, + N, + D, + nranks); } in_out.clear(); -- GitLab