diff --git a/paddle/fluid/operators/correlation_op.cu b/paddle/fluid/operators/correlation_op.cu index a51fce8132418b09c8f2db397fc83c8c69a8a429..76e10f90ef833fb7ad65008cd3e426c883d76d7e 100644 --- a/paddle/fluid/operators/correlation_op.cu +++ b/paddle/fluid/operators/correlation_op.cu @@ -42,6 +42,7 @@ __forceinline__ __device__ T blockReduceSum(T val) { int wid = threadIdx.x / warpSize; val = warpReduceSum(val); + __syncthreads(); if (lane == 0) shared[wid] = val; __syncthreads(); diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index f955011675cf5dada3cbcc3711838d50ff33a025..25c722358c4e326897cef98be2b62e5071959cf9 100755 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -64,17 +64,16 @@ static __forceinline__ __device__ U WarpReduceSum(U val) { } template -__forceinline__ __device__ U BlockReduceSum(U val) { - static __shared__ U shared[32]; +__forceinline__ __device__ U BlockReduceSum(U val, U *shared) { int lane = threadIdx.x % warpSize; int wid = threadIdx.x / warpSize; val = WarpReduceSum(val); // Each warp performs partial reduction + __syncthreads(); if (lane == 0) shared[wid] = val; // Write reduced value to shared memory __syncthreads(); // Wait for all partial reductions - // read from shared memory only if that warp existed val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast(0); @@ -183,6 +182,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, int64_t feature_size) { __shared__ U mean_share; __shared__ U var_share; + __shared__ U shared_mean[32]; + __shared__ U shared_var[32]; int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x; int64_t end_idx = (blockIdx.x + 1) * feature_size; @@ -196,8 +197,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, var_val += (tmp * tmp); } - mean_val = BlockReduceSum(mean_val); - var_val = BlockReduceSum(var_val); + mean_val = BlockReduceSum(mean_val, shared_mean); + var_val = BlockReduceSum(var_val, shared_var); if (threadIdx.x == 0) { auto scale = static_cast(1.) / static_cast(feature_size); @@ -541,8 +542,10 @@ __global__ void LayerNormBackwardGradientAll( } } - d_scale_partial = BlockReduceSum(d_scale_partial); - d_bias_partial = BlockReduceSum(d_bias_partial); + __shared__ U shared_scale[32]; + __shared__ U shared_bias[32]; + d_scale_partial = BlockReduceSum(d_scale_partial, shared_scale); + d_bias_partial = BlockReduceSum(d_bias_partial, shared_bias); if (threadIdx.x == 0) { d_scale[blockIdx.x + col_offset] = d_scale_partial; diff --git a/paddle/fluid/operators/math/math_cuda_utils.h b/paddle/fluid/operators/math/math_cuda_utils.h index e97dbd20ca142af75420ccf3ce349c1bdc928b09..8de4e8221c0e473e4577cf897762b8773f50ebb3 100644 --- a/paddle/fluid/operators/math/math_cuda_utils.h +++ b/paddle/fluid/operators/math/math_cuda_utils.h @@ -188,6 +188,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) { val = warpReduceSum(val, mask); + __syncthreads(); if (lane == 0) shared[wid] = val; __syncthreads();