From d93c63a049091bd68465eb320ec8bb352384751f Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com>
Date: Thu, 9 Feb 2023 12:28:45 +0800
Subject: [PATCH] [Paddle-TRT] GroupNorm int8 nchw32 fake kernel (#50146)

* add fmha_flashattention oss plugin
* add fmhca
* add oss fmhca
* code reconstruct and add ut
* code style refine
* fix ut and enforce check
* refine trt version check
refine compile
fix compile
* fix cross ut
* code refine
* use runtime trt version check
* bug fix and code refine
* compile fix
* merge develop
* add GN QDQ kernel
* support GN int8 fake kernel
* add with_int8
* add GN int8 fake kernel
* add GN int8 fake kernel
* add GN int8 fake kernel
* add GN int8 fake kernel
* add GN int8 fake kernel
* add GN int8 fake kernel
* add GN int8 fake kernel
* add GN int8 UT
* add version > 8000 in GN int8 UT
* add some check in .cu
* add stdlib.h in UT
* little change in .cu
* remove rand_r use rand
* remove use rand
* setAxis(1)
* when int8 is on allow fall back to fp16

---------

Co-authored-by: wwbitejotunn
---
 .../tensorrt/convert/group_norm_op.cc         |   7 +-
 paddle/fluid/inference/tensorrt/engine.h      |   6 +
 .../plugin/common/groupNormPluginCommon.h     |   3 +
 .../tensorrt/plugin/group_norm_op_plugin.cu   | 335 +++++++++++++++++-
 .../tensorrt/plugin/group_norm_op_plugin.h    |  15 +-
 .../inference/tensorrt/test_dynamic_engine.cc | 304 ++++++++++++++++
 6 files changed, 659 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc
index 4384f7d2b3..5d1f2031f6 100644
--- a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc
@@ -67,7 +67,9 @@ class GroupNormOpConverter : public OpConverter {
     auto scale_weights = GetWeight(scale_name, &scale_dims);
     auto bias_weights = GetWeight(bias_name, &bias_dims);
     bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-
+    bool with_int8 = engine_->WithInt8();
+    // when int8 is on, allow fall back to fp16
+    if (with_int8) with_fp16 = true;
     if (engine_->with_dynamic_shape()) {
       int gn_num = groups;
       std::vector<int64_t> mean_shape({gn_num});
@@ -83,7 +85,8 @@
           mean_shape,
           variance_shape,
           with_silu,
-          with_fp16);
+          with_fp16,
+          with_int8);
       nvinfer1::ILayer* groupnorm_layer =
           engine_->AddDynamicPlugin(&input_itensor, 1, plugin);
       auto output_name = op_desc.Output("Y")[0];
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index e3f1ed2a37..421842cf56 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -361,6 +361,12 @@ class TensorRTEngine {
     return enable_fp16 && support_fp16;
   }
 
+  bool WithInt8() {
+    bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
+    bool support_int8 = infer_builder_->platformHasFastInt8();
+    return enable_int8 && support_int8;
+  }
+
   int GetDeviceId() { return device_id_; }
 
   nvinfer1::IPluginV2Layer* AddPlugin(nvinfer1::ITensor* const* inputs,
diff --git a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
index 915ee1b5e2..1ba134a6fc 100644
--- a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
+++ b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
@@ -69,6 +69,9 @@ struct GroupNormNHWCParams {
   int32_t groupsPerBlock;
   // epsilon, Constant for numerical stability
   float eps;
+  // for NCHW32 int8 use
+  float dqScaleIn;
+  float inv_qScale;
 };
 
 }  // namespace plugin
diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
index fc139a9734..279a005896 100644
--- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
@@ -188,6 +188,257 @@ void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
   }
 }
 
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) {
+  // The object in charge of doing the sums for the different blocks.
+  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
+
+  // Allocate shared memory for BlockScan.
+  __shared__ typename BlockScan::TempStorage tempStorage;
+  // Allocate shared memory for the groups. We could reduce the amount of shared
+  // memory reserved.
+  __shared__ float2 smem[tTHREADS_PER_BLOCK];
+
+  // The instance in the batch.
+  int32_t ni = blockIdx.z;
+  // The channel loaded by that thread (2 channels per thread for int8x2).
+  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+
+  // The first activation loaded by that block.
+  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+  // The last activation loaded by that block.
+  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+  // The sums.
+  float sum = 0.F;
+  float sumSq = 0.F;
+
+  const int8_t *src_ptr = reinterpret_cast<const int8_t *>(params.srcX);
+
+  // nchw32 layout
+  // batch offset + channel offset
+  int nc_offset = static_cast<int64_t>(ni) * params.hwc +
+                  ci / 32 * params.hw * 32 + ci % 32;
+
+  // Iterate over the activations to compute the sums.
+  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+    // The offset.
+    int64_t offset = nc_offset + static_cast<int64_t>(hwi) * 32;
+
+    // Fetch two channels per thread.
+    __half2 h2(0, 0);
+    if (ci < params.c) {
+      int8_t tmp_in[2];
+      *reinterpret_cast<int16_t *>(tmp_in) =
+          *reinterpret_cast<const int16_t *>(&src_ptr[offset]);
+      h2.x = params.dqScaleIn * tmp_in[0];
+      h2.y = params.dqScaleIn * tmp_in[1];
+    }
+
+    // Extract the two half values.
+    float2 f2 = __half22float2(h2);
+
+    // Update the sum.
+    sum += f2.x + f2.y;
+    // Update the sum of squares.
+    sumSq += f2.x * f2.x + f2.y * f2.y;
+  }
+
+  // The group that thread works on and the channel in the group (modulus).
+  int32_t gi = threadIdx.x * 2 / params.cPerGroup;
+  int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
+
+  // The data for the summations.
+  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
+
+  // Do the segmented scan.
+  GroupSums out;
+  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
+
+  // Store the results for the groups in shared memory (to produce coalesced
+  // stores later).
+  // 2 channels per thread
+  if (cj == params.cPerGroup - 2) {
+    smem[gi] = make_float2(out.sum, out.sumSq);
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The global group index.
+  int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
+
+  // Threads that have nothing left to do, exit.
+  if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
+    return;
+  }
+
+  // The first threads (those storing to global memory, load the values).
+  float2 sums = smem[threadIdx.x];
+
+  // Store to global memory.
+ atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); + atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); +} + +void groupNormNCHW32SumQDQ(const GroupNormNHWCParams ¶ms, + cudaStream_t stream) { + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = params.c / params.cPerBlock; + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. + grid.z = params.n; + + switch (params.cPerBlock) { + case 320: + groupNormNCHW32SumKernelQDQ<160><<>>(params); + break; + case 480: + groupNormNCHW32SumKernelQDQ<256><<>>(params); + break; + case 256: + groupNormNCHW32SumKernelQDQ<128><<>>(params); + break; + case 128: + groupNormNCHW32SumKernelQDQ<64><<>>(params); + break; + case 8: + groupNormNCHW32SumKernelQDQ<4><<>>(params); + break; + } +} + +template +__global__ void groupNormNCHW32ScaleKernelQDQ( + const GroupNormNHWCParams params) { + // The instance in the batch. + int32_t ni = blockIdx.z; + // The channel loaded by that thread (2 channels per thread for F16x2). + int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + // The group that thread works on and the channel in the group (modulus). + int32_t gi = ci / params.cPerGroup; + + const int8_t *src_ptr = reinterpret_cast(params.srcX); + int8_t *dst_ptr = reinterpret_cast(params.dst); + + // Load the sum and sum of squares for the group. + float sum = 0.F, sumSq = 0.F; + if (gi < params.groups) { + sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; + sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + } + + // Load gamma/beta. + float2 gammaF2, betaF2; + if (ci < params.c) { + gammaF2 = __half22float2(*reinterpret_cast( + reinterpret_cast(params.gamma) + ci)); + betaF2 = __half22float2(*reinterpret_cast( + reinterpret_cast(params.beta) + ci)); + } + + // Compute the mean. + float mean = sum * params.invHWC; + // Compute the variance. + float var = sumSq * params.invHWC - (mean * mean); + // Compute the inverse of the stddev. + float invStdDev = rsqrtf(var + params.eps); + + // The first activation loaded by that block. + int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + + // nchw32 layout + int c_offset = ci / 32 * params.hw * 32 + ci % 32; + + // Iterate over the activations to compute the sums. + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The src/dst offset. + int64_t offset = static_cast(ni) * params.hwc + c_offset + + static_cast(hwi) * 32; + + // Fetch two channels per thread. + __half2 h2(0, 0); + if (ci < params.c) { + int8_t tmp_in[2]; + *reinterpret_cast(tmp_in) = + *reinterpret_cast(&src_ptr[offset]); + h2.x = params.dqScaleIn * tmp_in[0]; + h2.y = params.dqScaleIn * tmp_in[1]; + } + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Silu if needed. + if (params.withSilu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + // Store the scaled values. 
+ if (ci < params.c) { + int8_t tmp_in[2]; + int32_t tmpq0 = __float2int_rn(params.inv_qScale * f2.x); + int32_t tmpq1 = __float2int_rn(params.inv_qScale * f2.y); + tmpq0 = max(-128, tmpq0); + tmpq0 = min(127, tmpq0); + tmpq1 = max(-128, tmpq1); + tmpq1 = min(127, tmpq1); + tmp_in[0] = tmpq0; + tmp_in[1] = tmpq1; + *reinterpret_cast(&dst_ptr[offset]) = + *reinterpret_cast(tmp_in); + } + } +} + +void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams ¶ms, + cudaStream_t stream) { + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = params.c / params.cPerBlock; + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. + grid.z = params.n; + + switch (params.cPerBlock) { + case 320: + groupNormNCHW32ScaleKernelQDQ<160><<>>(params); + break; + case 480: + groupNormNCHW32ScaleKernelQDQ<256><<>>(params); + break; + case 256: + groupNormNCHW32ScaleKernelQDQ<128><<>>(params); + break; + case 128: + groupNormNCHW32ScaleKernelQDQ<64><<>>(params); + break; + case 8: + groupNormNCHW32ScaleKernelQDQ<4><<>>(params); + break; + default: + PADDLE_THROW( + platform::errors::Fatal("The function groupNormNCHW32ScaleQDQ of " + "GroupNorm TRT Plugin encounter error")); + } +} + template __global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) { // The instance in the batch. @@ -454,11 +705,19 @@ bool GroupNormPluginDynamic::supportsFormatCombination( pos, nb_inputs + nb_outputs)); const nvinfer1::PluginTensorDesc &in = in_out[pos]; + + bool int8_support = in.type == nvinfer1::DataType::kINT8 && + in.format == nvinfer1::PluginFormat::kCHW32; + bool fp16_support = + (in.type == nvinfer1::DataType::kHALF) && + ((!with_silu_ && in.format == nvinfer1::PluginFormat::kLINEAR) || + in.format == nvinfer1::PluginFormat::kHWC8); + if (pos == 0) { - if (with_fp16_) { - return ((in.type == nvinfer1::DataType::kHALF) && - ((!with_silu_ && in.format == nvinfer1::PluginFormat::kLINEAR) || - in.format == nvinfer1::PluginFormat::kHWC8)); + if (with_int8_) { + return int8_support || fp16_support; + } else if (with_fp16_) { + return fp16_support; } else { return (in.type == nvinfer1::DataType::kFLOAT) && (in.format == nvinfer1::TensorFormat::kLINEAR); @@ -655,10 +914,76 @@ int GroupNormPluginDynamic::enqueue( PADDLE_THROW(platform::errors::Fatal( "The Groupnorm TRT Plugin's only support nchw or nhwc8 input")); } + } else if (input_type == nvinfer1::DataType::kINT8) { + const int8_t *input = reinterpret_cast(inputs[0]); + int8_t *output = static_cast(outputs[0]); + + if (input_desc[0].format == nvinfer1::PluginFormat::kCHW32) { + int32_t cPerBlock = 320; + int32_t maxBlocksPerHW = 1024; + switch (input_desc[0].dims.d[1]) { + case 960: + case 1920: + cPerBlock = 480; + break; + case 512: + case 256: + cPerBlock = 256; + break; + case 128: + cPerBlock = 128; + break; + default: + cPerBlock = 320; + } + if (cPerBlock > input_desc[0].dims.d[1]) { + cPerBlock = 8; + } + params_.withSilu = with_silu_; + params_.dst = static_cast(outputs[0]); + params_.srcX = static_cast(inputs[0]); + + params_.gamma = scale_gpu_; + params_.beta = bias_gpu_; + params_.redBuffer = static_cast(workspace); + params_.n = input_desc[0].dims.d[0]; + params_.h = input_desc[0].dims.d[2]; + params_.w = input_desc[0].dims.d[3]; + params_.c = input_desc[0].dims.d[1]; + params_.groups = groups_; + params_.hw = params_.h * params_.w; + const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); + params_.hwPerBlock 
= divUp(params_.hw, blocksPerHW); + params_.cPerBlock = cPerBlock; + params_.cPerGroup = params_.c / params_.groups; + params_.hwc = params_.hw * params_.c; + params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.groupsPerBlock = cPerBlock / params_.cPerGroup; + CHECK_EQ(cPerBlock % params_.cPerGroup, 0); + CHECK_EQ(params_.cPerGroup % 2, 0); + params_.eps = eps_; + params_.dqScaleIn = input_desc[0].scale; + params_.inv_qScale = 1.f / output_desc[0].scale; + + // Just used for TensorRTDynamicShapeGNTes in test_dynamic_engine.cc + // Do not Edit it + // params_.dqScaleIn = 1.f; + // params_.inv_qScale = 1 / 0.05f; + + cudaMemsetAsync(params_.redBuffer, + 0, + 2 * sizeof(float) * params_.n * groups_, + stream); + groupNormNCHW32SumQDQ(params_, stream); + groupNormNCHW32ScaleQDQ(params_, stream); + } else { + PADDLE_THROW(platform::errors::Fatal( + "The Groupnorm TRT Plugin only support nchw32 input")); + } } else { // input not float PADDLE_THROW(platform::errors::Fatal( - "The Groupnorm TRT Plugin's only support fp32 or fp16 input")); + "The Groupnorm TRT Plugin's only support fp32, fp16 or int8 input")); } return cudaGetLastError() != cudaSuccess; diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h index 3feb35e070..8f521c8bf5 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h @@ -165,13 +165,15 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { std::vector mean_shape, std::vector variance_shape, bool with_silu, - bool with_fp16) + bool with_fp16, + bool with_int8) : groups_(groups), eps_(eps), mean_shape_(mean_shape), variance_shape_(variance_shape), with_silu_(with_silu), - with_fp16_(with_fp16) { + with_fp16_(with_fp16), + with_int8_(with_int8) { scale_.resize(scale_num); bias_.resize(bias_num); std::copy(scale, scale + scale_num, scale_.data()); @@ -187,6 +189,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { DeserializeValue(&serialData, &serialLength, &variance_shape_); DeserializeValue(&serialData, &serialLength, &with_silu_); DeserializeValue(&serialData, &serialLength, &with_fp16_); + DeserializeValue(&serialData, &serialLength, &with_int8_); } nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { auto* ptr = new GroupNormPluginDynamic(scale_.data(), @@ -198,7 +201,8 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { mean_shape_, variance_shape_, with_silu_, - with_fp16_); + with_fp16_, + with_int8_); ptr->scale_gpu_ = scale_gpu_; ptr->bias_gpu_ = bias_gpu_; return ptr; @@ -214,7 +218,8 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { return SerializedSize(scale_) + SerializedSize(bias_) + SerializedSize(eps_) + SerializedSize(groups_) + SerializedSize(mean_shape_) + SerializedSize(variance_shape_) + - SerializedSize(with_silu_) + SerializedSize(with_fp16_); + SerializedSize(with_silu_) + SerializedSize(with_fp16_) + + +SerializedSize(with_int8_); } void serialize(void* buffer) const TRT_NOEXCEPT override { SerializeValue(&buffer, scale_); @@ -225,6 +230,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { SerializeValue(&buffer, variance_shape_); SerializeValue(&buffer, with_silu_); SerializeValue(&buffer, with_fp16_); + SerializeValue(&buffer, with_int8_); } nvinfer1::DimsExprs getOutputDimensions( int output_index, @@ -284,6 +290,7 @@ class GroupNormPluginDynamic : public 
DynamicPluginTensorRT { GroupNormNHWCParams params_; bool with_silu_; bool with_fp16_; + bool with_int8_; }; class GroupNormPluginDynamicCreator : public TensorRTPluginCreator { public: diff --git a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc index 36d0f4b1d3..6e3f587dd4 100644 --- a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" #endif #include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/common/float16.h" @@ -745,7 +746,310 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) { LOG(INFO) << "finish"; #endif } +#if IS_TRT_VERSION_GE(8000) +class TensorRTDynamicShapeGNTest : public ::testing::Test { + protected: + void SetUp() override { + ctx_ = new phi::GPUContext(platform::CUDAPlace(0)); + ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(platform::CUDAPlace(0), ctx_->stream()) + .get()); + ctx_->SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + ctx_->SetZeroAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetZeroAllocator(platform::CUDAPlace(0)) + .get()); + ctx_->SetPinnedAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CUDAPinnedPlace()) + .get()); + ctx_->PartialInitWithAllocator(); + std::map> min_input_shape = { + {"x", {n_, c_, h_, w_}}}; + std::map> max_input_shape = { + {"x", {n_, c_, h_, w_}}}; + std::map> optim_input_shape = { + {"x", {n_, c_, h_, w_}}}; + std::map> min_input_value = {}; + std::map> max_input_value = {}; + std::map> optim_input_value = {}; + + engine_ = new TensorRTEngine(16, + 1 << 10, + AnalysisConfig::Precision::kInt8, + nullptr, + 0, + min_input_shape, + max_input_shape, + optim_input_shape, + min_input_value, + max_input_value, + optim_input_value, + false, + phi::DataType::FLOAT32, + NaiveLogger::Global()); + engine_->InitNetwork(); + } + + void TearDown() override { + if (engine_) { + delete engine_; + engine_ = nullptr; + } + } + + void PrepareInputOutput(const std::vector &input, + std::vector output_shape) { + paddle::framework::TensorFromVector(input, *ctx_, &x_); + paddle::framework::TensorFromVector(input, *ctx_, &y_); + } + void GetOutput(std::vector *output) { + paddle::framework::TensorToVector(y_, *ctx_, output); + } + + struct logical_struct { + int n; + int c; + int h; + int w; + }; + + int nchw(struct logical_struct shape, struct logical_struct index) { + return index.n * shape.c * shape.h * shape.w + index.c * shape.h * shape.w + + index.h * shape.w + index.w; + } + + // this function + void naive_qdq_cpu( + float *output, const float *input, int n, float q, float dq) { + for (int i = 0; i < n; i++) { + float tmp = input[i]; + int32_t qtmp = std::round(tmp / q); + qtmp = std::max(-128, qtmp); + qtmp = std::min(127, qtmp); + output[i] = qtmp * dq; + } + } + + void naive_groupnorm_post_qdq(float *output, + const float *input, + int n, + int c, + int h, + int w, + int groups, + float epsilon, + float post_scale, + const float *scale, + const float *bias, + bool with_silu) { + assert(c % groups == 0); + struct 
logical_struct shape { + n, c, h, w + }; + + for (int ni = 0; ni < n; ni++) { + for (int group_i = 0; group_i < groups; group_i++) { + int ci_begin = group_i * (c / groups); + int ci_end = (group_i + 1) * (c / groups); + + float sum = 0.f; + float q2sum = 0.f; + + for (int ci = ci_begin; ci < ci_end; ci++) { + for (int hi = 0; hi < h; hi++) { + for (int wi = 0; wi < w; wi++) { + struct logical_struct index { + ni, ci, hi, wi + }; + float tmp_data = *(input + nchw(shape, index)); + sum += tmp_data; + q2sum += tmp_data * tmp_data; + } + } + } + + int nums = h * w * c / groups; + float mean = sum / nums; + float sigma = sqrtf(q2sum / nums - mean * mean + epsilon); + + for (int ci = ci_begin; ci < ci_end; ci++) { + for (int hi = 0; hi < h; hi++) { + for (int wi = 0; wi < w; wi++) { + struct logical_struct index { + ni, ci, hi, wi + }; + float tmp_data = *(input + nchw(shape, index)); + float norm_data = (tmp_data - mean) / sigma; + *(output + nchw(shape, index)) = norm_data; + } + } + } + } + } + + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < c; ci++) { + for (int hi = 0; hi < h; hi++) { + for (int wi = 0; wi < w; wi++) { + struct logical_struct index { + ni, ci, hi, wi + }; + float tmp = *(output + nchw(shape, index)); + float scale_v = scale[ci]; + float bias_v = bias[ci]; + float x = tmp * scale_v + bias_v; + if (with_silu) { + x = x / (1 + std::exp(-x)); + } + *(output + nchw(shape, index)) = x; + } + } + } + } + + naive_qdq_cpu(output, output, n * c * h * w, post_scale, post_scale); + } + + protected: + phi::DenseTensor x_; + phi::DenseTensor y_; + TensorRTEngine *engine_; + phi::GPUContext *ctx_; + // case from SD + int n_ = 2; + int c_ = 320; + int h_ = 14; + int w_ = 14; + int groups_ = 32; + float epsilon_ = 0.000009999999747378752; +}; + +TEST_F(TensorRTDynamicShapeGNTest, test_trt_dynamic_shape_groupnorm) { + tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt(); + + float *bias = new float[c_]; + float *scale = new float[c_]; + for (int i = 0; i < c_; i++) { + bias[i] = (i % 100) / 100.f; + } + for (int i = 0; i < c_; i++) { + scale[i] = (i % 100) / 100.f; + } + + auto *x = engine_->DeclareInput( + "x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{n_, c_, h_, w_}); + + nvinfer1::Dims scale_dims; + scale_dims.nbDims = 1; + scale_dims.d[0] = 1; + // must set qscale_data = 1.f! 
+ float qscale_data = 1.f; + float dqscale_data = 1.f; + TensorRTEngine::Weight q_weight(nvinfer1::DataType::kFLOAT, &qscale_data, 1); + TensorRTEngine::Weight dq_weight( + nvinfer1::DataType::kFLOAT, &dqscale_data, 1); + + auto *qscale_tensor = + TRT_ENGINE_ADD_LAYER(engine_, Constant, scale_dims, q_weight.get()) + ->getOutput(0); + auto *dqscale_tensor = + TRT_ENGINE_ADD_LAYER(engine_, Constant, scale_dims, dq_weight.get()) + ->getOutput(0); + + auto *q_layer = TRT_ENGINE_ADD_LAYER(engine_, Quantize, *x, *qscale_tensor); + q_layer->setAxis(1); + auto *q_layer_tensor = q_layer->getOutput(0); + + int gn_num = n_ * groups_; + std::vector mean_shape({gn_num}); + std::vector variance_shape({gn_num}); + bool with_fp16 = true; + bool with_int8 = true; + bool with_silu = true; + plugin::GroupNormPluginDynamic *plugin = + new plugin::GroupNormPluginDynamic(scale, + c_, + bias, + c_, + epsilon_, + groups_, + mean_shape, + variance_shape, + with_silu, + with_fp16, + with_int8); + + nvinfer1::ILayer *groupnorm_layer = + engine_->AddDynamicPlugin(&q_layer_tensor, 1, plugin); + groupnorm_layer->setOutputType(0, nvinfer1::DataType::kINT8); + auto *gn_tensor = groupnorm_layer->getOutput(0); + auto *dq_layer = + TRT_ENGINE_ADD_LAYER(engine_, Dequantize, *gn_tensor, *dqscale_tensor); + dq_layer->setAxis(1); + + PADDLE_ENFORCE_NOT_NULL(groupnorm_layer, + platform::errors::InvalidArgument( + "TRT GN plugin layer building failed.")); + + engine_->DeclareOutput(dq_layer, 0, "y"); + engine_->FreezeNetwork(); + + int input_num = n_ * c_ * h_ * w_; + std::vector shape_v = {n_, c_, h_, w_}; + + std::vector x_v(input_num); + for (int i = 0; i < input_num; i++) { + x_v[i] = i % 32 - 16; + } + + PrepareInputOutput(x_v, shape_v); + + engine_->context()->setBindingDimensions(0, nvinfer1::Dims4{n_, c_, h_, w_}); + + auto *x_gpu_data = x_.data(); + auto *y_gpu_data = y_.mutable_data(ctx_->GetPlace()); + std::vector buffers(2); + buffers[0] = reinterpret_cast(x_gpu_data); + buffers[1] = reinterpret_cast(y_gpu_data); + + engine_->Execute(-1, &buffers, ctx_->stream()); + cudaStreamSynchronize(ctx_->stream()); + + std::vector y_cpu; + GetOutput(&y_cpu); + std::vector y_base(input_num); + naive_groupnorm_post_qdq(y_base.data(), + x_v.data(), + n_, + c_, + h_, + w_, + groups_, + epsilon_, + dqscale_data, + bias, + scale, + with_silu); + float max_diff = -1; + int right_num = 0; + for (uint64_t i = 0; i < y_cpu.size(); i++) { + float diff = std::abs(y_base[i] - y_cpu[i]); + if (diff < 6e-2) right_num++; + if (diff > max_diff) max_diff = diff; + } + + ASSERT_EQ(right_num, input_num); + + delete[] bias; + delete[] scale; + return; +} +#endif } // namespace tensorrt } // namespace inference } // namespace paddle -- GitLab
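
Editor's note: the patch itself contains no prose walkthrough of the int8 path, so the following is a minimal host-side sketch (not part of the patch) of the two ideas the new kernels rely on: the NCHW32 offset arithmetic and the symmetric int8 dequantize/requantize driven by dqScaleIn and inv_qScale. The helper names OffsetNCHW32, DequantizeInt8 and QuantizeInt8 are hypothetical illustrations, not Paddle or TensorRT APIs.

// Host-side sketch, assuming only what the kernels above show.
#include <algorithm>
#include <cmath>
#include <cstdint>

// NCHW32 packs channels in blocks of 32 along the innermost dimension:
// offset(n, c, h, w) = n*C*H*W + (c/32)*(H*W*32) + (h*W + w)*32 + c%32.
// This mirrors the "nc_offset + hwi * 32" indexing used by
// groupNormNCHW32SumKernelQDQ and groupNormNCHW32ScaleKernelQDQ.
inline int64_t OffsetNCHW32(int n, int c, int h, int w, int C, int H, int W) {
  const int64_t hw = static_cast<int64_t>(H) * W;
  return static_cast<int64_t>(n) * C * hw + (c / 32) * hw * 32 +
         (static_cast<int64_t>(h) * W + w) * 32 + (c % 32);
}

// Dequantize an int8 value with the input tensor's scale
// (params_.dqScaleIn, i.e. input_desc[0].scale in enqueue()).
inline float DequantizeInt8(int8_t q, float dq_scale_in) {
  return dq_scale_in * static_cast<float>(q);
}

// Requantize with the reciprocal of the output scale
// (params_.inv_qScale = 1.f / output_desc[0].scale), rounding to nearest
// and clamping to [-128, 127], as the scale kernel does with
// __float2int_rn followed by max/min.
inline int8_t QuantizeInt8(float x, float inv_q_scale) {
  int32_t q = static_cast<int32_t>(std::nearbyint(x * inv_q_scale));
  q = std::max(-128, std::min(127, q));
  return static_cast<int8_t>(q);
}

Because the unit test pins both the Quantize and Dequantize layer scales to 1.f (the "must set qscale_data = 1.f" comment above), the CPU reference in test_dynamic_engine.cc only needs a single round trip of this kind, which is what naive_qdq_cpu applies after the group normalization.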