From 44b81e633b87ed2366649f13523d1a4b24082ad6 Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Thu, 7 Jan 2021 16:07:55 +0800 Subject: [PATCH] [Cherry-pick] Layer norm fp16 and Nvidia optimize (#29169 #29434 #29522 #29576) (#30110) * Layer norm fp16 (#29169) * add fp16 for layer_norm op * revert layernorm api * fix forward * fix forward * fix backward for layernorm with fp16 * fix unit test for layernorm with fp16 * fix with_mkldnn compile error for layernorm with fp16 * 1. revert to PADDLE_ENFORCE_NOT_NULL, 2. change static_cast to static_cast * fix with_mkldnn compile error for layernorm with fp16 * fix with_mkldnn compile error for layernorm with fp16 Co-authored-by: zhiqiu * fix layer_norm accuracy (#29434) * Layernorm opt (#29522) * layernorm fw opt * layernorm bw opt * fix typo, test=develop * remove const dim3 for windows CI compatibility * merge develop Co-authored-by: zlsh80826 * Fix compile problem when cuda_arch < 6000 (#29576) * fix compile problem when cuda_arch < 6000 * refine code * refine code Co-authored-by: zhiqiu Co-authored-by: zlsh80826 --- paddle/fluid/operators/layer_norm_op.cc | 35 +- paddle/fluid/operators/layer_norm_op.cu | 618 ++++++++++++++---- .../contrib/mixed_precision/fp16_lists.py | 4 +- .../contrib/mixed_precision/fp16_utils.py | 7 +- .../tests/unittests/test_layer_norm_op.py | 3 + python/paddle/nn/functional/norm.py | 11 +- 6 files changed, 535 insertions(+), 143 deletions(-) diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 6f83a667a5..23de34bc6f 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/layer_norm_op.h" #include +#include #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -98,7 +99,26 @@ class LayerNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const { + const framework::ExecutionContext &ctx) const override { + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + // By default, the type of the scale, bias, mean, + // and var tensors should both be float. (For float or float16 input tensor) + // or double (For double input tensor). + auto ln_param_type = framework::proto::VarType::FP32; + if (input_data_type == framework::proto::VarType::FP64) { + ln_param_type = framework::proto::VarType::FP64; + } + if (ctx.HasInput("Scale")) { + PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input("Scale")->type(), + platform::errors::InvalidArgument( + "Scale input should be of float type")); + } + if (ctx.HasInput("Bias")) { + PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input("Bias")->type(), + platform::errors::InvalidArgument( + "Bias input should be of float type")); + } + framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; @@ -110,9 +130,8 @@ class LayerNormOp : public framework::OperatorWithKernel { } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library); } }; @@ -224,7 +243,13 @@ class LayerNormGradOp : public framework::OperatorWithKernel { } PADDLE_ENFORCE_NOT_NULL( t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found.")); - return framework::OpKernelType(t->type(), ctx.GetPlace()); + + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout, library); } }; diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index 30bafb5c13..ad15b18d7f 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -15,12 +15,22 @@ limitations under the License. */ #include #include #include + #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/operators/layer_norm_op.h" +#include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; +using DataLayout = framework::DataLayout; +template +using CudnnDataType = platform::CudnnDataType; +template +using LayerNormParamType = typename CudnnDataType::BatchNormParamType; + inline static int GetDesiredBlockDim(int block_dim) { const int kMaxBlockDim = 512; return block_dim >= kMaxBlockDim @@ -97,59 +107,350 @@ struct PairForLayerNormAddFunctor { } }; -template -__global__ void LayerNormForward(const T *x, const T *scale, const T *bias, - T *y, T *mean, T *var, float epsilon, +template +__inline__ __device__ T rsqrt(const T val) { + return static_cast(1) / sqrt(val); +} + +template <> +__inline__ __device__ float rsqrt(const float val) { + return rsqrtf(val); +} + +template <> +__inline__ __device__ double rsqrt(const double val) { + return rsqrt(val); +} + +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +template <> +__inline__ __device__ half rsqrt(const half val) { + return hrsqrt(val); +} +#endif + +template +__global__ void LayerNormForward(const T *x, const U *scale, const U *bias, + T *y, U *mean, U *var, float epsilon, int feature_size) { - using BlockReduce = cub::BlockReduce, BlockDim>; + using BlockReduce = cub::BlockReduce, BlockDim>; __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U mean_share; + __shared__ U var_share; int beg_idx = blockIdx.x * feature_size + threadIdx.x; int end_idx = (blockIdx.x + 1) * feature_size; // Step 1: Reduce to calculate mean and var - double mean_val = 0; - double var_val = 0; + U mean_val = 0; + U var_val = 0; for (int i = beg_idx; i < end_idx; i += BlockDim) { - T tmp = x[i]; + U tmp = static_cast(x[i]); mean_val += tmp; var_val += (tmp * tmp); } auto pair = BlockReduce(temp_storage) - .Reduce(PairForLayerNorm(mean_val, var_val), - PairForLayerNormAddFunctor()); + .Reduce(PairForLayerNorm(mean_val, var_val), + PairForLayerNormAddFunctor()); if (threadIdx.x == 0) { auto tmp = pair.first_ / feature_size; - mean[blockIdx.x] = static_cast(tmp); - var[blockIdx.x] = static_cast(pair.second_ / feature_size - tmp * tmp); + mean[blockIdx.x] = mean_share = static_cast(tmp); + var[blockIdx.x] = var_share = + static_cast(pair.second_ / feature_size - tmp * tmp); } __syncthreads(); - mean_val = mean[blockIdx.x]; - var_val = static_cast(real_sqrt(var[blockIdx.x] + epsilon)); + + mean_val = mean_share; + U invvar = rsqrt(var_share + static_cast(epsilon)); // Step 2: Calculate y if (scale != nullptr) { if (bias != nullptr) { for (int i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = scale[j] * (x[i] - mean_val) / var_val + bias[j]; + y[i] = static_cast( + scale[j] * (static_cast(x[i]) - mean_val) * invvar + bias[j]); } } else { for (int i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = scale[j] * (x[i] - mean_val) / var_val; + y[i] = static_cast(scale[j] * (static_cast(x[i]) - mean_val) * + invvar); } } } else { // scale == nullptr if (bias != nullptr) { for (int i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = (x[i] - mean_val) / var_val + bias[j]; + y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar + + bias[j]); } } else { for (int i = beg_idx, j = threadIdx.x; i < end_idx; i += BlockDim, j += BlockDim) { - y[i] = (x[i] - mean_val) / var_val; + y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar); + } + } + } +} + +template +__inline__ __device__ void cuLoadAddStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, + const T *input, const T *dout, const int i1_end, const int n2, + const U *__restrict__ mean, const U *__restrict__ var, + const float epsilon) { + const int i1 = i1_block + thr_load_row_off; + if (i1 >= i1_end) return; + U curr_mean = mean[i1]; + U curr_invvar = rsqrt(var[i1] + epsilon); + for (int k = 0; k < VPT; ++k) { + const int i2 = i2_off + k; + const int load_idx = i1 * n2 + i2; + const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } +} + +template +__global__ void LayerNormBackwardPartGradGammaBeta( + const T *__restrict__ dout, const T *__restrict__ input, const int n1, + const int n2, const U *__restrict__ mean, const U *__restrict__ var, + float epsilon, U *part_grad_gamma, U *part_grad_beta) { + // VPTX -> value per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y, BDIMX + // -> blockDim.x + // template for compile time optimizations + + constexpr int row_stride = BDIMX + 1; + const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1); + const int thr_load_row_off = + (threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY; + const int i2_off = blockIdx.x * BDIMX + thr_load_col_off; + + constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride) + ? BDIMX * BDIMY + : 2 * VPTX * BDIMY * row_stride; + __shared__ U buf[shared_cap]; + + U *warp_buf1 = reinterpret_cast(buf); + U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride; + + for (int idx = threadIdx.y * blockDim.x + threadIdx.x; + idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) { + buf[idx] = U(0); + } + __syncthreads(); + + for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1; + i1_block += VPTX * BDIMY * gridDim.y) { + cuLoadAddStridedInputs( + i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, + warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon); + } + __syncthreads(); + + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < VPTX; ++k) { + int row1 = threadIdx.y + k * VPTX; + int idx1 = row1 * row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = VPTX >> 1; offset > 1; offset >>= 1) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void LayerNormBackwardSumGradGammaBeta( + const U *part_grad_gamma, const U *part_grad_beta, const int part_size, + // const int n1, const int n2, T* grad_gamma, T* grad_beta) { + const int n1, const int n2, U *grad_gamma, U *grad_beta) { + // sum partial gradients for gamma and beta + __shared__ U buf[BDIMX * BDIMY]; + int i2 = blockIdx.x * BDIMX + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / BDIMY; + U sum_gamma = U(0); + U sum_beta = U(0); + const U *part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U *part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + // inter-warp reductions + constexpr int nbsize3 = BDIMX * BDIMY / 2; + for (int offset = BDIMY / 2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx + nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * BDIMX + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx + nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template +__global__ void LayerNormBackwardComputeGradInput( + const T *__restrict__ dout, const T *__restrict__ input, const int n1, + const int n2, + // const U* __restrict__ mean, const U* __restrict__ var, const float + // epsilon, const T* gamma, + const U *__restrict__ mean, const U *__restrict__ var, const float epsilon, + const U *gamma, T *grad_input) { + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = rsqrt(var[i1] + epsilon); + const T *k_input = input + i1 * n2; + const T *k_dout = dout + i1 * n2; + constexpr int numx = BDIMX * BDIMY; + const int thrx = threadIdx.x + threadIdx.y * BDIMX; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss * gamma[l + k]; + sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = BDIMX / 2; mask > 0; mask /= 2) { + sum_loss1 += + __shfl_xor_sync(0xffffffff, sum_loss1, mask, + warpSize); // WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += + __shfl_xor_sync(0xffffffff, sum_loss2, mask, + warpSize); // WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (BDIMY > 1) { + __shared__ U buf[BDIMX * BDIMY]; + for (int offset = BDIMY / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x; + buf[2 * wrt_i] = sum_loss1; + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2 * read_i]; + sum_loss2 += buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2 * threadIdx.x] = sum_loss1; + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + sum_loss1 = buf[2 * threadIdx.x]; + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T *k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); } } } @@ -157,35 +458,37 @@ __global__ void LayerNormForward(const T *x, const T *scale, const T *bias, // Make sure that d_scale != nullptr && d_bias != nullptr // Since d_scale != nullptr, scale would not be nullptr -template +template __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y, - T *d_scale, T *d_bias, T *d_x, - const T *mean, const T *var, - const T *scale, float epsilon, + U *d_scale, U *d_bias, T *d_x, + const U *mean, const U *var, + const U *scale, float epsilon, int batch_size, int feature_size, int col_offset) { - using BlockReduce = cub::BlockReduce, BlockDim>; + using BlockReduce = cub::BlockReduce, BlockDim>; __shared__ typename BlockReduce::TempStorage temp_storage; int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset); int end_idx = batch_size * feature_size + (blockIdx.x + col_offset); int stride = BlockDim * feature_size; - T d_scale_partial = 0, d_bias_partial = 0; + U d_scale_partial = static_cast(0), d_bias_partial = static_cast(0); for (int i = beg_idx; i < end_idx; i += stride) { int row_idx = i / feature_size; - auto var_val = static_cast(real_sqrt(var[row_idx] + epsilon)); - d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val; - d_bias_partial += d_y[i]; + auto var_val = real_sqrt(static_cast(var[row_idx]) + epsilon); + d_scale_partial += static_cast(d_y[i]) * + (static_cast(x[i]) - mean[row_idx]) / var_val; + d_bias_partial += static_cast(d_y[i]); if (HasDx) { - d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val; + d_x[i] = static_cast(static_cast(d_y[i]) * + scale[blockIdx.x + col_offset] / var_val); } } auto pair = BlockReduce(temp_storage) - .Reduce(PairForLayerNorm(d_scale_partial, d_bias_partial), - PairForLayerNormAddFunctor()); + .Reduce(PairForLayerNorm(d_scale_partial, d_bias_partial), + PairForLayerNormAddFunctor()); if (threadIdx.x == 0) { d_scale[blockIdx.x + col_offset] = pair.first_; @@ -196,32 +499,36 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y, // Make sure that there is only one true expression: d_scale != nullptr // or d_bias != nullptr // Notice: scale may be nullptr -template +template __global__ void LayerNormBackwardGradientScaleOrBias( - const T *x, const T *d_y, T *d_scale, T *d_bias, T *d_x, const T *mean, - const T *var, const T *scale, float epsilon, int batch_size, + const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean, + const U *var, const U *scale, float epsilon, int batch_size, int feature_size, int col_offset) { - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset; int end_idx = batch_size * feature_size + blockIdx.x + col_offset; int stride = BlockDim * feature_size; - T d_scale_or_d_bias_partial = 0; + U d_scale_or_d_bias_partial = static_cast(0); for (int i = beg_idx; i < end_idx; i += stride) { int row_idx = i / feature_size; - auto var_val = static_cast(real_sqrt(var[row_idx] + epsilon)); + auto var_val = + static_cast(real_sqrt(static_cast(var[row_idx]) + epsilon)); if (HasDScale) { - d_scale_or_d_bias_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val; + d_scale_or_d_bias_partial += static_cast(d_y[i]) * + (static_cast(x[i]) - mean[row_idx]) / + var_val; } else { // d_bias != nullptr - d_scale_or_d_bias_partial += d_y[i]; + d_scale_or_d_bias_partial += static_cast(d_y[i]); } if (HasDx) { if (scale != nullptr) { - d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val; + d_x[i] = static_cast(static_cast(d_y[i]) * + scale[blockIdx.x + col_offset] / var_val); } else { - d_x[i] = d_y[i] / var_val; + d_x[i] = static_cast(static_cast(d_y[i]) / var_val); } } } @@ -238,121 +545,138 @@ __global__ void LayerNormBackwardGradientScaleOrBias( } } -template +template __global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x, - const T *mean, - const T *var, + const U *mean, + const U *var, float epsilon, int feature_size) { - using BlockReduce = cub::BlockReduce, BlockDim>; + using BlockReduce = cub::BlockReduce, BlockDim>; __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ T d_x_reduce_tmp[2]; + __shared__ U d_x_reduce_tmp[2]; int beg_idx = blockIdx.x * feature_size + threadIdx.x; int end_idx = (blockIdx.x + 1) * feature_size; - T block_mean = mean[blockIdx.x]; - T block_var = var[blockIdx.x]; - T d_x_mean_partial = 0, d_x_var_partial = 0; + U block_mean = mean[blockIdx.x]; + U block_var = var[blockIdx.x]; + U d_x_mean_partial = static_cast(0), d_x_var_partial = static_cast(0); for (int i = beg_idx; i < end_idx; i += BlockDim) { - d_x_mean_partial += d_x[i]; - d_x_var_partial += d_x[i] * (x[i] - block_mean); + d_x_mean_partial += static_cast(d_x[i]); + d_x_var_partial += + static_cast(d_x[i]) * (static_cast(x[i]) - block_mean); } auto pair = BlockReduce(temp_storage) - .Reduce(PairForLayerNorm(d_x_mean_partial, d_x_var_partial), - PairForLayerNormAddFunctor()); + .Reduce(PairForLayerNorm(d_x_mean_partial, d_x_var_partial), + PairForLayerNormAddFunctor()); if (threadIdx.x == 0) { - d_x_reduce_tmp[0] = pair.first_ / feature_size; - d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon)); + d_x_reduce_tmp[0] = static_cast(pair.first_) / feature_size; + d_x_reduce_tmp[1] = + static_cast(pair.second_) / + (feature_size * (static_cast(block_var) + epsilon)); } __syncthreads(); d_x_mean_partial = d_x_reduce_tmp[0]; d_x_var_partial = d_x_reduce_tmp[1]; for (int i = beg_idx; i < end_idx; i += BlockDim) { - d_x[i] -= d_x_mean_partial; - d_x[i] -= (x[i] - block_mean) * d_x_var_partial; + d_x[i] -= static_cast(d_x_mean_partial); + d_x[i] -= + static_cast((static_cast(x[i]) - block_mean) * d_x_var_partial); } } // Here, we only calculate d_x -template +template __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y, - T *d_x, const T *mean, - const T *var, const T *scale, + T *d_x, const U *mean, + const U *var, const U *scale, float epsilon, int feature_size) { - using BlockReduce = cub::BlockReduce, BlockDim>; + using BlockReduce = cub::BlockReduce, BlockDim>; __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ T d_x_reduce_tmp[2]; + __shared__ U d_x_reduce_tmp[2]; int beg_idx = blockIdx.x * feature_size + threadIdx.x; int end_idx = (blockIdx.x + 1) * feature_size; - T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x]; - T d_x_mean_partial = 0, d_x_var_partial = 0; + U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x]; + U d_x_mean_partial = static_cast(0), d_x_var_partial = static_cast(0); for (int i = beg_idx; i < end_idx; i += BlockDim) { - auto var_val = static_cast(real_sqrt(block_var + epsilon)); + auto var_val = + static_cast(real_sqrt(static_cast(block_var) + epsilon)); if (scale != nullptr) { int col_idx = i % feature_size; - d_x[i] = d_y[i] * scale[col_idx] / var_val; + d_x[i] = + static_cast(static_cast(d_y[i]) * scale[col_idx] / var_val); } else { - d_x[i] = d_y[i] / var_val; + d_x[i] = static_cast(static_cast(d_y[i]) / var_val); } - d_x_mean_partial += d_x[i]; - d_x_var_partial += d_x[i] * (x[i] - block_mean); + d_x_mean_partial += static_cast(d_x[i]); + d_x_var_partial += + static_cast(d_x[i]) * (static_cast(x[i]) - block_mean); } auto pair = BlockReduce(temp_storage) - .Reduce(PairForLayerNorm(d_x_mean_partial, d_x_var_partial), - PairForLayerNormAddFunctor()); + .Reduce(PairForLayerNorm(d_x_mean_partial, d_x_var_partial), + PairForLayerNormAddFunctor()); if (threadIdx.x == 0) { - d_x_reduce_tmp[0] = pair.first_ / feature_size; - d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon)); + d_x_reduce_tmp[0] = static_cast(pair.first_) / feature_size; + d_x_reduce_tmp[1] = + static_cast(pair.second_) / + (feature_size * (static_cast(block_var) + epsilon)); } __syncthreads(); d_x_mean_partial = d_x_reduce_tmp[0]; d_x_var_partial = d_x_reduce_tmp[1]; for (int i = beg_idx; i < end_idx; i += BlockDim) { - d_x[i] -= d_x_mean_partial; - d_x[i] -= (x[i] - block_mean) * d_x_var_partial; + d_x[i] -= static_cast(d_x_mean_partial); + d_x[i] -= + static_cast((static_cast(x[i]) - block_mean) * d_x_var_partial); } } -template +template __global__ void LayerNormBackwardWhenBatchSizeIsOne( - const T *x, const T *d_y, T *d_x, T *d_scale, T *d_bias, const T *mean, - const T *var, const T *scale, float epsilon, int feature_size) { + const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean, + const U *var, const U *scale, float epsilon, int feature_size) { int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < feature_size) { - auto var_val = static_cast(real_sqrt(var[idx] + epsilon)); + auto var_val = + static_cast(real_sqrt(static_cast(var[idx]) + epsilon)); if (d_x != nullptr) { if (d_scale == nullptr) { - d_x[idx] = d_y[idx] / var_val; + d_x[idx] = static_cast(static_cast(d_y[idx]) / var_val); } else { - d_x[idx] = d_y[idx] * scale[idx] / var_val; + d_x[idx] = + static_cast(static_cast(d_y[idx]) * scale[idx] / var_val); } } if (d_scale != nullptr) { - d_scale[idx] = d_y[idx] * (x[idx] - mean[idx]) / var_val; + d_scale[idx] = static_cast(d_y[idx]) * + (static_cast(x[idx]) - mean[idx]) / var_val; } - if (d_bias != nullptr) d_bias[idx] = d_y[idx]; + if (d_bias != nullptr) d_bias[idx] = static_cast(d_y[idx]); } } -template -static void LayerNormBackward(const T *x, const T *d_y, const T *scale, - const T *mean, const T *var, T *d_x, T *d_scale, - T *d_bias, float epsilon, int batch_size, - int feature_size, cudaStream_t stream) { +template +static void LayerNormBackward(const T *x, const T *d_y, const U *scale, + const U *mean, const U *var, T *d_x, U *d_scale, + U *d_bias, float epsilon, int batch_size, + int feature_size, + const framework::ExecutionContext &ctx) { + auto &dev_ctx = ctx.cuda_device_context(); + auto stream = dev_ctx.stream(); + const int kMaxBlockDim = 512; const int kMaxBlockNum = 128; int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) | @@ -362,14 +686,14 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, if (batch_size == 1) { LayerNormBackwardWhenBatchSizeIsOne< - T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0, - stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon, - feature_size); + T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, + 0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, + epsilon, feature_size); if (d_x != nullptr) { switch (GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX< - T, kBlockDim><<<1, kBlockDim, 0, stream>>>( + T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>( x, d_x, mean, var, epsilon, feature_size)); } } @@ -383,7 +707,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientScaleOrBias< - T, kBlockDim, false, + T, U, kBlockDim, false, false><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); @@ -394,7 +718,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientScaleOrBias< - T, kBlockDim, false, true><<>>( + T, U, kBlockDim, false, + true><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } @@ -404,7 +729,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientAll< - T, kBlockDim, false><<>>( + T, U, kBlockDim, false><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } @@ -413,7 +738,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, switch (GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( LayerNormBackwardGradientOnlyDX< - T, kBlockDim><<>>( + T, U, kBlockDim><<>>( x, d_y, d_x, mean, var, scale, epsilon, feature_size)); } break; @@ -422,14 +747,15 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientScaleOrBias< - T, kBlockDim, true, false><<>>( + T, U, kBlockDim, true, + false><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } switch (GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( LayerNormBackwardPostProcessToCalculateDX< - T, kBlockDim><<>>( + T, U, kBlockDim><<>>( x, d_x, mean, var, epsilon, feature_size)); } break; @@ -438,33 +764,57 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( feature_size, kMaxBlockNum, LayerNormBackwardGradientScaleOrBias< - T, kBlockDim, true, true><<>>( + T, U, kBlockDim, true, + true><<>>( x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size, feature_size, col_offset)); } switch (GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( LayerNormBackwardPostProcessToCalculateDX< - T, kBlockDim><<>>( + T, U, kBlockDim><<>>( x, d_x, mean, var, epsilon, feature_size)); } break; case 7: // d_x != nullptr, d_scale != nullptr, d_bias != nullptr - switch (block_dim) { - FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( - feature_size, kMaxBlockNum, - LayerNormBackwardGradientAll< - T, kBlockDim, true><<>>( - x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, - batch_size, feature_size, col_offset)); - } - switch (GetDesiredBlockDim(feature_size)) { - FIXED_BLOCK_DIM_CASE( - LayerNormBackwardPostProcessToCalculateDX< - T, kBlockDim><<>>( - x, d_x, mean, var, epsilon, feature_size)); - } + { + constexpr int VPT = 4; + constexpr int BDIMX2 = 32; + constexpr int BDIMY2 = 4; + dim3 threads2(BDIMX2, BDIMY2, 1); + constexpr int part_size = BDIMY2 * VPT; + const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1); + + auto part_grad_gamma_ptr = + memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U)); + auto part_grad_beta_ptr = + memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U)); + U *part_grad_gamma = reinterpret_cast(part_grad_gamma_ptr->ptr()); + U *part_grad_beta = reinterpret_cast(part_grad_beta_ptr->ptr()); + + LayerNormBackwardPartGradGammaBeta<<>>( + d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma, + part_grad_beta); // compute part_grad_gamma, beta + + constexpr int BDIMX3 = 32; + constexpr int BDIMY3 = 8; + dim3 threads3(BDIMX3, BDIMY3, 1); + const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1); + LayerNormBackwardSumGradGammaBeta< + T, U, BDIMX3, BDIMY3><<>>( + part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size, + d_scale, d_bias); + + constexpr int BDIMX1 = 32; + constexpr int BDIMY1 = 4; + dim3 threads1(BDIMX1, BDIMY1, 1); + const dim3 blocks1(1, batch_size, 1); + LayerNormBackwardComputeGradInput< + T, U, BDIMX1, BDIMY1><<>>( + d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x); break; + } default: break; } @@ -483,7 +833,7 @@ void LayerNormDirectCUDAFunctor::operator()(cudaStream_t stream, int feature_size = static_cast(matrix_dim[1]); switch (GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( - LayerNormForward<<>>( + LayerNormForward<<>>( input, scale, bias, output, mean, variance, eps, feature_size)); default: PADDLE_THROW(platform::errors::InvalidArgument( @@ -498,6 +848,7 @@ class LayerNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; const float epsilon = ctx.Attr("epsilon"); auto *scale = ctx.Input("Scale"); auto *bias = ctx.Input("Bias"); @@ -511,10 +862,10 @@ class LayerNormKernel const auto x_dims = x->dims(); auto *x_data = x->data(); auto *y_data = y->mutable_data(ctx.GetPlace()); - auto *mean_data = mean->mutable_data(ctx.GetPlace()); - auto *var_data = var->mutable_data(ctx.GetPlace()); - auto *scale_data = (scale == nullptr ? nullptr : scale->data()); - auto *bias_data = (bias == nullptr ? nullptr : bias->data()); + auto *mean_data = mean->mutable_data(ctx.GetPlace()); + auto *var_data = var->mutable_data(ctx.GetPlace()); + auto *scale_data = (scale == nullptr ? nullptr : scale->data()); + auto *bias_data = (bias == nullptr ? nullptr : bias->data()); auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); int batch_size = static_cast(matrix_dim[0]); @@ -524,7 +875,8 @@ class LayerNormKernel switch (GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( - LayerNormForward<<>>( + LayerNormForward<<>>( x_data, scale_data, bias_data, y_data, mean_data, var_data, epsilon, feature_size)); default: @@ -540,6 +892,7 @@ class LayerNormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; const float epsilon = ctx.Attr("epsilon"); // d_x, d_scale, d_bias may be nullptr auto *d_x = ctx.Output(framework::GradVarName("X")); @@ -554,14 +907,15 @@ class LayerNormGradKernel auto *x_data = x->data(); auto *d_y_data = d_y->data(); - auto *mean_data = mean->data(); - auto *var_data = var->data(); - auto *scale_data = (scale == nullptr ? nullptr : scale->data()); + auto *mean_data = mean->data(); + auto *var_data = var->data(); + + auto *scale_data = (scale == nullptr ? nullptr : scale->data()); auto *d_scale_data = (d_scale == nullptr ? nullptr - : d_scale->mutable_data(ctx.GetPlace())); + : d_scale->mutable_data(ctx.GetPlace())); auto *d_bias_data = - (d_bias == nullptr ? nullptr : d_bias->mutable_data(ctx.GetPlace())); + (d_bias == nullptr ? nullptr : d_bias->mutable_data(ctx.GetPlace())); auto *d_x_data = (d_x == nullptr ? nullptr : d_x->mutable_data(ctx.GetPlace())); @@ -571,14 +925,14 @@ class LayerNormGradKernel int batch_size = static_cast(matrix_dim[0]); int feature_size = static_cast(matrix_dim[1]); - auto stream = ctx.cuda_device_context().stream(); - - LayerNormBackward(x_data, d_y_data, scale_data, mean_data, var_data, - d_x_data, d_scale_data, d_bias_data, epsilon, - batch_size, feature_size, stream); + LayerNormBackward(x_data, d_y_data, scale_data, mean_data, var_data, + d_x_data, d_scale_data, d_bias_data, epsilon, + batch_size, feature_size, ctx); } }; + template class LayerNormDirectCUDAFunctor; + #undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE #undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE #undef FIXED_BLOCK_DIM_CASE_BASE @@ -587,11 +941,15 @@ template class LayerNormDirectCUDAFunctor; } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( layer_norm, ops::LayerNormKernel, - ops::LayerNormKernel); + ops::LayerNormKernel, + ops::LayerNormKernel); REGISTER_OP_CUDA_KERNEL( layer_norm_grad, ops::LayerNormGradKernel, - ops::LayerNormGradKernel); + ops::LayerNormGradKernel, + ops::LayerNormGradKernel); diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 8c467a4969..a92d8f17db 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -109,9 +109,11 @@ gray_list = { 'elementwise_mod', 'elementwise_floordiv', 'batch_norm', + 'layer_norm', 'tanh', 'sigmoid', 'lookup_table', + 'lookup_table_v2', 'top_k', 'pool2d', 'pool3d', @@ -123,6 +125,7 @@ gray_list = { 'flatten2', 'stack', 'unstack', + 'uniform_random', 'uniform_random_batch_size_like', 'gaussian_random', 'gaussian_random_batch_size_like', @@ -192,7 +195,6 @@ unsupported_fp16_list = { 'sequence_concat', 'sequence_slice', 'data_norm', - 'layer_norm', 'group_norm', 'spectral_norm', 'depthwise_conv2d_transpose', diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 6987b92a89..2f2f476a87 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -76,7 +76,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): for in_name in op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and op.type in [ - 'batch_norm', 'fused_bn_add_activation' + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' ]: if in_name not in {'X', 'Z'}: continue @@ -110,8 +110,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): op._set_attr('in_dtype', dest_dtype) if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16: for out_name in op.output_names: - if op.type in ['batch_norm', 'fused_bn_add_activation' - ] and out_name != 'Y': + if op.type in [ + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + ] and out_name != 'Y': continue for out_var_name in op.output(out_name): out_var = block.var(out_var_name) diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index d2c07c185d..51224002c9 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -15,6 +15,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle from operator import mul import paddle.fluid.core as core @@ -310,6 +311,8 @@ class TestLayerNormAPI(unittest.TestCase): class TestDygraphLayerNormAPIError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): + paddle.enable_static() + layer_norm = fluid.LayerNorm([32, 32]) # the input of LayerNorm must be Variable. x1 = np.random.random((3, 32, 32)).astype('float32') diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 8d62535a25..fcda579332 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -299,7 +299,8 @@ def layer_norm(x, 'begin_norm_axis', begin_norm_axis) return dygraph_utils._append_activation_in_dygraph(pre_act, act=None) - check_variable_and_dtype(x, 'input', ['float32', 'float64'], 'LayerNorm') + check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'], + 'LayerNorm') inputs = dict() inputs['X'] = [x] @@ -311,11 +312,13 @@ def layer_norm(x, # create output helper = LayerHelper('layer_norm', **locals()) + + dtype = x.dtype mean_out = helper.create_variable_for_type_inference( - dtype=x.dtype, stop_gradient=True) + dtype=dtype, stop_gradient=True) variance_out = helper.create_variable_for_type_inference( - dtype=x.dtype, stop_gradient=True) - layer_norm_out = helper.create_variable_for_type_inference(x.dtype) + dtype=dtype, stop_gradient=True) + layer_norm_out = helper.create_variable_for_type_inference(dtype) helper.append_op( type="layer_norm", -- GitLab