Unverified commit e61d892a, authored by yangjianfengo1, committed by GitHub

[Inference] Replace groupNorm when data types are bf16 and fp16, and data format is NHWC implementation. (#55399)

* finish

* cpergroup odd

* fix bf16

* single channel

* code style

* precision alignment

* add head_file

* add bf16 head file

* bf16 2

* bf16

* bf16 head

* bf16 compile

* py test

* bf16 compile

* bf16 compile

* unset py test

* nhwc

* test

* mean var

* bf16 success

* su

* ctest success

* use is_same_as

* is_same

* use is_same

* rtol

* gpu_stream

* del sigmoid

* fix bfloat16 type

* use cuda_bf16_hpp

* use_cuda_arch

* bfloat162float2

* del inplace_tol

* del max_releative_tol

* temp store

* precision alignment

* temp store

* plugin

* precision alignment

* alignment

* include cuda.h

* del half

* half single

* ci

* add const

* ci

* cudamemset

* del printf

* fp16 test

* add half compute

* del bf16 ci

* del ci

* ci approve

* del fluid include
Parent 883ccdd5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdint.h>
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
struct GroupNormNHWCParams {
  // The output buffer. Layout NHWC.
  __half* dst;
  // The output buffer. Layout NHWC.
  __half* eleOut;
  // The input buffer. Layout NHWC.
  __half const* srcX;
  // The input buffer. Layout NHWC.
  __half const* srcY;
  // The gamma scaling factor.
  void const* gamma;
  // The beta term to add in GN.
  void const* beta;
  // The temporary buffer to do the global parallel reduction. Size:
  // BLOCKS_PER_BATCH x C x 2.
  float* redBuffer;
  // The number of instances in the batch.
  int32_t n;
  // The height and width of each activation map.
  int32_t h, w;
  // The number of channels.
  int32_t c;
  // The number of groups.
  int32_t groups;
  // Do we apply the Silu activation function?
  bool withSilu;
  // Precomputed values and parameters to control the execution of the kernels.
  // The number of activations per instance (h * w) and the number of
  // activations per block.
  int32_t hw, hwPerBlock;
  // The number of channels per group and blocks per activation in the C
  // dimension.
  int32_t cPerBlock, cPerGroup;
  // The precomputed stride between instances.
  int32_t hwc;
  // The inverse of hwc in floats (to compute mean/var).
  float invHWC;
  // Epsilon, a small constant for numerical stability.
  float eps;
  // The precomputed number of groups per block.
  int32_t groupsPerBlock;
  // Dequantize/quantize scales for the NCHW32 int8 path.
  float dqScaleIn;
  float inv_qScale;
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
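The struct above mixes raw shape fields with precomputed ones. A minimal sketch (hypothetical, not part of this commit) of how the derived fields relate to the raw ones, mirroring the enqueue logic later in this diff; the concrete cPerBlock and hwPerBlock values here are illustrative assumptions (upstream picks cPerBlock from a small fixed set and hwPerBlock via findMaxDivisor):

// Hypothetical helper: derive the precomputed fields of GroupNormNHWCParams.
inline void fillDerived(GroupNormNHWCParams* p, float eps) {
  p->hw = p->h * p->w;              // activations per instance
  p->hwc = p->hw * p->c;            // stride between two instances
  p->cPerGroup = p->c / p->groups;  // channels handled by one group
  p->cPerBlock = 320;               // assumed; one of {320, 480, 256, 128, 8}
  p->hwPerBlock = 256;              // assumed; chosen via findMaxDivisor upstream
  p->groupsPerBlock = p->cPerBlock / p->cPerGroup;
  // One group holds hw * cPerGroup elements; invHWC turns sums into means.
  p->invHWC = 1.F / static_cast<float>(p->hw * p->cPerGroup);
  p->eps = eps;
}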
@@ -73,123 +73,8 @@ static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
}
template <int tTHREADS_PER_BLOCK>
-__global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) {
+__global__ void groupNormNCHW32SumKernelQDQ(
+    const GroupNormNHWCParams<__half> params) {
  // The object in charge of doing the sums for the different blocks.
  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
  // Allocate shared memory for BlockScan.
  __shared__ typename BlockScan::TempStorage tempStorage;
  // Allocate shared memory for the groups. We could reduce the amount of
  // shared memory reserved.
  __shared__ float2 smem[tTHREADS_PER_BLOCK];
  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
  // The first activation loaded by that block.
  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
  // The last activation loaded by that block.
  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
  // The sums.
  float sum = 0.F;
  float sumSq = 0.F;
  // Iterate over the activations to compute the sums.
  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
    // The offset.
    int64_t offset = static_cast<int64_t>(ni) * params.hwc +
                     static_cast<int64_t>(hwi) * params.c + ci;
    // Fetch two channels per thread.
    __half2 h2(0, 0);
    if (ci < params.c) {
      h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
    }
    // Extract the two half values.
    float2 f2 = __half22float2(h2);
    // Update the sum.
    sum += f2.x + f2.y;
    // Update the sum of squares.
    sumSq += f2.x * f2.x + f2.y * f2.y;
  }
  // The group that thread works on and the channel in the group (modulus).
  int32_t gi = threadIdx.x * 2 / params.cPerGroup;
  int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
  // The data for the summations.
  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
  // Do the segmented scan.
  GroupSums out;
  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
  // Store the results for the groups in shared memory (to produce coalesced
  // stores later). 2 channels per thread.
  if (cj == params.cPerGroup - 2) {
    smem[gi] = make_float2(out.sum, out.sumSq);
  }
  // Make sure the data is in shared memory.
  __syncthreads();
  // The global group index.
  int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
  // Threads that have nothing left to do, exit.
  if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
    return;
  }
  // The first threads (those storing to global memory) load the values.
  float2 sums = smem[threadIdx.x];
  // Store to global memory.
  atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}

void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
  dim3 grid;
  // The number of blocks to compute all the channels.
  grid.x = divUp(params.c, params.cPerBlock);
  // The number of blocks to compute all the activations in a given instance.
  grid.y = divUp(params.hw, params.hwPerBlock);
  // The number of instances.
  grid.z = params.n;
  switch (params.cPerBlock) {
    case 320:
      groupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
      break;
    case 480:
      groupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
      break;
    case 256:
      groupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
      break;
    case 128:
      groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
      break;
    case 8:
      groupNormNHWCSumKernel<4><<<grid, 4, 0, stream>>>(params);
      break;
    default:
      PADDLE_THROW(platform::errors::Fatal(
          "The function groupNormNHWCSum of GroupNormPlugin TRT Plugin "
          "encounter error"));
  }
}
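This removed sum pass is the first half of the two-pass GroupNorm that this PR relocates behind the phi functors: each block segment-scans per-thread (sum, sumSq) pairs with cub::BlockScan, then atomically accumulates one pair per (instance, group) into redBuffer. A CPU reference for that contract (a hypothetical checking helper operating on an fp32 copy of the input, not plugin code):

// Hypothetical CPU reference: after the sum pass,
// redBuffer[(2*ni+0)*groups + gi] holds sum(x) and
// redBuffer[(2*ni+1)*groups + gi] holds sum(x*x) over the h*w*cPerGroup
// values of group gi in instance ni.
#include <cstdint>
#include <vector>

void cpuGroupSums(const float *src,  // NHWC, n * hw * c elements
                  int32_t n, int32_t hw, int32_t c, int32_t groups,
                  std::vector<float> *redBuffer) {
  const int32_t cPerGroup = c / groups;
  redBuffer->assign(2 * n * groups, 0.F);
  for (int32_t ni = 0; ni < n; ++ni)
    for (int32_t hwi = 0; hwi < hw; ++hwi)
      for (int32_t ci = 0; ci < c; ++ci) {
        const float v =
            src[static_cast<int64_t>(ni) * hw * c +
                static_cast<int64_t>(hwi) * c + ci];
        const int32_t gi = ci / cPerGroup;
        (*redBuffer)[(2 * ni + 0) * groups + gi] += v;      // sum
        (*redBuffer)[(2 * ni + 1) * groups + gi] += v * v;  // sum of squares
      }
}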
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) {
  // The object in charge of doing the sums for the different blocks.
  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
@@ -281,7 +166,7 @@ __global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) {
  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
-void groupNormNCHW32SumQDQ(const GroupNormNHWCParams &params,
+void groupNormNCHW32SumQDQ(const GroupNormNHWCParams<__half> &params,
                           cudaStream_t stream) {
  dim3 grid;
@@ -313,7 +198,7 @@ void groupNormNCHW32SumQDQ(const GroupNormNHWCParams &params,
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNCHW32ScaleKernelQDQ(
-    const GroupNormNHWCParams params) {
+    const GroupNormNHWCParams<__half> params) {
  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
@@ -405,7 +290,7 @@ __global__ void groupNormNCHW32ScaleKernelQDQ(
  }
}
-void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams &params,
+void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams<__half> &params,
                             cudaStream_t stream) {
  dim3 grid;
@@ -439,112 +324,6 @@ void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams &params,
  }
}
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) {
  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
  // The group that thread works on and the channel in the group (modulus).
  int32_t gi = ci / params.cPerGroup;
  // Load the sum and sum of squares for the group.
  float sum = 0.F, sumSq = 0.F;
  if (gi < params.groups) {
    sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
    sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
  }
  // Load gamma/beta.
  float2 gammaF2, betaF2;
  if (ci < params.c) {
    gammaF2 = __half22float2(*reinterpret_cast<half2 const *>(
        reinterpret_cast<half const *>(params.gamma) + ci));
    betaF2 = __half22float2(*reinterpret_cast<half2 const *>(
        reinterpret_cast<half const *>(params.beta) + ci));
  }
  // Compute the mean.
  float mean = sum * params.invHWC;
  // Compute the variance.
  float var = sumSq * params.invHWC - (mean * mean);
  // Compute the inverse of the stddev.
  float invStdDev = rsqrtf(var + params.eps);
  // The first activation loaded by that block.
  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
  // The last activation loaded by that block.
  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
  // Iterate over the activations to normalize them.
  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
    // The src/dst offset.
    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
    // Fetch two channels per thread.
    __half2 h2(0, 0);
    if (ci < params.c) {
      h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
    }
    // Extract the two half values.
    float2 f2 = __half22float2(h2);
    // Normalize the channels.
    f2.x = (f2.x - mean) * invStdDev;
    f2.y = (f2.y - mean) * invStdDev;
    // Scale by gamma and add beta.
    f2.x = gammaF2.x * f2.x + betaF2.x;
    f2.y = gammaF2.y * f2.y + betaF2.y;
    // Apply Silu if needed.
    if (params.withSilu) {
      f2.x = f2.x * sigmoid(f2.x);
      f2.y = f2.y * sigmoid(f2.y);
    }
    // Store the scaled values.
    if (ci < params.c) {
      *reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
    }
  }
}

void groupNormNHWCScale(const GroupNormNHWCParams &params,
                        cudaStream_t stream) {
  dim3 grid;
  // The number of blocks to compute all the channels.
  grid.x = divUp(params.c, params.cPerBlock);
  // The number of blocks to compute all the activations in a given instance.
  grid.y = divUp(params.hw, params.hwPerBlock);
  // The number of instances.
  grid.z = params.n;
  switch (params.cPerBlock) {
    case 320:
      groupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
      break;
    case 480:
      groupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
      break;
    case 256:
      groupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
      break;
    case 128:
      groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
      break;
    case 8:
      groupNormNHWCScaleKernel<4><<<grid, 4, 0, stream>>>(params);
      break;
    default:
      PADDLE_THROW(platform::errors::Fatal(
          "The function groupNormNHWCScale of GroupNormPlugin TRT Plugin "
          "encounter error"));
  }
}
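The removed scale pass converts each group's (sum, sumSq) into mean, variance, and inverse stddev, then applies the affine transform and the optional Silu, where sigmoid(x) = 1 / (1 + exp(-x)) and Silu(x) = x * sigmoid(x). In scalar form (a hypothetical reference helper for clarity, not plugin code):

// Hypothetical scalar reference for the per-element math of the scale pass.
#include <cmath>

float groupNormScaleOne(float x, float sum, float sumSq, float invHWC,
                        float eps, float gamma, float beta, bool withSilu) {
  const float mean = sum * invHWC;
  const float var = sumSq * invHWC - mean * mean;  // E[x^2] - E[x]^2
  const float invStdDev = 1.F / std::sqrt(var + eps);
  float y = gamma * ((x - mean) * invStdDev) + beta;
  if (withSilu) y = y * (1.F / (1.F + std::exp(-y)));  // Silu
  return y;
}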
int GroupNormPlugin::initialize() TRT_NOEXCEPT {
  if (!with_fp16_) {
    // if use fp32
@@ -886,9 +665,10 @@ int GroupNormPluginDynamic::enqueue(
    params_.withSilu = with_silu_;
    params_.dst = static_cast<half *>(outputs[0]);
    params_.srcX = static_cast<half const *>(inputs[0]);
-    params_.gamma = scale_gpu_;
-    params_.beta = bias_gpu_;
+    params_.gamma = reinterpret_cast<half *>(scale_gpu_);
+    params_.beta = reinterpret_cast<half *>(bias_gpu_);
    params_.redBuffer = static_cast<float *>(workspace);
+    params_.var_data = nullptr;
    params_.n = input_desc[0].dims.d[0];
    params_.h = input_desc[0].dims.d[2];
    params_.w = input_desc[0].dims.d[3];
@@ -903,13 +683,17 @@ int GroupNormPluginDynamic::enqueue(
    params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
    params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
    params_.eps = eps_;
+    params_.var_data = nullptr;
    cudaMemsetAsync(params_.redBuffer,
                    0,
                    2 * sizeof(float) * params_.n * groups_,
                    stream);
-    groupNormNHWCSum(params_, stream);
-    groupNormNHWCScale(params_, stream);
+    phi::groupNormNHWCSum<half> nhwc_sum;
+    nhwc_sum(&params_, stream);
+    phi::groupNormNHWCScale<half> nhwc_scale;
+    nhwc_scale(params_, stream);
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The Groupnorm TRT Plugin's only support nchw or nhwc8 input"));
......
@@ -21,13 +21,15 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/phi/kernels/group_norm_kernel.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
using phi::GroupNormNHWCParams;

class GroupNormPlugin : public PluginTensorRT {
 public:
  size_t getSerializationSize() const TRT_NOEXCEPT override {
@@ -287,7 +289,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
  float eps_;
  std::vector<int64_t> mean_shape_;
  std::vector<int64_t> variance_shape_;
-  GroupNormNHWCParams params_;
+  GroupNormNHWCParams<half> params_;
  bool with_silu_;
  bool with_fp16_;
  bool with_int8_;
......
@@ -120,7 +120,8 @@ struct GroupSumsOp {
};
template <int32_t tTHREADS_PER_BLOCK>
-__global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
+__global__ void prelnGroupNormNHWCSumKernel(
+    GroupNormNHWCParams<__half> params) {
  // The object in charge of doing the sums for the different blocks.
  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
@@ -212,7 +213,7 @@ __global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
-void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params,
+void prelnGroupNormNHWCSum(GroupNormNHWCParams<__half> const &params,
                           cudaStream_t stream) {
  // Make sure the values are as we expect.
  PADDLE_ENFORCE_EQ(params.c % params.cPerBlock,
@@ -272,7 +273,8 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params,
}
template <int32_t tTHREADS_PER_BLOCK>
-__global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
+__global__ void prelnGroupNormNHWCScaleKernel(
+    GroupNormNHWCParams<__half> params) {
  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
@@ -343,7 +345,7 @@ __global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
  }
}
-void prelnGroupNormNHWCScale(GroupNormNHWCParams const &params,
+void prelnGroupNormNHWCScale(GroupNormNHWCParams<__half> const &params,
                             cudaStream_t stream) {
  // Make sure the dimensions are aligned with what we expect.
  PADDLE_ENFORCE_EQ(
......
@@ -21,13 +21,14 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/phi/kernels/group_norm_kernel.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
using phi::GroupNormNHWCParams;

class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
 public:
  PrelnGroupnormActPluginDynamic(const float* scale,
@@ -173,7 +174,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
  std::vector<float> bias_;
  std::shared_ptr<void> scale_gpu_;
  std::shared_ptr<void> bias_gpu_;
-  GroupNormNHWCParams params_;
+  GroupNormNHWCParams<__half> params_;
  int groups_;
  float eps_;
  bool with_silu_;
......
@@ -131,7 +131,7 @@ struct GroupSumsOp {
};
template <int32_t tTHREADS_PER_BLOCK>
-__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
+__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams<__half> params) {
  // The object in charge of doing the sums for the different blocks.
  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
@@ -224,7 +224,7 @@ __global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
-void skipGroupNormNHWCSum(GroupNormNHWCParams const &params,
+void skipGroupNormNHWCSum(GroupNormNHWCParams<__half> const &params,
                          cudaStream_t stream) {
  // Make sure the values are as we expect.
  PADDLE_ENFORCE_EQ(
@@ -282,7 +282,8 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const &params,
}
template <int32_t tTHREADS_PER_BLOCK>
-__global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
+__global__ void skipGroupNormNHWCScaleKernel(
+    GroupNormNHWCParams<__half> params) {
  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
@@ -353,7 +354,7 @@ __global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
  }
}
-void skipGroupNormNHWCScale(GroupNormNHWCParams const &params,
+void skipGroupNormNHWCScale(GroupNormNHWCParams<__half> const &params,
                            cudaStream_t stream) {
  // Make sure the dimensions are aligned with what we expect.
  PADDLE_ENFORCE_EQ(params.c % params.cPerBlock,
......
@@ -21,13 +21,14 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/phi/kernels/group_norm_kernel.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
using phi::GroupNormNHWCParams;

class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT {
 public:
  SkipGroupnormActPluginDynamic(const float* scale,
@@ -168,7 +169,7 @@ class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT {
  std::vector<float> bias_;
  std::shared_ptr<void> scale_gpu_;
  std::shared_ptr<void> bias_gpu_;
-  GroupNormNHWCParams params_;
+  GroupNormNHWCParams<__half> params_;
  int groups_;
  float eps_;
  bool with_fp16_;
......
@@ -18,6 +18,11 @@
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/dense_tensor.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_fp16.h>
#endif
#include <stdint.h>

namespace phi {
@@ -52,4 +57,70 @@ class GroupNormDirectCUDAFunctor {
};
#endif
template <typename T>
struct GroupNormNHWCParams {
  // The output buffer. Layout NHWC.
  T* dst;
  // The output buffer. Layout NHWC.
  T* eleOut;
  // The input buffer. Layout NHWC.
  T const* srcX;
  // The input buffer. Layout NHWC.
  T const* srcY;
  // The gamma scaling factor.
  void const* gamma;
  // The beta term to add in GN.
  void const* beta;
  // The temporary buffer to do the global parallel reduction. Size:
  // BLOCKS_PER_BATCH x C x 2.
  float* redBuffer;
  // Optional buffer for the computed variance; may be nullptr when unused.
  float* var_data;
  // The number of instances in the batch.
  int32_t n;
  // The height and width of each activation map.
  int32_t h, w;
  // The number of channels.
  int32_t c;
  // The number of groups.
  int32_t groups;
  // Do we apply the Silu activation function?
  bool withSilu;
  // Precomputed values and parameters to control the execution of the kernels.
  // The number of activations per instance (h * w) and the number of
  // activations per block.
  int32_t hw, hwPerBlock;
  // The number of channels per group and blocks per activation in the C
  // dimension.
  int32_t cPerBlock, cPerGroup;
  // The precomputed stride between instances.
  int32_t hwc;
  // The inverse of hwc in floats (to compute mean/var).
  float invHWC;
  // The precomputed number of groups per block.
  int32_t groupsPerBlock;
  // Epsilon, a small constant for numerical stability.
  float eps;
  // Dequantize/quantize scales for the NCHW32 int8 path.
  float dqScaleIn;
  float inv_qScale;
};

template <typename T>
class groupNormNHWCSum {
 public:
  void operator()(GroupNormNHWCParams<T>* params, const gpuStream_t stream);
};

template <typename T>
class groupNormNHWCScale {
 public:
  void operator()(const GroupNormNHWCParams<T>& params,
                  const gpuStream_t stream);
};
}  // namespace phi
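With the parameters and kernels now living in phi, callers drive the two passes through these functors, as the updated enqueue above does. A minimal usage sketch (hypothetical wrapper; it assumes a fully populated parameter struct whose redBuffer points at 2 * n * groups floats of workspace):

// Hypothetical usage sketch mirroring the enqueue changes in this diff.
void runGroupNormNHWC(phi::GroupNormNHWCParams<half>* params,
                      cudaStream_t stream) {
  // The sum kernel accumulates with atomicAdd, so redBuffer must hold
  // 2 * n * groups floats and start zeroed.
  cudaMemsetAsync(params->redBuffer, 0,
                  2 * sizeof(float) * params->n * params->groups, stream);
  phi::groupNormNHWCSum<half> nhwc_sum;
  nhwc_sum(params, stream);     // pass 1: per-group (sum, sumSq)
  phi::groupNormNHWCScale<half> nhwc_scale;
  nhwc_scale(*params, stream);  // pass 2: normalize, gamma/beta, optional Silu
}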
@@ -511,7 +511,11 @@ class GroupNorm(Layer):
                "Mean": mean_out,
                "Variance": variance_out,
            },
-            attrs={"epsilon": self._epsilon, "groups": self._num_groups},
+            attrs={
+                "epsilon": self._epsilon,
+                "groups": self._num_groups,
+                "data_layout": self._data_format,
+            },
        )
        return self._helper.append_activation(group_norm_out, None)
......
@@ -350,6 +350,51 @@ class TestGroupNormOp2_With_NHWC(TestGroupNormOp):
        self.data_format = "NHWC"
class TestGroupNormFP16Op_With_NHWC(TestGroupNormFP16OP):
    def init_test_case(self):
        self.attrs['groups'] = 1
        self.data_format = "NHWC"
        self.attrs['epsilon'] = 0.5
        self.shape = (1, 100, 4, 4)
        self.dtype = np.float16


class TestGroupNormBF16Op_With_NHWC(TestGroupNormBF16Op):
    def setUp(self):
        self.op_type = "group_norm"
        self.python_api = group_norm_wrapper
        self.python_out_sig = ["Y"]
        self.data_format = "NHWC"
        self.dtype = np.uint16
        self.shape = (1, 3, 5, 100)
        self.attrs = {
            'epsilon': 5e-2,
            'groups': 2,
            'data_layout': self.data_format,
        }
        self.compare_between_place = False
        self.init_test_case()

        input = np.random.random(self.shape).astype(np.float32)
        scale = np.random.random([self.shape[3]]).astype(np.float32)
        bias = np.random.random([self.shape[3]]).astype(np.float32)
        output, mean, var = group_norm_naive(
            input,
            scale,
            bias,
            self.attrs['epsilon'],
            self.attrs['groups'],
            self.data_format,
        )
        self.inputs = {
            'X': convert_float_to_uint16(input),
            'Scale': convert_float_to_uint16(scale),
            'Bias': convert_float_to_uint16(bias),
        }
        self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['groups'] = 1
......