Unverified · Commit 1b6e4664 authored by Li Min, committed by GitHub

Add fp16 support for the scale and bias parameters of the fused_layernorm_residual_dropout op. (#38775)

* Add fp16 support for scale/bias for the fused_layernorm_residual_dropout_bias op.
Parent 42cfd15e
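The change threads two new template parameters through the fused kernels: U, the LayerNorm parameter type, and a ScaleBiasWithSameTypeX flag that selects whether gamma/beta are stored in fp32 (LayerNormParamType<T>, the previous behavior) or in the same type as the input X. The selection is a compile-time std::conditional, defined in the diff below. A minimal, self-contained sketch of that mechanism (the static_asserts are illustrative additions, not part of the patch):

#include <cuda_fp16.h>
#include <type_traits>

// Pick the scale/bias element type at compile time: the input type T when
// ScaleBiasWithSameTypeX is true, otherwise the fp32 parameter type U.
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
using LayerNormScaleBiasT =
    typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;

// fp16 input with the flag set -> fp16 scale/bias; unset -> fp32 scale/bias.
static_assert(
    std::is_same<LayerNormScaleBiasT<__half, float, true>, __half>::value,
    "same-type path selects the input type");
static_assert(
    std::is_same<LayerNormScaleBiasT<__half, float, false>, float>::value,
    "default path keeps fp32 parameters");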
@@ -250,11 +250,14 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
   }

   // out = layernorm(residual + dropout(src + bias))
-  void LayernormResidualDropoutBias(
-      const platform::CUDADeviceContext& ctx, const T* src, const T* residual,
-      const T* bias, const LayerNormParamType<T>* gamma,
-      const LayerNormParamType<T>* beta, T* dropout_out, MaskType* mask, T* out,
-      LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) {
+  template <typename P = LayerNormParamType<T>, bool is_same_type = false>
+  void LayernormResidualDropoutBias(const platform::CUDADeviceContext& ctx,
+                                    const T* src, const T* residual,
+                                    const T* bias, const P* gamma,
+                                    const P* beta, T* dropout_out,
+                                    MaskType* mask, T* out,
+                                    LayerNormParamType<T>* mean,
+                                    LayerNormParamType<T>* variance) {
     using U = LayerNormParamType<T>;
     int vec_size = MAX_CACHE_BYTES / sizeof(T);
     if (this->cols_ % vec_size != 0) {
@@ -263,7 +266,7 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
     int threads = GetDesiredBlockDim(this->cols_ / vec_size);
     int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
     increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
-    LaunchLayernormResidualDropoutBias<T, MaskType>(
+    LaunchLayernormResidualDropoutBias<T, MaskType, U, is_same_type>(
         this->rows_, this->cols_, increment, this->dropout_param_.seed,
         this->dropout_param_.dropout_prob, epsilon_,
         this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test,
@@ -271,17 +274,19 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
         variance, ctx);
   }
-  void LayernormResidualDropoutBiasGrad(
-      const platform::CUDADeviceContext& ctx, const T* d_out,
-      const T* layernorm_src, const MaskType* mask,
-      const LayerNormParamType<T>* gamma, const LayerNormParamType<T>* mean,
-      const LayerNormParamType<T>* variance, T* d_layernorm_src,
-      LayerNormParamType<T>* d_scale, LayerNormParamType<T>* d_layernorm_bias,
-      T* d_dropout_src, T* d_bias, T* d_residual) {
+  template <typename P = LayerNormParamType<T>, bool is_same_type = false>
+  void LayernormResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx,
+                                        const T* d_out, const T* layernorm_src,
+                                        const MaskType* mask, const P* gamma,
+                                        const LayerNormParamType<T>* mean,
+                                        const LayerNormParamType<T>* variance,
+                                        T* d_layernorm_src, P* d_scale,
+                                        P* d_layernorm_bias, T* d_dropout_src,
+                                        T* d_bias, T* d_residual) {
     using U = LayerNormParamType<T>;
-    LayerNormBackward<T, U>(layernorm_src, d_out, gamma, mean, variance,
-                            d_layernorm_src, d_scale, d_layernorm_bias,
-                            epsilon_, this->rows_, this->cols_, ctx);
+    LayerNormBackward<T, U, is_same_type>(
+        layernorm_src, d_out, gamma, mean, variance, d_layernorm_src, d_scale,
+        d_layernorm_bias, epsilon_, this->rows_, this->cols_, ctx);
     this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src,
                                   d_residual, d_bias);
   }
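With default template arguments (P = LayerNormParamType<T>, is_same_type = false), every existing call site of the helper keeps compiling with fp32 gamma/beta; only callers that explicitly pass a half-precision P opt in to the new path. A standalone analogue of this design choice (toy names, assumed for illustration, not Paddle code):

#include <cstdio>

template <typename T>
struct Helper {
  // P defaults to float, mirroring LayerNormParamType<T>; is_same_type
  // mirrors the new flag threaded through to the kernels.
  template <typename P = float, bool is_same_type = false>
  void Run(const P* gamma, const P* beta) {
    std::printf("is_same_type=%d, sizeof(P)=%zu\n", static_cast<int>(is_same_type),
                sizeof(P));
  }
};

int main() {
  Helper<short> h;  // short stands in for fp16 here
  float gamma_f[4] = {}, beta_f[4] = {};
  short gamma_h[4] = {}, beta_h[4] = {};
  h.Run(gamma_f, beta_f);               // old call sites compile unchanged
  h.Run<short, true>(gamma_h, beta_h);  // new same-type path opts in explicitly
}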
@@ -24,46 +24,57 @@ using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

+template <typename T, typename U, bool ScaleBiasWithSameTypeX>
+using LayerNormScaleBiasT =
+    typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;
+
 /**
  * @brief fused add_bias, dropout, add residual and layer_norm into one
  * operator. Currently only supports the forward pass.
  */
-template <typename T, int VecSize>
-__device__ void CalcLayernormY(const LayerNormParamType<T> *scale,
-                               const LayerNormParamType<T> *bias, const T *x,
-                               T *y, const int row_id, const int col_id,
-                               const int cols,
-                               const LayerNormParamType<T> mean_val,
-                               const LayerNormParamType<T> invvar) {
-  using U = LayerNormParamType<T>;
+template <typename T, int VecSize, typename U,
+          bool ScaleBiasWithSameTypeX = false>
+__device__ void CalcLayernormY(
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *bias, const T *x,
+    T *y, const int row_id, const int col_id, const int cols,
+    const LayerNormParamType<T> mean_val, const LayerNormParamType<T> invvar) {
   using LoadT = platform::AlignedVector<T, VecSize>;
   using StoreT = platform::AlignedVector<T, VecSize>;
-  using LoadU = platform::AlignedVector<U, VecSize>;
+  using LoadScaleOrBias =
+      platform::AlignedVector<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
+                              VecSize>;
   for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
-    LoadU scale_vec;
-    LoadU bias_vec;
+    LoadScaleOrBias scale_vec;
+    LoadScaleOrBias bias_vec;
     LoadT x_vec;
 #pragma unroll
     for (int ii = 0; ii < VecSize; ii++) {
-      scale_vec[ii] = static_cast<U>(1);
-      bias_vec[ii] = static_cast<U>(0);
+      scale_vec[ii] =
+          static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(1);
+      bias_vec[ii] =
+          static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(0);
     }
     // vectorized load from global memory
     platform::Load<T, VecSize>(&x[row_id * cols + i], &x_vec);
     if (scale != nullptr) {
-      platform::Load<U, VecSize>(&scale[i], &scale_vec);
+      platform::Load<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
+                     VecSize>(&scale[i], &scale_vec);
     }
     if (bias != nullptr) {
-      platform::Load<U, VecSize>(&bias[i], &bias_vec);
+      platform::Load<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
+                     VecSize>(&bias[i], &bias_vec);
     }

     StoreT y_vec;
     for (int ii = 0; ii < VecSize; ii++) {
-      y_vec[ii] = static_cast<T>(
-          scale_vec[ii] * (static_cast<U>(x_vec[ii]) - mean_val) * invvar +
-          bias_vec[ii]);
+      y_vec[ii] =
+          static_cast<T>(static_cast<U>(scale_vec[ii]) *
+                             (static_cast<U>(x_vec[ii]) - mean_val) * invvar +
+                         static_cast<U>(bias_vec[ii]));
     }
     platform::Store<T, VecSize>(y_vec, &y[row_id * cols + i]);
   }
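CalcLayernormY touches global memory through platform::AlignedVector so that each thread moves VecSize elements per load/store instruction. A standalone sketch of that pattern (assumed types; Paddle's actual AlignedVector/Load/Store helpers live under paddle/fluid/platform, so this is an illustration, not the library's definition):

#include <cuda_runtime.h>

// A struct whose alignment matches its total size lets the compiler emit one
// wide memory instruction (e.g. ld.global.v4) for the whole group.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVector {
  T val[VecSize];
  __device__ T& operator[](int i) { return val[i]; }
};

template <typename T, int VecSize>
__device__ void VecLoad(const T* src, AlignedVector<T, VecSize>* dst) {
  // Assumes src is aligned to sizeof(T) * VecSize; one struct copy becomes a
  // single vectorized load instead of VecSize scalar loads.
  *dst = *reinterpret_cast<const AlignedVector<T, VecSize>*>(src);
}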
@@ -85,15 +96,17 @@ __device__ void CalcLayernormY(const LayerNormParamType<T> *scale,
  * means: [rows]: layernorm means
  * vars: [rows]: layernorm vars
  */
-template <typename T, typename MaskType, int VecSize>
+template <typename T, typename MaskType, int VecSize, typename U,
+          bool ScaleBiasWithSameTypeX = false>
 __global__ void FusedLayernormResidualDropoutBias(
     const size_t rows, const size_t cols, uint64_t seed,
     const float dropout_prob, const bool is_upscale_in_train,
     const bool is_test, const uint64_t increment, const float epsilon,
     const T *src, const T *residual, const T *bias,
-    const LayerNormParamType<T> *scale,
-    const LayerNormParamType<T> *layernorm_bias, MaskType *mask, T *dst,
-    T *layernorm_dst, LayerNormParamType<T> *mean, LayerNormParamType<T> *var) {
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
+    MaskType *mask, T *dst, T *layernorm_dst, LayerNormParamType<T> *mean,
+    LayerNormParamType<T> *var) {
   int col_id = threadIdx.x;
   int row_id = blockIdx.x;
   int idx = row_id * cols + col_id;
@@ -101,7 +114,6 @@ __global__ void FusedLayernormResidualDropoutBias(
   curand_init(seed, idx, increment, &state);

   T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
-  using U = LayerNormParamType<T>;
   __shared__ U mean_share;
   __shared__ U var_share;
@@ -121,10 +133,12 @@ __global__ void FusedLayernormResidualDropoutBias(
   mean_val = BlockReduceSum<U>(mean_val, shared_mean);
   var_val = BlockReduceSum<U>(var_val, shared_var);
   if (threadIdx.x == 0) {
-    auto scale = static_cast<float>(1.) / static_cast<float>(cols);
-    auto tmp = mean_val * scale;
+    auto scale = static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
+        static_cast<float>(1.) / static_cast<float>(cols));
+    auto tmp = mean_val * static_cast<U>(scale);
     mean[row_id] = mean_share = static_cast<U>(tmp);
-    var_share = static_cast<U>(var_val * scale - mean_share * mean_share);
+    var_share = static_cast<U>(var_val * static_cast<U>(scale) -
+                               mean_share * mean_share);
     var_share = var_share > U(0) ? var_share : U(0);
     var[row_id] = var_share;
   }
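The statistics above come from a single pass: each thread accumulates sum(x) and sum(x*x), BlockReduceSum combines the partial sums across the block, and thread 0 derives mean = sum/cols and var = E[x^2] - mean^2, clamping at zero because floating-point rounding can make the difference slightly negative. A host-side sketch of the same arithmetic (illustrative, not Paddle code):

#include <cstddef>

void RowStats(const float* x, size_t cols, float* mean, float* var) {
  float sum = 0.f, sq_sum = 0.f;
  for (size_t i = 0; i < cols; ++i) {
    sum += x[i];
    sq_sum += x[i] * x[i];
  }
  *mean = sum / cols;
  // var = E[x^2] - mean^2, clamped at zero against rounding error.
  float v = sq_sum / cols - (*mean) * (*mean);
  *var = v > 0.f ? v : 0.f;
}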
@@ -134,8 +148,9 @@ __global__ void FusedLayernormResidualDropoutBias(
   U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));

   // calculate layernorm_dst
-  CalcLayernormY<T, VecSize>(scale, layernorm_bias, dst, layernorm_dst, row_id,
-                             col_id, cols, mean_val, invvar);
+  CalcLayernormY<T, VecSize, U, ScaleBiasWithSameTypeX>(
+      scale, layernorm_bias, dst, layernorm_dst, row_id, col_id, cols, mean_val,
+      invvar);
 }

 /**
@@ -154,16 +169,17 @@ __global__ void FusedLayernormResidualDropoutBias(
  * means: [rows]: layernorm means
  * vars: [rows]: layernorm vars
  */
-template <typename T, typename MaskType>
+template <typename T, typename MaskType, typename U,
+          bool ScaleBiasWithSameTypeX = false>
 void LaunchLayernormResidualDropoutBias(
     const uint32_t rows, const uint32_t cols, const int increment,
     uint64_t seed, const float dropout_prob, const float epsilon,
     const bool is_upscale_in_train, const bool is_test, const T *src,
-    const T *residual, const T *bias, const LayerNormParamType<T> *scale,
-    const LayerNormParamType<T> *layernorm_bias, MaskType *mask_data, T *dst,
-    T *layernorm_dst, LayerNormParamType<T> *mean, LayerNormParamType<T> *var,
-    const platform::CUDADeviceContext &ctx) {
-  using U = LayerNormParamType<T>;
+    const T *residual, const T *bias,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
+    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *layernorm_bias,
+    MaskType *mask_data, T *dst, T *layernorm_dst, LayerNormParamType<T> *mean,
+    LayerNormParamType<T> *var, const platform::CUDADeviceContext &ctx) {
   // dropout_prob == 1.0f
   if (std::abs(dropout_prob - 1.0f) < 1e-5) {
     auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
@@ -175,8 +191,9 @@ void LaunchLayernormResidualDropoutBias(
     // call layernorm forward
     switch (GetDesiredBlockDim(cols)) {
       FIXED_BLOCK_DIM_CASE(
-          LayerNormForward<T, U,
-                           kBlockDim><<<rows, kBlockDim, 0, ctx.stream()>>>(
+          LayerNormForward<
+              T, U, kBlockDim,
+              ScaleBiasWithSameTypeX><<<rows, kBlockDim, 0, ctx.stream()>>>(
               dst, scale, layernorm_bias, layernorm_dst, mean, var, epsilon,
               cols));
       default:
@@ -184,21 +201,24 @@ void LaunchLayernormResidualDropoutBias(
             "Product from begin_norm_axis to end must be larger than 1"));
         break;
     }
     return;
   }

   const int VecSize = MAX_CACHE_BYTES / sizeof(T);
   if (cols % VecSize != 0) {
     int blockDim = GetDesiredBlockDim(cols);
-    FusedLayernormResidualDropoutBias<T, uint8_t,
-                                      1><<<rows, blockDim, 0, ctx.stream()>>>(
+    FusedLayernormResidualDropoutBias<
+        T, uint8_t, 1, U,
+        ScaleBiasWithSameTypeX><<<rows, blockDim, 0, ctx.stream()>>>(
         rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
         epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
         layernorm_dst, mean, var);
   } else {
     int blockDim = GetDesiredBlockDim(cols / VecSize);
     FusedLayernormResidualDropoutBias<
-        T, uint8_t, VecSize><<<rows, blockDim, 0, ctx.stream()>>>(
+        T, uint8_t, VecSize, U,
+        ScaleBiasWithSameTypeX><<<rows, blockDim, 0, ctx.stream()>>>(
         rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
         epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
         layernorm_dst, mean, var);
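The launcher dispatches on divisibility: when cols is a multiple of the vector width VecSize = MAX_CACHE_BYTES / sizeof(T), it launches the vectorized kernel with proportionally fewer threads per row; otherwise it falls back to the scalar (VecSize = 1) instantiation. The rule in isolation (kMaxCacheBytes is an assumed stand-in value for Paddle's MAX_CACHE_BYTES):

constexpr int kMaxCacheBytes = 16;  // assumed stand-in for MAX_CACHE_BYTES

// Returns the per-thread vector width usable for a row of `cols` elements.
template <typename T>
constexpr int ChooseVecSize(int cols) {
  const int vec_size = kMaxCacheBytes / static_cast<int>(sizeof(T));
  return (cols % vec_size == 0) ? vec_size : 1;
}

static_assert(ChooseVecSize<float>(1024) == 4, "1024 % 4 == 0 -> vectorized");
static_assert(ChooseVecSize<float>(1023) == 1, "odd width -> scalar fallback");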
@@ -223,7 +223,7 @@ struct TestFusedLayernormResidualDropoutBias {
       layernorm_bias_ptr = layernorm_bias.data<U>();
     }

-    paddle::operators::LaunchLayernormResidualDropoutBias<T, uint8_t>(
+    paddle::operators::LaunchLayernormResidualDropoutBias<T, uint8_t, U, false>(
         rows, cols, increment, seed, dropout_prob, epsilon, is_upscale_in_train,
         is_test, src.data<T>(), residual.data<T>(), bias_ptr, scale_ptr,
         layernorm_bias_ptr, mask.data<uint8_t>(), out.data<T>(),