未验证 提交 536d9a3b 编写于 作者: G Guanghua Yu 提交者: GitHub

[cherry-pick]fix error message & label check in softmax_with_cross_entropy (#31123)

* fix error message & label check in softmax_with_cross_entropy

* fix error message & label check in softmax_with_cross_entropy

* fix print comment

* fix ignore_index check in softmax_with_cross_entropy
上级 84a5ed9f
...@@ -253,12 +253,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { ...@@ -253,12 +253,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
public: public:
HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss, HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
T* log_softmax, int64_t d, T* log_softmax, int64_t d,
int axis_dim) int axis_dim, int ignore_idx)
: labels_(labels), : labels_(labels),
loss_(loss), loss_(loss),
log_softmax_(log_softmax), log_softmax_(log_softmax),
d_(d), d_(d),
axis_dim_(axis_dim) {} axis_dim_(axis_dim),
ignore_idx_(ignore_idx) {}
__device__ void operator()(int64_t idx) const { __device__ void operator()(int64_t idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain // logits view as [n, axis_dim, remain], where d = axis_dim * remain
...@@ -268,6 +269,11 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { ...@@ -268,6 +269,11 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
int64_t idx_remain = idx % remain; int64_t idx_remain = idx % remain;
// labels, loss view as [n, remain] // labels, loss view as [n, remain]
int64_t idx_lbl = idx_n * remain + idx_remain; int64_t idx_lbl = idx_n * remain + idx_remain;
PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_ ||
labels_[idx_lbl] == ignore_idx_,
"The value of label[%ld] expected >= 0 and < %ld, or == %d,"
"but got %ld. Please check input value.",
idx_lbl, d_, ignore_idx_, labels_[idx_lbl]);
// It also would ignore labels not in range(class_num). // It also would ignore labels not in range(class_num).
if (idx_axis != labels_[idx_lbl]) { if (idx_axis != labels_[idx_lbl]) {
log_softmax_[idx] = exp_on_device(log_softmax_[idx]); log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
...@@ -284,6 +290,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { ...@@ -284,6 +290,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
T* log_softmax_; T* log_softmax_;
int64_t d_; int64_t d_;
int axis_dim_; int axis_dim_;
int ignore_idx_;
}; };
template <typename T> template <typename T>
...@@ -351,7 +358,7 @@ static void HardLabelSoftmaxWithCrossEntropy( ...@@ -351,7 +358,7 @@ static void HardLabelSoftmaxWithCrossEntropy(
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} else { \ } else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \ for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
labels_data, loss_data, softmax_data, d, axis_dim)); \ labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} \ } \
} break } break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册