未验证 提交 ae6b4713 编写于 作者: W Wang Bojun 提交者: GitHub

preln_res_bias_layernorm half2 bugfix and unroll opt (#46619)

* preln_res_bias_layernorm bugfix unroll opt

* code style refine

* NOLINT for codestyle
上级 9ea279a4
......@@ -69,6 +69,7 @@ __inline__ __device__ T blockReduceSumV2(T *val) {
return (T)0.0f;
}
template <int UNROLL_FACTOR>
__global__ void generalAddBiasResidualLayerNormOpt2(
half2 *normed_output,
half2 *output,
......@@ -87,7 +88,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(
float x2_sum = 0.0f;
const int b_offset = blockIdx.x * n;
#pragma unroll 2
#pragma unroll UNROLL_FACTOR
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int index = b_offset + i;
float val_1 = 0.0f;
......@@ -129,7 +130,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(
half2 mean_2 = __float2half2_rn(s_mean);
half2 var_2 = __float2half2_rn(s_variance);
#pragma unroll 2
#pragma unroll UNROLL_FACTOR
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int index = b_offset + i;
half2 val = __hmul2(__hmul2(__hsub2(output[index], mean_2), var_2),
......@@ -141,6 +142,20 @@ __global__ void generalAddBiasResidualLayerNormOpt2(
}
#endif
}
#define HALF2_ADD_BIAS_RESIDUAL_LAYERNORM_OPT2(UNROLL_FACTOR) \
generalAddBiasResidualLayerNormOpt2<UNROLL_FACTOR> \
<<<rows, block, 0, stream>>>(reinterpret_cast<half2 *>(layernorm_dst), \
reinterpret_cast<half2 *>(dst), \
(const half2 *)bias, \
(const half2 *)input2, \
(const half2 *)input1, \
(const half2 *)fp16_scale_gpu_, \
(const half2 *)fp16_bias_gpu_, \
rows, \
half_n, \
epsilon);
#endif
using half = phi::dtype::float16;
......@@ -157,6 +172,18 @@ int PrelnResidualBiasPluginDynamic::initialize() TRT_NOEXCEPT {
scale_.data(),
scale_size_ * sizeof(float),
cudaMemcpyHostToDevice);
if (with_fp16_) {
cudaMalloc(&fp16_bias_gpu_, sizeof(half) * bias_size_);
cudaMemcpy(fp16_bias_gpu_,
fp16_bias_.data(),
bias_size_ * sizeof(half),
cudaMemcpyHostToDevice);
cudaMalloc(&fp16_scale_gpu_, sizeof(half) * scale_size_);
cudaMemcpy(fp16_scale_gpu_,
fp16_scale_.data(),
scale_size_ * sizeof(half),
cudaMemcpyHostToDevice);
}
if (ele_bias_size_ > 0) {
if (with_fp16_) {
cudaMalloc(&ele_bias_gpu_, sizeof(half) * ele_bias_size_);
......@@ -183,10 +210,18 @@ void PrelnResidualBiasPluginDynamic::terminate() TRT_NOEXCEPT {
cudaFree(bias_gpu_);
bias_gpu_ = nullptr;
}
if (fp16_bias_gpu_) {
cudaFree(fp16_bias_gpu_);
fp16_bias_gpu_ = nullptr;
}
if (scale_gpu_) {
cudaFree(scale_gpu_);
scale_gpu_ = nullptr;
}
if (fp16_scale_gpu_) {
cudaFree(fp16_scale_gpu_);
fp16_scale_gpu_ = nullptr;
}
if (ele_bias_gpu_) {
cudaFree(ele_bias_gpu_);
ele_bias_gpu_ = nullptr;
......@@ -217,7 +252,9 @@ nvinfer1::IPluginV2DynamicExt *PrelnResidualBiasPluginDynamic::clone() const
}
ptr->bias_gpu_ = bias_gpu_;
ptr->fp16_bias_gpu_ = fp16_bias_gpu_;
ptr->scale_gpu_ = scale_gpu_;
ptr->fp16_scale_gpu_ = fp16_scale_gpu_;
ptr->ele_bias_gpu_ = ele_bias_gpu_;
return ptr;
}
......@@ -232,7 +269,8 @@ int PrelnResidualBiasPluginDynamic::getNbOutputs() const TRT_NOEXCEPT {
size_t PrelnResidualBiasPluginDynamic::getSerializationSize() const
TRT_NOEXCEPT {
size_t ser_size = SerializedSize(bias_) + SerializedSize(scale_) +
size_t ser_size = SerializedSize(bias_) + SerializedSize(fp16_bias_) +
SerializedSize(scale_) + SerializedSize(fp16_scale_) +
SerializedSize(fp32_ele_bias_) +
SerializedSize(fp16_ele_bias_) +
SerializedSize(bias_size_) + SerializedSize(scale_size_) +
......@@ -243,7 +281,9 @@ size_t PrelnResidualBiasPluginDynamic::getSerializationSize() const
void PrelnResidualBiasPluginDynamic::serialize(void *buffer) const
TRT_NOEXCEPT {
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, fp16_bias_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, fp16_scale_);
SerializeValue(&buffer, fp32_ele_bias_);
SerializeValue(&buffer, fp16_ele_bias_);
SerializeValue(&buffer, bias_size_);
......@@ -419,23 +459,34 @@ int PrelnResidualBiasPluginDynamic::enqueue(
float *mean = nullptr;
float *var = nullptr;
const int VecSize = 8;
// if odd
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
if (hidden & 1 == 0) {
// if hidden is even, use half2 kernel generalAddBiasResidualLayerNormOpt2
if (hidden % 2 == 0) {
int half_n = hidden / 2;
int half_n_32 = (half_n + 31) / 32 * 32;
int block(std::min(half_n_32, 512));
generalAddBiasResidualLayerNormOpt2<<<rows, block, 0, stream>>>(
reinterpret_cast<half2 *>(layernorm_dst),
reinterpret_cast<half2 *>(dst),
(const half2 *)bias,
(const half2 *)input2,
(const half2 *)input1,
(const half2 *)scale,
(const half2 *)layernorm_bias,
rows,
half_n,
epsilon);
dim3 block(std::min(half_n_32, 512));
int rolls_per_thread = half_n / block.x;
int unroll_factor = 8;
while (unroll_factor > rolls_per_thread && unroll_factor > 1) {
unroll_factor /= 2;
}
switch (unroll_factor) {
case 1:
HALF2_ADD_BIAS_RESIDUAL_LAYERNORM_OPT2(1);
break;
case 2:
HALF2_ADD_BIAS_RESIDUAL_LAYERNORM_OPT2(2);
break;
case 4:
HALF2_ADD_BIAS_RESIDUAL_LAYERNORM_OPT2(4);
break;
case 8:
HALF2_ADD_BIAS_RESIDUAL_LAYERNORM_OPT2(8);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"Invalid UNROLL_FACTOR in preln_residual_bias trt plugin."));
}
} else {
paddle::operators::FusedLayernormResidualDropoutBiasFunctor<half,
uint8_t,
......
......@@ -45,10 +45,19 @@ class PrelnResidualBiasPluginDynamic : public DynamicPluginTensorRT {
bias_.resize(bias_size);
scale_.resize(scale_size);
fp16_bias_.resize(bias_size);
fp16_scale_.resize(scale_size);
fp16_ele_bias_.resize(ele_bias_size);
std::copy(ele_bias, ele_bias + ele_bias_size, fp16_ele_bias_.data());
std::copy(bias, bias + bias_size, bias_.data());
std::copy(scale, scale + scale_size, scale_.data());
for (int i = 0; i < bias_size; i++) {
fp16_bias_[i] = static_cast<half>(bias[i]);
}
for (int i = 0; i < scale_size; i++) {
fp16_scale_[i] = static_cast<half>(scale[i]);
}
}
explicit PrelnResidualBiasPluginDynamic(const float* bias,
......@@ -76,7 +85,9 @@ class PrelnResidualBiasPluginDynamic : public DynamicPluginTensorRT {
PrelnResidualBiasPluginDynamic(void const* serial_data,
size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &bias_);
DeserializeValue(&serial_data, &serial_length, &fp16_bias_);
DeserializeValue(&serial_data, &serial_length, &scale_);
DeserializeValue(&serial_data, &serial_length, &fp16_scale_);
DeserializeValue(&serial_data, &serial_length, &fp32_ele_bias_);
DeserializeValue(&serial_data, &serial_length, &fp16_ele_bias_);
DeserializeValue(&serial_data, &serial_length, &bias_size_);
......@@ -95,12 +106,12 @@ class PrelnResidualBiasPluginDynamic : public DynamicPluginTensorRT {
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder)
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
......@@ -131,13 +142,21 @@ class PrelnResidualBiasPluginDynamic : public DynamicPluginTensorRT {
void terminate() TRT_NOEXCEPT override;
private:
// bias for fp32 mode
std::vector<float> bias_;
// bias for fp16 mode
std::vector<half> fp16_bias_;
// scale for fp32 mode
std::vector<float> scale_;
// scale for fp16 mode
std::vector<half> fp16_scale_;
std::vector<float> fp32_ele_bias_;
std::vector<half> fp16_ele_bias_;
float* bias_gpu_{nullptr};
half* fp16_bias_gpu_{nullptr};
float* scale_gpu_{nullptr};
half* fp16_scale_gpu_{nullptr};
void* ele_bias_gpu_{nullptr};
int bias_size_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册