未验证 提交 110febdc 编写于 作者: G Guo Sheng 提交者: GitHub

Fix gradients with ignore_idx in softmax_with_cross_entropy (#28622)

* Fix gradients with ignore_idx in softmax_with_cross_entropy.
test=develop

* Fix gradients with ignore_idx in softmax_with_cross_entropy on cpu.
Remove softmax_with_cross_entropy from op_threshold_white_list.
test=develop

* Fix test_softmax_cross_entropy_op.py.
test=develop
上级 a3bc3bcd
...@@ -37,11 +37,17 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, ...@@ -37,11 +37,17 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
template <typename T> template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num, __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
const int d, const int remain) { const int d, const int remain, const int64_t* labels,
const int ignore_index) {
CUDA_KERNEL_LOOP(index, num) { CUDA_KERNEL_LOOP(index, num) {
int idx_n = index / d; int idx_n = index / d;
int idx_remain = index % remain; int idx_remain = index % remain;
logit_grad[index] *= loss_grad[idx_n * remain + idx_remain]; int idx_lbl = idx_n * remain + idx_remain;
if (labels[idx_lbl] == ignore_index) {
logit_grad[index] = static_cast<T>(0.);
} else {
logit_grad[index] *= loss_grad[idx_lbl];
}
} }
} }
...@@ -260,6 +266,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { ...@@ -260,6 +266,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
int idx_remain = idx % remain; int idx_remain = idx % remain;
// labels, loss view as [n, remain] // labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain; int idx_lbl = idx_n * remain + idx_remain;
// It also would ignore labels not in range(class_num).
if (idx_axis != labels_[idx_lbl]) { if (idx_axis != labels_[idx_lbl]) {
log_softmax_[idx] = exp_on_device(log_softmax_[idx]); log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
} else { } else {
...@@ -513,7 +520,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> { ...@@ -513,7 +520,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
int num = n * d; int num = n * d;
grid = (num + block - 1) / block; grid = (num + block - 1) / block;
Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num, Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
d, remain); d, remain, label_data, ignore_index);
} }
} }
}; };
......
...@@ -82,6 +82,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> { ...@@ -82,6 +82,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
} }
const bool soft_label = context.Attr<bool>("soft_label"); const bool soft_label = context.Attr<bool>("soft_label");
auto ignore_index = context.Attr<int>("ignore_index");
const int rank = logit_grad->dims().size(); const int rank = logit_grad->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank); const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
...@@ -115,8 +116,14 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> { ...@@ -115,8 +116,14 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
for (int j = 0; j < remain; j++) { for (int j = 0; j < remain; j++) {
int idx = i * remain + j; int idx = i * remain + j;
logit_grad_data[i * d + label_data[idx] * remain + j] -= if (label_data[idx] == ignore_index) {
out_grad_data[idx]; for (int k = 0; k < axis_dim; ++k) {
logit_grad_data[i * d + k * remain + j] = 0;
}
} else {
logit_grad_data[i * d + label_data[idx] * remain + j] -=
out_grad_data[idx];
}
} }
} }
} }
......
...@@ -83,9 +83,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -83,9 +83,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.attrs = { self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode, "numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label, "soft_label": self.soft_label,
"ignore_index": self.ignore_index,
} }
if self.ignore_index >= 0:
self.attrs['ignore_index'] = self.ignore_index
if self.axis != -1: if self.axis != -1:
self.attrs['axis'] = self.axis self.attrs['axis'] = self.axis
...@@ -93,7 +93,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -93,7 +93,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=0.05) self.check_grad(["Logits"], "Loss", max_relative_error=5e-5)
class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册