From 61cae0dff33a20d0af97cf2cf380ef0982181758 Mon Sep 17 00:00:00 2001 From: Lijunhui <1578034415@qq.com> Date: Fri, 11 Jun 2021 11:20:16 +0800 Subject: [PATCH] [cherry-pick]Fixed a bug of log_softmax: op input was modified to 'nan' (#32937) (#33436) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 使用op benchmark时发现,当输入数据量小于某个值时,python 端 log_softmax 接口的输入值经过计算过后 会被改变为nan。输出正常。 cherry-pick自 #32937 --- paddle/fluid/operators/log_softmax_op.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/log_softmax_op.cu b/paddle/fluid/operators/log_softmax_op.cu index e4fe92c6256..12c607adb44 100644 --- a/paddle/fluid/operators/log_softmax_op.cu +++ b/paddle/fluid/operators/log_softmax_op.cu @@ -104,7 +104,7 @@ __global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src, #pragma unroll for (int it = 0; it < warp_iter; ++it) { int element_index = thread_in_warp_idx + it * kernel_warp_size; - if (element_index < element_count) { + if (element_index < effective_element_count) { dst[batch_id * element_count + element_index] = static_cast(elements[it] - max_value - sum); } else { @@ -226,7 +226,7 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output, #pragma unroll for (int iter = 0; iter < warp_iter; ++iter) { int element_index = thread_in_warp_idx + iter * kernel_warp_size; - if (element_index < element_count) { + if (element_index < effective_element_count) { grad_input[batch_id * element_count + element_index] = static_cast( (grad_output_register[iter] - std::exp(output_register[iter]) * sum)); } -- GitLab