未验证 提交 babd26ee 编写于 作者: W wenbin 提交者: GitHub

groupnorm nhwc8 (#49160)

* gn nhwc8

* remove error
上级 6439e91d
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdint.h>
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
struct GroupNormNHWCParams {
// The output buffer. Layout NHWC.
__half* dst;
// The input buffer. Layout NHWC.
__half const* srcX;
// The input buffer. Layout NHWC.
__half const* srcY;
// The gamma scaling factor.
void const* gamma;
// The beta term to add in GN.
void const* beta;
// The temporary buffer to do the global parallel reduction. Size:
// BLOCKS_PER_BATCH x C x 2.
float* redBuffer;
// The number of instances in the batch.
int32_t n;
// The height and width of each activation map.
int32_t h, w;
// The number of channels.
int32_t c;
// The number of groups.
int32_t groups;
// Do we apply the Swish activation function?
bool withSwish;
// Precomputed values and parameters to control the execution of the kernels.
// The number of activations per instance (h * w) and the number of
// activations per block.
int32_t hw, hwPerBlock;
// The number of channels per group and blocks per activation in the C
// dimension.
int32_t cPerBlock, cPerGroup;
// The precomputed stride between instances.
int32_t hwc;
// The inverse of hwc in floats (to compute mean/var).
float invHWC;
// The precomputed number of groups per block.
int32_t groupsPerBlock;
// epsilon, Constant for numerical stability
float eps;
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -15,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h"
#include "paddle/phi/kernels/group_norm_kernel.h"
#include <cub/cub.cuh>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -25,6 +28,262 @@ namespace tensorrt {
namespace plugin {
using DataLayout = phi::DataLayout;
static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }
static inline __device__ __host__ float sigmoid(float x) {
return 1.F / (1.F + expf(-x));
}
struct GroupSums {
// Is it the 1st element of the group?
int32_t flag;
// The sum.
float sum;
// The sum of squares.
float sumSq;
};
struct GroupSumsOp {
inline __device__ GroupSums operator()(GroupSums const &a,
GroupSums const &b) {
GroupSums dst;
dst.sum = b.flag ? b.sum : (a.sum + b.sum);
dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
dst.flag = a.flag + b.flag;
return dst;
}
};
static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
int32_t maxDivisor = -1;
for (int32_t i = 1; i <= std::sqrt(n); i++) {
if (n % i == 0) {
int32_t divisor1 = n / i;
int32_t divisor2 = i;
if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
maxDivisor = divisor1;
}
if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
maxDivisor = divisor2;
}
}
}
return maxDivisor;
}
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) {
// The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
// Allocate shared memory for BlockScan.
__shared__ typename BlockScan::TempStorage tempStorage;
// Allocate shared memory for the groups. We could reduce the amount of shared
// memory reserved.
__shared__ float2 smem[tTHREADS_PER_BLOCK];
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// The sums.
float sum = 0.F;
float sumSq = 0.F;
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.hwc +
static_cast<int64_t>(hwi) * params.c + ci;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Update the sum.
sum += f2.x + f2.y;
// Update the sum of squares.
sumSq += f2.x * f2.x + f2.y * f2.y;
}
// The group that thread works on and the channel in the group (modulus).
int32_t gi = threadIdx.x * 2 / params.cPerGroup;
int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
// The data for the summations.
GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
// Do the segmented scan.
GroupSums out;
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
// Store the results for the groups in shared memory (to produce coalesced
// stores later).
// 2 channels per thread
if (cj == params.cPerGroup - 2) {
smem[gi] = make_float2(out.sum, out.sumSq);
}
// Make sure the data is in shared memory.
__syncthreads();
// The global group index.
int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
// Threads that have nothing left to do, exit.
if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
return;
}
// The first threads (those storing to global memory, load the values).
float2 sums = smem[threadIdx.x];
// Store to global memory.
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
groupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
groupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
groupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break;
}
}
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) {
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The group that thread works on and the channel in the group (modulus).
int32_t gi = ci / params.cPerGroup;
// Load the sum and sum of squares for the group.
float sum = 0.F, sumSq = 0.F;
if (gi < params.groups) {
sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
}
// Load gamma/beta.
float2 gammaF2, betaF2;
if (ci < params.c) {
gammaF2 = __half22float2(*reinterpret_cast<half2 const *>(
reinterpret_cast<half const *>(params.gamma) + ci));
betaF2 = __half22float2(*reinterpret_cast<half2 const *>(
reinterpret_cast<half const *>(params.beta) + ci));
}
// Compute the mean.
float mean = sum * params.invHWC;
// Compute the variance.
float var = sumSq * params.invHWC - (mean * mean);
// Compute the inverse of the stddev.
float invStdDev = rsqrtf(var + params.eps);
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
// Apply Swish if needed.
if (params.withSwish) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
// Store the scaled values.
if (ci < params.c) {
*reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
}
}
}
void groupNormNHWCScale(const GroupNormNHWCParams &params,
cudaStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock;
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
groupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
groupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
groupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break;
default:
PADDLE_THROW(
platform::errors::Fatal("The function groupNormNHWCScale of "
"GroupNorm TRT Plugin encounter error"));
}
}
int GroupNormPlugin::initialize() TRT_NOEXCEPT {
if (!with_fp16_) {
// if use fp32
......@@ -188,7 +447,8 @@ bool GroupNormPluginDynamic::supportsFormatCombination(
if (pos == 0) {
if (with_fp16_) {
return ((in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::PluginFormat::kLINEAR));
(in.format == nvinfer1::PluginFormat::kLINEAR ||
in.format == nvinfer1::PluginFormat::kHWC8));
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
......@@ -275,9 +535,7 @@ int GroupNormPluginDynamic::enqueue(
int C = input_shape[1];
int image_size = input_shape[2] * input_shape[3];
int batchSize = input_shape[0];
std::vector<int64_t> batched_mean_shape = {batchSize * mean_shape_[0]};
std::vector<int64_t> batched_variance_shape = {batchSize *
variance_shape_[0]};
PADDLE_ENFORCE_EQ(
C,
scale_.size(),
......@@ -320,25 +578,76 @@ int GroupNormPluginDynamic::enqueue(
VLOG(1) << "TRT Plugin DataType selected. GroupNorm-->fp16";
const half *input = reinterpret_cast<const half *>(inputs[0]);
half *output = static_cast<half *>(outputs[0]);
if (input_desc[0].format == nvinfer1::PluginFormat::kLINEAR) {
phi::GroupNormDirectCUDAFunctor<half, float> group_norm;
group_norm(stream,
input,
input_shape,
reinterpret_cast<half *>(bias_gpu_),
reinterpret_cast<half *>(scale_gpu_),
temp_variance_d,
groups,
eps,
output,
mean_d,
variance_d,
DataLayout::kNCHW);
} else if (input_desc[0].format == nvinfer1::PluginFormat::kHWC8) {
int32_t cPerBlock = 320;
int32_t maxBlocksPerHW = 1024;
switch (input_desc[0].dims.d[1]) {
case 960:
case 1920:
cPerBlock = 480;
break;
case 512:
case 256:
cPerBlock = 256;
break;
case 128:
cPerBlock = 128;
break;
default:
cPerBlock = 320;
}
phi::GroupNormDirectCUDAFunctor<half, float> group_norm;
group_norm(stream,
input,
input_shape,
reinterpret_cast<half *>(bias_gpu_),
reinterpret_cast<half *>(scale_gpu_),
temp_variance_d,
groups,
eps,
output,
mean_d,
variance_d,
DataLayout::kNCHW);
params_.withSwish = false;
params_.dst = static_cast<half *>(outputs[0]);
params_.srcX = static_cast<half const *>(inputs[0]);
params_.gamma = scale_gpu_;
params_.beta = bias_gpu_;
params_.redBuffer = static_cast<float *>(workspace);
params_.n = input_desc[0].dims.d[0];
params_.h = input_desc[0].dims.d[2];
params_.w = input_desc[0].dims.d[3];
params_.c = input_desc[0].dims.d[1];
params_.groups = groups_;
params_.hw = params_.h * params_.w;
const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW);
params_.hwPerBlock = divUp(params_.hw, blocksPerHW);
params_.cPerBlock = cPerBlock;
params_.cPerGroup = params_.c / params_.groups;
params_.hwc = params_.hw * params_.c;
params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
params_.eps = eps_;
cudaMemsetAsync(params_.redBuffer,
0,
2 * sizeof(float) * params_.n * groups_,
stream);
groupNormNHWCSum(params_, stream);
groupNormNHWCScale(params_, stream);
} else {
PADDLE_THROW(platform::errors::Fatal(
"The Groupnorm TRT Plugin's only support nchw or nhwc8 input"));
}
} else {
// input not float
PADDLE_THROW(platform::errors::Fatal(
"The Groupnorm TRT Plugin's only support fp32 input"));
"The Groupnorm TRT Plugin's only support fp32 or fp16 input"));
}
return cudaGetLastError() != cudaSuccess;
}
......
......@@ -21,7 +21,9 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
......@@ -274,6 +276,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
float eps_;
std::vector<int64_t> mean_shape_;
std::vector<int64_t> variance_shape_;
GroupNormNHWCParams params_;
bool with_fp16_;
};
class GroupNormPluginDynamicCreator : public TensorRTPluginCreator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册