未验证 提交 a522f8c6 编写于 作者: F Feng Xing 提交者: GitHub

negative label in softmax cross entropy (#36891)

上级 4c93c4c3
......@@ -73,6 +73,9 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
// thread ids compute loss[ids] using softmax[idx]
if (ids < n * d) {
if (labels[ids] < 0) { // label is negative
loss[ids] = static_cast<T>(0.0);
} else { // label is positive of zero
int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d;
if (IgnoreIndex == true) {
// IgnoreIndex is true
......@@ -86,6 +89,7 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
loss[ids] = -Log(softmax[idx]);
}
}
}
}
/*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册