未验证 提交 dcadc256 编写于 作者: F Feng Xing 提交者: GitHub

negative label in softmax cross entropy (#36907)

上级 09bc9c06
......@@ -73,17 +73,21 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
// thread ids compute loss[ids] using softmax[idx]
if (ids < n * d) {
int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d;
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (labels[ids] == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
if (labels[ids] < 0) { // label is negative
loss[ids] = static_cast<T>(0.0);
} else { // label is positive of zero
int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d;
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (labels[ids] == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
loss[ids] = -Log(softmax[idx]);
}
} else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]);
}
} else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册