Unverified commit 55accdfc authored by: W wenbin, committed by: GitHub

preln_residual_bias optimization (#46496)

* half2

* add epsilon
Parent 4d772144
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -30,6 +31,116 @@ namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
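// The kernels below are added by this commit: a half2-vectorized path that
// fuses the bias + residual add with the following LayerNorm for the
// PrelnResidualBias plugin. enqueue() selects this path when the hidden
// dimension is even (so each pair of halves maps to one half2) and otherwise
// falls back to FusedLayernormResidualDropoutBiasFunctor; epsilon is now
// passed through to the new kernel explicitly.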
#ifdef TRT_PLUGIN_FP16_AVALIABLE
#define FINAL_MASK 0xffffffff
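// Warp-level reduction: butterfly XOR shuffles sum each of the NUM values
// across the 32 lanes, so every lane ends up holding the warp-wide sums.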
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T *val) {
#pragma unroll
  for (int i = 0; i < NUM; i++) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
  }
  return (T)(0.0f);
}
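// Block-level reduction built on warpReduceSumV2: lane 0 of every warp parks
// its warp sums in shared memory, and the first warp then reduces those
// per-warp partials, leaving the block-wide sums in val[] for its lanes.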
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T *val) {
  static __shared__ T shared[NUM][33];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduceSumV2<T, NUM>(val);

  if (lane == 0) {
#pragma unroll
    for (int i = 0; i < NUM; i++) {
      shared[i][wid] = val[i];
    }
  }

  __syncthreads();
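  // Only the leading threads (one per warp) reload the per-warp partials; the
  // remaining lanes contribute zeros before the final warp reduction below.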
  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
  for (int i = 0; i < NUM; i++) {
    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
  }
  warpReduceSumV2<T, NUM>(val);
  return (T)0.0f;
}
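// Fused add-bias + residual + LayerNorm over half2 elements. One block handles
// one of the m rows; n is the number of half2 elements per row
// (hidden_size / 2). Pass 1 accumulates src + residual (+ optional bias) in
// fp32, writes the fused sum to `output`, and collects per-thread
// sum / sum-of-squares; pass 2 normalizes that sum and applies gamma and the
// optional beta into `normed_output`.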
__global__ void generalAddBiasResidualLayerNormOpt2(
    half2 *normed_output,
    half2 *output,
    const half2 *__restrict bias,
    const half2 *__restrict src,
    const half2 *__restrict residual,
    const half2 *__restrict gamma,
    const half2 *__restrict beta,
    int m,
    int n,
    float epsilon) {
  __shared__ float s_mean;
  __shared__ float s_variance;
  float x_sum = 0.0f;
  float x2_sum = 0.0f;
  const int b_offset = blockIdx.x * n;

#pragma unroll 2
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int index = b_offset + i;
    float val_1 = 0.0f;
    float val_2 = 0.0f;
    half2 tmp;
    if (bias) {
      tmp = __ldg(&bias[i]);
      val_1 += static_cast<float>(tmp.x);
      val_2 += static_cast<float>(tmp.y);
    }
    {
      tmp = __ldg(&residual[index]);
      val_1 += static_cast<float>(tmp.x);
      val_2 += static_cast<float>(tmp.y);
    }
    {
      tmp = __ldg(&src[index]);
      val_1 += static_cast<float>(tmp.x);
      val_2 += static_cast<float>(tmp.y);
    }
    tmp.x = __float2half_rn(val_1);
    tmp.y = __float2half_rn(val_2);
    output[index] = tmp;
    x_sum += val_1 + val_2;
    x2_sum += val_1 * val_1 + val_2 * val_2;
  }
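  // Reduce the per-thread partial sums across the block; thread 0 then derives
  // the row mean and the inverse standard deviation (mean is taken over the
  // 2 * n half values, with epsilon added for numerical stability).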
  float sums[2];
  sums[0] = x_sum;
  sums[1] = x2_sum;
  blockReduceSumV2<float, 2>(sums);

  if (threadIdx.x == 0) {
    s_mean = sums[0] / n / 2;
    s_variance = rsqrtf(sums[1] / n / 2 - s_mean * s_mean + epsilon);
  }
  __syncthreads();

  half2 mean_2 = __float2half2_rn(s_mean);
  half2 var_2 = __float2half2_rn(s_variance);
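  // Pass 2: normalize the fused values written in pass 1 and apply the
  // LayerNorm scale (gamma) and optional shift (beta).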
#pragma unroll 2
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int index = b_offset + i;
    half2 val = __hmul2(__hmul2(__hsub2(output[index], mean_2), var_2),
                        __ldg(&gamma[i]));
    if (beta) {
      val = __hadd2(val, __ldg(&beta[i]));
    }
    normed_output[index] = val;
  }
}
#endif
using half = phi::dtype::float16;
#if IS_TRT_VERSION_GE(6000)
@@ -306,30 +417,48 @@ int PrelnResidualBiasPluginDynamic::enqueue(
  float *mean = nullptr;
  float *var = nullptr;
  const int VecSize = 8;
  paddle::operators::FusedLayernormResidualDropoutBiasFunctor<half,
                                                              uint8_t,
                                                              VecSize,
                                                              float,
                                                              false>()(
      rows,
      cols,
      seed,
      dropout_prob,
      is_upscale_in_train,
      is_test,
      increment,
      epsilon,
      src,
      residual,
      bias,
      scale,
      layernorm_bias,
      mask_data,
      dst,
      layernorm_dst,
      mean,
      var,
      stream);
  // Even hidden size: every pair of halves can be handled as one half2, so use
  // the vectorized fused kernel. (The original check `hidden & 1 == 0` parses
  // as `hidden & (1 == 0)` and is always false; corrected here.)
  if (hidden % 2 == 0) {
    int half_n = hidden / 2;
    // One block per row; round half the hidden size up to a multiple of the
    // warp size and cap the block at 512 threads.
    int half_n_32 = (half_n + 31) / 32 * 32;
    int block(std::min(half_n_32, 512));
    generalAddBiasResidualLayerNormOpt2<<<rows, block, 0, stream>>>(
        reinterpret_cast<half2 *>(layernorm_dst),
        reinterpret_cast<half2 *>(dst),
        (const half2 *)bias,
        (const half2 *)input2,
        (const half2 *)input1,
        (const half2 *)scale,
        (const half2 *)layernorm_bias,
        rows,
        half_n,
        epsilon);
  } else {
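    // Odd hidden size: fall back to the generic fused layer-norm + residual +
    // dropout-bias functor.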
    paddle::operators::FusedLayernormResidualDropoutBiasFunctor<half,
                                                                uint8_t,
                                                                VecSize,
                                                                float,
                                                                false>()(
        rows,
        cols,
        seed,
        dropout_prob,
        is_upscale_in_train,
        is_test,
        increment,
        epsilon,
        src,
        residual,
        bias,
        scale,
        layernorm_bias,
        mask_data,
        dst,
        layernorm_dst,
        mean,
        var,
        stream);
  }
#else
  PADDLE_THROW(platform::errors::Fatal(
      "The Ernie(Bert) tensorRT plugin should be "
......