Unverified Commit b4a93884 authored by limingshu, committed by GitHub

optimize bwd layer_norm kernel with fast method (#42491)

Parent 798e2e7e
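For orientation: the parameter gradients this backward pass produces are the standard layer_norm ones, accumulated in two steps (per-CTA partial sums in step-1, a cross-CTA reduction in step-2). With rows i and columns j, in LaTeX:

    \hat{x}_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}, \qquad
    d\gamma_j = \sum_i dy_{ij}\,\hat{x}_{ij}, \qquad
    d\beta_j = \sum_i dy_{ij}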
@@ -541,7 +541,7 @@ void LaunchLayernormResidualDropoutGrad(
  if (!is_upscale_in_train) {
    factor = static_cast<T>(1.0f);
  }
-  ln_bwd_1024_kernel_driver<
+  ln_bwd_fast_kernel_driver<
      T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, MaskType>(
      dev_ctx, rows, cols, epsilon, layernorm_src, scale, mean, var, d_out,
      d_residual, d_scale, d_layernorm_bias, mask_data, factor, d_dropout_src);
......
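About the factor override above: under the usual dropout conventions, upscale_in_train rescales kept activations by 1 / (1 - dropout_prob) at train time, while the other mode leaves them unscaled, which is why the gradient factor collapses to 1 there. A minimal sketch of the masked-gradient rule the driver applies, with hypothetical names and under assumed semantics:

template <typename T, typename MaskType>
__device__ T MaskedDropoutGrad(T d_out, MaskType mask, T factor) {
  // Sketch, not Paddle's actual helper: factor == 1 / (1 - dropout_prob)
  // for upscale_in_train, otherwise factor == 1 (the override above).
  return d_out * static_cast<T>(mask) * factor;
}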
@@ -22,6 +22,8 @@ limitations under the License. */
namespace cub = hipcub;
#endif
#include <iostream>
+#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/ddim.h"
@@ -428,7 +430,7 @@ template <
    int THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
    int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
    int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
-__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
    const int rows, float epsilon, const T *__restrict__ x_ptr,
    const ScaleT *__restrict__ gamma_ptr, const U *__restrict__ mean_ptr,
    const U *__restrict__ var_ptr, const T *__restrict__ dout_ptr,
@@ -671,7 +673,7 @@ template <
    int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
    int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA,
    int VEC_COLS = ELTS_PER_ROW / VecSize>
-__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_fast_final_kernel(
    const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
    ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
  using Vec = phi::AlignedVector<U, VecSize>;
@@ -795,7 +797,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
*/
template <typename T, typename U, typename ScaleT = U,
          typename MaskType = uint8_t>
-void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
+void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
                               const int cols, float epsilon, const T *x_ptr,
                               const ScaleT *scale_ptr, const U *mean_ptr,
                               const U *var_ptr, const T *dout_ptr, T *dx_ptr,
@@ -804,10 +806,10 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
                               T factor = static_cast<T>(0),
                               T *d_dropout_src_ptr = nullptr) {
  auto stream = dev_ctx.stream();
-  if (cols == 1024) {
+  if (cols == 1024 || cols == 384 || cols == 256) {
    // step-1: compute dx and the partial (per-CTA) results of dscale and dbias.
-    const int WARPS_M = 4;
-    const int WARPS_N = 1;
+    const int WARPS_M = 4;  // how many rows are dealt with per CTA.
+    const int WARPS_N = 1;  // how many warps cooperate on one row.
    const int BYTES_PER_LDG = 16;
    const int VecSize = BYTES_PER_LDG / sizeof(T);
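Why the dispatch below pins the 384-column case to vec_size 1 while 1024 and 256 keep the full vector width: with WARPS_N = 1 a single warp (32 lanes) owns a row, BYTES_PER_LDG = 16 gives VecSize = 8 for fp16, and a row must split evenly into THREADS_PER_ROW * VecSize chunks (LDGS passes per lane); 384 is not a multiple of 256, so it falls back to scalar loads (and stays scalar for fp32 as well, presumably to share one instantiation across dtypes). A compile-time sketch of that divisibility check, using a hypothetical helper rather than the kernel's own code:

constexpr int kWarpSize = 32;

// True when `cols` splits evenly into one warp's vectorized loads (WARPS_N == 1).
constexpr bool RowFitsVectorized(int cols, int bytes_per_ldg, int elt_size) {
  const int vec_size = bytes_per_ldg / elt_size;          // 16 / 2 = 8 for fp16
  const int elts_per_row_per_cta = kWarpSize * vec_size;  // THREADS_PER_ROW * VecSize
  return cols % elts_per_row_per_cta == 0;
}

static_assert(RowFitsVectorized(1024, 16, 2), "fp16, cols = 1024: LDGS = 4");
static_assert(RowFitsVectorized(256, 16, 2), "fp16, cols = 256: LDGS = 1");
static_assert(!RowFitsVectorized(384, 16, 2),
              "fp16, cols = 384: needs vec_size = 1 (12 scalar passes per lane)");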
@@ -839,20 +841,52 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
              "To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
              "can't be null"));
      }
-      fused_ln_bwd_1024_kernel<true, T, U, ScaleT, MaskType, VecSize, WARPS_M,
-                               WARPS_N, BYTES_PER_LDG>
-          <<<gridx, THREADS_PER_CTA, 0, stream>>>(
-              rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
-              dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor,
-              d_dropout_src_ptr);
+#define LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row)          \
+  fused_ln_bwd_fast_kernel<true, T, U, ScaleT, MaskType, vec_size, WARPS_M,  \
+                           WARPS_N, BYTES_PER_LDG, ele_per_row>              \
+      <<<gridx, THREADS_PER_CTA, 0, stream>>>(                               \
+          rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,      \
+          dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor,         \
+          d_dropout_src_ptr);
+      if (cols == 1024) {
+        LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(VecSize, 256);
+            break;
+        }
+      }
+#undef LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
    } else {
-      fused_ln_bwd_1024_kernel<false, T, U, ScaleT, MaskType, VecSize, WARPS_M,
-                               WARPS_N, BYTES_PER_LDG>
-          <<<gridx, THREADS_PER_CTA, 0, stream>>>(
-              rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
-              dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
+#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row)               \
+  fused_ln_bwd_fast_kernel<false, T, U, ScaleT, MaskType, vec_size, WARPS_M, \
+                           WARPS_N, BYTES_PER_LDG, ele_per_row>              \
+      <<<gridx, THREADS_PER_CTA, 0, stream>>>(                               \
+          rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,      \
+          dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
+      if (cols == 1024) {
+        LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_FUSED_LN_BWD_FAST_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 256);
+            break;
+        }
+      }
+#undef LAUNCH_FUSED_LN_BWD_FAST_KERNEL
    }
    const int WARPS_M_2 = 16;
    const int WARPS_N_2 = 1;
    const int BYTES_PER_LDG_2 = 4;
@@ -865,18 +899,36 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
        WARPS_M_2 * THREADS_PER_ROW_2;  // 16 * 32 = 512
    const int ROWS_PER_CTA_2 = WARPS_M_2;  // 16
-    const int gridx_2 = static_cast<int>(
-        std::ceil(1024 / static_cast<float>(THREADS_PER_ROW_2 * VecSize_2)));
    // e.g., for cols = 1024: #blocks = 32, #threads_per_block = 512.
    // Note: double type is not supported.
    if (sizeof(U) > 4) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Only support float and fp16 type"));
    } else {
-      ln_bwd_1024_final_kernel<U, ScaleT, VecSize_2, WARPS_M_2, WARPS_N_2,
-                               BYTES_PER_LDG_2>
-          <<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>(
-              gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
+      int gridx_2 = 0;
+#define LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(vec_size, ele_per_row)            \
+  gridx_2 = static_cast<int>(std::ceil(                                   \
+      ele_per_row / static_cast<float>(THREADS_PER_ROW_2 * vec_size)));   \
+  ln_bwd_fast_final_kernel<U, ScaleT, vec_size, WARPS_M_2, WARPS_N_2,     \
+                           BYTES_PER_LDG_2, ele_per_row>                  \
+      <<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>(                        \
+          gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
+      if (cols == 1024) {
+        LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(VecSize_2, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(VecSize_2, 256);
+            break;
+        }
+      }
+#undef LAUNCH_LN_BWD_BETA_GAMMA_KERNEL
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
@@ -1484,15 +1536,17 @@ static void LayerNormBackward(
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
    {
#ifdef PADDLE_WITH_CUDA
-      bool can_call_1024_kernel = false;
+      bool can_call_fast_kernel = false;
      // todo: rule out double type.
-      if (feature_size == 1024 && sizeof(T) <= 4) {
-        can_call_1024_kernel = true;
+      if ((feature_size == 1024 || feature_size == 384 ||
+           feature_size == 256) &&
+          sizeof(T) <= 4) {
+        can_call_fast_kernel = true;
      }
-      VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
-      if (can_call_1024_kernel) {
-        ln_bwd_1024_kernel_driver<
+      VLOG(6) << "can_call_fast_kernel = " << can_call_fast_kernel;
+      if (can_call_fast_kernel) {
+        ln_bwd_fast_kernel_driver<
            T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
            dev_ctx, batch_size, feature_size, epsilon, x, scale, mean, var,
            d_y, d_x, d_scale, d_bias);
......
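To recap the second step changed above: step-1 leaves gridx partial rows of dscale/dbias (one per CTA), and ln_bwd_fast_final_kernel reduces them column-wise, so its grid only needs to cover the cols columns. A host-side sketch of that launch geometry under the constants above (names are illustrative, not the file's code):

#include <cmath>

struct FinalReduceLaunch {
  int blocks;             // gridx_2
  int threads_per_block;  // THREADS_PER_CTA_2
};

// U is the statistics type (float here), so BYTES_PER_LDG_2 == 4 gives vec_size 1.
inline FinalReduceLaunch MakeFinalReduceLaunch(int cols, int sizeof_u) {
  const int kWarpSize = 32;
  const int warps_m_2 = 16;                 // partial rows folded per CTA
  const int threads_per_row_2 = kWarpSize;  // WARPS_N_2 == 1
  const int vec_size_2 = 4 / sizeof_u;      // BYTES_PER_LDG_2 / sizeof(U)
  const int blocks = static_cast<int>(
      std::ceil(cols / static_cast<float>(threads_per_row_2 * vec_size_2)));
  return {blocks, warps_m_2 * threads_per_row_2};
}
// cols = 1024 reproduces the "#blocks = 32, #threads_per_block = 512" note;
// cols = 384 and 256 shrink gridx_2 to 12 and 8 blocks respectively.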
@@ -247,7 +247,6 @@ class TestLayerNormOp(unittest.TestCase):
    def test_check_forward_backward_with_scale_and_bias(self):
        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(shape=[2, 3, 4, 5],
                                    begin_norm_axis=1,
@@ -288,6 +287,14 @@ class TestLayerNormOp(unittest.TestCase):
                                    begin_norm_axis=1,
                                    has_scale=True,
                                    has_bias=True)
+        self.check_forward_backward(shape=[1, 128, 256, 256],
+                                    begin_norm_axis=3,
+                                    has_scale=True,
+                                    has_bias=True)
+        self.check_forward_backward(shape=[1, 256, 384],
+                                    begin_norm_axis=2,
+                                    has_scale=True,
+                                    has_bias=True)

class TestLayerNormAPI(unittest.TestCase):
......
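The two added test cases target exactly the new fast-path widths: layer_norm flattens every dimension from begin_norm_axis onward into the normalized feature dimension, so [1, 128, 256, 256] with begin_norm_axis=3 yields feature_size 256 and [1, 256, 384] with begin_norm_axis=2 yields 384. A small illustrative check (not Paddle code):

#include <cassert>
#include <vector>

// Product of the dims at and after begin_norm_axis, i.e. the `cols` the
// backward driver sees.
int FeatureSize(const std::vector<int>& shape, int begin_norm_axis) {
  int feature_size = 1;
  for (size_t i = begin_norm_axis; i < shape.size(); ++i) {
    feature_size *= shape[i];
  }
  return feature_size;
}

int main() {
  assert(FeatureSize({1, 128, 256, 256}, 3) == 256);  // exercises the new 256 path
  assert(FeatureSize({1, 256, 384}, 2) == 384);       // exercises the new 384 path
  return 0;
}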