diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index bc8860eaa055e391c9b5212eb2447d55a932873e..d5a57dd9ddcad9d70a257738a05b3e5025a2264e 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -107,35 +107,54 @@ struct PairForLayerNormAddFunctor { } }; +template +__inline__ __device__ T rsqrt(const T val) { + return ::rsqrt(val); +} + +template <> +__inline__ __device__ float rsqrt(const float val) { + return rsqrtf(val); +} + +template <> +__inline__ __device__ half rsqrt(const half val) { + return hrsqrt(val); +} + template __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, T *y, U *mean, U *var, float epsilon, int feature_size) { - using BlockReduce = cub::BlockReduce, BlockDim>; + using BlockReduce = cub::BlockReduce, BlockDim>; __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U mean_share; + __shared__ U var_share; int beg_idx = blockIdx.x * feature_size + threadIdx.x; int end_idx = (blockIdx.x + 1) * feature_size; // Step 1: Reduce to calculate mean and var - double mean_val = 0; - double var_val = 0; + U mean_val = 0; + U var_val = 0; for (int i = beg_idx; i < end_idx; i += BlockDim) { U tmp = static_cast(x[i]); mean_val += tmp; var_val += (tmp * tmp); } auto pair = BlockReduce(temp_storage) - .Reduce(PairForLayerNorm(mean_val, var_val), - PairForLayerNormAddFunctor()); + .Reduce(PairForLayerNorm(mean_val, var_val), + PairForLayerNormAddFunctor()); if (threadIdx.x == 0) { auto tmp = pair.first_ / feature_size; - mean[blockIdx.x] = static_cast(tmp); - var[blockIdx.x] = static_cast(pair.second_ / feature_size - tmp * tmp); + mean[blockIdx.x] = mean_share = static_cast(tmp); + var[blockIdx.x] = var_share = + static_cast(pair.second_ / feature_size - tmp * tmp); } __syncthreads(); - mean_val = mean[blockIdx.x]; - var_val = static_cast(real_sqrt(var[blockIdx.x] + epsilon)); + + mean_val = mean_share; + U invvar = rsqrt(var_share + static_cast(epsilon)); // Step 2: Calculate y if (scale != nullptr) { @@ -143,26 +162,288 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, for (int i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { y[i] = static_cast( - scale[j] * (static_cast(x[i]) - mean_val) / var_val + bias[j]); + scale[j] * (static_cast(x[i]) - mean_val) * invvar + bias[j]); } } else { for (int i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast(scale[j] * (static_cast(x[i]) - mean_val) / - var_val); + y[i] = static_cast(scale[j] * (static_cast(x[i]) - mean_val) * + invvar); } } } else { // scale == nullptr if (bias != nullptr) { for (int i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast((static_cast(x[i]) - mean_val) / var_val + + y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar + bias[j]); } } else { for (int i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = static_cast((static_cast(x[i]) - mean_val) / var_val); + y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar); + } + } + } +} + +template +__inline__ __device__ void cuLoadAddStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, + const T *input, const T *dout, const int i1_end, const int n2, + const U *__restrict__ mean, const U *__restrict__ var, + const float epsilon) { + const int i1 = i1_block + thr_load_row_off; + if (i1 >= i1_end) return; + U curr_mean = mean[i1]; + U curr_invvar = rsqrt(var[i1] + epsilon); + for (int k = 0; k < VPT; ++k) { + const int i2 = i2_off + k; + const int load_idx = i1 * n2 + i2; + const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } +} + +template +__global__ void LayerNormBackwardPartGradGammaBeta( + const T *__restrict__ dout, const T *__restrict__ input, const int n1, + const int n2, const U *__restrict__ mean, const U *__restrict__ var, + float epsilon, U *part_grad_gamma, U *part_grad_beta) { + // VPTX -> value per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y, BDIMX + // -> blockDim.x + // template for compile time optimizations + + constexpr int row_stride = BDIMX + 1; + const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1); + const int thr_load_row_off = + (threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY; + const int i2_off = blockIdx.x * BDIMX + thr_load_col_off; + + constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride) + ? BDIMX * BDIMY + : 2 * VPTX * BDIMY * row_stride; + __shared__ U buf[shared_cap]; + + U *warp_buf1 = reinterpret_cast(buf); + U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride; + + for (int idx = threadIdx.y * blockDim.x + threadIdx.x; + idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) { + buf[idx] = U(0); + } + __syncthreads(); + + for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1; + i1_block += VPTX * BDIMY * gridDim.y) { + cuLoadAddStridedInputs( + i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, + warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon); + } + __syncthreads(); + + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < VPTX; ++k) { + int row1 = threadIdx.y + k * VPTX; + int idx1 = row1 * row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = VPTX >> 1; offset > 1; offset >>= 1) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void LayerNormBackwardSumGradGammaBeta( + const U *part_grad_gamma, const U *part_grad_beta, const int part_size, + // const int n1, const int n2, T* grad_gamma, T* grad_beta) { + const int n1, const int n2, U *grad_gamma, U *grad_beta) { + // sum partial gradients for gamma and beta + __shared__ U buf[BDIMX * BDIMY]; + int i2 = blockIdx.x * BDIMX + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / BDIMY; + U sum_gamma = U(0); + U sum_beta = U(0); + const U *part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U *part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + // inter-warp reductions + constexpr int nbsize3 = BDIMX * BDIMY / 2; + for (int offset = BDIMY / 2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx + nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * BDIMX + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx + nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template +__global__ void LayerNormBackwardComputeGradInput( + const T *__restrict__ dout, const T *__restrict__ input, const int n1, + const int n2, + // const U* __restrict__ mean, const U* __restrict__ var, const float + // epsilon, const T* gamma, + const U *__restrict__ mean, const U *__restrict__ var, const float epsilon, + const U *gamma, T *grad_input) { + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = rsqrt(var[i1] + epsilon); + const T *k_input = input + i1 * n2; + const T *k_dout = dout + i1 * n2; + constexpr int numx = BDIMX * BDIMY; + const int thrx = threadIdx.x + threadIdx.y * BDIMX; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss * gamma[l + k]; + sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = BDIMX / 2; mask > 0; mask /= 2) { + sum_loss1 += + __shfl_xor_sync(0xffffffff, sum_loss1, mask, + warpSize); // WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += + __shfl_xor_sync(0xffffffff, sum_loss2, mask, + warpSize); // WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (BDIMY > 1) { + __shared__ U buf[BDIMX * BDIMY]; + for (int offset = BDIMY / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x; + buf[2 * wrt_i] = sum_loss1; + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2 * read_i]; + sum_loss2 += buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2 * threadIdx.x] = sum_loss1; + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + sum_loss1 = buf[2 * threadIdx.x]; + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T *k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); } } } @@ -384,7 +665,11 @@ template static void LayerNormBackward(const T *x, const T *d_y, const U *scale, const U *mean, const U *var, T *d_x, U *d_scale, U *d_bias, float epsilon, int batch_size, - int feature_size, cudaStream_t stream) { + int feature_size, + const framework::ExecutionContext &ctx) { + auto &dev_ctx = ctx.cuda_device_context(); + auto stream = dev_ctx.stream(); + const int kMaxBlockDim = 512; const int kMaxBlockNum = 128; int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) | @@ -485,21 +770,44 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale, } break; case 7: // d_x != nullptr, d_scale != nullptr, d_bias != nullptr - switch (block_dim) { - FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( - feature_size, kMaxBlockNum, - LayerNormBackwardGradientAll< - T, U, kBlockDim, true><<>>( - x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, - batch_size, feature_size, col_offset)); - } - switch (GetDesiredBlockDim(feature_size)) { - FIXED_BLOCK_DIM_CASE( - LayerNormBackwardPostProcessToCalculateDX< - T, U, kBlockDim><<>>( - x, d_x, mean, var, epsilon, feature_size)); - } + { + constexpr int VPT = 4; + constexpr int BDIMX2 = 32; + constexpr int BDIMY2 = 4; + dim3 threads2(BDIMX2, BDIMY2, 1); + constexpr int part_size = BDIMY2 * VPT; + const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1); + + auto part_grad_gamma_ptr = + memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U)); + auto part_grad_beta_ptr = + memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U)); + U *part_grad_gamma = reinterpret_cast(part_grad_gamma_ptr->ptr()); + U *part_grad_beta = reinterpret_cast(part_grad_beta_ptr->ptr()); + + LayerNormBackwardPartGradGammaBeta<<>>( + d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma, + part_grad_beta); // compute part_grad_gamma, beta + + constexpr int BDIMX3 = 32; + constexpr int BDIMY3 = 8; + dim3 threads3(BDIMX3, BDIMY3, 1); + const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1); + LayerNormBackwardSumGradGammaBeta< + T, U, BDIMX3, BDIMY3><<>>( + part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size, + d_scale, d_bias); + + constexpr int BDIMX1 = 32; + constexpr int BDIMY1 = 4; + dim3 threads1(BDIMX1, BDIMY1, 1); + const dim3 blocks1(1, batch_size, 1); + LayerNormBackwardComputeGradInput< + T, U, BDIMX1, BDIMY1><<>>( + d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x); break; + } default: break; } @@ -611,11 +919,9 @@ class LayerNormGradKernel int batch_size = static_cast(matrix_dim[0]); int feature_size = static_cast(matrix_dim[1]); - auto stream = ctx.cuda_device_context().stream(); - LayerNormBackward(x_data, d_y_data, scale_data, mean_data, var_data, d_x_data, d_scale_data, d_bias_data, epsilon, - batch_size, feature_size, stream); + batch_size, feature_size, ctx); } };