From b3090ad406796a198e7de8c53c51299d81cd8a52 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 3 Dec 2019 07:14:58 +0800 Subject: [PATCH] fix synchronization problem in softmax_with_cross_entropy_op, test=develop (#21480) --- paddle/fluid/operators/softmax_with_cross_entropy_op.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 287f0670a81..98e0464134f 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -200,6 +200,10 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data, softmax[beg_idx] -= diff_max_sum; beg_idx += step; } + + // Note(zhiqiu): since different threads may use max_data[blockIdx.x] to + // calculate diff_max_sum, __syncthreads() is needed here. + __syncthreads(); if (threadIdx.x == 0) max_data[blockIdx.x] = 0; } -- GitLab