diff --git a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
new file mode 100644
index 0000000000000000000000000000000000000000..3798dcf1d501b3bd0717777221fced76699a2bb5
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
@@ -0,0 +1,75 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cstdint>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+struct GroupNormNHWCParams {
+  // The output buffer. Layout NHWC.
+  __half* dst;
+  // The input buffer. Layout NHWC.
+  __half const* srcX;
+  // The input buffer. Layout NHWC.
+  __half const* srcY;
+  // The gamma scaling factor.
+  void const* gamma;
+  // The beta term to add in GN.
+  void const* beta;
+  // The temporary buffer for the global parallel reduction.
+  // Size: 2 x N x groups (in floats).
+  float* redBuffer;
+
+  // The number of instances in the batch.
+  int32_t n;
+  // The height and width of each activation map.
+  int32_t h, w;
+  // The number of channels.
+  int32_t c;
+  // The number of groups.
+  int32_t groups;
+  // Do we apply the Swish activation function?
+  bool withSwish;
+
+  // Precomputed values and parameters to control the execution of the kernels.
+
+  // The number of activations per instance (h * w) and the number of
+  // activations per block.
+  int32_t hw, hwPerBlock;
+  // The number of channels per block (cPerBlock) and per group (cPerGroup) in
+  // the C dimension.
+  int32_t cPerBlock, cPerGroup;
+
+  // The precomputed stride between instances.
+  int32_t hwc;
+  // The reciprocal of hw * cPerGroup, used to compute the per-group mean/var.
+  float invHWC;
+  // The precomputed number of groups per block.
+  int32_t groupsPerBlock;
+  // Epsilon, a small constant for numerical stability.
+  float eps;
+};
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
index 4affb00898a3ee838b7587c4c8485ec7ebcbb952..e8c9c593a4654c17009f9c1d0a1ba4150e7c8a00 100644
--- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
@@ -1,4 +1,6 @@
 /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
+AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,6 +17,7 @@ limitations under the License.
 */
 #include "paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h"
 #include "paddle/phi/kernels/group_norm_kernel.h"
+#include <cub/cub.cuh>
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -25,6 +28,262 @@ namespace tensorrt {
 namespace plugin {
 using DataLayout = phi::DataLayout;
+static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }
+
+static inline __device__ __host__ float sigmoid(float x) {
+  return 1.F / (1.F + expf(-x));
+}
+
+struct GroupSums {
+  // Is it the 1st element of the group?
+  int32_t flag;
+  // The sum.
+  float sum;
+  // The sum of squares.
+  float sumSq;
+};
+
+struct GroupSumsOp {
+  inline __device__ GroupSums operator()(GroupSums const &a,
+                                         GroupSums const &b) {
+    GroupSums dst;
+    dst.sum = b.flag ? b.sum : (a.sum + b.sum);
+    dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
+    dst.flag = a.flag + b.flag;
+    return dst;
+  }
+};
+
+static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
+  int32_t maxDivisor = -1;
+  for (int32_t i = 1; i <= std::sqrt(n); i++) {
+    if (n % i == 0) {
+      int32_t divisor1 = n / i;
+      int32_t divisor2 = i;
+
+      if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
+        maxDivisor = divisor1;
+      }
+      if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
+        maxDivisor = divisor2;
+      }
+    }
+  }
+  return maxDivisor;
+}
+
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) {
+  // The object in charge of doing the sums for the different blocks.
+  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
+
+  // Allocate shared memory for BlockScan.
+  __shared__ typename BlockScan::TempStorage tempStorage;
+  // Allocate shared memory for the groups. We could reduce the amount of shared
+  // memory reserved.
+  __shared__ float2 smem[tTHREADS_PER_BLOCK];
+
+  // The instance in the batch.
+  int32_t ni = blockIdx.z;
+  // The channel loaded by that thread (2 channels per thread for F16x2).
+  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+
+  // The first activation loaded by that block.
+  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+  // The last activation loaded by that block.
+  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+  // The sums.
+  float sum = 0.F;
+  float sumSq = 0.F;
+
+  // Iterate over the activations to compute the sums.
+  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+    // The offset.
+    int64_t offset = static_cast<int64_t>(ni) * params.hwc +
+                     static_cast<int64_t>(hwi) * params.c + ci;
+
+    // Fetch two channels per thread.
+    __half2 h2(0, 0);
+    if (ci < params.c) {
+      h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
+    }
+
+    // Extract the two half values.
+    float2 f2 = __half22float2(h2);
+
+    // Update the sum.
+    sum += f2.x + f2.y;
+    // Update the sum of squares.
+    sumSq += f2.x * f2.x + f2.y * f2.y;
+  }
+
+  // The group that thread works on and the channel in the group (modulus).
+  int32_t gi = threadIdx.x * 2 / params.cPerGroup;
+  int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
+
+  // The data for the summations.
+  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
+
+  // Do the segmented scan.
+  GroupSums out;
+  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
+
+  // Store the results for the groups in shared memory (to produce coalesced
+  // stores later).
+  // 2 channels per thread
+  if (cj == params.cPerGroup - 2) {
+    smem[gi] = make_float2(out.sum, out.sumSq);
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The global group index.
+  int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
+
+  // Threads that have nothing left to do, exit.
+  if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
+    return;
+  }
+
+  // The first threads (those that will store to global memory) load the values.
+  float2 sums = smem[threadIdx.x];
+
+  // Store to global memory.
+  atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
+  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
+}
+
+void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
+  dim3 grid;
+
+  // The number of blocks to compute all the channels.
+  grid.x = params.c / params.cPerBlock;
+  // The number of blocks to compute all the activations in a given instance.
+  grid.y = divUp(params.hw, params.hwPerBlock);
+  // The number of instances.
+  grid.z = params.n;
+
+  switch (params.cPerBlock) {
+    case 320:
+      groupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
+      break;
+    case 480:
+      groupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
+      break;
+    case 256:
+      groupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
+      break;
+    case 128:
+      groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
+      break;
+  }
+}
+
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) {
+  // The instance in the batch.
+  int32_t ni = blockIdx.z;
+  // The channel loaded by that thread (2 channels per thread for F16x2).
+  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+  // The group that thread works on and the channel in the group (modulus).
+  int32_t gi = ci / params.cPerGroup;
+
+  // Load the sum and sum of squares for the group.
+  float sum = 0.F, sumSq = 0.F;
+  if (gi < params.groups) {
+    sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
+    sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
+  }
+
+  // Load gamma/beta.
+  float2 gammaF2, betaF2;
+  if (ci < params.c) {
+    gammaF2 = __half22float2(*reinterpret_cast<__half2 const *>(
+        reinterpret_cast<__half const *>(params.gamma) + ci));
+    betaF2 = __half22float2(*reinterpret_cast<__half2 const *>(
+        reinterpret_cast<__half const *>(params.beta) + ci));
+  }
+
+  // Compute the mean.
+  float mean = sum * params.invHWC;
+  // Compute the variance.
+  float var = sumSq * params.invHWC - (mean * mean);
+  // Compute the inverse of the stddev.
+  float invStdDev = rsqrtf(var + params.eps);
+
+  // The first activation loaded by that block.
+  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+  // The last activation loaded by that block.
+  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+  // Iterate over the activations to normalize them.
+  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+    // The src/dst offset.
+    int64_t offset = static_cast<int64_t>(ni) * params.hwc +
+                     static_cast<int64_t>(hwi) * params.c + ci;
+
+    // Fetch two channels per thread.
+    __half2 h2(0, 0);
+    if (ci < params.c) {
+      h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
+    }
+
+    // Extract the two half values.
+    float2 f2 = __half22float2(h2);
+
+    // Normalize the channels.
+    f2.x = (f2.x - mean) * invStdDev;
+    f2.y = (f2.y - mean) * invStdDev;
+
+    // Scale by gamma and add beta.
+    f2.x = gammaF2.x * f2.x + betaF2.x;
+    f2.y = gammaF2.y * f2.y + betaF2.y;
+
+    // Apply Swish if needed.
+    if (params.withSwish) {
+      f2.x = f2.x * sigmoid(f2.x);
+      f2.y = f2.y * sigmoid(f2.y);
+    }
+
+    // Store the scaled values.
+    if (ci < params.c) {
+      *reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
+    }
+  }
+}
+
+void groupNormNHWCScale(const GroupNormNHWCParams &params,
+                        cudaStream_t stream) {
+  dim3 grid;
+
+  // The number of blocks to compute all the channels.
+  grid.x = params.c / params.cPerBlock;
+  // The number of blocks to compute all the activations in a given instance.
+  grid.y = divUp(params.hw, params.hwPerBlock);
+  // The number of instances.
+  grid.z = params.n;
+
+  switch (params.cPerBlock) {
+    case 320:
+      groupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
+      break;
+    case 480:
+      groupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
+      break;
+    case 256:
+      groupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
+      break;
+    case 128:
+      groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Fatal(
+          "The function groupNormNHWCScale of the GroupNorm TRT plugin "
+          "encountered an unsupported cPerBlock value"));
+  }
+}
+
 int GroupNormPlugin::initialize() TRT_NOEXCEPT {
   if (!with_fp16_) {
     // if use fp32
@@ -188,7 +447,8 @@ bool GroupNormPluginDynamic::supportsFormatCombination(
   if (pos == 0) {
     if (with_fp16_) {
       return ((in.type == nvinfer1::DataType::kHALF) &&
-              (in.format == nvinfer1::PluginFormat::kLINEAR));
+              (in.format == nvinfer1::PluginFormat::kLINEAR ||
+               in.format == nvinfer1::PluginFormat::kHWC8));
     } else {
       return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
@@ -275,9 +535,7 @@ int GroupNormPluginDynamic::enqueue(
   int C = input_shape[1];
   int image_size = input_shape[2] * input_shape[3];
   int batchSize = input_shape[0];
-  std::vector<int64_t> batched_mean_shape = {batchSize * mean_shape_[0]};
-  std::vector<int64_t> batched_variance_shape = {batchSize *
-                                                 variance_shape_[0]};
+
   PADDLE_ENFORCE_EQ(
       C,
       scale_.size(),
@@ -320,25 +578,76 @@
     VLOG(1) << "TRT Plugin DataType selected. GroupNorm-->fp16";
GroupNorm-->fp16"; const half *input = reinterpret_cast(inputs[0]); half *output = static_cast(outputs[0]); + if (input_desc[0].format == nvinfer1::PluginFormat::kLINEAR) { + phi::GroupNormDirectCUDAFunctor group_norm; + group_norm(stream, + input, + input_shape, + reinterpret_cast(bias_gpu_), + reinterpret_cast(scale_gpu_), + temp_variance_d, + groups, + eps, + output, + mean_d, + variance_d, + DataLayout::kNCHW); + } else if (input_desc[0].format == nvinfer1::PluginFormat::kHWC8) { + int32_t cPerBlock = 320; + int32_t maxBlocksPerHW = 1024; + switch (input_desc[0].dims.d[1]) { + case 960: + case 1920: + cPerBlock = 480; + break; + case 512: + case 256: + cPerBlock = 256; + break; + case 128: + cPerBlock = 128; + break; + default: + cPerBlock = 320; + } - phi::GroupNormDirectCUDAFunctor group_norm; - group_norm(stream, - input, - input_shape, - reinterpret_cast(bias_gpu_), - reinterpret_cast(scale_gpu_), - temp_variance_d, - groups, - eps, - output, - mean_d, - variance_d, - DataLayout::kNCHW); + params_.withSwish = false; + params_.dst = static_cast(outputs[0]); + params_.srcX = static_cast(inputs[0]); + params_.gamma = scale_gpu_; + params_.beta = bias_gpu_; + params_.redBuffer = static_cast(workspace); + params_.n = input_desc[0].dims.d[0]; + params_.h = input_desc[0].dims.d[2]; + params_.w = input_desc[0].dims.d[3]; + params_.c = input_desc[0].dims.d[1]; + params_.groups = groups_; + params_.hw = params_.h * params_.w; + const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); + params_.hwPerBlock = divUp(params_.hw, blocksPerHW); + params_.cPerBlock = cPerBlock; + params_.cPerGroup = params_.c / params_.groups; + params_.hwc = params_.hw * params_.c; + params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.groupsPerBlock = cPerBlock / params_.cPerGroup; + params_.eps = eps_; + + cudaMemsetAsync(params_.redBuffer, + 0, + 2 * sizeof(float) * params_.n * groups_, + stream); + groupNormNHWCSum(params_, stream); + groupNormNHWCScale(params_, stream); + } else { + PADDLE_THROW(platform::errors::Fatal( + "The Groupnorm TRT Plugin's only support nchw or nhwc8 input")); + } } else { // input not float PADDLE_THROW(platform::errors::Fatal( - "The Groupnorm TRT Plugin's only support fp32 input")); + "The Groupnorm TRT Plugin's only support fp32 or fp16 input")); } + return cudaGetLastError() != cudaSuccess; } diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h index 2eb81c0220747dd23dfa1098435bb207019d3422..1fa505c077ea81ffe1dcdded3363d21504f2c499 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h @@ -21,7 +21,9 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + namespace paddle { namespace inference { namespace tensorrt { @@ -274,6 +276,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { float eps_; std::vector mean_shape_; std::vector variance_shape_; + GroupNormNHWCParams params_; bool with_fp16_; }; class GroupNormPluginDynamicCreator : public TensorRTPluginCreator {