未验证 提交 4da15e6a 编写于 作者: L Lijunhui 提交者: GitHub

Fixed a bug of log_softmax: op input was modified to 'nan' (#32937)

上级 7101af3f
...@@ -104,7 +104,7 @@ __global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src, ...@@ -104,7 +104,7 @@ __global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src,
#pragma unroll #pragma unroll
for (int it = 0; it < warp_iter; ++it) { for (int it = 0; it < warp_iter; ++it) {
int element_index = thread_in_warp_idx + it * kernel_warp_size; int element_index = thread_in_warp_idx + it * kernel_warp_size;
if (element_index < element_count) { if (element_index < effective_element_count) {
dst[batch_id * element_count + element_index] = dst[batch_id * element_count + element_index] =
static_cast<T>(elements[it] - max_value - sum); static_cast<T>(elements[it] - max_value - sum);
} else { } else {
...@@ -226,7 +226,7 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output, ...@@ -226,7 +226,7 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
#pragma unroll #pragma unroll
for (int iter = 0; iter < warp_iter; ++iter) { for (int iter = 0; iter < warp_iter; ++iter) {
int element_index = thread_in_warp_idx + iter * kernel_warp_size; int element_index = thread_in_warp_idx + iter * kernel_warp_size;
if (element_index < element_count) { if (element_index < effective_element_count) {
grad_input[batch_id * element_count + element_index] = static_cast<T>( grad_input[batch_id * element_count + element_index] = static_cast<T>(
(grad_output_register[iter] - std::exp(output_register[iter]) * sum)); (grad_output_register[iter] - std::exp(output_register[iter]) * sum));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册