diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index c6205863103ff99e3d850c5acc739a400cdb5696..babf1c657f232d8316df924487a925c6b6162cf9 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -19,6 +19,8 @@ limitations under the License. */ namespace paddle { namespace operators { +#define LN_NUM_COLS 1024 + template using CudnnDataType = platform::CudnnDataType; template @@ -153,6 +155,191 @@ __global__ void FusedLayernormResidualDropoutBias( invvar); } +/* +* @brief layernorm(residual + dropout(x)); + * Conditions: + * (1) The number of cols is 1024; + * (2) layer_norm scale and bias is not null; + * (3) linear bias is null; + * @param + * rows: batch_size * seq_len + * cols: 1024 + * x_: [rows, cols], inputs + * residual_:[rows, cols] + * gamma_: [cols]: layernorm scale, not null + * beta_: [cols], layernorm bias, not null + * mask_out_: [rows, cols], dropout result + * residual_out_: [rows, cols], residual + dropout(src) + * y_: [rows, cols], layernorm result + * mean_out_: [rows]: layernorm means + * var_out_: [rows]: layernorm vars +*/ +template < + typename T, typename U, typename ScaleT = U, typename MaskType = uint8_t, + int VecSize = 8, int WARPS_M = 4, int WARPS_N = 1, int BYTES_PER_LDG = 16, + int ELTS_PER_ROW = 1024, int THREADS_PER_WARP = 32, + int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP, + int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M, + int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize, + int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA> +__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel( + int rows, int cols, uint64_t seed, const float dropout_prob, + const bool is_upscale_in_train, const bool is_test, + const uint64_t increment, const float epsilon, const T *__restrict__ x_ptr, + const T *__restrict__ residual_ptr, const ScaleT *__restrict__ gamma_ptr, + const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr, + U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr, + T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) { + using Vec = platform::AlignedVector; + using Vec_scale = platform::AlignedVector; + using MaskStoreT = platform::AlignedVector; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int lane = tidx % THREADS_PER_WARP; // 0, 1, ..., 31 + const int warp = tidx / THREADS_PER_WARP; // 0, 1, 2, 3 + const int warp_n = warp % WARPS_N; // 0 + const int warp_m = warp / WARPS_N; // 0, 1, 2, 3 + + const int c = warp_n * THREADS_PER_WARP + lane; // lane + const int r = bidx * ROWS_PER_CTA + warp_m; // row id + + int idx = r * LN_NUM_COLS + c; + curandStatePhilox4_32_10_t state; + curand_init(seed, idx, increment, &state); + + T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); + + Vec_scale gamma[LDGS]; + Vec_scale beta[LDGS]; +#pragma unroll + for (int it = 0, col = c; it < LDGS; it++) { + platform::Load(gamma_ptr + col * VecSize, &gamma[it]); + platform::Load(beta_ptr + col * VecSize, &beta[it]); + col += THREADS_PER_ROW; + } + + constexpr U rn = 1.f / U(LN_NUM_COLS); + for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) { + Vec x[LDGS]; + Vec residual[LDGS]; +#pragma unroll + for (int it = 0, col = c; it < LDGS; it++) { + platform::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, + &x[it]); + platform::Load( + residual_ptr + row * LN_NUM_COLS + col * VecSize, &residual[it]); + col += THREADS_PER_ROW; + } + + MaskStoreT mask_vec[LDGS]; + if (!is_test) { +#pragma unroll + for (int it = 0; it < LDGS; it++) { + float rand[VecSize]; + RandVec(&state, rand); +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { +#pragma unroll + mask_vec[it][jt] = static_cast(rand[jt] >= dropout_prob); + } + } + } else { +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + mask_vec[it][jt] = static_cast(1); + } + } + } + + // 4 * 8 + U xf[LDGS * VecSize]; +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + // dropout(x) + residual + x[it][jt] = x[it][jt] * static_cast(mask_vec[it][jt]) * factor + + residual[it][jt]; + xf[it * VecSize + jt] = U(x[it][jt]); + } + } + +// store dropout_residual_out and mask_out +#pragma unroll + for (int it = 0, col = c; it < LDGS; it++) { + platform::Store( + x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize); + platform::Store( + mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize); + col += THREADS_PER_ROW; + } + + U mu_local = 0.f; +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + mu_local += xf[it * VecSize + jt]; + } + } + +#pragma unroll + for (int it = 1; it < THREADS_PER_WARP; it *= 2) { + mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it); + } + mu_local *= rn; + if (lane == 0) { + mean_out_ptr[row] = mu_local; + } + U var_local = 0.f; + +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + U diff = xf[it * VecSize + jt] - mu_local; + var_local += diff * diff; + } + } + +#pragma unroll + for (int it = 1; it < THREADS_PER_WARP; it *= 2) { + var_local += __shfl_xor_sync(uint32_t(-1), var_local, it); + } + U rsigma = rsqrtf(var_local * rn + epsilon); + if (lane == 0) { + // Note: the stored var is different for paddle(ln) and apex (fast ln). + // var_out_ptr[row] = rsigma; + var_out_ptr[row] = var_local * rn; + } + +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + // use fp16 to compute + // ScaleT tmp = static_cast(rsigma * (xf[it * VecSize + jt] - + // mu_local)); + // x[it][jt] = gamma[it][jt] * tmp + beta[it][jt]; + // cast to fp32 to compute + U tmp = rsigma * (static_cast(xf[it * VecSize + jt]) - mu_local); + x[it][jt] = static_cast(static_cast(gamma[it][jt]) * tmp + + static_cast(beta[it][jt])); + } + } + +#pragma unroll + for (int it = 0, col = c; it < LDGS; it++) { + platform::Store(x[it], + y_ptr + row * LN_NUM_COLS + col * VecSize); + col += THREADS_PER_ROW; + } + } +} + /** * @brief layernorm(residual + dropout(src + bias)); * @param @@ -205,6 +392,13 @@ void LaunchLayernormResidualDropoutBias( return; } + bool can_call_1024_kernel = false; + if (cols == 1024 && scale != nullptr && layernorm_bias != nullptr && + bias == nullptr) { + can_call_1024_kernel = true; + } + VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel; + const int VecSize = MAX_CACHE_BYTES / sizeof(T); if (cols % VecSize != 0) { int blockDim = GetDesiredBlockDim(cols); @@ -215,13 +409,35 @@ void LaunchLayernormResidualDropoutBias( epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst, layernorm_dst, mean, var); } else { - int blockDim = GetDesiredBlockDim(cols / VecSize); - FusedLayernormResidualDropoutBias< - T, uint8_t, VecSize, U, - ScaleBiasWithSameTypeX><<>>( - rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment, - epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst, - layernorm_dst, mean, var); + if (can_call_1024_kernel) { + const int WARPS_M = 4; + const int WARPS_N = 1; + const int THREADS_PER_WARP = 32; + const int BYTES_PER_LDG = 16; + const int VecSize = BYTES_PER_LDG / sizeof(T); + + const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; + const int ROWS_PER_CTA = WARPS_M; + + // Note: the grid can not exceed max_grid of the gpu. + const int grid = + static_cast(std::ceil(rows / static_cast(ROWS_PER_CTA))); + fused_ln_fwd_1024_kernel< + T, U, LayerNormScaleBiasT, uint8_t, + VecSize, WARPS_M, WARPS_N, + BYTES_PER_LDG><<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, + increment, epsilon, src, residual, scale, layernorm_bias, mask_data, + mean, var, dst, layernorm_dst); + } else { + int blockDim = GetDesiredBlockDim(cols / VecSize); + FusedLayernormResidualDropoutBias< + T, uint8_t, VecSize, U, + ScaleBiasWithSameTypeX><<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, + increment, epsilon, src, residual, bias, scale, layernorm_bias, + mask_data, dst, layernorm_dst, mean, var); + } } } diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu index 57d3fc94dc88a0699b103c081642757798719332..cc14d0680d381ff2bbe73ee712e218c9c4d79185 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu @@ -66,12 +66,10 @@ struct TestFusedLayernormResidualDropoutBias { ctx = reinterpret_cast(devicectx); } - TestFusedLayernormResidualDropoutBias(int _rows, int _cols, - uint64_t _seed = 0, - float _dropout_prob = 0.0, - float _epsilon = 0.00001f, - bool _is_upscale_in_train = false, - bool _is_test = false) { + TestFusedLayernormResidualDropoutBias( + int _rows, int _cols, uint64_t _seed = 0, float _dropout_prob = 0.0, + float _epsilon = 0.00001f, bool _is_upscale_in_train = false, + bool _is_test = false, bool _has_bias = true) { rows = _rows; cols = _cols; seed = _seed; @@ -79,7 +77,7 @@ struct TestFusedLayernormResidualDropoutBias { epsilon = _epsilon; is_upscale_in_train = _is_upscale_in_train; is_test = _is_test; - has_bias = true; + has_bias = _has_bias; has_scale = true; has_layernorm_bias = true; platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); @@ -283,7 +281,6 @@ static void BaseTest(const bool is_fp16 = false) { } } } - TEST(FusedDropout, GPUFusedLayernormResidualDropoutBias) { BaseTest(); } TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasDouble) { @@ -330,3 +327,12 @@ TEST(FusedDropout, GPUFusedLayernormResidualDropoutLargeShape) { test.Run(); test.CheckOut(static_cast(1e-4)); } + +TEST(FusedDropout, GPUFusedLayernormResidualDropoutFp16MLperf) { + const int rows = 512; + const int cols = 1024; + TestFusedLayernormResidualDropoutBias test( + rows, cols, 0, 0, 0.00001f, false, false, false); + test.Run(); + test.CheckOut(static_cast(1e-2)); +} diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index 0c1f58a2f30f68c184906a0cebd78da98a83d952..bc00d875cd1dd37b64ae8a38c6949054bc168c7c 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -23,6 +23,7 @@ namespace cub = hipcub; #endif #include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" @@ -35,6 +36,8 @@ using CudnnDataType = platform::CudnnDataType; template using LayerNormParamType = typename CudnnDataType::BatchNormParamType; +#define LN_NUM_COLS 1024 + inline static int GetDesiredBlockDim(int64_t block_dim) { #ifdef __HIPCC__ const int kMaxBlockDim = 256; @@ -169,6 +172,118 @@ __inline__ __device__ half rsqrt_(const half val) { } #endif +#ifdef PADDLE_WITH_CUDA +template +__global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel( + int rows, int cols, const float epsilon, const T *__restrict__ x_ptr, + const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr, + U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr, + T *__restrict__ y_ptr) { + using Vec = platform::AlignedVector; + using Vec_scale = platform::AlignedVector; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int lane = tidx % THREADS_PER_WARP; // 0, 1, ..., 31 + const int warp = tidx / THREADS_PER_WARP; // 0, 1, 2, 3 + const int warp_n = warp % WARPS_N; // 0 + const int warp_m = warp / WARPS_N; // 0, 1, 2, 3 + + const int c = warp_n * THREADS_PER_WARP + lane; // lane + const int r = bidx * ROWS_PER_CTA + warp_m; // row id + + Vec_scale gamma[LDGS]; + Vec_scale beta[LDGS]; +#pragma unroll + for (int it = 0, col = c; it < LDGS; it++) { + platform::Load(gamma_ptr + col * VecSize, &gamma[it]); + platform::Load(beta_ptr + col * VecSize, &beta[it]); + col += THREADS_PER_ROW; + } + + constexpr U rn = 1.f / U(LN_NUM_COLS); + for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) { + Vec x[LDGS]; +#pragma unroll + for (int it = 0, col = c; it < LDGS; it++) { + platform::Load(x_ptr + row * LN_NUM_COLS + col * VecSize, + &x[it]); + col += THREADS_PER_ROW; + } + U xf[LDGS * VecSize]; + + U mu_local = 0.f; + +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + xf[it * VecSize + jt] = U(x[it][jt]); + mu_local += xf[it * VecSize + jt]; + } + } + +#pragma unroll + for (int it = 1; it < THREADS_PER_WARP; it *= 2) { + mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it); + } + mu_local *= rn; + if (lane == 0) { + mean_out_ptr[row] = mu_local; + } + U var_local = 0.f; + +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + U diff = xf[it * VecSize + jt] - mu_local; + var_local += diff * diff; + } + } + +#pragma unroll + for (int it = 1; it < THREADS_PER_WARP; it *= 2) { + var_local += __shfl_xor_sync(uint32_t(-1), var_local, it); + } + // Note: to assure if it is right for double + U rsigma = rsqrtf(var_local * rn + epsilon); + if (lane == 0) { + var_out_ptr[row] = var_local * rn; + } + +#pragma unroll + for (int it = 0; it < LDGS; it++) { +#pragma unroll + for (int jt = 0; jt < VecSize; jt++) { + // use fp16 to compute + // ScaleT tmp = static_cast(rsigma * (xf[it * VecSize + jt] - + // mu_local)); + // x[it][jt] = gamma[it][jt] * tmp + beta[it][jt]; + // cast to fp32 to compute + U tmp = (rsigma * (static_cast(xf[it * VecSize + jt]) - mu_local)); + x[it][jt] = static_cast(static_cast(gamma[it][jt]) * tmp + + static_cast(beta[it][jt])); + } + } + +#pragma unroll + for (int it = 0, col = c; it < LDGS; it++) { + platform::Store(x[it], + y_ptr + row * LN_NUM_COLS + col * VecSize); + col += THREADS_PER_ROW; + } + } +} +#endif + template using LayerNormScaleBiasT = typename std::conditional::type; diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index 7725f336416dbb80e0f65a38b6a4f16c88fb799f..ef4f0c6ba7063d4ff39732aed85ab5bbe007e7ca 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -112,11 +112,49 @@ class LayerNormKernel } \ } while (0) - if (is_scale_bias_same_dtype_with_x) { - PADDLE_LAUNCH_LAYERNORM_FWD(T, true); +#ifdef PADDLE_WITH_CUDA + bool can_call_1024_kernel = false; + if (feature_size == 1024 && scale != nullptr && bias != nullptr) { + can_call_1024_kernel = true; + } + if (can_call_1024_kernel) { + const int WARPS_M = 4; + const int WARPS_N = 1; + const int THREADS_PER_WARP = 32; + const int BYTES_PER_LDG = 16; + const int VecSize = BYTES_PER_LDG / sizeof(T); + + const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; + const int ROWS_PER_CTA = WARPS_M; + + const int grid = static_cast( + std::ceil(batch_size / static_cast(ROWS_PER_CTA))); + if (is_scale_bias_same_dtype_with_x) { + ln_fwd_1024_kernel<<>>( + batch_size, feature_size, epsilon, x_data, + static_cast(void_scale_data), + static_cast(void_bias_data), mean_data, var_data, + y_data); + } else { + ln_fwd_1024_kernel<<>>( + batch_size, feature_size, epsilon, x_data, + static_cast(void_scale_data), + static_cast(void_bias_data), mean_data, var_data, + y_data); + } } else { - PADDLE_LAUNCH_LAYERNORM_FWD(U, false); +#endif + if (is_scale_bias_same_dtype_with_x) { + PADDLE_LAUNCH_LAYERNORM_FWD(T, true); + } else { + PADDLE_LAUNCH_LAYERNORM_FWD(U, false); + } +#ifdef PADDLE_WITH_CUDA } +#endif + #undef PADDLE_LAUNCH_LAYERNORM_FWD } }; diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index d2d931f148078d124a25ddbb888b3e9cb5911211..7dd310d2b88a90e09ba5ceedb541da4be263e559 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -278,6 +278,8 @@ class TestLayerNormOp(unittest.TestCase): has_scale=False, has_bias=False, y_grad_scale=0.1) + self.check_forward_backward( + shape=[512, 1024], begin_norm_axis=1, has_scale=True, has_bias=True) class TestLayerNormAPI(unittest.TestCase):