未验证 提交 01d04be6 编写于 作者: L Li Min 提交者: GitHub

Optimize layer norm forward when cols is 1024. (#39167)

* Optimize layer_norm fwd when cols is 1024.
上级 6efb9f59
......@@ -19,6 +19,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
#define LN_NUM_COLS 1024
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
......@@ -153,6 +155,191 @@ __global__ void FusedLayernormResidualDropoutBias(
invvar);
}
/*
* @brief layernorm(residual + dropout(x));
* Conditions:
* (1) The number of cols is 1024;
* (2) layer_norm scale and bias is not null;
* (3) linear bias is null;
* @param
* rows: batch_size * seq_len
* cols: 1024
* x_: [rows, cols], inputs
* residual_:[rows, cols]
* gamma_: [cols]: layernorm scale, not null
* beta_: [cols], layernorm bias, not null
* mask_out_: [rows, cols], dropout result
* residual_out_: [rows, cols], residual + dropout(src)
* y_: [rows, cols], layernorm result
* mean_out_: [rows]: layernorm means
* var_out_: [rows]: layernorm vars
*/
template <
typename T, typename U, typename ScaleT = U, typename MaskType = uint8_t,
int VecSize = 8, int WARPS_M = 4, int WARPS_N = 1, int BYTES_PER_LDG = 16,
int ELTS_PER_ROW = 1024, int THREADS_PER_WARP = 32,
int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP,
int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
int rows, int cols, uint64_t seed, const float dropout_prob,
const bool is_upscale_in_train, const bool is_test,
const uint64_t increment, const float epsilon, const T *__restrict__ x_ptr,
const T *__restrict__ residual_ptr, const ScaleT *__restrict__ gamma_ptr,
const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr,
U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) {
using Vec = platform::AlignedVector<T, VecSize>;
using Vec_scale = platform::AlignedVector<ScaleT, VecSize>;
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP; // 0, 1, ..., 31
const int warp = tidx / THREADS_PER_WARP; // 0, 1, 2, 3
const int warp_n = warp % WARPS_N; // 0
const int warp_m = warp / WARPS_N; // 0, 1, 2, 3
const int c = warp_n * THREADS_PER_WARP + lane; // lane
const int r = bidx * ROWS_PER_CTA + warp_m; // row id
int idx = r * LN_NUM_COLS + c;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
Vec_scale gamma[LDGS];
Vec_scale beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
platform::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
col += THREADS_PER_ROW;
}
constexpr U rn = 1.f / U(LN_NUM_COLS);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
Vec residual[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize,
&x[it]);
platform::Load<T, VecSize>(
residual_ptr + row * LN_NUM_COLS + col * VecSize, &residual[it]);
col += THREADS_PER_ROW;
}
MaskStoreT mask_vec[LDGS];
if (!is_test) {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
float rand[VecSize];
RandVec<VecSize>(&state, rand);
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
#pragma unroll
mask_vec[it][jt] = static_cast<MaskType>(rand[jt] >= dropout_prob);
}
}
} else {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
mask_vec[it][jt] = static_cast<MaskType>(1);
}
}
}
// 4 * 8
U xf[LDGS * VecSize];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
xf[it * VecSize + jt] = U(x[it][jt]);
}
}
// store dropout_residual_out and mask_out
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Store<T, VecSize>(
x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize);
platform::Store<MaskType, VecSize>(
mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize);
col += THREADS_PER_ROW;
}
U mu_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
mu_local += xf[it * VecSize + jt];
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
mu_local *= rn;
if (lane == 0) {
mean_out_ptr[row] = mu_local;
}
U var_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
U diff = xf[it * VecSize + jt] - mu_local;
var_local += diff * diff;
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
U rsigma = rsqrtf(var_local * rn + epsilon);
if (lane == 0) {
// Note: the stored var is different for paddle(ln) and apex (fast ln).
// var_out_ptr[row] = rsigma;
var_out_ptr[row] = var_local * rn;
}
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// use fp16 to compute
// ScaleT tmp = static_cast<ScaleT>(rsigma * (xf[it * VecSize + jt] -
// mu_local));
// x[it][jt] = gamma[it][jt] * tmp + beta[it][jt];
// cast to fp32 to compute
U tmp = rsigma * (static_cast<U>(xf[it * VecSize + jt]) - mu_local);
x[it][jt] = static_cast<T>(static_cast<U>(gamma[it][jt]) * tmp +
static_cast<U>(beta[it][jt]));
}
}
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Store<T, VecSize>(x[it],
y_ptr + row * LN_NUM_COLS + col * VecSize);
col += THREADS_PER_ROW;
}
}
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
......@@ -205,6 +392,13 @@ void LaunchLayernormResidualDropoutBias(
return;
}
bool can_call_1024_kernel = false;
if (cols == 1024 && scale != nullptr && layernorm_bias != nullptr &&
bias == nullptr) {
can_call_1024_kernel = true;
}
VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
if (cols % VecSize != 0) {
int blockDim = GetDesiredBlockDim(cols);
......@@ -215,13 +409,35 @@ void LaunchLayernormResidualDropoutBias(
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var);
} else {
int blockDim = GetDesiredBlockDim(cols / VecSize);
FusedLayernormResidualDropoutBias<
T, uint8_t, VecSize, U,
ScaleBiasWithSameTypeX><<<rows, blockDim, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var);
if (can_call_1024_kernel) {
const int WARPS_M = 4;
const int WARPS_N = 1;
const int THREADS_PER_WARP = 32;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
const int ROWS_PER_CTA = WARPS_M;
// Note: the grid can not exceed max_grid of the gpu.
const int grid =
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA)));
fused_ln_fwd_1024_kernel<
T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, uint8_t,
VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
increment, epsilon, src, residual, scale, layernorm_bias, mask_data,
mean, var, dst, layernorm_dst);
} else {
int blockDim = GetDesiredBlockDim(cols / VecSize);
FusedLayernormResidualDropoutBias<
T, uint8_t, VecSize, U,
ScaleBiasWithSameTypeX><<<rows, blockDim, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
increment, epsilon, src, residual, bias, scale, layernorm_bias,
mask_data, dst, layernorm_dst, mean, var);
}
}
}
......
......@@ -66,12 +66,10 @@ struct TestFusedLayernormResidualDropoutBias {
ctx = reinterpret_cast<platform::CUDADeviceContext *>(devicectx);
}
TestFusedLayernormResidualDropoutBias(int _rows, int _cols,
uint64_t _seed = 0,
float _dropout_prob = 0.0,
float _epsilon = 0.00001f,
bool _is_upscale_in_train = false,
bool _is_test = false) {
TestFusedLayernormResidualDropoutBias(
int _rows, int _cols, uint64_t _seed = 0, float _dropout_prob = 0.0,
float _epsilon = 0.00001f, bool _is_upscale_in_train = false,
bool _is_test = false, bool _has_bias = true) {
rows = _rows;
cols = _cols;
seed = _seed;
......@@ -79,7 +77,7 @@ struct TestFusedLayernormResidualDropoutBias {
epsilon = _epsilon;
is_upscale_in_train = _is_upscale_in_train;
is_test = _is_test;
has_bias = true;
has_bias = _has_bias;
has_scale = true;
has_layernorm_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -283,7 +281,6 @@ static void BaseTest(const bool is_fp16 = false) {
}
}
}
TEST(FusedDropout, GPUFusedLayernormResidualDropoutBias) { BaseTest<float>(); }
TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasDouble) {
......@@ -330,3 +327,12 @@ TEST(FusedDropout, GPUFusedLayernormResidualDropoutLargeShape) {
test.Run();
test.CheckOut(static_cast<float>(1e-4));
}
TEST(FusedDropout, GPUFusedLayernormResidualDropoutFp16MLperf) {
const int rows = 512;
const int cols = 1024;
TestFusedLayernormResidualDropoutBias<platform::float16> test(
rows, cols, 0, 0, 0.00001f, false, false, false);
test.Run();
test.CheckOut(static_cast<platform::float16>(1e-2));
}
......@@ -23,6 +23,7 @@ namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
......@@ -35,6 +36,8 @@ using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
#define LN_NUM_COLS 1024
inline static int GetDesiredBlockDim(int64_t block_dim) {
#ifdef __HIPCC__
const int kMaxBlockDim = 256;
......@@ -169,6 +172,118 @@ __inline__ __device__ half rsqrt_(const half val) {
}
#endif
#ifdef PADDLE_WITH_CUDA
template <typename T, typename U, typename ScaleT = U, int VecSize = 8,
int WARPS_M = 4, int WARPS_N = 1, int BYTES_PER_LDG = 16,
int ELTS_PER_ROW = 1024, int THREADS_PER_WARP = 32,
int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP,
int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW,
int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
int rows, int cols, const float epsilon, const T *__restrict__ x_ptr,
const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
T *__restrict__ y_ptr) {
using Vec = platform::AlignedVector<T, VecSize>;
using Vec_scale = platform::AlignedVector<ScaleT, VecSize>;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP; // 0, 1, ..., 31
const int warp = tidx / THREADS_PER_WARP; // 0, 1, 2, 3
const int warp_n = warp % WARPS_N; // 0
const int warp_m = warp / WARPS_N; // 0, 1, 2, 3
const int c = warp_n * THREADS_PER_WARP + lane; // lane
const int r = bidx * ROWS_PER_CTA + warp_m; // row id
Vec_scale gamma[LDGS];
Vec_scale beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
platform::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
col += THREADS_PER_ROW;
}
constexpr U rn = 1.f / U(LN_NUM_COLS);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize,
&x[it]);
col += THREADS_PER_ROW;
}
U xf[LDGS * VecSize];
U mu_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
xf[it * VecSize + jt] = U(x[it][jt]);
mu_local += xf[it * VecSize + jt];
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
mu_local *= rn;
if (lane == 0) {
mean_out_ptr[row] = mu_local;
}
U var_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
U diff = xf[it * VecSize + jt] - mu_local;
var_local += diff * diff;
}
}
#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
// Note: to assure if it is right for double
U rsigma = rsqrtf(var_local * rn + epsilon);
if (lane == 0) {
var_out_ptr[row] = var_local * rn;
}
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// use fp16 to compute
// ScaleT tmp = static_cast<ScaleT>(rsigma * (xf[it * VecSize + jt] -
// mu_local));
// x[it][jt] = gamma[it][jt] * tmp + beta[it][jt];
// cast to fp32 to compute
U tmp = (rsigma * (static_cast<U>(xf[it * VecSize + jt]) - mu_local));
x[it][jt] = static_cast<T>(static_cast<U>(gamma[it][jt]) * tmp +
static_cast<U>(beta[it][jt]));
}
}
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Store<T, VecSize>(x[it],
y_ptr + row * LN_NUM_COLS + col * VecSize);
col += THREADS_PER_ROW;
}
}
}
#endif
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
using LayerNormScaleBiasT =
typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;
......
......@@ -112,11 +112,49 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
} \
} while (0)
if (is_scale_bias_same_dtype_with_x) {
PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
#ifdef PADDLE_WITH_CUDA
bool can_call_1024_kernel = false;
if (feature_size == 1024 && scale != nullptr && bias != nullptr) {
can_call_1024_kernel = true;
}
if (can_call_1024_kernel) {
const int WARPS_M = 4;
const int WARPS_N = 1;
const int THREADS_PER_WARP = 32;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
const int ROWS_PER_CTA = WARPS_M;
const int grid = static_cast<int>(
std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));
if (is_scale_bias_same_dtype_with_x) {
ln_fwd_1024_kernel<T, U, T, VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size, feature_size, epsilon, x_data,
static_cast<const T *>(void_scale_data),
static_cast<const T *>(void_bias_data), mean_data, var_data,
y_data);
} else {
ln_fwd_1024_kernel<T, U, U, VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size, feature_size, epsilon, x_data,
static_cast<const U *>(void_scale_data),
static_cast<const U *>(void_bias_data), mean_data, var_data,
y_data);
}
} else {
PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
#endif
if (is_scale_bias_same_dtype_with_x) {
PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
}
#ifdef PADDLE_WITH_CUDA
}
#endif
#undef PADDLE_LAUNCH_LAYERNORM_FWD
}
};
......
......@@ -278,6 +278,8 @@ class TestLayerNormOp(unittest.TestCase):
has_scale=False,
has_bias=False,
y_grad_scale=0.1)
self.check_forward_backward(
shape=[512, 1024], begin_norm_axis=1, has_scale=True, has_bias=True)
class TestLayerNormAPI(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册