diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h
index 44c233be5750d4b48a63f3b274c5c6a5830c0482..8d1d3f6f1614a9a7d564c3d2a90e583435df7e02 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -148,6 +148,8 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
   size_t width_stride = gridDim.x * blockDim.x;
   size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
                       ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
+  size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
+                       ((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
 
 #pragma unroll
   for (size_t w = idx; w < full_width; w += width_stride) {
@@ -155,10 +157,10 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
     __syncthreads();
     size_t offset = w + threadIdx.y * width;
 #pragma unroll
-    for (size_t h = threadIdx.y; h < height;
+    for (size_t h = threadIdx.y; h < full_height;
          h += BLOCK_H) {  // block-stride loop across matrix height
       sdata[threadIdx.y][threadIdx.x] +=
-          (w < width) ? in[offset] : (static_cast<T>(0));
+          (w < width && h < height) ? in[offset] : (static_cast<T>(0));
       offset += width * BLOCK_H;
     }
     __syncthreads();
@@ -184,21 +186,24 @@ __global__ void FP16MatrixColReduce(
   size_t width_stride = gridDim.x * blockDim.x;
   size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
                       ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
-
+  size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
+                       ((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
 #pragma unroll
   for (size_t w = idx; w < full_width; w += width_stride) {
     for (int r = 0; r < repeats; r++) {
       sdata[threadIdx.y + r * BLOCK_W][threadIdx.x] = 0;
     }
     __syncthreads();
+#pragma unroll
     for (int r = 0; r < repeats; r++) {
       size_t offset = w + (r * BLOCK_W + threadIdx.y) * width;
 #pragma unroll
-      for (size_t h = r * BLOCK_H + threadIdx.y; h < height;
+      for (size_t h = threadIdx.y + r * BLOCK_W; h < full_height;
            h += BLOCK_H) {  // block-stride loop across matrix height
         sdata[r * BLOCK_W + threadIdx.y][threadIdx.x] +=
-            (w < width) ? in[offset + r * BLOCK_W * width]
-                        : (static_cast<T>(0));
+            (w < width && h < height)
+                ? in[offset]
+                : (static_cast<T>(0));
         offset += width * BLOCK_H;
       }
     }
@@ -373,6 +378,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
       err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, dout_data,
                                    out_data, nums, stream);
       PADDLE_ENFORCE_CUDA_SUCCESS(err);
+      return;
     }
 
     constexpr int block_x = 32;