Unverified commit d93c63a0, authored by zhoutianzi666, committed by GitHub

[Paddle-TRT] GroupNorm int8 nchw32 fake kernel (#50146)

* add fmha_flashattention oss plugin

* add fmhca

* add oss fmhca

* code reconstruct and add ut

* code style refine

* fix ut and enforce check

* refine trt version check

refine compile

fix compile

* fix cross ut

* code refine

* use runtime trt version check

* bug fix and code refine

* compile fix

* merge develop

* add GN QDQ kernel

* support GN int8 fake kernel

* add with_int8

* add GN int8 fake kernel

* add GN int8 fake kernel

* add GN int8 fake kernel

* add GN int8 fake kernel

* add GN int8 fake kernel

* add GN int8 fake kernel

* add GN int8 fake kernel

* add GN int8 UT

* add version > 8000 check in GN int8 UT

* add some check in .cu

* add stdlib.h in UT

* little change in .cu

* remove rand_r, use rand

* remove use of rand

* setAxis(1)

* when int8 is on, allow fallback to fp16

---------
Co-authored-by: wwbitejotunn <wang_bojun@outlook.com>
Parent d9b70950
@@ -67,7 +67,9 @@ class GroupNormOpConverter : public OpConverter {
auto scale_weights = GetWeight(scale_name, &scale_dims);
auto bias_weights = GetWeight(bias_name, &bias_dims);
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
bool with_int8 = engine_->WithInt8();
// when int8 is on, allow fall back to fp16
if (with_int8) with_fp16 = true;
if (engine_->with_dynamic_shape()) {
int gn_num = groups;
std::vector<int64_t> mean_shape({gn_num});
@@ -83,7 +85,8 @@ class GroupNormOpConverter : public OpConverter {
mean_shape,
variance_shape,
with_silu,
with_fp16,
with_int8);
nvinfer1::ILayer* groupnorm_layer =
engine_->AddDynamicPlugin(&input_itensor, 1, plugin);
auto output_name = op_desc.Output("Y")[0];
......
@@ -361,6 +361,12 @@ class TensorRTEngine {
return enable_fp16 && support_fp16;
}
bool WithInt8() {
bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
bool support_int8 = infer_builder_->platformHasFastInt8();
return enable_int8 && support_int8;
}
int GetDeviceId() { return device_id_; }
nvinfer1::IPluginV2Layer* AddPlugin(nvinfer1::ITensor* const* inputs,
......
@@ -69,6 +69,9 @@ struct GroupNormNHWCParams {
int32_t groupsPerBlock;
// epsilon, Constant for numerical stability
float eps;
// for NCHW32 int8 use
float dqScaleIn;
float inv_qScale;
};
} // namespace plugin
......
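
For orientation, the two new fields carry the quantization scales for the int8 path: dqScaleIn is the per-tensor scale TensorRT reports for the int8 input (used to dequantize each value to float inside the kernel), and inv_qScale is the reciprocal of the output tensor's scale (used to requantize the normalized result back to int8 with round-and-saturate). Below is a minimal CPU sketch of that round trip, mirroring the __float2int_rn/clamp sequence in the kernels further down; the helper names and the example scales are illustrative only, not part of the plugin.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helpers showing how dqScaleIn and inv_qScale are meant to be used.
float dequantize(int8_t q, float dqScaleIn) { return dqScaleIn * q; }

int8_t requantize(float x, float inv_qScale) {
  int32_t q = static_cast<int32_t>(std::lround(inv_qScale * x));  // round to nearest
  q = std::min(127, std::max(-128, q));                           // saturate to int8
  return static_cast<int8_t>(q);
}

int main() {
  const float in_scale = 0.05f;   // assumed per-tensor input scale
  const float out_scale = 0.05f;  // assumed per-tensor output scale
  float x = dequantize(37, in_scale);          // int8 -> float
  // ... GroupNorm runs in float here ...
  int8_t y = requantize(x, 1.0f / out_scale);  // float -> int8
  std::printf("x=%f y=%d\n", x, static_cast<int>(y));
  return 0;
}
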
@@ -188,6 +188,257 @@ void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
}
}
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) {
// The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
// Allocate shared memory for BlockScan.
__shared__ typename BlockScan::TempStorage tempStorage;
// Allocate shared memory for the groups. We could reduce the amount of shared
// memory reserved.
__shared__ float2 smem[tTHREADS_PER_BLOCK];
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for int8x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// The sums.
float sum = 0.F;
float sumSq = 0.F;
const int8_t *src_ptr = reinterpret_cast<const int8_t *>(params.srcX);
// nchw32 layout
// batch offset + channel offset
int nc_offset = static_cast<int64_t>(ni) * params.hwc +
ci / 32 * params.hw * 32 + ci % 32;
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The offset.
int64_t offset = nc_offset + static_cast<int64_t>(hwi) * 32;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
int8_t tmp_in[2];
*reinterpret_cast<int16_t *>(tmp_in) =
*reinterpret_cast<int16_t const *>(&src_ptr[offset]);
h2.x = params.dqScaleIn * tmp_in[0];
h2.y = params.dqScaleIn * tmp_in[1];
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Update the sum.
sum += f2.x + f2.y;
// Update the sum of squares.
sumSq += f2.x * f2.x + f2.y * f2.y;
}
// The group that thread works on and the channel in the group (modulus).
int32_t gi = threadIdx.x * 2 / params.cPerGroup;
int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
// The data for the summations.
GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
// Do the segmented scan.
GroupSums out;
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
// Store the results for the groups in shared memory (to produce coalesced
// stores later).
// 2 channels per thread
if (cj == params.cPerGroup - 2) {
smem[gi] = make_float2(out.sum, out.sumSq);
}
// Make sure the data is in shared memory.
__syncthreads();
// The global group index.
int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
// Threads that have nothing left to do, exit.
if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
return;
}
// The first threads (those storing to global memory) load the values.
float2 sums = smem[threadIdx.x];
// Store to global memory.
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
void groupNormNCHW32SumQDQ(const GroupNormNHWCParams &params,
cudaStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
groupNormNCHW32SumKernelQDQ<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
groupNormNCHW32SumKernelQDQ<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
groupNormNCHW32SumKernelQDQ<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
groupNormNCHW32SumKernelQDQ<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
groupNormNCHW32SumKernelQDQ<4><<<grid, 4, 0, stream>>>(params);
break;
}
}
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNCHW32ScaleKernelQDQ(
const GroupNormNHWCParams params) {
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The group that thread works on and the channel in the group (modulus).
int32_t gi = ci / params.cPerGroup;
const int8_t *src_ptr = reinterpret_cast<const int8_t *>(params.srcX);
int8_t *dst_ptr = reinterpret_cast<int8_t *>(params.dst);
// Load the sum and sum of squares for the group.
float sum = 0.F, sumSq = 0.F;
if (gi < params.groups) {
sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
}
// Load gamma/beta.
float2 gammaF2, betaF2;
if (ci < params.c) {
gammaF2 = __half22float2(*reinterpret_cast<half2 const *>(
reinterpret_cast<half const *>(params.gamma) + ci));
betaF2 = __half22float2(*reinterpret_cast<half2 const *>(
reinterpret_cast<half const *>(params.beta) + ci));
}
// Compute the mean.
float mean = sum * params.invHWC;
// Compute the variance.
float var = sumSq * params.invHWC - (mean * mean);
// Compute the inverse of the stddev.
float invStdDev = rsqrtf(var + params.eps);
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// nchw32 layout
int c_offset = ci / 32 * params.hw * 32 + ci % 32;
// Iterate over the activations to normalize them and store the results.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = static_cast<int64_t>(ni) * params.hwc + c_offset +
static_cast<int64_t>(hwi) * 32;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
int8_t tmp_in[2];
*reinterpret_cast<int16_t *>(tmp_in) =
*reinterpret_cast<int16_t const *>(&src_ptr[offset]);
h2.x = params.dqScaleIn * tmp_in[0];
h2.y = params.dqScaleIn * tmp_in[1];
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
// Apply Silu if needed.
if (params.withSilu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
// Store the scaled values.
if (ci < params.c) {
int8_t tmp_in[2];
int32_t tmpq0 = __float2int_rn(params.inv_qScale * f2.x);
int32_t tmpq1 = __float2int_rn(params.inv_qScale * f2.y);
tmpq0 = max(-128, tmpq0);
tmpq0 = min(127, tmpq0);
tmpq1 = max(-128, tmpq1);
tmpq1 = min(127, tmpq1);
tmp_in[0] = tmpq0;
tmp_in[1] = tmpq1;
*reinterpret_cast<int16_t *>(&dst_ptr[offset]) =
*reinterpret_cast<int16_t *>(tmp_in);
}
}
}
void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams &params,
cudaStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
groupNormNCHW32ScaleKernelQDQ<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
groupNormNCHW32ScaleKernelQDQ<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
groupNormNCHW32ScaleKernelQDQ<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
groupNormNCHW32ScaleKernelQDQ<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
groupNormNCHW32ScaleKernelQDQ<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(
platform::errors::Fatal("The function groupNormNCHW32ScaleQDQ of "
"GroupNorm TRT Plugin encounter error"));
}
}
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) {
// The instance in the batch.
@@ -454,11 +705,19 @@ bool GroupNormPluginDynamic::supportsFormatCombination(
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
bool int8_support = in.type == nvinfer1::DataType::kINT8 &&
in.format == nvinfer1::PluginFormat::kCHW32;
bool fp16_support =
(in.type == nvinfer1::DataType::kHALF) &&
((!with_silu_ && in.format == nvinfer1::PluginFormat::kLINEAR) ||
in.format == nvinfer1::PluginFormat::kHWC8);
if (pos == 0) {
if (with_int8_) {
return int8_support || fp16_support;
} else if (with_fp16_) {
return fp16_support;
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
@@ -655,10 +914,76 @@ int GroupNormPluginDynamic::enqueue(
PADDLE_THROW(platform::errors::Fatal(
"The Groupnorm TRT Plugin's only support nchw or nhwc8 input"));
}
} else if (input_type == nvinfer1::DataType::kINT8) {
const int8_t *input = reinterpret_cast<const int8_t *>(inputs[0]);
int8_t *output = static_cast<int8_t *>(outputs[0]);
if (input_desc[0].format == nvinfer1::PluginFormat::kCHW32) {
int32_t cPerBlock = 320;
int32_t maxBlocksPerHW = 1024;
switch (input_desc[0].dims.d[1]) {
case 960:
case 1920:
cPerBlock = 480;
break;
case 512:
case 256:
cPerBlock = 256;
break;
case 128:
cPerBlock = 128;
break;
default:
cPerBlock = 320;
}
if (cPerBlock > input_desc[0].dims.d[1]) {
cPerBlock = 8;
}
params_.withSilu = with_silu_;
params_.dst = static_cast<half *>(outputs[0]);
params_.srcX = static_cast<half const *>(inputs[0]);
params_.gamma = scale_gpu_;
params_.beta = bias_gpu_;
params_.redBuffer = static_cast<float *>(workspace);
params_.n = input_desc[0].dims.d[0];
params_.h = input_desc[0].dims.d[2];
params_.w = input_desc[0].dims.d[3];
params_.c = input_desc[0].dims.d[1];
params_.groups = groups_;
params_.hw = params_.h * params_.w;
const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW);
params_.hwPerBlock = divUp(params_.hw, blocksPerHW);
params_.cPerBlock = cPerBlock;
params_.cPerGroup = params_.c / params_.groups;
params_.hwc = params_.hw * params_.c;
params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
CHECK_EQ(cPerBlock % params_.cPerGroup, 0);
CHECK_EQ(params_.cPerGroup % 2, 0);
params_.eps = eps_;
params_.dqScaleIn = input_desc[0].scale;
params_.inv_qScale = 1.f / output_desc[0].scale;
// Only used by TensorRTDynamicShapeGNTest in test_dynamic_engine.cc.
// Do not edit it.
// params_.dqScaleIn = 1.f;
// params_.inv_qScale = 1 / 0.05f;
cudaMemsetAsync(params_.redBuffer,
0,
2 * sizeof(float) * params_.n * groups_,
stream);
groupNormNCHW32SumQDQ(params_, stream);
groupNormNCHW32ScaleQDQ(params_, stream);
} else {
PADDLE_THROW(platform::errors::Fatal(
"The Groupnorm TRT Plugin only support nchw32 input"));
}
} else {
// input not float
PADDLE_THROW(platform::errors::Fatal(
"The Groupnorm TRT Plugin's only support fp32, fp16 or int8 input"));
}
return cudaGetLastError() != cudaSuccess;
......
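
A note on the indexing used by the two NCHW32 kernels above: the tensor is stored in TensorRT's CHW32 vectorized format, so channels are packed in groups of 32 that sit innermost. The offset of element (n, c, h, w) is the batch offset n * H * W * C (params.hwc), plus the channel-pack offset (c / 32) * H * W * 32, plus the spatial offset (h * W + w) * 32 (hwi * 32 in the kernels), plus the lane c % 32. A small standalone sketch of that arithmetic, assuming C is padded to a multiple of 32; the helper name and the printed example are illustrative only.

#include <cstdint>
#include <cstdio>

// Hypothetical helper: linear offset of (n, c, h, w) in CHW32 (NCHW32) layout,
// matching the offset math in groupNormNCHW32SumKernelQDQ / groupNormNCHW32ScaleKernelQDQ.
int64_t nchw32_offset(int64_t n, int64_t c, int64_t h, int64_t w,
                      int64_t C, int64_t H, int64_t W) {
  return n * H * W * C          // batch offset (hwc)
       + (c / 32) * H * W * 32  // channel-pack offset
       + (h * W + w) * 32       // spatial offset (hwi * 32)
       + c % 32;                // lane inside the 32-channel pack
}

int main() {
  // Shapes taken from the unit test below: N=2, C=320, H=W=14.
  std::printf("%lld\n",
              static_cast<long long>(nchw32_offset(1, 37, 3, 5, 320, 14, 14)));
  return 0;
}
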
@@ -165,13 +165,15 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
std::vector<int64_t> mean_shape,
std::vector<int64_t> variance_shape,
bool with_silu,
bool with_fp16,
bool with_int8)
: groups_(groups),
eps_(eps),
mean_shape_(mean_shape),
variance_shape_(variance_shape),
with_silu_(with_silu),
with_fp16_(with_fp16),
with_int8_(with_int8) {
scale_.resize(scale_num);
bias_.resize(bias_num);
std::copy(scale, scale + scale_num, scale_.data());
@@ -187,6 +189,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
DeserializeValue(&serialData, &serialLength, &variance_shape_);
DeserializeValue(&serialData, &serialLength, &with_silu_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
DeserializeValue(&serialData, &serialLength, &with_int8_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
auto* ptr = new GroupNormPluginDynamic(scale_.data(),
@@ -198,7 +201,8 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
mean_shape_,
variance_shape_,
with_silu_,
with_fp16_,
with_int8_);
ptr->scale_gpu_ = scale_gpu_;
ptr->bias_gpu_ = bias_gpu_;
return ptr;
@@ -214,7 +218,8 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
return SerializedSize(scale_) + SerializedSize(bias_) +
SerializedSize(eps_) + SerializedSize(groups_) +
SerializedSize(mean_shape_) + SerializedSize(variance_shape_) +
SerializedSize(with_silu_) + SerializedSize(with_fp16_) +
SerializedSize(with_int8_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, scale_);
@@ -225,6 +230,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
SerializeValue(&buffer, variance_shape_);
SerializeValue(&buffer, with_silu_);
SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, with_int8_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
@@ -284,6 +290,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
GroupNormNHWCParams params_;
bool with_silu_;
bool with_fp16_;
bool with_int8_;
};
class GroupNormPluginDynamicCreator : public TensorRTPluginCreator {
public:
......
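
A design note on the header changes above: the new with_int8_ flag has to be accounted for consistently in getSerializationSize, serialize, and deserialize, and new fields should be appended after the existing ones in the same order in all three places; otherwise an engine serialized with one layout is deserialized misaligned. Below is a generic, self-contained sketch of that invariant, using plain memcpy as a stand-in for Paddle's SerializeValue/DeserializeValue helpers; WriteValue/ReadValue are illustrative names, not Paddle APIs.

#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Append a trivially copyable value to the buffer.
template <typename T>
void WriteValue(std::vector<char>* buf, const T& v) {
  const char* p = reinterpret_cast<const char*>(&v);
  buf->insert(buf->end(), p, p + sizeof(T));
}

// Read the next value; the read order must match the write order exactly.
template <typename T>
void ReadValue(const char** cursor, size_t* remaining, T* v) {
  assert(*remaining >= sizeof(T));
  std::memcpy(v, *cursor, sizeof(T));
  *cursor += sizeof(T);
  *remaining -= sizeof(T);
}

int main() {
  bool with_silu = true, with_fp16 = true, with_int8 = true;
  std::vector<char> buf;
  WriteValue(&buf, with_silu);
  WriteValue(&buf, with_fp16);
  WriteValue(&buf, with_int8);  // new field appended last, as in the plugin

  bool silu = false, fp16 = false, int8_flag = false;
  const char* cursor = buf.data();
  size_t remaining = buf.size();
  ReadValue(&cursor, &remaining, &silu);
  ReadValue(&cursor, &remaining, &fp16);
  ReadValue(&cursor, &remaining, &int8_flag);
  assert(silu && fp16 && int8_flag && remaining == 0);
  return 0;
}
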
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h"
#endif
#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
...@@ -745,7 +746,310 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) { ...@@ -745,7 +746,310 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
LOG(INFO) << "finish"; LOG(INFO) << "finish";
#endif #endif
} }
#if IS_TRT_VERSION_GE(8000)
class TensorRTDynamicShapeGNTest : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new phi::GPUContext(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
ctx_->SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
ctx_->SetZeroAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(platform::CUDAPlace(0))
.get());
ctx_->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx_->PartialInitWithAllocator();
std::map<std::string, std::vector<int>> min_input_shape = {
{"x", {n_, c_, h_, w_}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"x", {n_, c_, h_, w_}}};
std::map<std::string, std::vector<int>> optim_input_shape = {
{"x", {n_, c_, h_, w_}}};
std::map<std::string, std::vector<int>> min_input_value = {};
std::map<std::string, std::vector<int>> max_input_value = {};
std::map<std::string, std::vector<int>> optim_input_value = {};
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kInt8,
nullptr,
0,
min_input_shape,
max_input_shape,
optim_input_shape,
min_input_value,
max_input_value,
optim_input_value,
false,
phi::DataType::FLOAT32,
NaiveLogger::Global());
engine_->InitNetwork();
}
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
}
void PrepareInputOutput(const std::vector<float> &input,
std::vector<int> output_shape) {
paddle::framework::TensorFromVector(input, *ctx_, &x_);
paddle::framework::TensorFromVector(input, *ctx_, &y_);
}
void GetOutput(std::vector<float> *output) {
paddle::framework::TensorToVector(y_, *ctx_, output);
}
struct logical_struct {
int n;
int c;
int h;
int w;
};
int nchw(struct logical_struct shape, struct logical_struct index) {
return index.n * shape.c * shape.h * shape.w + index.c * shape.h * shape.w +
index.h * shape.w + index.w;
}
// Naive CPU reference: quantize each element with scale q, clamp to the int8 range, then dequantize with scale dq.
void naive_qdq_cpu(
float *output, const float *input, int n, float q, float dq) {
for (int i = 0; i < n; i++) {
float tmp = input[i];
int32_t qtmp = std::round(tmp / q);
qtmp = std::max(-128, qtmp);
qtmp = std::min(127, qtmp);
output[i] = qtmp * dq;
}
}
void naive_groupnorm_post_qdq(float *output,
const float *input,
int n,
int c,
int h,
int w,
int groups,
float epsilon,
float post_scale,
const float *scale,
const float *bias,
bool with_silu) {
assert(c % groups == 0);
struct logical_struct shape {
n, c, h, w
};
for (int ni = 0; ni < n; ni++) {
for (int group_i = 0; group_i < groups; group_i++) {
int ci_begin = group_i * (c / groups);
int ci_end = (group_i + 1) * (c / groups);
float sum = 0.f;
float q2sum = 0.f;
for (int ci = ci_begin; ci < ci_end; ci++) {
for (int hi = 0; hi < h; hi++) {
for (int wi = 0; wi < w; wi++) {
struct logical_struct index {
ni, ci, hi, wi
};
float tmp_data = *(input + nchw(shape, index));
sum += tmp_data;
q2sum += tmp_data * tmp_data;
}
}
}
int nums = h * w * c / groups;
float mean = sum / nums;
float sigma = sqrtf(q2sum / nums - mean * mean + epsilon);
for (int ci = ci_begin; ci < ci_end; ci++) {
for (int hi = 0; hi < h; hi++) {
for (int wi = 0; wi < w; wi++) {
struct logical_struct index {
ni, ci, hi, wi
};
float tmp_data = *(input + nchw(shape, index));
float norm_data = (tmp_data - mean) / sigma;
*(output + nchw(shape, index)) = norm_data;
}
}
}
}
}
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < c; ci++) {
for (int hi = 0; hi < h; hi++) {
for (int wi = 0; wi < w; wi++) {
struct logical_struct index {
ni, ci, hi, wi
};
float tmp = *(output + nchw(shape, index));
float scale_v = scale[ci];
float bias_v = bias[ci];
float x = tmp * scale_v + bias_v;
if (with_silu) {
x = x / (1 + std::exp(-x));
}
*(output + nchw(shape, index)) = x;
}
}
}
}
naive_qdq_cpu(output, output, n * c * h * w, post_scale, post_scale);
}
protected:
phi::DenseTensor x_;
phi::DenseTensor y_;
TensorRTEngine *engine_;
phi::GPUContext *ctx_;
// case from SD
int n_ = 2;
int c_ = 320;
int h_ = 14;
int w_ = 14;
int groups_ = 32;
float epsilon_ = 0.000009999999747378752;
};
TEST_F(TensorRTDynamicShapeGNTest, test_trt_dynamic_shape_groupnorm) {
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
float *bias = new float[c_];
float *scale = new float[c_];
for (int i = 0; i < c_; i++) {
bias[i] = (i % 100) / 100.f;
}
for (int i = 0; i < c_; i++) {
scale[i] = (i % 100) / 100.f;
}
auto *x = engine_->DeclareInput(
"x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{n_, c_, h_, w_});
nvinfer1::Dims scale_dims;
scale_dims.nbDims = 1;
scale_dims.d[0] = 1;
// must set qscale_data = 1.f!
float qscale_data = 1.f;
float dqscale_data = 1.f;
TensorRTEngine::Weight q_weight(nvinfer1::DataType::kFLOAT, &qscale_data, 1);
TensorRTEngine::Weight dq_weight(
nvinfer1::DataType::kFLOAT, &dqscale_data, 1);
auto *qscale_tensor =
TRT_ENGINE_ADD_LAYER(engine_, Constant, scale_dims, q_weight.get())
->getOutput(0);
auto *dqscale_tensor =
TRT_ENGINE_ADD_LAYER(engine_, Constant, scale_dims, dq_weight.get())
->getOutput(0);
auto *q_layer = TRT_ENGINE_ADD_LAYER(engine_, Quantize, *x, *qscale_tensor);
q_layer->setAxis(1);
auto *q_layer_tensor = q_layer->getOutput(0);
int gn_num = n_ * groups_;
std::vector<int64_t> mean_shape({gn_num});
std::vector<int64_t> variance_shape({gn_num});
bool with_fp16 = true;
bool with_int8 = true;
bool with_silu = true;
plugin::GroupNormPluginDynamic *plugin =
new plugin::GroupNormPluginDynamic(scale,
c_,
bias,
c_,
epsilon_,
groups_,
mean_shape,
variance_shape,
with_silu,
with_fp16,
with_int8);
nvinfer1::ILayer *groupnorm_layer =
engine_->AddDynamicPlugin(&q_layer_tensor, 1, plugin);
groupnorm_layer->setOutputType(0, nvinfer1::DataType::kINT8);
auto *gn_tensor = groupnorm_layer->getOutput(0);
auto *dq_layer =
TRT_ENGINE_ADD_LAYER(engine_, Dequantize, *gn_tensor, *dqscale_tensor);
dq_layer->setAxis(1);
PADDLE_ENFORCE_NOT_NULL(groupnorm_layer,
platform::errors::InvalidArgument(
"TRT GN plugin layer building failed."));
engine_->DeclareOutput(dq_layer, 0, "y");
engine_->FreezeNetwork();
int input_num = n_ * c_ * h_ * w_;
std::vector<int> shape_v = {n_, c_, h_, w_};
std::vector<float> x_v(input_num);
for (int i = 0; i < input_num; i++) {
x_v[i] = i % 32 - 16;
}
PrepareInputOutput(x_v, shape_v);
engine_->context()->setBindingDimensions(0, nvinfer1::Dims4{n_, c_, h_, w_});
auto *x_gpu_data = x_.data<float>();
auto *y_gpu_data = y_.mutable_data<float>(ctx_->GetPlace());
std::vector<void *> buffers(2);
buffers[0] = reinterpret_cast<void *>(x_gpu_data);
buffers[1] = reinterpret_cast<void *>(y_gpu_data);
engine_->Execute(-1, &buffers, ctx_->stream());
cudaStreamSynchronize(ctx_->stream());
std::vector<float> y_cpu;
GetOutput(&y_cpu);
std::vector<float> y_base(input_num);
naive_groupnorm_post_qdq(y_base.data(),
x_v.data(),
n_,
c_,
h_,
w_,
groups_,
epsilon_,
dqscale_data,
bias,
scale,
with_silu);
float max_diff = -1;
int right_num = 0;
for (uint64_t i = 0; i < y_cpu.size(); i++) {
float diff = std::abs(y_base[i] - y_cpu[i]);
if (diff < 6e-2) right_num++;
if (diff > max_diff) max_diff = diff;
}
ASSERT_EQ(right_num, input_num);
delete[] bias;
delete[] scale;
return;
}
#endif
} // namespace tensorrt
} // namespace inference
} // namespace paddle