未验证 提交 9bf70039 编写于 作者: H hong 提交者: GitHub

fix softmax with cross entropy out of bound; test=develop (#25549) (#25567)

* fix softmax with cross entropy out of bound; test=develop (#25549)

* fix index error; test=develop
上级 11c231a7
...@@ -28,9 +28,11 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, ...@@ -28,9 +28,11 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
i += blockDim.x * gridDim.x) { i += blockDim.x * gridDim.x) {
int idx_n = i / remain; int idx_n = i / remain;
int idx_remain = i % remain; int idx_remain = i % remain;
int idx = idx_n * d + labels[i] * remain + idx_remain; int tmp = labels[i];
logit_grad[idx] -= if (ignore_index != tmp) {
ignore_index == labels[i] ? static_cast<T>(0.) : static_cast<T>(1.); int idx = idx_n * d + tmp * remain + idx_remain;
logit_grad[idx] -= static_cast<T>(1.);
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册