diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 4c3b8ec78190723598a56f7633764f10dd5047f3..b395739809dbd187e36c28b3f609a0a08c839643 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -273,9 +273,9 @@
 op_library(squeeze_op DEPS reshape_op)
 op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
-
 if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)
+    op_library(layer_norm_op DEPS cub)
 else()
     op_library(conv_op DEPS vol2col im2col)
 endif()
diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu
index b2900cbf352636acc0db1bbb86778be4cea0aacf..0886c41a1b582881faf24f5531d414db4e4db71c 100644
--- a/paddle/fluid/operators/layer_norm_op.cu
+++ b/paddle/fluid/operators/layer_norm_op.cu
@@ -45,38 +45,55 @@ inline static int GetDesiredBlockDim(int block_dim) {
 static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
 static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }
 
+template <typename T>
+struct PairForLayerNorm {
+  __device__ __forceinline__ PairForLayerNorm() {}
+  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
+      : first_(first), second_(second) {}
+
+  T first_;
+  T second_;
+};
+
+template <typename T>
+struct PairForLayerNormAddFunctor {
+  __device__ __forceinline__ PairForLayerNorm<T> operator()(
+      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
+    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
+  }
+};
+
 template <typename T, int BlockDim>
 __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
                                  T *y, T *mean, T *var, float epsilon,
                                  int feature_size) {
-  using BlockReduce = cub::BlockReduce<T, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   int beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int end_idx = (blockIdx.x + 1) * feature_size;
 
-  // Step 1: Reduce to calculate mean
+  // Step 1: Reduce to calculate mean and var
   T mean_val = static_cast<T>(0);
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    mean_val += x[i];
-  }
-  mean_val = BlockReduce(temp_storage).Reduce(mean_val, cub::Sum());
-  if (threadIdx.x == 0) mean[blockIdx.x] = mean_val / feature_size;
-  __syncthreads();
-  mean_val = mean[blockIdx.x];
-
-  // Step 2: Reduce to calculate var
   T var_val = static_cast<T>(0);
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    T tmp = x[i] - mean_val;
+    T tmp = x[i];
+    mean_val += tmp;
     var_val += (tmp * tmp);
   }
-  var_val = BlockReduce(temp_storage).Reduce(var_val, cub::Sum());
-  if (threadIdx.x == 0) var[blockIdx.x] = var_val / feature_size;
+  auto pair = BlockReduce(temp_storage)
+                  .Reduce(PairForLayerNorm<T>(mean_val, var_val),
+                          PairForLayerNormAddFunctor<T>());
+  if (threadIdx.x == 0) {
+    auto tmp = pair.first_ / feature_size;
+    mean[blockIdx.x] = tmp;
+    var[blockIdx.x] = pair.second_ / feature_size - tmp * tmp;
+  }
   __syncthreads();
+  mean_val = mean[blockIdx.x];
   var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));
 
-  // Step 3: Calculate y
+  // Step 2: Calculate y
   if (scale != nullptr) {
     if (bias != nullptr) {
       for (int i = beg_idx, j = threadIdx.x; i < end_idx;
@@ -104,27 +121,6 @@ __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
   }
 }
 
-template <typename T>
-struct PairForLayerNormBackward {
-  __device__ __forceinline__ PairForLayerNormBackward() {}
-  __device__ __forceinline__ PairForLayerNormBackward(const T &first,
-                                                      const T &second)
-      : first_(first), second_(second) {}
-
-  T first_;
-  T second_;
-};
-
-template <typename T>
-struct PairForLayerNormBackwardAddFunctor {
-  __device__ __forceinline__ PairForLayerNormBackward<T> operator()(
-      const PairForLayerNormBackward<T> &p1,
-      const PairForLayerNormBackward<T> &p2) {
-    return PairForLayerNormBackward<T>(p1.first_ + p2.first_,
-                                       p1.second_ + p2.second_);
-  }
-};
-
 // Make sure that d_scale != nullptr && d_bias != nullptr
 // Since d_scale != nullptr, scale would not be nullptr
 template <typename T, int BlockDim, bool HasDx>
@@ -133,12 +129,13 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                              const T *mean, const T *var,
                                              const T *scale, float epsilon,
                                              int batch_size, int feature_size) {
-  using BlockReduce = cub::BlockReduce<T, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   int beg_idx = threadIdx.x * feature_size + blockIdx.x;
   int end_idx = batch_size * feature_size + blockIdx.x;
   int stride = BlockDim * feature_size;
+
   T d_scale_partial = 0, d_bias_partial = 0;
 
   for (int i = beg_idx; i < end_idx; i += stride) {
@@ -146,13 +143,14 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
     auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
     d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
     d_bias_partial += d_y[i];
-    if (HasDx) d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+    if (HasDx) {
+      d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+    }
   }
 
-  auto pair =
-      BlockReduce(temp_storage)
-          .Reduce(PairForLayerNormBackward<T>(d_scale_partial, d_bias_partial),
-                  PairForLayerNormBackwardAddFunctor<T>());
+  auto pair = BlockReduce(temp_storage)
+                  .Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial),
+                          PairForLayerNormAddFunctor<T>());
 
   if (threadIdx.x == 0) {
     d_scale[blockIdx.x] = pair.first_;
@@ -205,22 +203,90 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
   }
 }
 
+template <typename T, int BlockDim>
+__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
+                                                          const T *mean,
+                                                          const T *var,
+                                                          float epsilon,
+                                                          int feature_size) {
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T d_x_reduce_tmp[2];
+
+  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * feature_size;
+
+  T block_mean = mean[blockIdx.x];
+  T block_var = var[blockIdx.x];
+  T d_x_mean_partial = 0, d_x_var_partial = 0;
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x_mean_partial += d_x[i];
+    d_x_var_partial += d_x[i] * (x[i] - block_mean);
+  }
+
+  auto pair =
+      BlockReduce(temp_storage)
+          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
+                  PairForLayerNormAddFunctor<T>());
+
+  if (threadIdx.x == 0) {
+    d_x_reduce_tmp[0] = pair.first_ / feature_size;
+    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
+  }
+  __syncthreads();
+
+  d_x_mean_partial = d_x_reduce_tmp[0];
+  d_x_var_partial = d_x_reduce_tmp[1];
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x[i] -= d_x_mean_partial;
+    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
+  }
+}
+
 // Here, we only calculate d_x
-template <typename T>
-__global__ void LayerNormBackwardGradientOnlyX(const T *d_y, T *d_x,
-                                               const T *var, const T *scale,
-                                               float epsilon, int batch_size,
-                                               int feature_size) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx < batch_size * feature_size) {
-    int row_idx = idx / feature_size;
-    auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
+template <typename T, int BlockDim>
+__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
+                                                T *d_x, const T *mean,
+                                                const T *var, const T *scale,
+                                                float epsilon,
+                                                int feature_size) {
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T d_x_reduce_tmp[2];
+
+  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * feature_size;
+
+  T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
+  T d_x_mean_partial = 0, d_x_var_partial = 0;
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    auto var_val = static_cast<T>(real_sqrt(block_var + epsilon));
     if (scale != nullptr) {
-      int col_idx = idx % feature_size;
-      d_x[idx] = d_y[idx] * scale[col_idx] / var_val;
+      int col_idx = i % feature_size;
+      d_x[i] = d_y[i] * scale[col_idx] / var_val;
     } else {
-      d_x[idx] = d_y[idx] / var_val;
+      d_x[i] = d_y[i] / var_val;
     }
+    d_x_mean_partial += d_x[i];
+    d_x_var_partial += d_x[i] * (x[i] - block_mean);
+  }
+
+  auto pair =
+      BlockReduce(temp_storage)
+          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
+                  PairForLayerNormAddFunctor<T>());
+
+  if (threadIdx.x == 0) {
+    d_x_reduce_tmp[0] = pair.first_ / feature_size;
+    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
+  }
+  __syncthreads();
+
+  d_x_mean_partial = d_x_reduce_tmp[0];
+  d_x_var_partial = d_x_reduce_tmp[1];
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x[i] -= d_x_mean_partial;
+    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
   }
 }
 
@@ -263,6 +329,14 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
         T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
              stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
                        feature_size);
+
+    if (d_x != nullptr) {
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
+                             T, kBlockDim><<<1, kBlockDim, 0, stream>>>(
+            x, d_x, mean, var, epsilon, feature_size));
+      }
+    }
     return;
   }
 
@@ -296,10 +370,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       }
       break;
     case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
-      LayerNormBackwardGradientOnlyX<
-          T><<<(batch_size * feature_size + kMaxBlockDim - 1) / kMaxBlockDim,
-               kMaxBlockDim, 0, stream>>>(d_y, d_x, var, scale, epsilon,
-                                          batch_size, feature_size);
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardGradientOnlyDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
+      }
       break;
     case 5:  // d_x != nulptr, d_scale == nullptr, d_bias != nullptr
       switch (block_dim) {
@@ -309,6 +385,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
               x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
               batch_size, feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
       switch (block_dim) {
@@ -318,6 +400,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
               x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
               batch_size, feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
       switch (block_dim) {
@@ -327,6 +415,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
               x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
               batch_size, feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     default:
       break;
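
Note (illustration only, not part of the patch): the forward-kernel change folds the two separate block reductions (mean first, then variance) into a single pass by reducing the pair (sum(x), sum(x^2)) and recovering the variance as E[x^2] - E[x]^2. The following standalone CUDA sketch shows just that idea in isolation; the kernel name FusedMeanVarKernel, the SumPair/SumPairAdd helpers, and the host driver are hypothetical names invented for this sketch and do not come from the PR.

// fused_mean_var_sketch.cu -- minimal sketch of a one-pass mean/variance
// reduction per row, in the spirit of the forward kernel above.
#include <cub/cub.cuh>
#include <cstdio>
#include <vector>

template <typename T>
struct SumPair {
  T sum;     // running sum(x)
  T sq_sum;  // running sum(x * x)
};

struct SumPairAdd {
  template <typename T>
  __device__ __forceinline__ SumPair<T> operator()(const SumPair<T> &a,
                                                   const SumPair<T> &b) const {
    return {a.sum + b.sum, a.sq_sum + b.sq_sum};
  }
};

template <typename T, int BlockDim>
__global__ void FusedMeanVarKernel(const T *x, T *mean, T *var,
                                   int feature_size) {
  using BlockReduce = cub::BlockReduce<SumPair<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  // Each thread accumulates both statistics in a single sweep over its slice.
  SumPair<T> local = {static_cast<T>(0), static_cast<T>(0)};
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    local.sum += x[i];
    local.sq_sum += x[i] * x[i];
  }

  // One block-wide pair reduction replaces two separate scalar reductions.
  SumPair<T> total = BlockReduce(temp_storage).Reduce(local, SumPairAdd());
  if (threadIdx.x == 0) {
    T m = total.sum / feature_size;
    mean[blockIdx.x] = m;
    var[blockIdx.x] = total.sq_sum / feature_size - m * m;  // E[x^2] - E[x]^2
  }
}

int main() {
  constexpr int kRows = 2, kCols = 256, kBlockDim = 128;
  std::vector<float> h_x(kRows * kCols);
  for (int i = 0; i < kRows * kCols; ++i) h_x[i] = static_cast<float>(i % 7);

  float *d_x, *d_mean, *d_var;
  cudaMalloc(&d_x, sizeof(float) * h_x.size());
  cudaMalloc(&d_mean, sizeof(float) * kRows);
  cudaMalloc(&d_var, sizeof(float) * kRows);
  cudaMemcpy(d_x, h_x.data(), sizeof(float) * h_x.size(),
             cudaMemcpyHostToDevice);

  // One block per row, as in the layer_norm kernels above.
  FusedMeanVarKernel<float, kBlockDim><<<kRows, kBlockDim>>>(d_x, d_mean,
                                                             d_var, kCols);

  std::vector<float> h_mean(kRows), h_var(kRows);
  cudaMemcpy(h_mean.data(), d_mean, sizeof(float) * kRows,
             cudaMemcpyDeviceToHost);
  cudaMemcpy(h_var.data(), d_var, sizeof(float) * kRows,
             cudaMemcpyDeviceToHost);
  for (int r = 0; r < kRows; ++r)
    printf("row %d: mean=%f var=%f\n", r, h_mean[r], h_var[r]);

  cudaFree(d_x);
  cudaFree(d_mean);
  cudaFree(d_var);
  return 0;
}

The backward kernels in the diff reuse the same pair-reduction pattern for (sum(d_x), sum(d_x * (x - mean))) and then correct d_x in place, which is what LayerNormBackwardPostProcessToCalculateDX and LayerNormBackwardGradientOnlyDX do above.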