Unverified commit 7a3d05d9, authored by limingshu, committed by GitHub

Optimization for layerNormGrad [Part1] (#51282)

* first commit

* fix code bugs in for_loop

* fix bugs in cuLoadAddStridedInputs.

* optimization for LayerNormBackwardComputeGradInput

* add unit test for validating the optimization

* fix windows ci error
Parent e4ba5f86
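For reference while reading the hunks below: all of the grad-input kernels in this patch evaluate the standard LayerNorm input gradient. Writing H for the feature size (n2), mu and sigma^2 for the per-row mean and variance, and sigma^{-1} = rsqrt(sigma^2 + epsilon), the per-element formula the code computes is (my own transcription of the kernel arithmetic, not text from the patch):

\[
\frac{\partial L}{\partial x_i}
  = \frac{\sigma^{-1}}{H}\Bigl( H\,\gamma_i\,dy_i
  \;-\; \sum_{j=1}^{H}\gamma_j\,dy_j
  \;-\; (x_i-\mu)\,\sigma^{-1}\sum_{j=1}^{H}\gamma_j\,dy_j\,(x_j-\mu)\,\sigma^{-1} \Bigr)
\]

The two row-wide sums are what the kernels accumulate as sum_loss1 and sum_loss2.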
@@ -65,10 +65,7 @@ __forceinline__ __device__ U BlockReduceSum(U val, U *shared) {
   int wid = threadIdx.x / warpSize;
   val = WarpReduceSum(val);  // Each warp performs partial reduction
-  __syncthreads();
   if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory
   __syncthreads();  // Wait for all partial reductions
   // read from shared memory only if that warp existed
   val =
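The deleted barrier was redundant: each warp writes only its own shared[wid] slot, and the surviving __syncthreads() already orders those writes before any cross-warp read. A minimal sketch of the warp-then-block pattern this function relies on (simplified stand-in, not the Paddle helper):

// Sketch: shuffle-based per-warp reduction; lane 0 of each warp ends up
// holding that warp's partial sum, which BlockReduceSum then combines.
__device__ float WarpReduceSumSketch(float val) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;  // valid in lane 0
}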
@@ -507,8 +504,8 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
                                                   const int row_stride,
                                                   U *warp_buf1,
                                                   U *warp_buf2,
-                                                  const T *input,
-                                                  const T *dout,
+                                                  const T *__restrict__ input,
+                                                  const T *__restrict__ dout,
                                                   const int64_t i1_end,
                                                   const int64_t n2,
                                                   const U *__restrict__ mean,
@@ -518,6 +515,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
   if (i1 >= i1_end) return;
   U curr_mean = mean[i1];
   U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
+#pragma unroll
   for (int k = 0; k < VPT; ++k) {
     const int i2 = i2_off + k;
     const int64_t load_idx = i1 * n2 + i2;
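Adding __restrict__ tells the compiler the input and dout buffers never alias, which lets loads stay in registers across the now-unrolled loop. A hypothetical toy kernel showing the same annotation (illustration only, not part of this patch):

__global__ void scaled_add(const float *__restrict__ x,
                           const float *__restrict__ dy,
                           float *__restrict__ out,
                           float a, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  // With the no-alias promise, x[i] and dy[i] need not be reloaded
  // even though out[i] is written in the same statement.
  if (i < n) out[i] = x[i] + a * dy[i];
}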
@@ -1151,38 +1149,38 @@ __global__ void LayerNormBackwardPartGradGammaBeta(const T *__restrict__ dout,
                                                    float epsilon,
                                                    U *part_grad_gamma,
                                                    U *part_grad_beta) {
-  // VPTX -> value per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y, BDIMX
-  // -> blockDim.x
-  // template for compile time optimizations
-  constexpr int row_stride = BDIMX + 1;
+  // VPTX -> value per thread.x, BDIMX -> blockDim.x,
+  // BDIMY -> blockDim.y, template for compile time optimizations.
+  constexpr int RowStride = BDIMX + 1;
+  constexpr int BLOCK_SIZE = BDIMX * BDIMY;
+  constexpr int VPTX_MUL_BDIMY = VPTX * BDIMY;
+  constexpr int SharedSize = (BLOCK_SIZE > 2 * VPTX_MUL_BDIMY * RowStride)
+                                 ? BLOCK_SIZE
+                                 : 2 * VPTX_MUL_BDIMY * RowStride;
   const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
   const int thr_load_row_off =
       (threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
   const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;
-  constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
-                                 ? BDIMX * BDIMY
-                                 : 2 * VPTX * BDIMY * row_stride;
-  __shared__ U buf[shared_cap];
+  __shared__ U buf[SharedSize];
   U *warp_buf1 = reinterpret_cast<U *>(buf);
-  U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;
-  for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
-       idx < 2 * VPTX * BDIMY * row_stride;
-       idx += BDIMX * BDIMY) {
+  U *warp_buf2 = warp_buf1 + VPTX_MUL_BDIMY * RowStride;
+  for (int idx = threadIdx.y * BDIMX + threadIdx.x;
+       idx < 2 * VPTX_MUL_BDIMY * RowStride;
+       idx += BLOCK_SIZE) {
     buf[idx] = U(0);
   }
   __syncthreads();
   for (int64_t i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
-       i1_block += VPTX * BDIMY * gridDim.y) {
+       i1_block += VPTX_MUL_BDIMY * gridDim.y) {
     cuLoadAddStridedInputs<T, U, VPTX>(i1_block,
                                        thr_load_row_off,
                                        thr_load_col_off,
                                        i2_off,
-                                       row_stride,
+                                       RowStride,
                                        warp_buf1,
                                        warp_buf2,
                                        input,
@@ -1195,54 +1193,53 @@ __global__ void LayerNormBackwardPartGradGammaBeta(const T *__restrict__ dout,
   }
   __syncthreads();
-  // inter-warp reductions
-  // sum within each warp
+  // inter-warp reductions, sum within each warp
   U acc1 = U(0);
   U acc2 = U(0);
+#pragma unroll
   for (int k = 0; k < VPTX; ++k) {
     int row1 = threadIdx.y + k * VPTX;
-    int idx1 = row1 * row_stride + threadIdx.x;
+    int idx1 = row1 * RowStride + threadIdx.x;
     acc1 += warp_buf1[idx1];
     acc2 += warp_buf2[idx1];
   }
-  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
-  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
+  warp_buf1[threadIdx.y * RowStride + threadIdx.x] = acc1;
+  warp_buf2[threadIdx.y * RowStride + threadIdx.x] = acc2;
   __syncthreads();
   // sum all warps
+#pragma unroll
   for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
     if (threadIdx.y < offset) {
       int row1 = threadIdx.y;
       int row2 = threadIdx.y + offset;
-      int idx1 = row1 * row_stride + threadIdx.x;
-      int idx2 = row2 * row_stride + threadIdx.x;
+      int idx1 = row1 * RowStride + threadIdx.x;
+      int idx2 = row2 * RowStride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
-  int64_t i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
   if (threadIdx.y == 0 && i2 < n2) {
     int row1 = threadIdx.y;
     int row2 = threadIdx.y + 1;
-    int idx1 = row1 * row_stride + threadIdx.x;
-    int idx2 = row2 * row_stride + threadIdx.x;
+    int idx1 = row1 * RowStride + threadIdx.x;
+    int idx2 = row2 * RowStride + threadIdx.x;
     part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
     part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
   }
 }
-template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
-__global__ void LayerNormBackwardSumGradGammaBeta(
-    const U *part_grad_gamma,
+template <typename T, typename U, int BDIMX, int BDIMY, typename ScaleT>
+__global__ void LayerNormBackwardSumGradGammaBeta(const U *part_grad_gamma,
                                                   const U *part_grad_beta,
                                                   const int part_size,
-                                                  // const int n1, const int n2, T* grad_gamma, T* grad_beta) {
                                                   const int n1,
                                                   const int n2,
-                                                  LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_gamma,
-                                                  LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_beta) {
+                                                  ScaleT *grad_gamma,
+                                                  ScaleT *grad_beta) {
   // sum partial gradients for gamma and beta
-  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX>;
   __shared__ U buf[BDIMX * BDIMY];
   int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
   if (i2 < n2) {
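Sanity check on the new SharedSize in the hunk above, plugging in the template arguments this kernel is launched with later in the patch (BDIMX = 32, BDIMY1 = 4, VPT = 4):

// Values for BDIMX = 32, BDIMY = 4, VPTX = 4:
constexpr int RowStride      = 32 + 1;  // 33 (padded row to avoid bank conflicts)
constexpr int BLOCK_SIZE     = 32 * 4;  // 128 threads
constexpr int VPTX_MUL_BDIMY = 4 * 4;   // 16 staged rows per buffer
// max(128, 2 * 16 * 33) = 1056 elements of U -- identical to the old
// shared_cap, so the renaming changes no allocation size.
constexpr int SharedSize =
    (BLOCK_SIZE > 2 * VPTX_MUL_BDIMY * RowStride)
        ? BLOCK_SIZE
        : 2 * VPTX_MUL_BDIMY * RowStride;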
@@ -1279,27 +1276,26 @@ __global__ void LayerNormBackwardSumGradGammaBeta(
     }
     // write out fully summed gradients
     if (threadIdx.y == 0) {
-      grad_gamma[i2] = static_cast<ScaleBiasT>(sum_gamma);
-      grad_beta[i2] = static_cast<ScaleBiasT>(sum_beta);
+      grad_gamma[i2] = static_cast<ScaleT>(sum_gamma);
+      grad_beta[i2] = static_cast<ScaleT>(sum_beta);
     }
   }
 }
-template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
-__global__ void LayerNormBackwardComputeGradInput(
-    const T *__restrict__ dout,
+template <typename T, typename U, int BDIMX, int BDIMY, typename ScaleT>
+__global__ void LayerNormBackwardComputeGradInput(const T *__restrict__ dout,
                                                   const T *__restrict__ input,
                                                   const int n1,
                                                   const int n2,
                                                   const U *__restrict__ mean,
                                                   const U *__restrict__ var,
                                                   const float epsilon,
-                                                  const LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *gamma,
+                                                  const ScaleT *gamma,
                                                   T *grad_input) {
 #ifdef __HIPCC__
-  for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
+  for (int64_t i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
 #else
-  for (auto i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
+  for (int64_t i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
 #endif
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
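One behavioral detail in this hunk: the induction variable changes from auto (deduced as unsigned int from blockIdx.x) to int64_t, the usual guard against wraparound and mixed-sign comparison surprises in grid-stride loops over large row counts. The canonical shape, for reference (illustrative only):

// Canonical grid-stride loop over rows.
for (int64_t row = blockIdx.x; row < n_rows; row += gridDim.x) {
  // each block visits rows blockIdx.x, blockIdx.x + gridDim.x, ...
}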
@@ -1345,25 +1341,16 @@ __global__ void LayerNormBackwardComputeGradInput(
       }
     }
     // intra-warp reductions
+#pragma unroll
     for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
 #ifdef PADDLE_WITH_HIP
-      sum_loss1 += __shfl_xor(sum_loss1,
-                              mask,
-                              warpSize);  // WARP_SHFL_XOR(sum_loss1, mask);
-      sum_loss2 += __shfl_xor(sum_loss2,
-                              mask,
-                              warpSize);  // WARP_SHFL_XOR(sum_loss2, mask);
+      // WARP_SHFL_XOR(sum_loss, mask);
+      sum_loss1 += __shfl_xor(sum_loss1, mask, warpSize);
+      sum_loss2 += __shfl_xor(sum_loss2, mask, warpSize);
 #else
-      sum_loss1 +=
-          __shfl_xor_sync(0xffffffff,
-                          sum_loss1,
-                          mask,
-                          warpSize);  // WARP_SHFL_XOR(sum_loss1, mask);
-      sum_loss2 +=
-          __shfl_xor_sync(0xffffffff,
-                          sum_loss2,
-                          mask,
-                          warpSize);  // WARP_SHFL_XOR(sum_loss2, mask);
+      // WARP_SHFL_XOR(sum_loss, mask);
+      sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize);
+      sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize);
 #endif
     }
     // inter-warp reductions
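The reflowed shuffle lines are the standard XOR-butterfly all-reduce: after log2(BDIMX) steps every lane holds the full sum, so no lane-0 broadcast is needed afterwards. Equivalent standalone form (a sketch, not the patch's helper):

// XOR-butterfly warp all-reduce: every lane ends with the warp-wide sum.
__device__ float WarpAllReduceSum(float v) {
#pragma unroll
  for (int mask = 16; mask > 0; mask /= 2)
    v += __shfl_xor_sync(0xffffffffu, v, mask, 32);
  return v;
}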
@@ -1423,6 +1410,167 @@ __global__ void LayerNormBackwardComputeGradInput(
   }
 }
+template <typename T, typename U, typename ScaleT, int DataPerTid>
+__global__ void LayerNormBackwardComputeGradInputWithSmallFeatureSize(
+    const T *__restrict__ dout,
+    const T *__restrict__ input,
+    const int n1,
+    const int n2,
+    const U *__restrict__ mean,
+    const U *__restrict__ var,
+    const float epsilon,
+    const ScaleT *__restrict__ gamma,
+    T *grad_input) {
+  constexpr int WarpSize = 32;
+#ifdef __HIPCC__
+  for (int64_t bid = hipBlockIdx_x; bid < n1; bid += hipGridDim_x) {
+#else
+  for (int64_t bid = blockIdx.x; bid < n1; bid += gridDim.x) {
+#endif
+    U sum_loss1 = U(0);
+    U sum_loss2 = U(0);
+    const U c_mean = mean[bid];
+    const U c_invvar = rsqrt_<U>(var[bid] + epsilon);
+    const int main_vec_n2 = n2 / DataPerTid;
+    const int tid_num = WarpSize * blockDim.y;
+    const int thrx = threadIdx.x + threadIdx.y * WarpSize;
+    // One row of feature_size elements per block.
+    const T *__restrict__ k_dout = dout + bid * n2;
+    const T *__restrict__ k_input = input + bid * n2;
+    T *k_grad_input = grad_input + bid * n2;
+    // Stage the loaded values in local registers.
+    using VecT = phi::AlignedVector<T, DataPerTid>;
+    using VecScaleT = phi::AlignedVector<ScaleT, DataPerTid>;
+    const VecT *__restrict__ v_k_dout =
+        reinterpret_cast<const VecT *__restrict__>(k_dout);
+    const VecT *__restrict__ v_k_input =
+        reinterpret_cast<const VecT *__restrict__>(k_input);
+    const VecScaleT *__restrict__ v_gamma =
+        reinterpret_cast<const VecScaleT *__restrict__>(gamma);
+    VecT *v_grad = reinterpret_cast<VecT *>(k_grad_input);
+    // Each thread handles at most 8 elements.
+    U dout_data[8];
+    U input_data[8];
+    U gamma_data[8];
+    if (gamma != NULL) {
+      int tid = thrx;
+      for (int i = 0; tid < main_vec_n2; tid += tid_num, ++i) {
+        VecT v_tmp_dout = v_k_dout[tid];
+        VecT v_tmp_input = v_k_input[tid];
+        VecScaleT v_tmp_gamma = v_gamma[tid];
+#pragma unroll
+        for (int k = 0; k < DataPerTid; ++k) {
+          const int idx = k + i * DataPerTid;
+          dout_data[idx] = static_cast<U>(v_tmp_dout[k]);
+          input_data[idx] = static_cast<U>(v_tmp_input[k]);
+          gamma_data[idx] = static_cast<U>(v_tmp_gamma[k]);
+          sum_loss1 += dout_data[idx] * gamma_data[idx];
+          sum_loss2 += dout_data[idx] * gamma_data[idx] *
+                       (input_data[idx] - c_mean) * c_invvar;
+        }
+      }
+    } else {
+      int tid = thrx;
+      for (int i = 0; tid < main_vec_n2; tid += tid_num, ++i) {
+        VecT v_tmp_dout = v_k_dout[tid];
+        VecT v_tmp_input = v_k_input[tid];
+#pragma unroll
+        for (int k = 0; k < DataPerTid; ++k) {
+          const int idx = k + i * DataPerTid;
+          dout_data[idx] = static_cast<U>(v_tmp_dout[k]);
+          input_data[idx] = static_cast<U>(v_tmp_input[k]);
+          sum_loss1 += dout_data[idx];
+          sum_loss2 += dout_data[idx] * (input_data[idx] - c_mean) * c_invvar;
+        }
+      }
+    }
+    // intra-warp reductions
+#pragma unroll
+    for (int mask = WarpSize / 2; mask > 0; mask /= 2) {
+#ifdef PADDLE_WITH_HIP
+      // WARP_SHFL_XOR(sum_loss, mask);
+      sum_loss1 += __shfl_xor(sum_loss1, mask, warpSize);
+      sum_loss2 += __shfl_xor(sum_loss2, mask, warpSize);
+#else
+      // WARP_SHFL_XOR(sum_loss, mask);
+      sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, WarpSize);
+      sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, WarpSize);
+#endif
+    }
+    // inter-warp reductions
+    if (blockDim.y > 1) {
+      __shared__ U buf[512];
+      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
+        // upper half of warps write to shared
+        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
+          const int wrt_i = (threadIdx.y - offset) * WarpSize + threadIdx.x;
+          buf[2 * wrt_i] = sum_loss1;
+          buf[2 * wrt_i + 1] = sum_loss2;
+        }
+        __syncthreads();
+        // lower half merges
+        if (threadIdx.y < offset) {
+          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
+          sum_loss1 += buf[2 * read_i];
+          sum_loss2 += buf[2 * read_i + 1];
+        }
+        __syncthreads();
+      }
+      if (threadIdx.y == 0) {
+        buf[2 * threadIdx.x] = sum_loss1;
+        buf[2 * threadIdx.x + 1] = sum_loss2;
+      }
+      __syncthreads();
+      if (threadIdx.y != 0) {
+        sum_loss1 = buf[2 * threadIdx.x];
+        sum_loss2 = buf[2 * threadIdx.x + 1];
+      }
+    }
+    U fH = static_cast<U>(n2);
+    U ratio_term = (static_cast<U>(1) / fH) * c_invvar;
+    if (gamma != NULL) {
+      int tid = thrx;
+      for (int i = 0; tid < main_vec_n2; tid += tid_num, ++i) {
+        VecT temp_grad;
+#pragma unroll
+        for (int k = 0; k < DataPerTid; ++k) {
+          const int idx = i * DataPerTid + k;
+          const U c_h = input_data[idx];
+          const U c_loss = dout_data[idx];
+          U f_grad_input = fH * c_loss * gamma_data[idx] - sum_loss1;
+          f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
+          temp_grad[k] = static_cast<T>(f_grad_input * ratio_term);
+        }
+        v_grad[tid] = temp_grad;
+      }
+    } else {
+      int tid = thrx;
+      for (int i = 0; tid < main_vec_n2; tid += tid_num, ++i) {
+        VecT temp_grad;
+#pragma unroll
+        for (int k = 0; k < DataPerTid; ++k) {
+          const int idx = i * DataPerTid + k;
+          const U c_h = input_data[idx];
+          const U c_loss = dout_data[idx];
+          U f_grad_input = fH * c_loss - sum_loss1;
+          f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
+          temp_grad[k] = static_cast<T>(f_grad_input * ratio_term);
+        }
+        v_grad[tid] = temp_grad;
+      }
+    }
+  }
+}
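The new small-feature-size kernel leans on phi::AlignedVector<T, N> to turn N scalar global loads into one wide load. A self-contained sketch of that access pattern with a hypothetical Vec type (the real Paddle type differs in detail):

// Hypothetical aligned vector type; assumes n is a multiple of N.
template <typename T, int N>
struct alignas(sizeof(T) * N) Vec {
  T val[N];
  __device__ T &operator[](int i) { return val[i]; }
  __device__ const T &operator[](int i) const { return val[i]; }
};

template <typename T, int N>
__global__ void vectorized_copy(const T *__restrict__ in,
                                T *__restrict__ out, int n) {
  using VecT = Vec<T, N>;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n / N) {
    // One wide load and one wide store instead of N scalar accesses,
    // the same trick the kernel above plays with v_k_dout / v_grad.
    VecT v = reinterpret_cast<const VecT *>(in)[i];
    reinterpret_cast<VecT *>(out)[i] = v;
  }
}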
 // Make sure that d_scale != nullptr && d_bias != nullptr
 // Since d_scale != nullptr, scale would not be nullptr
 template <typename T,
@@ -1684,6 +1832,18 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
   }
 }
+inline int VecSizeJudgeForeGradInput(const int feature_size,
+                                     const int vec_size) {
+  if (!(feature_size & (vec_size - 1))) {
+    return vec_size;
+  } else if (vec_size == 4) {
+    if (!(feature_size & 1)) {
+      return 2;
+    }
+  }
+  return 1;
+}
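Since vec_size is a power of two here, feature_size & (vec_size - 1) is zero exactly when feature_size is divisible by vec_size. Expected results for feature sizes exercised by the updated unit test:

// VecSizeJudgeForeGradInput(1024, 4) == 4  // 1024 % 4 == 0
// VecSizeJudgeForeGradInput(1134, 4) == 2  // 1134 % 4 != 0, but even
// VecSizeJudgeForeGradInput(1133, 4) == 1  // odd feature size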
 template <typename T, typename U, bool ScaleBiasWithSameTypeX = false>
 static void LayerNormBackward(
     const T *x,
@@ -1916,27 +2076,26 @@ static void LayerNormBackward(
                  d_bias);
   } else {
 #endif
+    using ScaleT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
+    constexpr int BDIMX = 32;
     constexpr int VPT = 4;
-    constexpr int BDIMX2 = 32;
-    constexpr int BDIMY2 = 4;
-    dim3 threads2(BDIMX2, BDIMY2, 1);
-    constexpr int part_size = BDIMY2 * VPT;
-    const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);
+    constexpr int BDIMY1 = 4;
+    constexpr int PartSize = BDIMY1 * VPT;
+    dim3 threads2(BDIMX, BDIMY1, 1);
+    dim3 blocks2((feature_size + BDIMX - 1) / BDIMX, PartSize, 1);
-    auto part_grad_gamma_ptr = phi::memory_utils::Alloc(
-        dev_ctx.GetPlace(),
-        part_size * feature_size * sizeof(U),
-        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-    auto part_grad_beta_ptr = phi::memory_utils::Alloc(
+    int64_t param_num = PartSize * feature_size;
+    auto part_grad_param_ptr = phi::memory_utils::Alloc(
         dev_ctx.GetPlace(),
-        part_size * feature_size * sizeof(U),
+        param_num * sizeof(U) * 2,  // for both gamma and beta
         phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-    U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
-    U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());
+    U *part_grad_gamma = reinterpret_cast<U *>(part_grad_param_ptr->ptr());
+    U *part_grad_beta = reinterpret_cast<U *>(part_grad_gamma + param_num);
-    LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2, VPT>
-        <<<blocks2, threads2, 0, stream>>>(
-            d_y,
+    LayerNormBackwardPartGradGammaBeta<T, U, BDIMX, BDIMY1, VPT>
+        <<<blocks2, threads2, 0, stream>>>(d_y,
                                            x,
                                            batch_size,
                                            feature_size,
@@ -1944,33 +2103,66 @@ static void LayerNormBackward(
                                            var,
                                            epsilon,
                                            part_grad_gamma,
-                                           part_grad_beta);  // compute part_grad_gamma, beta
+                                           part_grad_beta);
-    constexpr int BDIMX3 = 32;
-    constexpr int BDIMY3 = 8;
-    dim3 threads3(BDIMX3, BDIMY3, 1);
-    const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
-    LayerNormBackwardSumGradGammaBeta<T,
-                                      U,
-                                      BDIMX3,
-                                      BDIMY3,
-                                      ScaleBiasWithSameTypeX>
+    constexpr int BDIMY2 = 8;
+    dim3 threads3(BDIMX, BDIMY2, 1);
+    const dim3 blocks3((feature_size + BDIMX - 1) / BDIMX, 1, 1);
+    LayerNormBackwardSumGradGammaBeta<T, U, BDIMX, BDIMY2, ScaleT>
         <<<blocks3, threads3, 0, stream>>>(part_grad_gamma,
                                            part_grad_beta,
-                                           part_size,
+                                           PartSize,
                                            batch_size,
                                            feature_size,
                                            d_scale,
                                            d_bias);
-    constexpr int BDIMX1 = 32;
-    constexpr int BDIMY1 = 4;
-    dim3 threads1(BDIMX1, BDIMY1, 1);
-    LayerNormBackwardComputeGradInput<T,
-                                      U,
-                                      BDIMX1,
-                                      BDIMY1,
-                                      ScaleBiasWithSameTypeX>
+    uint64_t addr = reinterpret_cast<uint64_t>(d_y) |
+                    reinterpret_cast<uint64_t>(x) |
+                    reinterpret_cast<uint64_t>(d_x);
+    int vec_size = phi::GetVectorizedSize<T>(reinterpret_cast<T *>(addr));
+    int real_vec = VecSizeJudgeForeGradInput(feature_size, vec_size);
+    if (feature_size <= 2048) {
+      // Each thread must process at least real_vec elements and at most 8.
+      int data_per_warp = BDIMX * real_vec;
+      uint32_t warp_num =
+          feature_size < data_per_warp ? 1 : (feature_size / data_per_warp);
+#if defined(__clang__) || defined(__GNUC__)
+      int block_dim_y = std::min(8, 1 << (31 - __builtin_clz(warp_num)));
+#else
+      int block_dim_y = 1;
+      while (warp_num != 0) {
+        warp_num = warp_num >> 1;
+        block_dim_y <<= 1;
+      }
+      block_dim_y = std::min(8, (block_dim_y / 2));
+#endif  // __GNUC__ or __clang__
+      dim3 threads1(BDIMX, block_dim_y, 1);
+#define IMPL_BACKWARD_FOR_INPUT(num)                                       \
+  LayerNormBackwardComputeGradInputWithSmallFeatureSize<T, U, ScaleT, num> \
+      <<<batch_size, threads1, 0, stream>>>(                               \
+          d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
+      switch (real_vec) {
+        case 4: {
+          IMPL_BACKWARD_FOR_INPUT(4);
+        } break;
+        case 2: {
+          IMPL_BACKWARD_FOR_INPUT(2);
+        } break;
+        default: {
+          IMPL_BACKWARD_FOR_INPUT(1);
+        }
+      }
+#undef IMPL_BACKWARD_FOR_INPUT
+    } else {
+      constexpr int BDIMY3 = 4;
+      dim3 threads1(BDIMX, BDIMY3, 1);
+      LayerNormBackwardComputeGradInput<T, U, BDIMX, BDIMY3, ScaleT>
         <<<batch_size, threads1, 0, stream>>>(d_y,
                                               x,
                                               batch_size,
@@ -1980,6 +2172,7 @@ static void LayerNormBackward(
                                               epsilon,
                                               scale,
                                               d_x);
+    }
 #ifdef PADDLE_WITH_CUDA
   }
 #endif
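Worked example of the new launch-shape logic, for feature_size = 1134 with real_vec = 2 (one of the shapes the updated unit test covers):

// feature_size = 1134, real_vec = 2:
//   data_per_warp = 32 * 2        = 64
//   warp_num      = 1134 / 64     = 17
//   1 << (31 - __builtin_clz(17)) = 16  (largest power of two <= warp_num)
//   block_dim_y   = min(8, 16)    = 8   -> launch shape dim3(32, 8, 1)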
...
@@ -265,8 +265,8 @@ class TestLayerNormOp(unittest.TestCase):
                 test_with_place(place, shape, begin_norm_axis)

     def test_check_forward_backward_with_scale_and_bias(self):
-        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
+        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(
             shape=[2, 3, 4, 5],
             begin_norm_axis=1,
@@ -290,6 +290,7 @@ class TestLayerNormOp(unittest.TestCase):
             shape=[92, 513, 129], begin_norm_axis=2, y_grad_scale=0.1
         )
         self.check_forward_backward(shape=[3, 34, 1134], begin_norm_axis=2)
+        self.check_forward_backward(shape=[3, 2, 1133], begin_norm_axis=2)
         self.check_forward_backward(
             shape=[92, 513, 1134], begin_norm_axis=2, y_grad_scale=0.1
         )
...