From 536d9a3b44e29ae6b2c42d2e63b01a1b226fd95d Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Fri, 26 Feb 2021 18:32:04 +0800 Subject: [PATCH] [cherry-pick]fix error message & label check in softmax_with_cross_entropy (#31123) * fix error message & label check in softmax_with_cross_entropy * fix error message & label check in softmax_with_cross_entropy * fix print comment * fix ignore_index check in softmax_with_cross_entropy --- .../operators/softmax_with_cross_entropy_op.cu | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index cb4eeab56a6..f3e7a33d9b1 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -253,12 +253,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { public: HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss, T* log_softmax, int64_t d, - int axis_dim) + int axis_dim, int ignore_idx) : labels_(labels), loss_(loss), log_softmax_(log_softmax), d_(d), - axis_dim_(axis_dim) {} + axis_dim_(axis_dim), + ignore_idx_(ignore_idx) {} __device__ void operator()(int64_t idx) const { // logits view as [n, axis_dim, remain], where d = axis_dim * remain @@ -268,6 +269,11 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { int64_t idx_remain = idx % remain; // labels, loss view as [n, remain] int64_t idx_lbl = idx_n * remain + idx_remain; + PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_ || + labels_[idx_lbl] == ignore_idx_, + "The value of label[%ld] expected >= 0 and < %ld, or == %d," + "but got %ld. Please check input value.", + idx_lbl, d_, ignore_idx_, labels_[idx_lbl]); // It also would ignore labels not in range(class_num). if (idx_axis != labels_[idx_lbl]) { log_softmax_[idx] = exp_on_device(log_softmax_[idx]); @@ -284,6 +290,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { T* log_softmax_; int64_t d_; int axis_dim_; + int ignore_idx_; }; template @@ -351,7 +358,7 @@ static void HardLabelSoftmaxWithCrossEntropy( labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ } else { \ for_range(HardLabelSoftmaxWithCrossEntropyFunctor( \ - labels_data, loss_data, softmax_data, d, axis_dim)); \ + labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ } \ } break -- GitLab