未验证 提交 adaa2510 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Cherry-Pick] Fix bug in log_softmax kernel when lastdim is larger than 100000 (#53656)

上级 fb3dbccc
......@@ -499,8 +499,7 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
// write data to softmax_output according to the LogMode
if (LogMode) {
LogSoftmaxForwardFunctor<AccT, T> reduction(thread_max,
std::log(thread_exp));
LogSoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
if (input_align_shift == output_align_shift) {
ThreadVecWriteVec<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
batch_output, batch_input, dim_size, input_align_shift, reduction);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册