diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 287f0670a81e1675b036f583b7de117e0713b0d4..98e0464134f05bff46b8962a13a01c8761d7c9b3 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -200,6 +200,10 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data, softmax[beg_idx] -= diff_max_sum; beg_idx += step; } + + // Note(zhiqiu): since different threads may use max_data[blockIdx.x] to + // calculate diff_max_sum, __syncthreads() is needed here. + __syncthreads(); if (threadIdx.x == 0) max_data[blockIdx.x] = 0; }