Unverified commit 9f926eb7 authored by Leo Chen, committed by GitHub

Layernorm opt (#29522)

* layernorm fw opt

* layernorm bw opt

* fix typo, test=develop

* remove const dim3 for windows CI compatibility

* merge develop
Co-authored-by: zlsh80826 <zlsh80826@gmail.com>
Parent 40019793
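For reference, with H = feature_size, γ = scale, and β = bias, the forward kernel below computes, for each row i:

\mu_i = \frac{1}{H}\sum_j x_{ij}, \qquad \sigma_i^2 = \frac{1}{H}\sum_j x_{ij}^2 - \mu_i^2, \qquad y_{ij} = \gamma_j \,\frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_j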
@@ -107,35 +107,54 @@ struct PairForLayerNormAddFunctor {
}
};
template <typename T>
__inline__ __device__ T rsqrt(const T val) {
return ::rsqrt(val);
}
template <>
__inline__ __device__ float rsqrt(const float val) {
return rsqrtf(val);
}
template <>
__inline__ __device__ half rsqrt(const half val) {
return hrsqrt(val);
}
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
T *y, U *mean, U *var, float epsilon,
int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ U mean_share;
__shared__ U var_share;
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
// Step 1: Reduce to calculate mean and var
double mean_val = 0;
double var_val = 0;
U mean_val = 0;
U var_val = 0;
for (int i = beg_idx; i < end_idx; i += BlockDim) {
U tmp = static_cast<U>(x[i]);
mean_val += tmp;
var_val += (tmp * tmp);
}
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<double>(mean_val, var_val),
PairForLayerNormAddFunctor<double>());
.Reduce(PairForLayerNorm<U>(mean_val, var_val),
PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) {
auto tmp = pair.first_ / feature_size;
mean[blockIdx.x] = static_cast<U>(tmp);
var[blockIdx.x] = static_cast<U>(pair.second_ / feature_size - tmp * tmp);
mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
var[blockIdx.x] = var_share =
static_cast<U>(pair.second_ / feature_size - tmp * tmp);
}
__syncthreads();
mean_val = mean[blockIdx.x];
var_val = static_cast<U>(real_sqrt(var[blockIdx.x] + epsilon));
mean_val = mean_share;
U invvar = rsqrt<U>(var_share + static_cast<U>(epsilon));
// Step 2: Calculate y
if (scale != nullptr) {
@@ -143,26 +162,288 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(
scale[j] * (static_cast<U>(x[i]) - mean_val) / var_val + bias[j]);
scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
}
} else {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) /
var_val);
y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
invvar);
}
}
} else { // scale == nullptr
if (bias != nullptr) {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) / var_val +
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
bias[j]);
}
} else {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) / var_val);
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
}
}
}
}
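A minimal host-side launch sketch for the forward kernel above: one block handles one row, as the blockIdx.x * feature_size indexing assumes. The float instantiation, the fixed 256-thread block, and the *_data pointer names are illustrative assumptions; the real caller selects the block size via GetDesiredBlockDim and is not part of this hunk.

// Illustrative sketch only; kBlockDim and the pointer names are assumed, not from this diff.
constexpr int kBlockDim = 256;
LayerNormForward<float, float, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
    x_data, scale_data, bias_data, y_data, mean_data, var_data, epsilon,
    feature_size);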
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const T *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ var,
const float epsilon) {
const int i1 = i1_block + thr_load_row_off;
if (i1 >= i1_end) return;
U curr_mean = mean[i1];
U curr_invvar = rsqrt<U>(var[i1] + epsilon);
for (int k = 0; k < VPT; ++k) {
const int i2 = i2_off + k;
const int load_idx = i1 * n2 + i2;
const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2, const U *__restrict__ mean, const U *__restrict__ var,
float epsilon, U *part_grad_gamma, U *part_grad_beta) {
// VPTX -> values per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y
// template for compile time optimizations
constexpr int row_stride = BDIMX + 1;
const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
const int thr_load_row_off =
(threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;
constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
? BDIMX * BDIMY
: 2 * VPTX * BDIMY * row_stride;
__shared__ U buf[shared_cap];
U *warp_buf1 = reinterpret_cast<U *>(buf);
U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;
for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) {
buf[idx] = U(0);
}
__syncthreads();
for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
i1_block += VPTX * BDIMY * gridDim.y) {
cuLoadAddStridedInputs<T, U, VPTX>(
i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < VPTX; ++k) {
int row1 = threadIdx.y + k * VPTX;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
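Each block row (blockIdx.y) of the kernel above accumulates, over its assigned slice of input rows, the partial parameter gradients

\partial\beta_j = \sum_i \frac{\partial L}{\partial y_{ij}}, \qquad \partial\gamma_j = \sum_i \frac{\partial L}{\partial y_{ij}}\,(x_{ij} - \mu_i)\,\frac{1}{\sqrt{\sigma_i^2 + \epsilon}},

with warp_buf1 and warp_buf2 holding the per-warp accumulators and the results written to part_grad_beta and part_grad_gamma; the kernel below then reduces the part_size partial rows into the final d_bias and d_scale.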
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardSumGradGammaBeta(
const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
// const int n1, const int n2, T* grad_gamma, T* grad_beta) {
const int n1, const int n2, U *grad_gamma, U *grad_beta) {
// sum partial gradients for gamma and beta
__shared__ U buf[BDIMX * BDIMY];
int i2 = blockIdx.x * BDIMX + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / BDIMY;
U sum_gamma = U(0);
U sum_beta = U(0);
const U *part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U *part_grad_beta_ptr =
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions;
++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
constexpr int nbsize3 = BDIMX * BDIMY / 2;
for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardComputeGradInput(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2,
// const U* __restrict__ mean, const U* __restrict__ var, const float
// epsilon, const T* gamma,
const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
const U *gamma, T *grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = rsqrt<U>(var[i1] + epsilon);
const T *k_input = input + i1 * n2;
const T *k_dout = dout + i1 * n2;
constexpr int numx = BDIMX * BDIMY;
const int thrx = threadIdx.x + threadIdx.y * BDIMX;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
sum_loss1 +=
__shfl_xor_sync(0xffffffff, sum_loss1, mask,
warpSize); // WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 +=
__shfl_xor_sync(0xffffffff, sum_loss2, mask,
warpSize); // WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (BDIMY > 1) {
__shared__ U buf[BDIMX * BDIMY];
for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T *k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
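Writing sum_loss1 = \sum_l \gamma_l \frac{\partial L}{\partial y_{il}} and sum_loss2 = \sum_l \gamma_l \frac{\partial L}{\partial y_{il}} \frac{x_{il} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} (the γ factors drop out when gamma == NULL), the value stored to k_grad_input above is

\frac{\partial L}{\partial x_{ij}} = \frac{1}{H\sqrt{\sigma_i^2 + \epsilon}} \left( H\,\gamma_j \frac{\partial L}{\partial y_{ij}} - \text{sum\_loss1} - \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}\,\text{sum\_loss2} \right)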
@@ -384,7 +665,11 @@ template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
const U *mean, const U *var, T *d_x, U *d_scale,
U *d_bias, float epsilon, int batch_size,
int feature_size, cudaStream_t stream) {
int feature_size,
const framework::ExecutionContext &ctx) {
auto &dev_ctx = ctx.cuda_device_context();
auto stream = dev_ctx.stream();
const int kMaxBlockDim = 512;
const int kMaxBlockNum = 128;
int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
@@ -485,21 +770,44 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
}
break;
case 7: // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
switch (block_dim) {
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientAll<
T, U, kBlockDim, true><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX<
T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
{
constexpr int VPT = 4;
constexpr int BDIMX2 = 32;
constexpr int BDIMY2 = 4;
dim3 threads2(BDIMX2, BDIMY2, 1);
constexpr int part_size = BDIMY2 * VPT;
const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);
auto part_grad_gamma_ptr =
memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
auto part_grad_beta_ptr =
memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());
LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2,
VPT><<<blocks2, threads2, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma,
part_grad_beta); // compute part_grad_gamma, beta
constexpr int BDIMX3 = 32;
constexpr int BDIMY3 = 8;
dim3 threads3(BDIMX3, BDIMY3, 1);
const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
LayerNormBackwardSumGradGammaBeta<
T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
d_scale, d_bias);
constexpr int BDIMX1 = 32;
constexpr int BDIMY1 = 4;
dim3 threads1(BDIMX1, BDIMY1, 1);
const dim3 blocks1(1, batch_size, 1);
LayerNormBackwardComputeGradInput<
T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
break;
}
default:
break;
}
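A sizing note for the case-7 path above: the temporary workspace depends only on feature_size. With part_size = BDIMY2 * VPT = 4 * 4 = 16, part_grad_gamma and part_grad_beta each occupy 16 * feature_size * sizeof(U) bytes, e.g. 64 KB per buffer for feature_size = 1024 with U = float, independent of batch_size.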
@@ -611,11 +919,9 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
int batch_size = static_cast<int>(matrix_dim[0]);
int feature_size = static_cast<int>(matrix_dim[1]);
auto stream = ctx.cuda_device_context().stream();
LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
d_x_data, d_scale_data, d_bias_data, epsilon,
batch_size, feature_size, stream);
batch_size, feature_size, ctx);
}
};
......