/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cub/cub.cuh>
#include <memory>
#include <vector>

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
// Type used for mean / variance / scale / bias and for accumulation.
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

// Returns the largest power of two that is <= block_dim, capped at 512.
// The result is used as the `switch` key of the FIXED_BLOCK_DIM_* macros
// below, which only provide cases for 2, 4, ..., 512.
inline static int GetDesiredBlockDim(int block_dim) {
  const int kMaxBlockDim = 512;
  return block_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << (static_cast<int>(std::log2f(block_dim))));
}

// Expands to one `case` label that defines the compile-time constant
// `kBlockDim` and then evaluates the statement passed in __VA_ARGS__.
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

// Expands to the cases 512, 256, ..., 2. There is no case for 1.
#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)

// Like FIXED_BLOCK_DIM_CASE_BASE, but evaluates __VA_ARGS__ once per chunk of
// at most kMaxBlockNum columns. Inside __VA_ARGS__ the names `kBlockDim`,
// `block_num` (columns in this chunk) and `col_offset` (first column of this
// chunk) are available.
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                             \
    log2_block_dim, feature_size, kMaxBlockNum, ...)                           \
  case (1 << (log2_block_dim)): {                                              \
    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
      int col_offset = i * kMaxBlockNum;                                       \
      int block_num = std::min(feature_size - col_offset, kMaxBlockNum);       \
      constexpr auto kBlockDim = (1 << (log2_block_dim));                      \
      __VA_ARGS__;                                                             \
    }                                                                          \
  } break

#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__)

// Square root overloads so device code can call one name for float/double.
static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }

// Pair of values reduced together in a single cub::BlockReduce pass.
template <typename T>
struct PairForLayerNorm {
  __device__ __forceinline__ PairForLayerNorm() {}
  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
      : first_(first), second_(second) {}

  T first_;
  T second_;
};

// Element-wise addition of two PairForLayerNorm values (reduction operator).
template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
  }
};

// Reciprocal square root: generic fallback plus float/double/half
// specializations that map to the matching CUDA intrinsic.
template <typename T>
__inline__ __device__ T rsqrt_(const T val) {
  return static_cast<T>(1) / sqrt(val);
}

template <>
__inline__ __device__ float rsqrt_(const float val) {
  return rsqrtf(val);
}

template <>
__inline__ __device__ double rsqrt_(const double val) {
  return rsqrt(val);
}

#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__inline__ __device__ half rsqrt_(const half val) {
  return hrsqrt(val);
}
#endif

// Forward pass. Block `b` processes the elements
// x[b * feature_size, (b + 1) * feature_size) with BlockDim threads, writes
// mean[b] and var[b], and writes the normalized (optionally scaled / shifted)
// result to y. `scale` and `bias` may each be nullptr and are indexed by the
// position inside the row.
// Launch: 1-D grid with one block per row, blockDim.x == BlockDim.
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                 T *y, U *mean, U *var, float epsilon,
                                 int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U mean_share;
  __shared__ U var_share;

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce to calculate mean and var
  U mean_val = 0;
  U var_val = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U tmp = static_cast<U>(x[i]);
    mean_val += tmp;
    var_val += (tmp * tmp);
  }
  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<U>(mean_val, var_val),
                          PairForLayerNormAddFunctor<U>());
  if (threadIdx.x == 0) {
    auto tmp = pair.first_ / feature_size;
    // E[x^2] - E[x]^2 can round to a slightly negative number; clamp it so
    // the stored variance is never negative.
    U row_var = static_cast<U>(pair.second_ / feature_size - tmp * tmp);
    row_var = row_var > U(0) ? row_var : U(0);
    mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
    var[blockIdx.x] = var_share = row_var;
  }
  __syncthreads();

  mean_val = mean_share;
  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));

  // Step 2: Calculate y
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(
            scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
                              invvar);
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
                              bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
      }
    }
  }
}

// Helper of LayerNormBackwardPartGradGammaBeta. For row
// i1 = i1_block + thr_load_row_off (skipped when i1 >= i1_end) it adds, for
// VPT consecutive columns starting at i2_off, dout into warp_buf1 and
// dout * (input - mean) * rsqrt(var + epsilon) into warp_buf2.
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
    const T *input, const T *dout, const int i1_end, const int n2,
    const U *__restrict__ mean, const U *__restrict__ var,
    const float epsilon) {
  const int i1 = i1_block + thr_load_row_off;
  if (i1 >= i1_end) return;
  U curr_mean = mean[i1];
  U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
  for (int k = 0; k < VPT; ++k) {
    const int i2 = i2_off + k;
    const int load_idx = i1 * n2 + i2;
    const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
    if (i2 < n2) {
      U curr_input = static_cast<U>(input[load_idx]);
      U curr_dout = static_cast<U>(dout[load_idx]);
      warp_buf1[write_idx] += curr_dout;
      warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
    }
  }
}

// Computes partial column sums for the gradients of gamma (scale) and beta
// (bias). Row blockIdx.y of part_grad_gamma / part_grad_beta (each laid out
// as [gridDim.y, n2]) receives the sum over the rows visited by that block
// row. n1 is the number of rows, n2 the number of columns.
// Launch: blockDim == (BDIMX, BDIMY); gridDim.x covers n2 in chunks of BDIMX.
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
    float epsilon, U *part_grad_gamma, U *part_grad_beta) {
  // VPTX -> value per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y
  // template for compile time optimizations

  // +1 padding of each shared-memory row.
  constexpr int row_stride = BDIMX + 1;
  const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
  const int thr_load_row_off =
      (threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
  const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;

  constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
                                 ? BDIMX * BDIMY
                                 : 2 * VPTX * BDIMY * row_stride;
  __shared__ U buf[shared_cap];

  U *warp_buf1 = reinterpret_cast<U *>(buf);
  U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;

  // Zero both accumulation buffers before any thread adds into them.
  for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
       idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) {
    buf[idx] = U(0);
  }
  __syncthreads();

  for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
       i1_block += VPTX * BDIMY * gridDim.y) {
    cuLoadAddStridedInputs<T, U, VPTX>(
        i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
        warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon);
  }
  __syncthreads();

  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < VPTX; ++k) {
    int row1 = threadIdx.y + k * VPTX;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();

  // sum all warps
  for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }

  // The loop above stops at two remaining rows; add them while writing out.
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

// Sums the part_size partial rows produced by
// LayerNormBackwardPartGradGammaBeta into grad_gamma[0..n2) and
// grad_beta[0..n2). n1 is not used by this kernel.
// Launch: blockDim == (BDIMX, BDIMY); gridDim.x covers n2 in chunks of BDIMX.
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardSumGradGammaBeta(
    const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
    const int n1, const int n2, U *grad_gamma, U *grad_beta) {
  // sum partial gradients for gamma and beta
  __shared__ U buf[BDIMX * BDIMY];
  const int i2 = blockIdx.x * BDIMX + threadIdx.x;
  // Threads whose column is out of range keep zero sums and still take part
  // in the shared-memory reduction, so that every thread of the block reaches
  // each __syncthreads() below.
  const bool in_range = i2 < n2;

  U sum_gamma = U(0);
  U sum_beta = U(0);
  if (in_range) {
    // each warp does sequential reductions until reduced part_size is
    // num_warps
    const int num_warp_reductions = part_size / BDIMY;
    const U *part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U *part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
  }

  // inter-warp reductions
  constexpr int nbsize3 = BDIMX * BDIMY / 2;
  for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
    // top half write to shared memory
    if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
      const int write_idx = (threadIdx.y - offset) * BDIMX + threadIdx.x;
      buf[write_idx] = sum_gamma;
      buf[write_idx + nbsize3] = sum_beta;
    }
    __syncthreads();
    // bottom half sums
    if (threadIdx.y < offset) {
      const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
      sum_gamma += buf[read_idx];
      sum_beta += buf[read_idx + nbsize3];
    }
    __syncthreads();
  }

  // write out fully summed gradients
  if (in_range && threadIdx.y == 0) {
    grad_gamma[i2] = sum_gamma;
    grad_beta[i2] = sum_beta;
  }
}

// Computes grad_input for every row i1 in [0, n1), n2 elements per row.
// Rows are distributed over blockIdx.y with stride gridDim.y. `gamma` may be
// NULL, in which case it is treated as all ones.
// Launch: blockDim == (BDIMX, BDIMY). The shuffle reduction below uses the
// full-warp mask, so BDIMX is presumably expected to be the warp size.
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardComputeGradInput(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
    const float epsilon, const U *gamma, T *grad_input) {
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
    const T *k_input = input + i1 * n2;
    const T *k_dout = dout + i1 * n2;
    constexpr int numx = BDIMX * BDIMY;
    const int thrx = threadIdx.x + threadIdx.y * BDIMX;
    // Per-thread partial sums: groups of four consecutive elements first,
    // then whatever index range is left for this thread.
    if (gamma != NULL) {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * gamma[l + k];
          sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
      sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize);
      sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize);
    }
    // inter-warp reductions
    if (BDIMY > 1) {
      __shared__ U buf[BDIMX * BDIMY];
      for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      // Broadcast the totals held by threadIdx.y == 0 to the other rows.
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T *k_grad_input = grad_input + i1 * n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}

// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
//
// Each block handles one column (blockIdx.x + col_offset) and reduces over
// all batch_size rows to produce d_scale and d_bias for that column. When
// HasDx is true it also writes d_y * scale / sqrt(var + epsilon) into d_x;
// that value is adjusted afterwards by
// LayerNormBackwardPostProcessToCalculateDX.
template <typename T, typename U, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                             U *d_scale, U *d_bias, T *d_x,
                                             const U *mean, const U *var,
                                             const U *scale, float epsilon,
                                             int batch_size, int feature_size,
                                             int col_offset) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
  int stride = BlockDim * feature_size;

  U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val = real_sqrt(static_cast<float>(var[row_idx]) + epsilon);
    d_scale_partial += static_cast<U>(d_y[i]) *
                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
    d_bias_partial += static_cast<U>(d_y[i]);
    if (HasDx) {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                              scale[blockIdx.x + col_offset] / var_val);
    }
  }

  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<U>(d_scale_partial, d_bias_partial),
                          PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_scale[blockIdx.x + col_offset] = pair.first_;
    d_bias[blockIdx.x + col_offset] = pair.second_;
  }
}

// Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr
// Notice: scale may be nullptr
//
// Same column-per-block layout as LayerNormBackwardGradientAll, but reduces
// only one quantity: d_scale when HasDScale is true, d_bias otherwise.
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
    const U *var, const U *scale, float epsilon, int batch_size,
    int feature_size, int col_offset) {
  using BlockReduce = cub::BlockReduce<U, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
  int stride = BlockDim * feature_size;
  U d_scale_or_d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
    if (HasDScale) {
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
                                   (static_cast<U>(x[i]) - mean[row_idx]) /
                                   var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                                scale[blockIdx.x + col_offset] / var_val);
      } else {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    } else {
      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    }
  }
}

// Second stage of the d_x computation. On entry d_x holds the values written
// by one of the kernels above; block `b` updates row `b` in place by
// subtracting the row mean of d_x and
// (x - mean) * sum(d_x * (x - mean)) / (feature_size * (var + epsilon)).
// Launch: one block per row, blockDim.x == BlockDim.
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
                                                          const U *mean,
                                                          const U *var,
                                                          float epsilon,
                                                          int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x];
  U block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
    d_x_reduce_tmp[1] =
        static_cast<float>(pair.second_) /
        (feature_size * (static_cast<float>(block_var) + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

// Here, we only calculate d_x
//
// Block `b` first writes d_y (* scale, if scale != nullptr) /
// sqrt(var + epsilon) into row `b` of d_x, then applies the same in-place
// correction as LayerNormBackwardPostProcessToCalculateDX.
// Launch: one block per row, blockDim.x == BlockDim.
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                T *d_x, const U *mean,
                                                const U *var, const U *scale,
                                                float epsilon,
                                                int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
    if (scale != nullptr) {
      int col_idx = i % feature_size;
      d_x[i] =
          static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
    } else {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
    }
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
    d_x_reduce_tmp[1] =
        static_cast<float>(pair.second_) /
        (feature_size * (static_cast<float>(block_var) + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

// Backward pass for the single-row case (batch_size == 1): one thread per
// column, no reduction needed for d_scale / d_bias. Any of d_x, d_scale and
// d_bias may be nullptr; scale may be nullptr. When d_x is requested the
// caller still has to run LayerNormBackwardPostProcessToCalculateDX.
//
// With a single row there is exactly one mean and one variance, so both are
// read at index 0 (they must NOT be indexed by the column).
template <typename T, typename U>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
    const U *var, const U *scale, float epsilon, int feature_size) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < feature_size) {
    const U row_mean = mean[0];
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
    if (d_x != nullptr) {
      // Whether to multiply by scale depends on scale being present, not on
      // whether its gradient was requested.
      if (scale == nullptr) {
        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
      } else {
        d_x[idx] =
            static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
      }
    }

    if (d_scale != nullptr) {
      d_scale[idx] = static_cast<U>(d_y[idx]) *
                     (static_cast<U>(x[idx]) - row_mean) / var_val;
    }

    if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
  }
}

// Host-side dispatcher for the backward pass. x and d_y are treated as
// [batch_size, feature_size]; mean and var hold one value per row; scale,
// d_scale and d_bias hold one value per column. Any of d_x, d_scale, d_bias
// may be nullptr, and the combination selects which kernels are launched on
// the device context's stream. Returns immediately when all three are
// nullptr.
template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
                              const U *mean, const U *var, T *d_x, U *d_scale,
                              U *d_bias, float epsilon, int batch_size,
                              int feature_size,
                              const framework::ExecutionContext &ctx) {
  auto &dev_ctx = ctx.cuda_device_context();
  auto stream = dev_ctx.stream();
  const int kMaxBlockDim = 512;
  const int kMaxBlockNum = 128;
  // bit 2: d_x requested, bit 1: d_scale requested, bit 0: d_bias requested
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;

  if (batch_size == 1) {
    LayerNormBackwardWhenBatchSizeIsOne<
        T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
                0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
                             epsilon, feature_size);

    if (d_x != nullptr) {
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
                             T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>(
            x, d_x, mean, var, epsilon, feature_size));
      }
    }
    return;
  }

  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
    {
      // Stage 1: part_size partial column sums for d_scale / d_bias.
      constexpr int VPT = 4;
      constexpr int BDIMX2 = 32;
      constexpr int BDIMY2 = 4;
      dim3 threads2(BDIMX2, BDIMY2, 1);
      constexpr int part_size = BDIMY2 * VPT;
      const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);

      auto part_grad_gamma_ptr =
          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
      auto part_grad_beta_ptr =
          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
      U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
      U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());

      LayerNormBackwardPartGradGammaBeta<
          T, U, BDIMX2, BDIMY2, VPT><<<blocks2, threads2, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon,
          part_grad_gamma,
          part_grad_beta);  // compute part_grad_gamma, beta

      // Stage 2: reduce the partial sums into d_scale / d_bias.
      constexpr int BDIMX3 = 32;
      constexpr int BDIMY3 = 8;
      dim3 threads3(BDIMX3, BDIMY3, 1);
      const dim3 blocks3((feature_size + BDIMX3 - 1) / BDIMX3, 1, 1);
      LayerNormBackwardSumGradGammaBeta<
          T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
          part_grad_gamma, part_grad_beta, part_size, batch_size,
          feature_size, d_scale, d_bias);

      // Stage 3: d_x, one block per row.
      constexpr int BDIMX1 = 32;
      constexpr int BDIMY1 = 4;
      dim3 threads1(BDIMX1, BDIMY1, 1);
      const dim3 blocks1(1, batch_size, 1);
      LayerNormBackwardComputeGradInput<
          T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
      break;
    }
    default:
      break;
  }
}

// Runs the forward kernel directly on raw device pointers, without an
// ExecutionContext. input_shape is flattened to 2-D at begin_norm_axis; mean
// and variance use the same element type T as the data. Throws
// InvalidArgument when no block-dim case matches (flattened feature size 1).
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                               const T *input,
                                               std::vector<int> input_shape,
                                               const T *bias, const T *scale,
                                               T *output, T *mean, T *variance,
                                               int begin_norm_axis, float eps) {
  const auto x_dims = framework::make_ddim(input_shape);
  auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
  int batch_size = static_cast<int>(matrix_dim[0]);
  int feature_size = static_cast<int>(matrix_dim[1]);
  switch (GetDesiredBlockDim(feature_size)) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, feature_size));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

// CUDA kernel for the `layer_norm` operator: reads X (and optional Scale /
// Bias), writes Y, Mean and Variance.
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
    auto *var_data = var->mutable_data<U>(ctx.GetPlace());
    // Scale and Bias are optional inputs.
    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
    auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

    switch (GetDesiredBlockDim(feature_size)) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<T, U,
                           kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
              x_data, scale_data, bias_data, y_data, mean_data, var_data,
              epsilon, feature_size));
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Product from begin_norm_axis to end must be larger than 1"));
        break;
    }
  }
};

// CUDA kernel for the `layer_norm_grad` operator: gathers the tensors from
// the context and forwards raw pointers to LayerNormBackward.
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();
    auto *mean_data = mean->data<U>();
    auto *var_data = var->data<U>();

    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
    auto *d_scale_data =
        (d_scale == nullptr ? nullptr
                            : d_scale->mutable_data<U>(ctx.GetPlace()));
    auto *d_bias_data =
        (d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
    auto *d_x_data =
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
                            d_x_data, d_scale_data, d_bias_data, epsilon,
                            batch_size, feature_size, ctx);
  }
};

template class LayerNormDirectCUDAFunctor<float>;

#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
                             plat::float16>);