Unverified · Commit b4a93884 authored by limingshu, committed by GitHub

optimize bwd layer_norm kernel with fast method (#42491)

Parent 798e2e7e
@@ -541,7 +541,7 @@ void LaunchLayernormResidualDropoutGrad(
     if (!is_upscale_in_train) {
       factor = static_cast<T>(1.0f);
     }
-    ln_bwd_1024_kernel_driver<
+    ln_bwd_fast_kernel_driver<
         T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, MaskType>(
         dev_ctx, rows, cols, epsilon, layernorm_src, scale, mean, var, d_out,
         d_residual, d_scale, d_layernorm_bias, mask_data, factor, d_dropout_src);
...
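A note on the `factor` argument threaded through the call above: the sketch below is illustrative only and is not part of the commit (the kernel body is not shown in this hunk). With upscale_in_train, the dropout forward already divides by the keep probability, so the backward presumably applies `mask * 1 / (1 - dropout_prob)` per element; otherwise the hunk forces `factor` to 1 and only the mask is applied. The helper name is hypothetical.

```cpp
// Hypothetical illustration (not the committed kernel): how one element of the
// dropout gradient would consume `factor` and the saved dropout mask.
__device__ __forceinline__ float DropoutGradElem(float d_layernorm_in,
                                                 uint8_t mask, float factor) {
  // upscale_in_train: factor == 1 / (1 - dropout_prob); otherwise factor == 1.
  return factor * static_cast<float>(mask) * d_layernorm_in;
}
```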
@@ -22,6 +22,8 @@ limitations under the License. */
 namespace cub = hipcub;
 #endif
+#include <iostream>
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/phi/core/ddim.h"
@@ -428,7 +430,7 @@ template <
     int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
     int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
     int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
-__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
     const int rows, float epsilon, const T *__restrict__ x_ptr,
     const ScaleT *__restrict__ gamma_ptr, const U *__restrict__ mean_ptr,
     const U *__restrict__ var_ptr, const T *__restrict__ dout_ptr,
@@ -671,7 +673,7 @@ template <
     int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
     int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA,
     int VEC_COLS = ELTS_PER_ROW / VecSize>
-__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_fast_final_kernel(
     const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
     ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
   using Vec = phi::AlignedVector<U, VecSize>;
@@ -795,7 +797,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
 */
 template <typename T, typename U, typename ScaleT = U,
           typename MaskType = uint8_t>
-void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
+void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
                                const int cols, float epsilon, const T *x_ptr,
                                const ScaleT *scale_ptr, const U *mean_ptr,
                                const U *var_ptr, const T *dout_ptr, T *dx_ptr,
@@ -804,10 +806,10 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
                                T factor = static_cast<T>(0),
                                T *d_dropout_src_ptr = nullptr) {
   auto stream = dev_ctx.stream();
-  if (cols == 1024) {
+  if (cols == 1024 || cols == 384 || cols == 256) {
     // step-1: compute dx and reduced part results of dscale and dbias.
-    const int WARPS_M = 4;
-    const int WARPS_N = 1;
+    const int WARPS_M = 4;  // how many rows are dealt with in a CTA.
+    const int WARPS_N = 1;  // how many warps deal with one row.
     const int BYTES_PER_LDG = 16;
     const int VecSize = BYTES_PER_LDG / sizeof(T);
@@ -839,20 +841,52 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
             "To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
             "can't be null"));
       }
-      fused_ln_bwd_1024_kernel<true, T, U, ScaleT, MaskType, VecSize, WARPS_M,
-                               WARPS_N, BYTES_PER_LDG>
-          <<<gridx, THREADS_PER_CTA, 0, stream>>>(
-              rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
-              dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor,
-              d_dropout_src_ptr);
+#define LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row)          \
+  fused_ln_bwd_fast_kernel<true, T, U, ScaleT, MaskType, vec_size, WARPS_M,  \
+                           WARPS_N, BYTES_PER_LDG, ele_per_row>              \
+      <<<gridx, THREADS_PER_CTA, 0, stream>>>(                               \
+          rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,      \
+          dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor,         \
+          d_dropout_src_ptr);
+      if (cols == 1024) {
+        LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(VecSize, 256);
+            break;
+        }
+      }
+#undef LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
     } else {
-      fused_ln_bwd_1024_kernel<false, T, U, ScaleT, MaskType, VecSize, WARPS_M,
-                               WARPS_N, BYTES_PER_LDG>
-          <<<gridx, THREADS_PER_CTA, 0, stream>>>(
-              rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
-              dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
+#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row)               \
+  fused_ln_bwd_fast_kernel<false, T, U, ScaleT, MaskType, vec_size, WARPS_M, \
+                           WARPS_N, BYTES_PER_LDG, ele_per_row>              \
+      <<<gridx, THREADS_PER_CTA, 0, stream>>>(                               \
+          rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,      \
+          dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
+      if (cols == 1024) {
+        LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_FUSED_LN_BWD_FAST_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 256);
+            break;
+        }
+      }
+#undef LAUNCH_FUSED_LN_BWD_FAST_KERNEL
     }
     const int WARPS_M_2 = 16;
     const int WARPS_N_2 = 1;
     const int BYTES_PER_LDG_2 = 4;
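A note on the vec_size arguments chosen in the dispatch above, as a sketch under stated assumptions rather than part of the commit: each row is handled by one warp (WARPS_N = 1, so THREADS_PER_ROW = 32), and every thread loads vec_size elements per LDG, so ELTS_PER_ROW must be divisible by 32 * vec_size. With BYTES_PER_LDG = 16 the natural vec_size is 8 for half and 4 for float; 1024 and 256 satisfy this constraint, while 384 does not for half, which is presumably why the 384-column case is launched with vec_size = 1.

```cpp
// Divisibility check behind the vec_size choices (illustrative only).
constexpr int kThreadsPerRow = 32;  // WARPS_N = 1 warp per row.

constexpr bool DividesEvenly(int elts_per_row, int vec_size) {
  return elts_per_row % (kThreadsPerRow * vec_size) == 0;
}

static_assert(DividesEvenly(1024, 8), "1024 cols fit 16B LDGs of half");
static_assert(DividesEvenly(256, 8), "256 cols fit 16B LDGs of half");
static_assert(!DividesEvenly(384, 8),
              "384 cols need the scalar (vec_size = 1) launch for half");
```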
@@ -865,18 +899,36 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
         WARPS_M_2 * THREADS_PER_ROW_2;  // 16 * 32 = 512
     const int ROWS_PER_CTA_2 = WARPS_M_2;  // 16
-    const int gridx_2 = static_cast<int>(
-        std::ceil(1024 / static_cast<float>(THREADS_PER_ROW_2 * VecSize_2)));
     // #blocks: 32,#threads_per_block: 512
     // Note: it is not supported for double type.
     if (sizeof(U) > 4) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Only support float and fp16 type"));
     } else {
-      ln_bwd_1024_final_kernel<U, ScaleT, VecSize_2, WARPS_M_2, WARPS_N_2,
-                               BYTES_PER_LDG_2>
-          <<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>(
-              gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
+      int gridx_2 = 0;
+
+#define LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL(vec_size, ele_per_row)              \
+  gridx_2 = static_cast<int>(std::ceil(                                      \
+      ele_per_row / static_cast<float>(THREADS_PER_ROW_2 * vec_size)));      \
+  ln_bwd_fast_final_kernel<U, ScaleT, vec_size, WARPS_M_2, WARPS_N_2,        \
+                           BYTES_PER_LDG_2, ele_per_row>                     \
+      <<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>(                           \
+          gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
+      if (cols == 1024) {
+        LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL(VecSize_2, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL(VecSize_2, 256);
+            break;
+        }
+      }
+#undef LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL
     }
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
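A note on the grid size the LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL macro computes above: the removed code hard-coded the 1024-column case, while the macro derives gridx_2 from ele_per_row. The standalone sketch below (assuming U = float, so VecSize_2 = BYTES_PER_LDG_2 / sizeof(U) = 1, and THREADS_PER_ROW_2 = 32) reproduces that arithmetic; the printed values are illustrative, not taken from the commit.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const int threads_per_row_2 = 32;  // WARPS_N_2 = 1 -> one warp (32 threads) per row
  const int vec_size_2 = 1;          // BYTES_PER_LDG_2 (4 bytes) / sizeof(float)
  const int col_cases[] = {1024, 384, 256};
  for (int cols : col_cases) {
    // Same formula the macro expands to: one block per 32 * vec_size columns, rounded up.
    int gridx_2 = static_cast<int>(
        std::ceil(cols / static_cast<float>(threads_per_row_2 * vec_size_2)));
    std::printf("cols = %4d -> gridx_2 = %d\n", cols, gridx_2);
  }
  return 0;  // prints 32, 12 and 8 blocks respectively
}
```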
@@ -1484,15 +1536,17 @@ static void LayerNormBackward(
     case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
     {
 #ifdef PADDLE_WITH_CUDA
-      bool can_call_1024_kernel = false;
+      bool can_call_fast_kernel = false;
       // todo: rule out double type.
-      if (feature_size == 1024 && sizeof(T) <= 4) {
-        can_call_1024_kernel = true;
+      if ((feature_size == 1024 || feature_size == 384 ||
+           feature_size == 256) &&
+          sizeof(T) <= 4) {
+        can_call_fast_kernel = true;
       }
-      VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
-      if (can_call_1024_kernel) {
-        ln_bwd_1024_kernel_driver<
+      VLOG(6) << "can_call_fast_kernel = " << can_call_fast_kernel;
+      if (can_call_fast_kernel) {
+        ln_bwd_fast_kernel_driver<
             T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
             dev_ctx, batch_size, feature_size, epsilon, x, scale, mean, var,
             d_y, d_x, d_scale, d_bias);
...
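The condition above gates the fast backward path. A hypothetical helper (its name and signature are illustrative and not in the commit) states the rule in one place: the optimized kernels only cover feature sizes 256, 384 and 1024, and only element types of at most 4 bytes (float / half); everything else falls through to the generic LayerNormBackward kernels.

```cpp
// Illustrative restatement of the dispatch rule; not part of the commit.
inline bool CanUseFastLayerNormBwd(int feature_size, size_t elem_bytes) {
  return (feature_size == 1024 || feature_size == 384 ||
          feature_size == 256) &&
         elem_bytes <= 4;
}
```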
@@ -247,7 +247,6 @@ class TestLayerNormOp(unittest.TestCase):
     def test_check_forward_backward_with_scale_and_bias(self):
         self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(shape=[2, 3, 4, 5],
                                     begin_norm_axis=1,
@@ -288,6 +287,14 @@ class TestLayerNormOp(unittest.TestCase):
                                     begin_norm_axis=1,
                                     has_scale=True,
                                     has_bias=True)
+        self.check_forward_backward(shape=[1, 128, 256, 256],
+                                    begin_norm_axis=3,
+                                    has_scale=True,
+                                    has_bias=True)
+        self.check_forward_backward(shape=[1, 256, 384],
+                                    begin_norm_axis=2,
+                                    has_scale=True,
+                                    has_bias=True)
 
 class TestLayerNormAPI(unittest.TestCase):
...