Unverified commit 99cfcc09, authored by Li Min, committed by GitHub

Optimize layer norm backward cuda kernel when cols is 1024. (#39247)

* Add fp16 support for scale/bias for fused_layernnorm_residual_dropout_bias op.

* Remove useless code.

* Remove useless code.

* Optimize layer_norm fwd when cols is 1024.

* Remove useless code.

* Minors.

* Minors.

* Modifications according to reviews.

* Minors.

* Optimize layer_norm bwd kernel when cols is 1024.

* Polish layer_norm_bwd_1024 kernel.

* Limit ln_bwd_1024_kernel to paddle_with_cuda.

* Fix double type compile error.

* Add optimization of ln bwd for fused_dropout_add_ln op.

* Polish codes.
Parent 92da5055
@@ -284,12 +284,31 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
P* d_layernorm_bias, T* d_dropout_src,
T* d_bias, T* d_residual) {
using U = LayerNormParamType<T>;
bool can_call_1024_kernel = false;
// Fast implementation for the case where cols is 1024 and linear_bias is
// nullptr. A non-null linear_bias could also be handled by this kernel,
// but that case is not supported here.
if (this->cols_ == 1024 && d_bias == nullptr && d_scale != nullptr &&
d_layernorm_bias != nullptr && sizeof(T) <= 4) {
can_call_1024_kernel = true;
}
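// Note: sizeof(T) <= 4 excludes the double type; double falls back to the
// generic LayerNormBackward + ResidualDropoutBiasGrad path below.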
VLOG(6) << "LaunchLayernormResidualDropoutGrad = " << can_call_1024_kernel;
if (can_call_1024_kernel) {
LaunchLayernormResidualDropoutGrad<T, U, MaskType, is_same_type>(
ctx, this->rows_, this->cols_, epsilon_,
this->dropout_param_.dropout_prob,
this->dropout_param_.is_upscale_in_train, d_out, layernorm_src, gamma,
mean, variance, mask, d_scale, d_layernorm_bias, d_residual,
d_dropout_src);
} else {
LayerNormBackward<T, U, is_same_type>(
layernorm_src, d_out, gamma, mean, variance, d_layernorm_src, d_scale,
d_layernorm_bias, epsilon_, this->rows_, this->cols_, ctx);
this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src,
d_residual, d_bias);
}
}
protected:
float epsilon_;
......
@@ -441,5 +441,30 @@ void LaunchLayernormResidualDropoutBias(
}
}
template <typename T, typename U, typename MaskType,
bool ScaleBiasWithSameTypeX = false>
void LaunchLayernormResidualDropoutGrad(
const platform::CUDADeviceContext &dev_ctx, const uint32_t rows,
const uint32_t cols, const float epsilon, const float dropout_prob,
const bool is_upscale_in_train, const T *d_out, const T *layernorm_src,
const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
const LayerNormParamType<T> *mean, const LayerNormParamType<T> *var,
const MaskType *mask_data,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_layernorm_bias,
T *d_residual, T *d_dropout_src) {
const T zero = static_cast<T>(0.0f);
auto factor = dropout_prob == static_cast<float>(1.0f)
? zero
: static_cast<T>(1.0f / (1.0f - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
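// In upscale_in_train mode the forward pass scales kept activations by
// 1 / (1 - dropout_prob), so the backward pass applies the same factor to
// the masked gradients; otherwise the gradients are only masked (factor = 1).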
ln_bwd_1024_kernel_driver<
T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, MaskType>(
dev_ctx, rows, cols, epsilon, layernorm_src, scale, mean, var, d_out,
d_residual, d_scale, d_layernorm_bias, mask_data, factor, d_dropout_src);
}
} // namespace operators
} // namespace paddle
@@ -385,6 +385,471 @@ __inline__ __device__ void cuLoadAddStridedInputs(
}
}
#ifdef PADDLE_WITH_CUDA
template <
bool isFusedDropoutResidualLn, typename T, typename U, typename ScaleT = U,
typename MaskType = uint8_t, int VecSize = 8, int WARPS_M = 4,
int WARPS_N = 1, int BYTES_PER_LDG = 16, int ELTS_PER_ROW = 1024,
int THREADS_PER_WARP = 32, int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP,
int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
const int rows, float epsilon, const T *__restrict__ x_ptr,
const ScaleT *__restrict__ gamma_ptr, const U *__restrict__ mean_ptr,
const U *__restrict__ var_ptr, const T *__restrict__ dout_ptr,
U *__restrict__ dgamma_temp_ptr, U *__restrict__ dbeta_temp_ptr,
T *__restrict__ dx_ptr, const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) {
using Vec = platform::AlignedVector<T, VecSize>;
using Vec_scale = platform::AlignedVector<ScaleT, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP; // 0, 1, ..., 31
const int warp = tidx / THREADS_PER_WARP; // 0, 1, 2, 3
const int warp_m = warp / WARPS_N; // 0, 1, 2, 3
const int warp_n = warp % WARPS_N; // 0
const int tid_r = warp_n * THREADS_PER_WARP + lane; // 0, 1, ..., 31
const int r = bidx * ROWS_PER_CTA + warp_m;
const int c = warp_n * THREADS_PER_WARP + lane;
static_assert(LN_NUM_COLS == THREADS_PER_ROW * LDGS * VecSize, "");
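// Thread mapping: each CTA covers ROWS_PER_CTA (= WARPS_M) rows; within a
// row, THREADS_PER_ROW threads cover the LN_NUM_COLS (1024) columns using
// LDGS vectorized loads of VecSize elements each, as checked by the
// static_assert above.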
// smem for column reduction
__shared__ U smem_[ROWS_PER_CTA * LN_NUM_COLS];
U dgamma_sum[LDGS * VecSize];
U dbeta_sum[LDGS * VecSize];
memset(dgamma_sum, 0, sizeof(U) * LDGS * VecSize);
memset(dbeta_sum, 0, sizeof(U) * LDGS * VecSize);
// Note: these are not used when WARPS_N == 1.
__shared__ U smem_sum_loss1[ROWS_PER_CTA * WARPS_N]; // 4
__shared__ U smem_sum_loss2[ROWS_PER_CTA * WARPS_N]; // 4
U *sum_loss1_shared = &smem_sum_loss1[warp_m * WARPS_N];
U *sum_loss2_shared = &smem_sum_loss2[warp_m * WARPS_N];
// step-1: compute dx and local results of dscale and dbias
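// With y = (x - mean) * rstd and out = scale * y + bias, the layer_norm
// gradients are:
//   d_bias  = sum_over_rows(d_out)
//   d_scale = sum_over_rows(d_out * y)
//   d_x     = rstd * (scale * d_out - mean_over_cols(scale * d_out)
//                     - y * mean_over_cols(scale * d_out * y))
// sum_loss1 and sum_loss2 below accumulate the two per-row sums, which are
// scaled by rn (= 1 / 1024) to become the means.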
constexpr float rn = 1.f / static_cast<float>(LN_NUM_COLS);
Vec_scale gamma[LDGS];
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
platform::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
col += THREADS_PER_ROW;
}
#pragma unroll 1
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
const U mean_cur_row = mean_ptr[row];
const U var_cur_row = rsqrt_<U>(var_ptr[row] + epsilon);
Vec dout[LDGS], x[LDGS];
MaskLoadT mask_vec[LDGS];
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
platform::Load<T, VecSize>(dout_ptr + row * LN_NUM_COLS + col * VecSize,
&dout[it]);
platform::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize,
&x[it]);
if (isFusedDropoutResidualLn) {
platform::Load<MaskType, VecSize>(
mask_ptr + row * LN_NUM_COLS + col * VecSize, &mask_vec[it]);
}
col += THREADS_PER_ROW;
}
// local reductions
U dy[LDGS * VecSize];
U y[LDGS * VecSize];
U sum_loss1 = 0.f;
U sum_loss2 = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
U x_tmp = x[it][jt];
U y_tmp = var_cur_row * (x_tmp - mean_cur_row);
U dy_tmp = static_cast<U>(gamma[it][jt]) *
static_cast<U>(dout[it][jt]); // scale * dy
U dout_tmp = dout[it][jt]; // dy
// used to compute dx (row reduction)
sum_loss1 += dy_tmp; // scale * dy, sum_1
sum_loss2 += dy_tmp * y_tmp; // scale * dy * y, sum_2
dy[it * VecSize + jt] = dy_tmp; // scale * dy
y[it * VecSize + jt] = y_tmp; // y
// used to compute dscale and dbias (column reduction)
dgamma_sum[it * VecSize + jt] += dout_tmp * y_tmp; // dy * y
dbeta_sum[it * VecSize + jt] += dout_tmp; // dy
}
}
// reduction across row for sum_loss1, sum_loss2
if (WARPS_N == 1) {
#pragma unroll
// row reduction among 32 threads.
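// __shfl_xor_sync with strides 1, 2, 4, 8, 16 performs a butterfly
// reduction, so after 5 steps every lane holds the full row sum and no
// shared-memory round trip is needed in this WARPS_N == 1 case.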
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
sum_loss1 += __shfl_xor_sync(uint32_t(-1), sum_loss1, it);
sum_loss2 += __shfl_xor_sync(uint32_t(-1), sum_loss2, it);
}
sum_loss1 *= rn;
sum_loss2 *= rn;
} else {
#pragma unroll
for (int it = 16; it > 0; it /= 2) {
sum_loss1 += __shfl_down_sync(uint32_t(-1), sum_loss1, it);
sum_loss2 += __shfl_down_sync(uint32_t(-1), sum_loss2, it);
}
if (lane == 0) {
sum_loss1_shared[warp_n] = sum_loss1;
sum_loss2_shared[warp_n] = sum_loss2;
}
__syncthreads();
if (warp_n == 0 && lane == 0) {
sum_loss1 = 0.f;
sum_loss2 = 0.f;
for (int it = 0; it < WARPS_N; it++) {
sum_loss1 += sum_loss1_shared[it];
sum_loss2 += sum_loss2_shared[it];
}
sum_loss1_shared[0] = sum_loss1;
sum_loss2_shared[0] = sum_loss2;
}
__syncthreads();
sum_loss1 = sum_loss1_shared[0] * rn;
sum_loss2 = sum_loss2_shared[0] * rn;
}
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
U dy_tmp = dy[it * VecSize + jt]; // scale * dy
U y_tmp = y[it * VecSize + jt]; // y
// dx = var * (scale * dy - sum_loss2 * y - sum_loss1)
U dx_tmp = var_cur_row * (dy_tmp - sum_loss2 * y_tmp - sum_loss1);
// Note: reuse the x and dout vector registers to store dx and d_dropout_src.
x[it][jt] = static_cast<T>(dx_tmp);
if (isFusedDropoutResidualLn) {
dout[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor;
}
}
}
// store dx to global memory
col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
platform::Store<T, VecSize>(x[it],
dx_ptr + row * LN_NUM_COLS + col * VecSize);
if (isFusedDropoutResidualLn) {
platform::Store<T, VecSize>(
dout[it], d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize);
}
col += THREADS_PER_ROW;
}
}
// step-2: column reduction of dscale and dbias for each thread block.
// each block's sum: [4 * 1024] -> [1 * 1024]
enum { NUM_RES = LN_NUM_COLS / THREADS_PER_CTA }; // 1024/128 = 8
static_assert(NUM_RES * THREADS_PER_CTA == LN_NUM_COLS, "");
U *smem_write;
smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize]; // [4 * 1024]
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
smem_write[jt] = dbeta_sum[it * VecSize + jt];
}
smem_write += THREADS_PER_ROW * VecSize; // 32*8
}
__syncthreads();
U cta_dbeta_sum[NUM_RES];
memset(cta_dbeta_sum, 0, sizeof(U) * NUM_RES);
// column reduction for elems in smem: 4*1024 -> 1*1024.
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dbeta_sum[jt] +=
smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
}
}
__syncthreads();
smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
smem_write[jt] = dgamma_sum[it * VecSize + jt];
}
smem_write += THREADS_PER_ROW * VecSize;
}
__syncthreads();
U cta_dgamma_sum[NUM_RES];
memset(cta_dgamma_sum, 0, sizeof(U) * NUM_RES);
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dgamma_sum[jt] +=
smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
}
}
// the shape of the partial results: (#blocks, 1024)
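// These per-block partials are reduced to the final dscale/dbias by
// ln_bwd_1024_final_kernel, launched from ln_bwd_1024_kernel_driver.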
U *dgamma_part =
static_cast<U *>(dgamma_temp_ptr) + bidx * LN_NUM_COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dgamma_sum[jt];
dgamma_part += THREADS_PER_CTA;
}
U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * LN_NUM_COLS + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dbeta_part = cta_dbeta_sum[jt];
dbeta_part += THREADS_PER_CTA;
}
}
/* This function carries out a column reduction whose input is [rows, 1024]
 * and whose output is [1, 1024].
* #blocks: 32
* #threads: 512
*/
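/* Reduction strategy (for the float/fp16 case, where VecSize is 1): each of
 * the 32 blocks owns 32 columns, one column per thread; the 16 warps of a
 * block stride over the rows of the partial-sum matrix, and warp 0 then
 * combines the per-warp partials through shared memory before writing the
 * final dscale/dbias values.
 */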
// TODO(@limin29): consider whether there are better implementation strategies.
template <
typename U, typename ScaleT = U, int VecSize = 1, int WARPS_M = 16,
int WARPS_N = 1, int BYTES_PER_LDG = 4, int ELTS_PER_ROW = 1024,
int THREADS_PER_WARP = 32, int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP,
int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA,
int VEC_COLS = ELTS_PER_ROW / VecSize>
__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
using Vec = platform::AlignedVector<U, VecSize>;
static_assert(VEC_COLS == LN_NUM_COLS / VecSize, "");
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP;
const int warp = tidx / THREADS_PER_WARP;
const int warp_m = warp / WARPS_N;
const int warp_n = warp % WARPS_N;
const int tid_c = warp_n * THREADS_PER_WARP + lane;
const int c = bidx * THREADS_PER_ROW + tid_c;
const int r = warp_m;
__shared__ U smem_space[(WARPS_M - 1) * THREADS_PER_ROW * VecSize];
for (int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW) {
const U *dg_part_ptr = (dg_part_) + r * LN_NUM_COLS + col * VecSize;
const U *db_part_ptr = (db_part_) + r * LN_NUM_COLS + col * VecSize;
U dg_sum[VecSize];
U db_sum[VecSize];
memset(dg_sum, 0, sizeof(U) * VecSize);
memset(db_sum, 0, sizeof(U) * VecSize);
#pragma unroll
for (int row = r; row < rows; row += ROWS_PER_CTA) {
Vec dg;
Vec db;
platform::Load<U, VecSize>(dg_part_ptr, &dg);
platform::Load<U, VecSize>(db_part_ptr, &db);
dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
dg_sum[jt] += dg[jt];
db_sum[jt] += db[jt];
}
}
// reduction across rows of the thread block
U *smem_write;
smem_write = smem_space + (warp_m - 1) * THREADS_PER_ROW * VecSize + tid_c;
if (warp_m > 0) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
*smem_write = dg_sum[jt];
smem_write += THREADS_PER_ROW;
}
}
__syncthreads();
U *smem_read;
smem_read = smem_space + tid_c;
if (warp_m == 0) {
#pragma unroll
for (int it = 0; it < WARPS_M - 1; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
dg_sum[jt] += *smem_read;
smem_read += THREADS_PER_ROW;
}
}
}
__syncthreads();
smem_write = smem_space + (warp_m - 1) * THREADS_PER_ROW * VecSize + tid_c;
if (warp_m > 0) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
*smem_write = db_sum[jt];
smem_write += THREADS_PER_ROW;
}
}
__syncthreads();
smem_read = smem_space + tid_c;
if (warp_m == 0) {
#pragma unroll
for (int it = 0; it < WARPS_M - 1; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
db_sum[jt] += *smem_read;
smem_read += THREADS_PER_ROW;
}
}
union {
ScaleT raw;
ScaleT elt[VecSize];
} dg_out, db_out;
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
dg_out.elt[jt] = dg_sum[jt];
db_out.elt[jt] = db_sum[jt];
}
ScaleT *dg_ptr = reinterpret_cast<ScaleT *>(dg_) + col;
ScaleT *db_ptr = reinterpret_cast<ScaleT *>(db_) + col;
*dg_ptr = dg_out.raw;
*db_ptr = db_out.raw;
}
}
}
/* This function supports two kinds of computations (float and fp16 types
 * only):
 *
 * Case-1: compute layer_norm_grad for the layernorm op by setting mask_ptr
 * and d_dropout_src_ptr to nullptr. Here, dx_ptr returns the grad of the
 * layernorm input.
 *
 * Case-2: compute layer_norm_grad + residual_grad + dropout_grad for the
 * fused_dropout_residual_layernorm op. Here, dx_ptr returns residual_grad.
 *
 */
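/* Implementation note: the backward pass is split into two kernels. The
 * first (fused_ln_bwd_1024_kernel) computes dx (and, for Case-2,
 * d_dropout_src) and writes per-block partial sums of dscale/dbias of shape
 * (2 * #SM, 1024) into temporary tensors; the second
 * (ln_bwd_1024_final_kernel) reduces those partials along the row dimension
 * to produce the final dscale and dbias of shape (1024,).
 */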
template <typename T, typename U, typename ScaleT = U,
typename MaskType = uint8_t>
void ln_bwd_1024_kernel_driver(
const platform::CUDADeviceContext &dev_ctx, const int rows, const int cols,
float epsilon, const T *x_ptr, const ScaleT *scale_ptr, const U *mean_ptr,
const U *var_ptr, const T *dout_ptr, T *dx_ptr, ScaleT *dscale_ptr,
ScaleT *dbias_ptr, const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) {
auto stream = dev_ctx.stream();
if (cols == 1024) {
// step-1: compute dx and the per-block partial results of dscale and dbias.
const int WARPS_M = 4;
const int WARPS_N = 1;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);
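// With 16-byte loads, VecSize is 8 for fp16 and 4 for float.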
const int THREADS_PER_WARP = 32;
const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP;
const int THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW;
const int ROWS_PER_CTA = WARPS_M;
// 4 * 1024 * 4
const int SMEM_BYTES = ROWS_PER_CTA * cols * sizeof(U);
// #blocks = 2 * #SM
const int gridx = 2 * dev_ctx.GetSMCount();
// Allocate temp space for the per-block partial sums of dscale and dbias.
framework::Tensor dscale_temp;
dscale_temp.Resize({gridx, cols});
dscale_temp.mutable_data<U>(dev_ctx.GetPlace());
U *dscale_temp_ptr = dscale_temp.data<U>();
framework::Tensor dbias_temp;
dbias_temp.Resize({gridx, cols});
dbias_temp.mutable_data<U>(dev_ctx.GetPlace());
U *dbias_temp_ptr = dbias_temp.data<U>();
if (mask_ptr != nullptr) {
if (d_dropout_src_ptr == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
"can't be null"));
}
fused_ln_bwd_1024_kernel<
true, T, U, ScaleT, MaskType, VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<gridx, THREADS_PER_CTA, 0, stream>>>(
rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor,
d_dropout_src_ptr);
} else {
fused_ln_bwd_1024_kernel<
false, T, U, ScaleT, MaskType, VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<gridx, THREADS_PER_CTA, 0, stream>>>(
rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
}
const int WARPS_M_2 = 16;
const int WARPS_N_2 = 1;
const int BYTES_PER_LDG_2 = 4;
const int VecSize_2 =
std::max(1, static_cast<int>(BYTES_PER_LDG_2 / sizeof(U))); // 1
const int THREADS_PER_WARP_2 = 32;
const int THREADS_PER_ROW_2 = WARPS_N_2 * THREADS_PER_WARP_2; // 32
const int THREADS_PER_CTA_2 =
WARPS_M_2 * THREADS_PER_ROW_2; // 16 * 32 = 512
const int ROWS_PER_CTA_2 = WARPS_M_2; // 16
const int gridx_2 = static_cast<int>(
std::ceil(1024 / static_cast<float>(THREADS_PER_ROW_2 * VecSize_2)));
// #blocks: 32, #threads_per_block: 512
// Note: the double type is not supported.
if (sizeof(U) > 4) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support float and fp16 type"));
} else {
ln_bwd_1024_final_kernel<
U, ScaleT, VecSize_2, WARPS_M_2, WARPS_N_2,
BYTES_PER_LDG_2><<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>(
gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Fast layer_norm kernel is only used when feature_size is 1024"));
}
}
#endif
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
const T *__restrict__ dout, const T *__restrict__ input, const int64_t n1,
@@ -983,6 +1448,21 @@ static void LayerNormBackward(
break;
case 7: // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
{
#ifdef PADDLE_WITH_CUDA
bool can_call_1024_kernel = false;
// TODO: rule out the double type.
if (feature_size == 1024 && sizeof(T) <= 4) {
can_call_1024_kernel = true;
}
VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
if (can_call_1024_kernel) {
ln_bwd_1024_kernel_driver<
T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
dev_ctx, batch_size, feature_size, epsilon, x, scale, mean, var,
d_y, d_x, d_scale, d_bias);
} else {
#endif
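// Generic fallback: compute per-block partial dgamma/dbeta, reduce them,
// and then compute d_x.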
constexpr int VPT = 4;
constexpr int BDIMX2 = 32;
constexpr int BDIMY2 = 4;
@@ -997,9 +1477,10 @@ static void LayerNormBackward(
U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());
LayerNormBackwardPartGradGammaBeta<
T, U, BDIMX2, BDIMY2, VPT><<<blocks2, threads2, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon,
part_grad_gamma,
part_grad_beta); // compute part_grad_gamma, beta
constexpr int BDIMX3 = 32;
@@ -1009,8 +1490,8 @@ static void LayerNormBackward(
LayerNormBackwardSumGradGammaBeta<
T, U, BDIMX3, BDIMY3,
ScaleBiasWithSameTypeX><<<blocks3, threads3, 0, stream>>>(
part_grad_gamma, part_grad_beta, part_size, batch_size,
feature_size, d_scale, d_bias);
constexpr int BDIMX1 = 32;
constexpr int BDIMY1 = 4;
@@ -1019,6 +1500,10 @@ static void LayerNormBackward(
T, U, BDIMX1, BDIMY1,
ScaleBiasWithSameTypeX><<<batch_size, threads1, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
#ifdef PADDLE_WITH_CUDA
}
#endif
break;
}
default:
......