未验证 提交 a522f8c6 编写于 作者: F Feng Xing 提交者: GitHub

negative label in softmax cross entropy (#36891)

上级 4c93c4c3
...@@ -73,17 +73,21 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax, ...@@ -73,17 +73,21 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
// thread ids compute loss[ids] using softmax[idx] // thread ids compute loss[ids] using softmax[idx]
if (ids < n * d) { if (ids < n * d) {
int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d; if (labels[ids] < 0) { // label is negative
if (IgnoreIndex == true) { loss[ids] = static_cast<T>(0.0);
// IgnoreIndex is true } else { // label is positive of zero
if (labels[ids] == ignore_idx) { int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d;
loss[ids] = static_cast<T>(0.0); if (IgnoreIndex == true) {
// IgnoreIndex is true
if (labels[ids] == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
loss[ids] = -Log(softmax[idx]);
}
} else { } else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]); loss[ids] = -Log(softmax[idx]);
} }
} else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]);
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册