提交 e69d9c84 编写于 作者: B Bai Yifan 提交者: qingqing01

code fix (#13365)

上级 ae67dcea
......@@ -31,7 +31,8 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
i += blockDim.x * gridDim.x) {
int idx = i * class_num + labels[i];
logit_grad[idx] -= static_cast<T>(1.);
logit_grad[idx] -=
ignore_index == labels[i] ? static_cast<T>(0.) : static_cast<T>(1.);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册