diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu
old mode 100644
new mode 100755
index ea1bca8b4d58dfefa7b03b3c821faed2a175931e..b65ae01ddf919f48c868e8938126a62a2c165f5e
--- a/paddle/fluid/operators/layer_norm_op.cu
+++ b/paddle/fluid/operators/layer_norm_op.cu
@@ -42,15 +42,46 @@ using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 
-inline static int GetDesiredBlockDim(int block_dim) {
+inline static int GetDesiredBlockDim(int64_t block_dim) {
 #ifdef __HIPCC__
   const int kMaxBlockDim = 256;
+  const int lwarpSize = 64;
 #else
   const int kMaxBlockDim = 512;
+  const int lwarpSize = 32;
 #endif
-  return block_dim >= kMaxBlockDim
-             ? kMaxBlockDim
-             : (1 << (static_cast<int>(std::log2f(block_dim))));
+  return block_dim >= kMaxBlockDim ? kMaxBlockDim : lwarpSize;
+}
+
+template <typename U>
+static __forceinline__ __device__ U WarpReduceSum(U val) {
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+    val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
+  }
+  return val;
+}
+
+template <typename U>
+__forceinline__ __device__ U BlockReduceSum(U val) {
+  static __shared__ U shared[32];
+  int lane = threadIdx.x % warpSize;
+  int wid = threadIdx.x / warpSize;
+
+  val = WarpReduceSum(val);  // Each warp performs partial reduction
+
+  if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory
+
+  __syncthreads();  // Wait for all partial reductions
+
+  // read from shared memory only if that warp existed
+  val =
+      (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<U>(0);
+
+  if (wid == 0) val = WarpReduceSum(val);  // Final reduce within first warp
+
+  return val;
 }
 
 #define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
@@ -70,15 +101,17 @@ inline static int GetDesiredBlockDim(int block_dim) {
   FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__);  \
   FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)
 
-#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                             \
-    log2_block_dim, feature_size, kMaxBlockNum, ...)                          \
-  case (1 << (log2_block_dim)): {                                             \
-    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
-      int col_offset = i * kMaxBlockNum;                                       \
-      int block_num = std::min(feature_size - col_offset, kMaxBlockNum);       \
-      constexpr auto kBlockDim = (1 << (log2_block_dim));                      \
-      __VA_ARGS__;                                                             \
-    }                                                                          \
+#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                            \
+    log2_block_dim, feature_size, kMaxBlockNum, ...)                          \
+  case (1 << (log2_block_dim)): {                                             \
+    for (int64_t i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum));   \
+         i++) {                                                               \
+      int64_t col_offset = i * static_cast<int64_t>(kMaxBlockNum);            \
+      int block_num = static_cast<int>(std::min(                              \
+          feature_size - col_offset, static_cast<int64_t>(kMaxBlockNum)));    \
+      constexpr auto kBlockDim = (1 << (log2_block_dim));                     \
+      __VA_ARGS__;                                                            \
+    }                                                                         \
   } break
 
 #define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
@@ -147,31 +180,32 @@ __inline__ __device__ half rsqrt_(const half val) {
 template <typename T, typename U, int BlockDim>
 __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                  T *y, U *mean, U *var, float epsilon,
-                                 int feature_size) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
+                                 int64_t feature_size) {
   __shared__ U mean_share;
   __shared__ U var_share;
 
-  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
-  int end_idx = (blockIdx.x + 1) * feature_size;
+  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int64_t end_idx = (blockIdx.x + 1) * feature_size;
 
   // Step 1: Reduce to calculate mean and var
   U mean_val = 0;
   U var_val = 0;
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     U tmp = static_cast<U>(x[i]);
     mean_val += tmp;
     var_val += (tmp * tmp);
   }
-  auto pair = BlockReduce(temp_storage)
-                  .Reduce(PairForLayerNorm<U>(mean_val, var_val),
-                          PairForLayerNormAddFunctor<U>());
+
+  mean_val = BlockReduceSum<U>(mean_val);
+  var_val = BlockReduceSum<U>(var_val);
+
   if (threadIdx.x == 0) {
-    auto tmp = pair.first_ / feature_size;
+    auto scale = static_cast<float>(1.) / static_cast<float>(feature_size);
+    auto tmp = mean_val * scale;
     mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
-    var[blockIdx.x] = var_share =
-        static_cast<U>(pair.second_ / feature_size - tmp * tmp);
+    var_share = static_cast<U>(var_val * scale - mean_share * mean_share);
+    var_share = var_share > U(0) ? var_share : U(0);
+    var[blockIdx.x] = var_share;
   }
   __syncthreads();
 
@@ -181,13 +215,13 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
   // Step 2: Calculate y
   if (scale != nullptr) {
     if (bias != nullptr) {
-      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
+      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
            i += BlockDim, j += BlockDim) {
         y[i] = static_cast<T>(
             scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
       }
     } else {
-      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
+      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
         y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
                               invvar);
@@ -195,13 +229,13 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
     }
   } else {  // scale == nullptr
     if (bias != nullptr) {
-      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
+      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
         y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
                               bias[j]);
       }
     } else {
-      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
+      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
         y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
       }
@@ -278,18 +312,18 @@ __global__ void LayerNormForwardFP16(const T *x, const U *scale, const U *bias,
 
 template <typename T, typename U, int VPT>
 __inline__ __device__ void cuLoadAddStridedInputs(
-    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
-    const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
-    const T *input, const T *dout, const int i1_end, const int n2,
-    const U *__restrict__ mean, const U *__restrict__ var,
-    const float epsilon) {
-  const int i1 = i1_block + thr_load_row_off;
+    const int64_t i1_block, const int thr_load_row_off,
+    const int thr_load_col_off, const int i2_off, const int row_stride,
+    U *warp_buf1, U *warp_buf2, const T *input, const T *dout,
+    const int64_t i1_end, const int64_t n2, const U *__restrict__ mean,
+    const U *__restrict__ var, const float epsilon) {
+  const int64_t i1 = i1_block + thr_load_row_off;
   if (i1 >= i1_end) return;
   U curr_mean = mean[i1];
   U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
   for (int k = 0; k < VPT; ++k) {
     const int i2 = i2_off + k;
-    const int load_idx = i1 * n2 + i2;
+    const int64_t load_idx = i1 * n2 + i2;
     const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
     if (i2 < n2) {
       U curr_input = static_cast<U>(input[load_idx]);
@@ -303,8 +337,8 @@ __inline__ __device__ void cuLoadAddStridedInputs(
 
 template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
 __global__ void LayerNormBackwardPartGradGammaBeta(
-    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
-    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
+    const T *__restrict__ dout, const T *__restrict__ input, const int64_t n1,
+    const int64_t n2, const U *__restrict__ mean, const U *__restrict__ var,
     float epsilon, U *part_grad_gamma, U *part_grad_beta) {
   // VPTX -> value per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y, BDIMX
   // -> blockDim.x
@@ -330,7 +364,7 @@ __global__ void LayerNormBackwardPartGradGammaBeta(
   }
   __syncthreads();
 
-  for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
+  for (int64_t i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
        i1_block += VPTX * BDIMY * gridDim.y) {
     cuLoadAddStridedInputs<T, U, VPTX>(
         i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
@@ -363,7 +397,7 @@ __global__ void LayerNormBackwardPartGradGammaBeta(
     }
     __syncthreads();
   }
-  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = blockIdx.x * blockDim.x + threadIdx.x;
   if (threadIdx.y == 0 && i2 < n2) {
     int row1 = threadIdx.y;
     int row2 = threadIdx.y + 1;
@@ -381,7 +415,7 @@ __global__ void LayerNormBackwardSumGradGammaBeta(
     const int n1, const int n2, U *grad_gamma, U *grad_beta) {
   // sum partial gradients for gamma and beta
   __shared__ U buf[BDIMX * BDIMY];
-  int i2 = blockIdx.x * BDIMX + threadIdx.x;
+  int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
   if (i2 < n2) {
     // each warp does sequential reductions until reduced part_size is num_warps
     int num_warp_reductions = part_size / BDIMY;
@@ -552,22 +586,17 @@ __global__ void LayerNormBackwardComputeGradInput(
 
 // Make sure that d_scale != nullptr && d_bias != nullptr
 // Since d_scale != nullptr, scale would not be nullptr
 template <typename T, typename U, int BlockDim, bool HasDx>
-__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
-                                             U *d_scale, U *d_bias, T *d_x,
-                                             const U *mean, const U *var,
-                                             const U *scale, float epsilon,
-                                             int batch_size, int feature_size,
-                                             int col_offset) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-
-  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
-  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
-  int stride = BlockDim * feature_size;
+__global__ void LayerNormBackwardGradientAll(
+    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
+    const U *var, const U *scale, float epsilon, int64_t batch_size,
+    int64_t feature_size, int64_t col_offset) {
+  int64_t beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
+  int64_t end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
+  int64_t stride = BlockDim * feature_size;
 
   U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);
-  for (int i = beg_idx; i < end_idx; i += stride) {
+  for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val = real_sqrt(static_cast<float>(var[row_idx]) + epsilon);
     d_scale_partial += static_cast<U>(d_y[i]) *
@@ -579,13 +608,12 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
     }
   }
 
-  auto pair = BlockReduce(temp_storage)
-                  .Reduce(PairForLayerNorm<U>(d_scale_partial, d_bias_partial),
-                          PairForLayerNormAddFunctor<U>());
+  d_scale_partial = BlockReduceSum<U>(d_scale_partial);
+  d_bias_partial = BlockReduceSum<U>(d_bias_partial);
 
   if (threadIdx.x == 0) {
-    d_scale[blockIdx.x + col_offset] = pair.first_;
-    d_bias[blockIdx.x + col_offset] = pair.second_;
+    d_scale[blockIdx.x + col_offset] = d_scale_partial;
+    d_bias[blockIdx.x + col_offset] = d_bias_partial;
   }
 }
 
@@ -595,16 +623,16 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
 
 template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
 __global__ void LayerNormBackwardGradientScaleOrBias(
     const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
-    const U *var, const U *scale, float epsilon, int batch_size,
-    int feature_size, int col_offset) {
+    const U *var, const U *scale, float epsilon, int64_t batch_size,
+    int64_t feature_size, int col_offset) {
   using BlockReduce = cub::BlockReduce<U, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
-  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
-  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
+  int64_t beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
+  int64_t end_idx = batch_size * feature_size + blockIdx.x + col_offset;
   int stride = BlockDim * feature_size;
   U d_scale_or_d_bias_partial = static_cast<U>(0);
-  for (int i = beg_idx; i < end_idx; i += stride) {
+  for (int64_t i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val =
         static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
@@ -639,22 +667,20 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
 }
 
 template <typename T, typename U, int BlockDim>
-__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
-                                                          const U *mean,
-                                                          const U *var,
-                                                          float epsilon,
-                                                          int feature_size) {
+__global__ void LayerNormBackwardPostProcessToCalculateDX(
+    const T *x, T *d_x, const U *mean, const U *var, float epsilon,
+    int64_t feature_size) {
   using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   __shared__ U d_x_reduce_tmp[2];
 
-  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
-  int end_idx = (blockIdx.x + 1) * feature_size;
+  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int64_t end_idx = (blockIdx.x + 1) * feature_size;
 
   U block_mean = mean[blockIdx.x];
   U block_var = var[blockIdx.x];
   U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     d_x_mean_partial += static_cast<U>(d_x[i]);
     d_x_var_partial +=
         static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
@@ -675,7 +701,7 @@ __global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
   d_x_mean_partial = d_x_reduce_tmp[0];
   d_x_var_partial = d_x_reduce_tmp[1];
 
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     d_x[i] -= static_cast<T>(d_x_mean_partial);
     d_x[i] -=
         static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
@@ -688,17 +714,17 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                 T *d_x, const U *mean,
                                                 const U *var, const U *scale,
                                                 float epsilon,
-                                                int feature_size) {
+                                                int64_t feature_size) {
   using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   __shared__ U d_x_reduce_tmp[2];
 
-  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
-  int end_idx = (blockIdx.x + 1) * feature_size;
+  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int64_t end_idx = (blockIdx.x + 1) * feature_size;
 
   U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
   U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     auto var_val =
         static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
     if (scale != nullptr) {
@@ -728,7 +754,7 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
   d_x_mean_partial = d_x_reduce_tmp[0];
   d_x_var_partial = d_x_reduce_tmp[1];
 
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
     d_x[i] -= static_cast<T>(d_x_mean_partial);
     d_x[i] -=
         static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
@@ -738,8 +764,8 @@ __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
 template <typename T, typename U>
 __global__ void LayerNormBackwardWhenBatchSizeIsOne(
     const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
-    const U *var, const U *scale, float epsilon, int feature_size) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    const U *var, const U *scale, float epsilon, int64_t feature_size) {
+  int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < feature_size) {
     auto var_val =
         static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
@@ -764,8 +790,8 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
 template <typename T, typename U>
 static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
                               const U *mean, const U *var, T *d_x, U *d_scale,
-                              U *d_bias, float epsilon, int batch_size,
-                              int feature_size,
+                              U *d_bias, float epsilon, int64_t batch_size,
+                              int64_t feature_size,
                               const framework::ExecutionContext &ctx) {
   auto &dev_ctx = ctx.cuda_device_context();
   auto stream = dev_ctx.stream();
@@ -925,8 +951,8 @@ void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                                int begin_norm_axis, float eps) {
   const auto x_dims = framework::make_ddim(input_shape);
   auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-  int batch_size = static_cast<int>(matrix_dim[0]);
-  int feature_size = static_cast<int>(matrix_dim[1]);
+  int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
+  int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
   switch (GetDesiredBlockDim(feature_size)) {
     FIXED_BLOCK_DIM_CASE(
         LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
@@ -986,8 +1012,8 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
     auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());
 
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int batch_size = static_cast<int>(matrix_dim[0]);
-    int feature_size = static_cast<int>(matrix_dim[1]);
+    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
+    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
 
     auto stream = ctx.cuda_device_context().stream();
 
@@ -1040,8 +1066,8 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
     const auto &x_dims = x->dims();
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int batch_size = static_cast<int>(matrix_dim[0]);
-    int feature_size = static_cast<int>(matrix_dim[1]);
+    int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
+    int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
 
     LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
                             d_x_data, d_scale_data, d_bias_data, epsilon,
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py
index 77cd6926b563da69b33c5d52a7064137f5487ba0..987c3da4dd7be887c00007fb25d88acc3ae69762 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op_v2.py
@@ -51,6 +51,7 @@ class TestDygraphLayerNormv2(unittest.TestCase):
         self.assertTrue(np.allclose(y1, y2))
 
     def test_static(self):
+        paddle.enable_static()
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu("layer_norm"):
             places.append(fluid.CUDAPlace(0))
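Note on the reduction pattern introduced by this patch: the sketch below is a minimal, standalone illustration of the same two-level warp-shuffle reduction that the new WarpReduceSum/BlockReduceSum helpers implement. It is not part of the patch and does not use Paddle internals; it calls __shfl_down_sync with a full-lane mask directly instead of Paddle's CREATE_SHFL_MASK/CudaShuffleDownSync wrappers, and the kernel name and test harness (SumKernel, main) are invented for the example. Like the patch's helpers, it assumes blockDim.x is a multiple of warpSize and at most 1024.

// block_reduce_sketch.cu -- illustrative only, not from the patch
#include <cstdio>
#include <cuda_runtime.h>

template <typename U>
__device__ __forceinline__ U WarpReduceSumSketch(U val) {
  const unsigned full_mask = 0xffffffffu;  // all 32 lanes participate
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(full_mask, val, offset);
  }
  return val;  // lane 0 of each warp holds that warp's sum
}

template <typename U>
__device__ __forceinline__ U BlockReduceSumSketch(U val) {
  static __shared__ U shared[32];   // one slot per warp (<= 1024 threads)
  const int lane = threadIdx.x % warpSize;
  const int wid = threadIdx.x / warpSize;

  val = WarpReduceSumSketch(val);   // per-warp partial sums
  if (lane == 0) shared[wid] = val; // first lane of each warp publishes
  __syncthreads();

  // Only threads that map onto an existing warp read a partial sum.
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane]
                                              : static_cast<U>(0);
  if (wid == 0) val = WarpReduceSumSketch(val);  // final reduce in warp 0
  return val;                       // thread 0 now holds the block sum
}

__global__ void SumKernel(const float *x, float *out, int n) {
  float v = (threadIdx.x < n) ? x[threadIdx.x] : 0.f;
  float s = BlockReduceSumSketch(v);
  if (threadIdx.x == 0) *out = s;
}

int main() {
  const int n = 256;
  float host[n], *dev_x = nullptr, *dev_out = nullptr, result = 0.f;
  for (int i = 0; i < n; ++i) host[i] = 1.f;  // expected sum: 256
  cudaMalloc(&dev_x, n * sizeof(float));
  cudaMalloc(&dev_out, sizeof(float));
  cudaMemcpy(dev_x, host, n * sizeof(float), cudaMemcpyHostToDevice);
  SumKernel<<<1, n>>>(dev_x, dev_out, n);
  cudaMemcpy(&result, dev_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("block sum = %f\n", result);  // prints 256.000000
  cudaFree(dev_x);
  cudaFree(dev_out);
  return 0;
}

Compared with the cub::BlockReduce path it replaces, this pattern needs only a 32-element shared buffer per block and works for any thread count that is a multiple of warpSize, which is why the kernels above can drop the cub temp storage.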
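The forward kernel's statistics follow the single-pass moment formula: it accumulates sum(x) and sum(x*x) per row, then computes mean = sum(x) / n and var = sum(x*x) / n - mean * mean, clamping the variance at zero so floating-point cancellation cannot drive it negative. A hypothetical host-side reference (plain C++, not taken from the patch) that mirrors this arithmetic:

// moments_sketch.cc -- illustrative only, not from the patch
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {0.5f, 1.0f, 1.5f, 2.0f};
  double sum = 0.0, sq_sum = 0.0;
  for (float v : x) {          // one pass over the row
    sum += v;
    sq_sum += static_cast<double>(v) * v;
  }
  const double inv_n = 1.0 / static_cast<double>(x.size());
  const double mean = sum * inv_n;
  double var = sq_sum * inv_n - mean * mean;  // E[x^2] - E[x]^2
  if (var < 0.0) var = 0.0;    // guard against rounding, as the kernel does
  printf("mean = %f, var = %f\n", mean, var); // mean = 1.250000, var = 0.312500
  return 0;
}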