Unverified commit e61d892a, authored by Y yangjianfengo1, committed by GitHub

[Inference] Replace the groupNorm implementation when the data type is bf16 or fp16 and the data format is NHWC. (#55399)

* finish

* cpergroup odd

* fix bf16

* single channel

* code style

* precision alignment

* add head_file

* add bf16 head file

* bf16 2

* bf16

* bf16 head

* bf16 compile

* py test

* bf16 compile

* bf16 compile

* unset py test

* nhwc

* test

* mean var

* bf16 success

* su

* ctest success

* use is_same_as

* is_same

* use is_same

* rtol

* gpu_stream

* del sigmoid

* fix bfloat16 type

* use cuda_bf16_hpp

* use_cuda_arch

* bfloat162float2

* del inplace_tol

* del max_releative_tol

* temp store

* precision alignment

* temp store

* plugin

* precision alignment

* alignment

* include cuda.h

* del half

* half single

* ci

* add const

* ci

* cudamemset

* del printf

* fp16 test

* add half compute

* del bf16 ci

* del ci

* ci approve

* del fluid include
Parent 883ccdd5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdint.h>
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
struct GroupNormNHWCParams {
// The output buffer. Layout NHWC.
__half* dst;
// The output buffer. Layout NHWC.
__half* eleOut;
// The input buffer. Layout NHWC.
__half const* srcX;
// The input buffer. Layout NHWC.
__half const* srcY;
// The gamma scaling factor.
void const* gamma;
// The beta term to add in GN.
void const* beta;
// The temporary buffer to do the global parallel reduction. Size:
// BLOCKS_PER_BATCH x C x 2.
float* redBuffer;
// The number of instances in the batch.
int32_t n;
// The height and width of each activation map.
int32_t h, w;
// The number of channels.
int32_t c;
// The number of groups.
int32_t groups;
// Do we apply the Silu activation function?
bool withSilu;
// Precomputed values and parameters to control the execution of the kernels.
// The number of activations per instance (h * w) and the number of
// activations per block.
int32_t hw, hwPerBlock;
// The number of channels per group and blocks per activation in the C
// dimension.
int32_t cPerBlock, cPerGroup;
// The precomputed stride between instances.
int32_t hwc;
// The inverse of hwc in floats (to compute mean/var).
float invHWC;
// The precomputed number of groups per block.
int32_t groupsPerBlock;
// epsilon, Constant for numerical stability
float eps;
// for NCHW32 int8 use
float dqScaleIn;
float inv_qScale;
};
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
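For orientation, the sketch below is illustrative only and not part of this patch (the helper name and the blocksPerHW argument are assumptions): it shows how the precomputed fields of GroupNormNHWCParams relate to the raw shape parameters, mirroring the set-up done in GroupNormPluginDynamic::enqueue and in the phi NHWC kernel later in this diff.

// Illustrative: deriving the precomputed fields from (n, h, w, c, groups).
static void fillGroupNormNHWCParams(GroupNormNHWCParams &p,
                                    int32_t n, int32_t h, int32_t w,
                                    int32_t c, int32_t groups,
                                    int32_t cPerBlock, int32_t blocksPerHW,
                                    float eps) {
  p.n = n; p.h = h; p.w = w; p.c = c; p.groups = groups;
  p.hw = h * w;                       // activations per instance
  p.hwc = p.hw * c;                   // stride between two instances
  p.cPerGroup = c / groups;           // channels handled by one group
  p.cPerBlock = cPerBlock;            // channels handled by one block
  p.hwPerBlock = (p.hw + blocksPerHW - 1) / blocksPerHW;   // divUp(hw, blocksPerHW)
  p.groupsPerBlock = cPerBlock / p.cPerGroup;
  p.invHWC = 1.F / static_cast<float>(p.hw * p.cPerGroup);  // used for mean/var
  p.eps = eps;
}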
...@@ -73,123 +73,8 @@ static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { ...@@ -73,123 +73,8 @@ static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
} }
template <int tTHREADS_PER_BLOCK> template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) { __global__ void groupNormNCHW32SumKernelQDQ(
// The object in charge of doing the sums for the different blocks. const GroupNormNHWCParams<__half> params) {
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
// Allocate shared memory for BlockScan.
__shared__ typename BlockScan::TempStorage tempStorage;
// Allocate shared memory for the groups. We could reduce the amount of shared
// memory reserved.
__shared__ float2 smem[tTHREADS_PER_BLOCK];
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// The sums.
float sum = 0.F;
float sumSq = 0.F;
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.hwc +
static_cast<int64_t>(hwi) * params.c + ci;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Update the sum.
sum += f2.x + f2.y;
// Update the sum of squares.
sumSq += f2.x * f2.x + f2.y * f2.y;
}
// The group that thread works on and the channel in the group (modulus).
int32_t gi = threadIdx.x * 2 / params.cPerGroup;
int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
// The data for the summations.
GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
// Do the segmented scan.
GroupSums out;
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
// Store the results for the groups in shared memory (to produce coalesced
// stores later).
// 2 channels per thread
if (cj == params.cPerGroup - 2) {
smem[gi] = make_float2(out.sum, out.sumSq);
}
// Make sure the data is in shared memory.
__syncthreads();
// The global group index.
int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
// Threads that have nothing left to do, exit.
if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
return;
}
// The first threads (those storing to global memory, load the values).
float2 sums = smem[threadIdx.x];
// Store to global memory.
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = divUp(params.c, params.cPerBlock);
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
groupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
groupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
groupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
groupNormNHWCSumKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of GroupNormPlugin TRT Plugin "
"encounter error"));
}
}
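As a concrete illustration of the launch geometry this dispatcher produces (the shape is hypothetical), consider N = 2, H = W = 64, C = 320, groups = 32; the values follow the divUp / findMaxDivisor logic in this file and the channel switch in the phi kernel further down:

// Worked example (hypothetical shape N=2, H=W=64, C=320, groups=32):
//   cPerBlock   = 320                          (channel switch: default)
//   cPerGroup   = 320 / 32      = 10
//   hw          = 64 * 64       = 4096
//   blocksPerHW = findMaxDivisor(4096, 1024) = 512
//   hwPerBlock  = divUp(4096, 512)           = 8
//   grid        = {divUp(320, 320), divUp(4096, 8), 2} = {1, 512, 2}
// cPerBlock == 320 then selects groupNormNHWCSumKernel<160>: 160 threads per
// block, each handling two consecutive channels (F16x2), i.e. 320 channels.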
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) {
// The object in charge of doing the sums for the different blocks. // The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan; typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
...@@ -281,7 +166,7 @@ __global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) { ...@@ -281,7 +166,7 @@ __global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) {
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
} }
void groupNormNCHW32SumQDQ(const GroupNormNHWCParams &params, void groupNormNCHW32SumQDQ(const GroupNormNHWCParams<__half> &params,
cudaStream_t stream) { cudaStream_t stream) {
dim3 grid; dim3 grid;
...@@ -313,7 +198,7 @@ void groupNormNCHW32SumQDQ(const GroupNormNHWCParams &params, ...@@ -313,7 +198,7 @@ void groupNormNCHW32SumQDQ(const GroupNormNHWCParams &params,
template <int tTHREADS_PER_BLOCK> template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNCHW32ScaleKernelQDQ( __global__ void groupNormNCHW32ScaleKernelQDQ(
const GroupNormNHWCParams params) { const GroupNormNHWCParams<__half> params) {
// The instance in the batch. // The instance in the batch.
int32_t ni = blockIdx.z; int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2). // The channel loaded by that thread (2 channels per thread for F16x2).
...@@ -405,7 +290,7 @@ __global__ void groupNormNCHW32ScaleKernelQDQ( ...@@ -405,7 +290,7 @@ __global__ void groupNormNCHW32ScaleKernelQDQ(
} }
} }
void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams &params, void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams<__half> &params,
cudaStream_t stream) { cudaStream_t stream) {
dim3 grid; dim3 grid;
...@@ -439,112 +324,6 @@ void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams &params, ...@@ -439,112 +324,6 @@ void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams &params,
} }
} }
template <int tTHREADS_PER_BLOCK>
__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) {
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
// The group that thread works on and the channel in the group (modulus).
int32_t gi = ci / params.cPerGroup;
// Load the sum and sum of squares for the group.
float sum = 0.F, sumSq = 0.F;
if (gi < params.groups) {
sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
}
// Load gamma/beta.
float2 gammaF2, betaF2;
if (ci < params.c) {
gammaF2 = __half22float2(*reinterpret_cast<half2 const *>(
reinterpret_cast<half const *>(params.gamma) + ci));
betaF2 = __half22float2(*reinterpret_cast<half2 const *>(
reinterpret_cast<half const *>(params.beta) + ci));
}
// Compute the mean.
float mean = sum * params.invHWC;
// Compute the variance.
float var = sumSq * params.invHWC - (mean * mean);
// Compute the inverse of the stddev.
float invStdDev = rsqrtf(var + params.eps);
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
// Fetch two channels per thread.
__half2 h2(0, 0);
if (ci < params.c) {
h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
}
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
// Apply Silu if needed.
if (params.withSilu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
// Store the scaled values.
if (ci < params.c) {
*reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
}
}
}
void groupNormNHWCScale(const GroupNormNHWCParams &params,
cudaStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = divUp(params.c, params.cPerBlock);
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
switch (params.cPerBlock) {
case 320:
groupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
break;
case 480:
groupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
break;
case 256:
groupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
break;
case 128:
groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
groupNormNHWCScaleKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCScale of GroupNormPlugin TRT Plugin "
"encounter error"));
}
}
int GroupNormPlugin::initialize() TRT_NOEXCEPT { int GroupNormPlugin::initialize() TRT_NOEXCEPT {
if (!with_fp16_) { if (!with_fp16_) {
// if use fp32 // if use fp32
...@@ -886,9 +665,10 @@ int GroupNormPluginDynamic::enqueue( ...@@ -886,9 +665,10 @@ int GroupNormPluginDynamic::enqueue(
params_.withSilu = with_silu_; params_.withSilu = with_silu_;
params_.dst = static_cast<half *>(outputs[0]); params_.dst = static_cast<half *>(outputs[0]);
params_.srcX = static_cast<half const *>(inputs[0]); params_.srcX = static_cast<half const *>(inputs[0]);
params_.gamma = scale_gpu_; params_.gamma = reinterpret_cast<half *>(scale_gpu_);
params_.beta = bias_gpu_; params_.beta = reinterpret_cast<half *>(bias_gpu_);
params_.redBuffer = static_cast<float *>(workspace); params_.redBuffer = static_cast<float *>(workspace);
params_.var_data = nullptr;
params_.n = input_desc[0].dims.d[0]; params_.n = input_desc[0].dims.d[0];
params_.h = input_desc[0].dims.d[2]; params_.h = input_desc[0].dims.d[2];
params_.w = input_desc[0].dims.d[3]; params_.w = input_desc[0].dims.d[3];
...@@ -903,13 +683,17 @@ int GroupNormPluginDynamic::enqueue( ...@@ -903,13 +683,17 @@ int GroupNormPluginDynamic::enqueue(
params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup); params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
params_.groupsPerBlock = cPerBlock / params_.cPerGroup; params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
params_.eps = eps_; params_.eps = eps_;
params_.var_data = nullptr;
cudaMemsetAsync(params_.redBuffer, cudaMemsetAsync(params_.redBuffer,
0, 0,
2 * sizeof(float) * params_.n * groups_, 2 * sizeof(float) * params_.n * groups_,
stream); stream);
groupNormNHWCSum(params_, stream);
groupNormNHWCScale(params_, stream); phi::groupNormNHWCSum<half> nhwc_sum;
nhwc_sum(&params_, stream);
phi::groupNormNHWCScale<half> nhwc_scale;
nhwc_scale(params_, stream);
} else { } else {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"The Groupnorm TRT Plugin's only support nchw or nhwc8 input")); "The Groupnorm TRT Plugin's only support nchw or nhwc8 input"));
......
...@@ -21,13 +21,15 @@ limitations under the License. */ ...@@ -21,13 +21,15 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/phi/kernels/group_norm_kernel.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
using phi::GroupNormNHWCParams;
class GroupNormPlugin : public PluginTensorRT { class GroupNormPlugin : public PluginTensorRT {
public: public:
size_t getSerializationSize() const TRT_NOEXCEPT override { size_t getSerializationSize() const TRT_NOEXCEPT override {
...@@ -287,7 +289,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { ...@@ -287,7 +289,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
float eps_; float eps_;
std::vector<int64_t> mean_shape_; std::vector<int64_t> mean_shape_;
std::vector<int64_t> variance_shape_; std::vector<int64_t> variance_shape_;
GroupNormNHWCParams params_; GroupNormNHWCParams<half> params_;
bool with_silu_; bool with_silu_;
bool with_fp16_; bool with_fp16_;
bool with_int8_; bool with_int8_;
......
...@@ -120,7 +120,8 @@ struct GroupSumsOp { ...@@ -120,7 +120,8 @@ struct GroupSumsOp {
}; };
template <int32_t tTHREADS_PER_BLOCK> template <int32_t tTHREADS_PER_BLOCK>
__global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) { __global__ void prelnGroupNormNHWCSumKernel(
GroupNormNHWCParams<__half> params) {
// The object in charge of doing the sums for the different blocks. // The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan; typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
...@@ -212,7 +213,7 @@ __global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) { ...@@ -212,7 +213,7 @@ __global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
} }
void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params, void prelnGroupNormNHWCSum(GroupNormNHWCParams<__half> const &params,
cudaStream_t stream) { cudaStream_t stream) {
// Make sure the values are as we expect. // Make sure the values are as we expect.
PADDLE_ENFORCE_EQ(params.c % params.cPerBlock, PADDLE_ENFORCE_EQ(params.c % params.cPerBlock,
...@@ -272,7 +273,8 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params, ...@@ -272,7 +273,8 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params,
} }
template <int32_t tTHREADS_PER_BLOCK> template <int32_t tTHREADS_PER_BLOCK>
__global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) { __global__ void prelnGroupNormNHWCScaleKernel(
GroupNormNHWCParams<__half> params) {
// The instance in the batch. // The instance in the batch.
int32_t ni = blockIdx.z; int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2). // The channel loaded by that thread (2 channels per thread for F16x2).
...@@ -343,7 +345,7 @@ __global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) { ...@@ -343,7 +345,7 @@ __global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
} }
} }
void prelnGroupNormNHWCScale(GroupNormNHWCParams const &params, void prelnGroupNormNHWCScale(GroupNormNHWCParams<__half> const &params,
cudaStream_t stream) { cudaStream_t stream) {
// Make sure the dimensions are aligned with what we expect. // Make sure the dimensions are aligned with what we expect.
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -21,13 +21,14 @@ limitations under the License. */ ...@@ -21,13 +21,14 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/phi/kernels/group_norm_kernel.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
using phi::GroupNormNHWCParams;
class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT { class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
public: public:
PrelnGroupnormActPluginDynamic(const float* scale, PrelnGroupnormActPluginDynamic(const float* scale,
...@@ -173,7 +174,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT { ...@@ -173,7 +174,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
std::vector<float> bias_; std::vector<float> bias_;
std::shared_ptr<void> scale_gpu_; std::shared_ptr<void> scale_gpu_;
std::shared_ptr<void> bias_gpu_; std::shared_ptr<void> bias_gpu_;
GroupNormNHWCParams params_; GroupNormNHWCParams<__half> params_;
int groups_; int groups_;
float eps_; float eps_;
bool with_silu_; bool with_silu_;
......
...@@ -131,7 +131,7 @@ struct GroupSumsOp { ...@@ -131,7 +131,7 @@ struct GroupSumsOp {
}; };
template <int32_t tTHREADS_PER_BLOCK> template <int32_t tTHREADS_PER_BLOCK>
__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) { __global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams<__half> params) {
// The object in charge of doing the sums for the different blocks. // The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan; typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
...@@ -224,7 +224,7 @@ __global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) { ...@@ -224,7 +224,7 @@ __global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
} }
void skipGroupNormNHWCSum(GroupNormNHWCParams const &params, void skipGroupNormNHWCSum(GroupNormNHWCParams<__half> const &params,
cudaStream_t stream) { cudaStream_t stream) {
// Make sure the values are as we expect. // Make sure the values are as we expect.
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -282,7 +282,8 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const &params, ...@@ -282,7 +282,8 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const &params,
} }
template <int32_t tTHREADS_PER_BLOCK> template <int32_t tTHREADS_PER_BLOCK>
__global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) { __global__ void skipGroupNormNHWCScaleKernel(
GroupNormNHWCParams<__half> params) {
// The instance in the batch. // The instance in the batch.
int32_t ni = blockIdx.z; int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2). // The channel loaded by that thread (2 channels per thread for F16x2).
...@@ -353,7 +354,7 @@ __global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) { ...@@ -353,7 +354,7 @@ __global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
} }
} }
void skipGroupNormNHWCScale(GroupNormNHWCParams const &params, void skipGroupNormNHWCScale(GroupNormNHWCParams<__half> const &params,
cudaStream_t stream) { cudaStream_t stream) {
// Make sure the dimensions are aligned with what we expect. // Make sure the dimensions are aligned with what we expect.
PADDLE_ENFORCE_EQ(params.c % params.cPerBlock, PADDLE_ENFORCE_EQ(params.c % params.cPerBlock,
......
...@@ -21,13 +21,14 @@ limitations under the License. */ ...@@ -21,13 +21,14 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/phi/kernels/group_norm_kernel.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
using phi::GroupNormNHWCParams;
class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT { class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT {
public: public:
SkipGroupnormActPluginDynamic(const float* scale, SkipGroupnormActPluginDynamic(const float* scale,
...@@ -168,7 +169,7 @@ class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT { ...@@ -168,7 +169,7 @@ class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT {
std::vector<float> bias_; std::vector<float> bias_;
std::shared_ptr<void> scale_gpu_; std::shared_ptr<void> scale_gpu_;
std::shared_ptr<void> bias_gpu_; std::shared_ptr<void> bias_gpu_;
GroupNormNHWCParams params_; GroupNormNHWCParams<__half> params_;
int groups_; int groups_;
float eps_; float eps_;
bool with_fp16_; bool with_fp16_;
......
...@@ -26,6 +26,676 @@ ...@@ -26,6 +26,676 @@
namespace phi { namespace phi {
static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }
static inline __device__ __host__ float sigmoid(float x) {
return 1.F / (1.F + expf(-x));
}
#ifdef PADDLE_CUDA_BF16
__host__ __device__ inline float2 bfloat1622float2(const __nv_bfloat162 a) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __bfloat1622float2(a);
#else
float hi_float;
float lo_float;
lo_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).x);
hi_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).y);
return make_float2(lo_float, hi_float);
#endif
}
__host__ __device__ inline __nv_bfloat162 float22bfloat162_rn(const float2 a) {
__nv_bfloat162 val;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
val = __float22bfloat162_rn(a);
#else
val.x = __float2bfloat16_rn(a.x);
val.y = __float2bfloat16_rn(a.y);
#endif
return val;
}
#endif
template <typename T>
__host__ __device__ inline float __2float(const T a) {
return static_cast<float>(a);
}
template <>
__host__ __device__ inline float __2float<__half>(const __half a) {
return __half2float(a);
}
template <typename T>
__host__ __device__ inline T __2dst(const float a) {
return static_cast<T>(a);
}
template <>
__host__ __device__ inline __half __2dst<__half>(const float a) {
return __float2half(a);
}
struct GroupSums {
// Is it the 1st element of the group?
int32_t flag;
// The sum.
float sum;
// The sum of squares.
float sumSq;
};
struct GroupSumsOp {
inline __device__ GroupSums operator()(GroupSums const& a,
GroupSums const& b) {
GroupSums dst;
dst.sum = b.flag ? b.sum : (a.sum + b.sum);
dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
dst.flag = a.flag + b.flag;
return dst;
}
};
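GroupSumsOp is the combine step of a segmented inclusive scan: the flag marks the first channel of a group and resets the running sums, so the last element scanned within each group carries that group's totals. A small worked illustration with made-up values:

// Combining a (earlier) with b (later):
//   b.flag == 0 -> {a.flag + b.flag, a.sum + b.sum, a.sumSq + b.sumSq}
//   b.flag == 1 -> {a.flag + b.flag, b.sum,         b.sumSq}
// Inclusive scan over   {1, 1.0, 1.0} {0, 2.0, 4.0} {1, 3.0, 9.0} {0, 4.0, 16.0}
// produces              {1, 1.0, 1.0} {1, 3.0, 5.0} {2, 3.0, 9.0} {2, 7.0, 25.0}
// i.e. the accumulation restarts at each flagged element (group boundary).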
static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
int32_t maxDivisor = -1;
for (int32_t i = 1; i <= std::sqrt(n); i++) {
if (n % i == 0) {
int32_t divisor1 = n / i;
int32_t divisor2 = i;
if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
maxDivisor = divisor1;
}
if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
maxDivisor = divisor2;
}
}
}
return maxDivisor;
}
template <typename T, int THREADS_PER_CHANNEL>
inline __device__ void UpdateSum(const T* srcX, float* sum, float* sumSq) {
float src_data = phi::__2float<T>(*srcX);
*sum += src_data;
*sumSq += src_data * src_data;
}
template <>
inline __device__ void UpdateSum<__half, 2>(const __half* srcX,
float* sum,
float* sumSq) {
__half2 h2 = *reinterpret_cast<__half2 const*>(srcX);
float2 f2 = __half22float2(h2);
*sum += f2.x + f2.y;
*sumSq += f2.x * f2.x + f2.y * f2.y;
}
template <>
inline __device__ void UpdateSum<phi::dtype::float16, 2>(
const phi::dtype::float16* srcX, float* sum, float* sumSq) {
__half2 h2 = *reinterpret_cast<__half2 const*>(srcX);
float2 f2 = __half22float2(h2);
*sum += f2.x + f2.y;
*sumSq += f2.x * f2.x + f2.y * f2.y;
}
#ifdef PADDLE_CUDA_BF16
template <>
inline __device__ void UpdateSum<phi::dtype::bfloat16, 2>(
const phi::dtype::bfloat16* srcX, float* sum, float* sumSq) {
__nv_bfloat162 h2 = *reinterpret_cast<__nv_bfloat162 const*>(srcX);
float2 f2 = phi::bfloat1622float2(h2);
*sum += f2.x + f2.y;
*sumSq += f2.x * f2.x + f2.y * f2.y;
}
#endif
template <typename T, int THREADS_PER_BLOCK>
__global__ void groupNormNHWCSumSingerChannelKernel(
const GroupNormNHWCParams<T> params) {
// The instance in the batch.
__shared__ float2 smem[THREADS_PER_BLOCK];
int32_t ni = blockIdx.z;
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x;
if (ci >= params.c) {
return;
}
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// The sums.
float sum = 0.F;
float sumSq = 0.F;
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.hwc +
static_cast<int64_t>(hwi) * params.c + ci;
float src_data = *reinterpret_cast<float const*>(&params.srcX[offset]);
UpdateSum<T, 1>(&params.srcX[offset], &sum, &sumSq);
}
smem[threadIdx.x] = make_float2(sum, sumSq);
__syncthreads();
float2 sums = smem[threadIdx.x];
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + ci],
sums.x * params.invHWC);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + ci], sums.y);
}
template <typename T, int THREADS_PER_BLOCK, int THREADS_PER_CHANNEL>
__global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams<T> params) {
// The object in charge of doing the sums for the different blocks.
typedef cub::BlockScan<GroupSums, THREADS_PER_BLOCK> BlockScan;
__shared__ typename BlockScan::TempStorage tempStorage;
// Allocate shared memory for BlockScan.
// Allocate shared memory for the groups. We could reduce the amount of shared
// memory reserved.
__shared__ float2 smem[THREADS_PER_BLOCK];
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci =
blockIdx.x * params.cPerBlock + threadIdx.x * THREADS_PER_CHANNEL;
if (ci >= params.c || threadIdx.x * THREADS_PER_CHANNEL >= params.cPerBlock) {
return;
}
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
// The sums.
float sum = 0.F;
float sumSq = 0.F;
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The offset.
int64_t offset = static_cast<int64_t>(ni) * params.hwc +
static_cast<int64_t>(hwi) * params.c + ci;
float src_data = *reinterpret_cast<float const*>(&params.srcX[offset]);
UpdateSum<T, THREADS_PER_CHANNEL>(&params.srcX[offset], &sum, &sumSq);
}
// The group that thread works on and the channel in the group (modulus).
int32_t gi =
ci / params.cPerGroup - blockIdx.x * params.cPerBlock / params.cPerGroup;
int32_t cj = ci % params.cPerGroup;
int flag = (cj == 0 || threadIdx.x == 0) ? 1 : 0;
GroupSums inp{flag, sum, sumSq};
GroupSums out;
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
if (cj == params.cPerGroup - THREADS_PER_CHANNEL ||
threadIdx.x * THREADS_PER_CHANNEL ==
params.cPerBlock - THREADS_PER_CHANNEL) {
smem[gi] = make_float2(out.sum, out.sumSq);
}
__syncthreads();
int32_t gj = ci / params.cPerGroup;
if (cj == params.cPerGroup - THREADS_PER_CHANNEL ||
threadIdx.x * THREADS_PER_CHANNEL ==
params.cPerBlock - THREADS_PER_CHANNEL) {
float2 sums = smem[gi];
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj],
sums.x * params.invHWC);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
}
template <typename T>
void groupNormNHWCSum<T>::operator()(GroupNormNHWCParams<T>* params,
gpuStream_t stream) {
dim3 grid;
grid.x = divUp(params->c, params->cPerBlock);
grid.y = divUp(params->hw, params->hwPerBlock);
grid.z = params->n;
if (params->cPerGroup % 2 == 0) {
switch (params->cPerBlock) {
case 512:
case 480:
groupNormNHWCSumKernel<T, 256, 2><<<grid, 256, 0, stream>>>(*params);
break;
case 320:
groupNormNHWCSumKernel<T, 160, 2><<<grid, 160, 0, stream>>>(*params);
break;
case 256:
groupNormNHWCSumKernel<T, 128, 2><<<grid, 128, 0, stream>>>(*params);
break;
case 128:
groupNormNHWCSumKernel<T, 64, 2><<<grid, 64, 0, stream>>>(*params);
break;
default:
grid.x = divUp(params->c, 128);
params->cPerBlock = 128;
groupNormNHWCSumKernel<T, 64, 2><<<grid, 64, 0, stream>>>(*params);
}
} else {
if (params->cPerGroup != 1) {
switch (params->cPerBlock) {
case 512:
groupNormNHWCSumKernel<T, 512, 1><<<grid, 512, 0, stream>>>(*params);
break;
case 480:
groupNormNHWCSumKernel<T, 480, 1><<<grid, 480, 0, stream>>>(*params);
break;
case 320:
groupNormNHWCSumKernel<T, 320, 1><<<grid, 320, 0, stream>>>(*params);
break;
case 256:
groupNormNHWCSumKernel<T, 256, 1><<<grid, 256, 0, stream>>>(*params);
break;
case 128:
groupNormNHWCSumKernel<T, 128, 1><<<grid, 128, 0, stream>>>(*params);
break;
default:
grid.x = divUp(params->c, 128);
params->cPerBlock = 128;
groupNormNHWCSumKernel<T, 128, 1><<<grid, 128, 0, stream>>>(*params);
}
} else {
switch (params->cPerBlock) {
case 512:
groupNormNHWCSumSingerChannelKernel<T, 512>
<<<grid, 512, 0, stream>>>(*params);
break;
case 480:
groupNormNHWCSumSingerChannelKernel<T, 480>
<<<grid, 480, 0, stream>>>(*params);
break;
case 320:
groupNormNHWCSumSingerChannelKernel<T, 320>
<<<grid, 320, 0, stream>>>(*params);
break;
case 256:
groupNormNHWCSumSingerChannelKernel<T, 256>
<<<grid, 256, 0, stream>>>(*params);
break;
case 128:
groupNormNHWCSumSingerChannelKernel<T, 128>
<<<grid, 128, 0, stream>>>(*params);
break;
default:
grid.x = divUp(params->c, 128);
params->cPerBlock = 128;
groupNormNHWCSumSingerChannelKernel<T, 128>
<<<grid, 128, 0, stream>>>(*params);
}
}
}
}
template class groupNormNHWCSum<half>;
template <typename T, int THREADS_PER_CHANNEL>
inline __device__ void GroupNormCompute(int32_t hwBegin,
int32_t hwEnd,
int32_t ci,
const GroupNormNHWCParams<T>& params,
float mean,
float invStdDev) {
float gamma =
phi::__2float<T>(*(reinterpret_cast<T const*>(params.gamma) + ci));
float beta =
phi::__2float<T>(*(reinterpret_cast<T const*>(params.beta) + ci));
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci;
const float src_data = phi::__2float<T>(params.srcX[offset]);
// Normalize the channels.
float dst_data = (src_data - mean) * invStdDev;
// Scale by gamma and add beta.
dst_data = gamma * dst_data + beta;
// Apply Silu if needed.
if (params.withSilu) {
dst_data = dst_data * sigmoid(dst_data);
}
// Store the scaled values.
*reinterpret_cast<T*>(&params.dst[offset]) = phi::__2dst<T>(dst_data);
}
}
template <>
inline __device__ void GroupNormCompute<phi::dtype::float16, 2>(
int32_t hwBegin,
int32_t hwEnd,
int32_t ci,
const GroupNormNHWCParams<phi::dtype::float16>& params,
float mean,
float invStdDev) {
float2 gammaF2, betaF2;
gammaF2 = __half22float2(*reinterpret_cast<__half2 const*>(
reinterpret_cast<half const*>(params.gamma) + ci));
betaF2 = __half22float2(*reinterpret_cast<__half2 const*>(
reinterpret_cast<half const*>(params.beta) + ci));
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci;
// Fetch two channels per thread.
__half2 h2 = *reinterpret_cast<__half2 const*>(&params.srcX[offset]);
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
// Apply Silu if needed.
if (params.withSilu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
// Store the scaled values.
*reinterpret_cast<__half2*>(&params.dst[offset]) = __float22half2_rn(f2);
}
}
template <>
inline __device__ void GroupNormCompute<__half, 2>(
int32_t hwBegin,
int32_t hwEnd,
int32_t ci,
const GroupNormNHWCParams<__half>& params,
float mean,
float invStdDev) {
float2 gammaF2, betaF2;
gammaF2 = __half22float2(*reinterpret_cast<__half2 const*>(
reinterpret_cast<half const*>(params.gamma) + ci));
betaF2 = __half22float2(*reinterpret_cast<__half2 const*>(
reinterpret_cast<half const*>(params.beta) + ci));
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci;
// Fetch two channels per thread.
__half2 h2 = *reinterpret_cast<__half2 const*>(&params.srcX[offset]);
// Extract the two half values.
float2 f2 = __half22float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
// Apply Silu if needed.
if (params.withSilu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
// Store the scaled values.
*reinterpret_cast<__half2*>(&params.dst[offset]) = __float22half2_rn(f2);
}
}
#ifdef PADDLE_CUDA_BF16
template <>
inline __device__ void GroupNormCompute<phi::dtype::bfloat16, 2>(
int32_t hwBegin,
int32_t hwEnd,
int32_t ci,
const GroupNormNHWCParams<phi::dtype::bfloat16>& params,
float mean,
float invStdDev) {
float2 gammaF2, betaF2;
gammaF2 = phi::bfloat1622float2(*reinterpret_cast<__nv_bfloat162 const*>(
reinterpret_cast<__nv_bfloat16 const*>(params.gamma) + ci));
betaF2 = phi::bfloat1622float2(*reinterpret_cast<__nv_bfloat162 const*>(
reinterpret_cast<__nv_bfloat16 const*>(params.beta) + ci));
// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci;
// Fetch two channels per thread.
__nv_bfloat162 h2 =
*reinterpret_cast<__nv_bfloat162 const*>(&params.srcX[offset]);
// Extract the two half values.
float2 f2 = phi::bfloat1622float2(h2);
// Normalize the channels.
f2.x = (f2.x - mean) * invStdDev;
f2.y = (f2.y - mean) * invStdDev;
// Scale by gamma and add beta.
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;
// Apply Silu if needed.
if (params.withSilu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
// Store the scaled values.
*reinterpret_cast<__nv_bfloat162*>(&params.dst[offset]) =
phi::float22bfloat162_rn(f2);
}
}
#endif
template <typename T, int THREADS_PER_CHANNEL>
__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams<T> params) {
// The instance in the batch.
int32_t ni = blockIdx.z;
// The channel loaded by that thread (2 channels per thread for F16x2).
int32_t ci =
blockIdx.x * params.cPerBlock + threadIdx.x * THREADS_PER_CHANNEL;
// The group that thread works on and the channel in the group (modulus).
int32_t gi = ci / params.cPerGroup;
if (ci >= params.c || gi >= params.groups) {
return;
}
// Load the sum and sum of squares for the group.
float mean = params.redBuffer[(2 * ni + 0) * params.groups + gi];
float sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
// Compute the variance.
float var = sumSq * params.invHWC - (mean * mean);
if (params.var_data != nullptr) {
params.var_data[ni * params.groups + gi] = var;
}
// Compute the inverse of the stddev.
float invStdDev = rsqrtf(var + params.eps);
// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
GroupNormCompute<T, THREADS_PER_CHANNEL>(
hwBegin, hwEnd, ci, params, mean, invStdDev);
}
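In formula form, following the code above (the sum kernels already divide the first moment by the group size when writing it to redBuffer, so the scale kernel reads the mean directly), for a group g with C_g = cPerGroup channels:

\mu_g = \frac{1}{H W C_g}\sum_i x_i, \qquad
\sigma_g^2 = \frac{1}{H W C_g}\sum_i x_i^2 - \mu_g^2, \qquad
y_i = \gamma_c \, \frac{x_i - \mu_g}{\sqrt{\sigma_g^2 + \varepsilon}} + \beta_c

with an optional Silu applied afterwards: y_i \leftarrow y_i \cdot \mathrm{sigmoid}(y_i).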
template <typename T>
void groupNormNHWCScale<T>::operator()(const GroupNormNHWCParams<T>& params,
gpuStream_t stream) {
dim3 grid;
// The number of blocks to compute all the channels.
grid.x = divUp(params.c, params.cPerBlock);
// The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances.
grid.z = params.n;
if (params.cPerGroup % 2 == 0) {
switch (params.cPerBlock) {
case 512:
case 480:
groupNormNHWCScaleKernel<T, 2><<<grid, 256, 0, stream>>>(params);
break;
case 320:
groupNormNHWCScaleKernel<T, 2><<<grid, 160, 0, stream>>>(params);
break;
case 256:
groupNormNHWCScaleKernel<T, 2><<<grid, 128, 0, stream>>>(params);
break;
case 128:
groupNormNHWCScaleKernel<T, 2><<<grid, 64, 0, stream>>>(params);
break;
default:
grid.x = divUp(params.c, 128);
groupNormNHWCScaleKernel<T, 2><<<grid, 64, 0, stream>>>(params);
}
} else {
switch (params.cPerBlock) {
case 512:
groupNormNHWCScaleKernel<T, 1><<<grid, 512, 0, stream>>>(params);
break;
case 480:
groupNormNHWCScaleKernel<T, 1><<<grid, 480, 0, stream>>>(params);
break;
case 320:
groupNormNHWCScaleKernel<T, 1><<<grid, 320, 0, stream>>>(params);
break;
case 256:
groupNormNHWCScaleKernel<T, 1><<<grid, 256, 0, stream>>>(params);
break;
case 128:
groupNormNHWCScaleKernel<T, 1><<<grid, 128, 0, stream>>>(params);
break;
default:
grid.x = divUp(params.c, 128);
groupNormNHWCScaleKernel<T, 1><<<grid, 128, 0, stream>>>(params);
}
}
}
template class groupNormNHWCScale<half>;
template <typename T, typename Context>
void GroupNormNHWCKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int groups,
const std::string& data_layout_str,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
GroupNormNHWCParams<T> params_;
params_.withSilu = false;
const auto x_dims = x.dims();
dev_ctx.template Alloc<T>(y);
const T* x_data = x.data<T>();
T* y_data = y->data<T>();
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
const T* scale_data = nullptr;
if (scale_ptr) scale_data = scale_ptr->data<T>();
const T* bias_data = nullptr;
if (bias_ptr) bias_data = bias_ptr->data<T>();
params_.n = x_dims[0];
params_.c = x_dims[3];
params_.h = x_dims[1];
params_.w = x_dims[2];
dev_ctx.template Alloc<AccT>(mean);
dev_ctx.template Alloc<AccT>(var);
auto* mean_data = mean->data<AccT>();
auto* var_data = var->data<AccT>();
params_.var_data = var_data;
int32_t cPerBlock = 320;
int32_t maxBlocksPerHW = 1024;
switch (params_.c) {
case 2048:
case 1024:
cPerBlock = 512;
break;
case 960:
case 1920:
cPerBlock = 480;
break;
case 512:
case 256:
cPerBlock = 256;
break;
case 128:
cPerBlock = 128;
break;
default:
cPerBlock = 320;
}
params_.groups = groups;
params_.cPerGroup = params_.c / params_.groups;
if (cPerBlock % params_.cPerGroup != 0) {
cPerBlock = params_.cPerGroup;
}
params_.srcX = reinterpret_cast<const T*>(x_data);
params_.dst = reinterpret_cast<T*>(y_data);
params_.gamma = scale_data;
params_.beta = bias_data;
params_.hw = params_.h * params_.w;
const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW);
params_.hwPerBlock = divUp(params_.hw, blocksPerHW);
params_.cPerBlock = cPerBlock;
params_.hwc = params_.hw * params_.c;
params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
params_.eps = epsilon;
auto stream = dev_ctx.stream();
DenseTensor redBuffer;
int buffer_sizes = 2 * params_.n * groups;
redBuffer.Resize({1, buffer_sizes});
params_.redBuffer = dev_ctx.template Alloc<float>(&redBuffer);
#ifdef PADDLE_WITH_HIP
hipMemset(params_.redBuffer, 0, buffer_sizes * sizeof(float));
#else
cudaMemset(params_.redBuffer, 0, buffer_sizes * sizeof(float));
#endif
groupNormNHWCSum<T> nhwc_sum;
nhwc_sum(&params_, stream);
groupNormNHWCScale<T> nhwc_scale;
nhwc_scale(params_, stream);
#ifdef PADDLE_WITH_HIP
phi::backends::gpu::GpuMemcpyAsync(mean_data,
params_.redBuffer,
params_.n * groups * sizeof(float),
hipMemcpyDeviceToHost,
stream);
#else
phi::backends::gpu::GpuMemcpyAsync(mean_data,
params_.redBuffer,
params_.n * groups * sizeof(float),
cudaMemcpyDeviceToHost,
stream);
#endif
}
template <typename T, typename AccT> template <typename T, typename AccT>
__global__ void GroupNormForwardGetMeanAndVar(const T* x, __global__ void GroupNormForwardGetMeanAndVar(const T* x,
int N, int N,
...@@ -117,27 +787,128 @@ __global__ void GroupNormForward(const T* x, ...@@ -117,27 +787,128 @@ __global__ void GroupNormForward(const T* x,
} }
} }
template <typename T, typename AccT>
void GroupNormDirectCUDAFunctor<T, AccT>::operator()(
gpuStream_t stream,
const T* input,
std::vector<int> input_shape,
const T* bias,
const T* scale,
AccT* temp_variance,
int groups,
float eps,
T* output,
AccT* mean,
AccT* variance,
const DataLayout data_layout) {
const auto input_ddim = phi::make_ddim(input_shape);
const int C =
(data_layout == DataLayout::kNCHW ? input_ddim[1]
: input_ddim[input_ddim.size() - 1]);
const int group_size = C / groups;
const int W =
(data_layout == DataLayout::kNCHW ? input_ddim[input_ddim.size() - 1]
: input_ddim[input_ddim.size() - 2]);
int image_size = 1;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < input_ddim.size(); ++i) {
image_size *= input_ddim[i];
}
} else {
for (int i = 1; i < input_ddim.size() - 1; ++i) {
image_size *= input_ddim[i];
}
}
#ifdef __HIPCC__
int block_size = std::max(std::min(256, image_size), 64);
#else
int block_size = std::min(1024, image_size);
#endif
dim3 grid(group_size, groups, input_ddim[0]);
dim3 threads(block_size, 1, 1);
if (data_layout == DataLayout::kNCHW) {
constexpr int vec_size = sizeof(float4) / sizeof(T);
int size = group_size * image_size; // group element size
const int max_num_threads = 1024;
int max_block_size = std::min(size / vec_size, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, phi::kps::details::kWarpSize);
dim3 grids(input_ddim[0] * groups);
dim3 blocks(block_size_nchw);
if (size < vec_size * block_size_nchw) {
phi::ScalarGetMeanAndVarNCHW<T, AccT>
<<<grids, blocks, 0, stream>>>(input, mean, temp_variance, size);
} else {
phi::VectorizedGetMeanAndVarNCHW<T, AccT, vec_size>
<<<grids, blocks, 0, stream>>>(input, mean, temp_variance, size);
}
} else {
#ifdef PADDLE_WITH_HIP
hipMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups);
hipMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups);
#else
cudaMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups);
cudaMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups);
#endif
phi::GroupNormForwardGetMeanAndVar<T, AccT>
<<<grid, threads, 0, stream>>>(input,
input_ddim[0],
C,
W,
image_size,
groups,
group_size,
mean,
temp_variance);
}
GroupNormForward<T, AccT, 3>
<<<grid, threads, 0, stream>>>(input,
mean,
temp_variance,
scale,
bias,
input_ddim[0],
C,
W,
image_size,
groups,
group_size,
static_cast<AccT>(eps),
output,
variance,
data_layout);
}
template class GroupNormDirectCUDAFunctor<float, float>;
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
template class GroupNormDirectCUDAFunctor<half, float>;
#endif
template <typename T, typename Context> template <typename T, typename Context>
void GroupNormKernel(const Context& dev_ctx, void GroupNormGeneralCaseKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const paddle::optional<DenseTensor>& scale, const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias, const paddle::optional<DenseTensor>& bias,
float epsilon, float epsilon,
int groups, int groups,
const std::string& data_layout_str, const std::string& data_layout_str,
DenseTensor* y, DenseTensor* y,
DenseTensor* mean, DenseTensor* mean,
DenseTensor* var) { DenseTensor* var) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type; using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr(); const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr(); const auto bias_ptr = bias.get_ptr();
const auto x_dims = x.dims(); const auto x_dims = x.dims();
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int group_size = C / groups; const int group_size = C / groups;
const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]); : x_dims[x_dims.size() - 2]);
...@@ -235,108 +1006,51 @@ void GroupNormKernel(const Context& dev_ctx, ...@@ -235,108 +1006,51 @@ void GroupNormKernel(const Context& dev_ctx,
data_layout); data_layout);
} }
template <typename T, typename AccT> template <typename T, typename Context>
void GroupNormDirectCUDAFunctor<T, AccT>::operator()( void GroupNormKernel(const Context& dev_ctx,
gpuStream_t stream, const DenseTensor& x,
const T* input, const paddle::optional<DenseTensor>& scale,
std::vector<int> input_shape, const paddle::optional<DenseTensor>& bias,
const T* bias, float epsilon,
const T* scale, int groups,
AccT* temp_variance, const std::string& data_layout_str,
int groups, DenseTensor* y,
float eps, DenseTensor* mean,
T* output, DenseTensor* var) {
AccT* mean, using std::is_same;
AccT* variance, if (is_same<T, phi::dtype::float16>::value && data_layout_str == "NHWC") {
const DataLayout data_layout) { GroupNormNHWCKernel<phi::dtype::float16, Context>(dev_ctx,
const auto input_ddim = phi::make_ddim(input_shape); x,
const int C = scale,
(data_layout == DataLayout::kNCHW ? input_ddim[1] bias,
: input_ddim[input_ddim.size() - 1]); epsilon,
const int group_size = C / groups; groups,
const int W = data_layout_str,
(data_layout == DataLayout::kNCHW ? input_ddim[input_ddim.size() - 1] y,
: input_ddim[input_ddim.size() - 2]); mean,
var);
int image_size = 1; return;
if (data_layout == DataLayout::kNCHW) {
for (int i = 2; i < input_ddim.size(); ++i) {
image_size *= input_ddim[i];
}
} else {
for (int i = 1; i < input_ddim.size() - 1; ++i) {
image_size *= input_ddim[i];
}
} }
#ifdef __HIPCC__
int block_size = std::max(std::min(256, image_size), 64);
#else
int block_size = std::min(1024, image_size);
#endif
dim3 grid(group_size, groups, input_ddim[0]);
dim3 threads(block_size, 1, 1);
if (data_layout == DataLayout::kNCHW) {
constexpr int vec_size = sizeof(float4) / sizeof(T);
int size = group_size * image_size; // group element size
const int max_num_threads = 1024;
int max_block_size = std::min(size / vec_size, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, phi::kps::details::kWarpSize);
dim3 grids(input_ddim[0] * groups);
dim3 blocks(block_size_nchw);
if (size < vec_size * block_size_nchw) { #ifdef PADDLE_CUDA_BF16
phi::ScalarGetMeanAndVarNCHW<T, AccT> if (is_same<T, phi::dtype::bfloat16>::value && data_layout_str == "NHWC") {
<<<grids, blocks, 0, stream>>>(input, mean, temp_variance, size); GroupNormNHWCKernel<phi::dtype::bfloat16, Context>(dev_ctx,
} else { x,
phi::VectorizedGetMeanAndVarNCHW<T, AccT, vec_size> scale,
<<<grids, blocks, 0, stream>>>(input, mean, temp_variance, size); bias,
} epsilon,
} else { groups,
#ifdef PADDLE_WITH_HIP data_layout_str,
hipMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups); y,
hipMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups); mean,
#else var);
cudaMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups); return;
cudaMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups); }
#endif #endif
phi::GroupNormForwardGetMeanAndVar<T, AccT> GroupNormGeneralCaseKernel<T, Context>(
<<<grid, threads, 0, stream>>>(input, dev_ctx, x, scale, bias, epsilon, groups, data_layout_str, y, mean, var);
input_ddim[0],
C,
W,
image_size,
groups,
group_size,
mean,
temp_variance);
}
GroupNormForward<T, AccT, 3>
<<<grid, threads, 0, stream>>>(input,
mean,
temp_variance,
scale,
bias,
input_ddim[0],
C,
W,
image_size,
groups,
group_size,
static_cast<AccT>(eps),
output,
variance,
data_layout);
} }
template class GroupNormDirectCUDAFunctor<float, float>;
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
template class GroupNormDirectCUDAFunctor<half, float>;
#endif
} // namespace phi } // namespace phi
......
...@@ -18,6 +18,11 @@ ...@@ -18,6 +18,11 @@
#include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_fp16.h>
#endif
#include <stdint.h>
namespace phi { namespace phi {
...@@ -52,4 +57,70 @@ class GroupNormDirectCUDAFunctor { ...@@ -52,4 +57,70 @@ class GroupNormDirectCUDAFunctor {
}; };
#endif #endif
template <typename T>
struct GroupNormNHWCParams {
// The output buffer. Layout NHWC.
T* dst;
// The output buffer. Layout NHWC.
T* eleOut;
// The input buffer. Layout NHWC.
T const* srcX;
// The input buffer. Layout NHWC.
T const* srcY;
// The gamma scaling factor.
void const* gamma;
// The beta term to add in GN.
void const* beta;
// The temporary buffer to do the global parallel reduction. Size:
// BLOCKS_PER_BATCH x C x 2.
float* redBuffer;
float* var_data;
// The number of instances in the batch.
int32_t n;
// The height and width of each activation map.
int32_t h, w;
// The number of channels.
int32_t c;
// The number of groups.
int32_t groups;
// Do we apply the Silu activation function?
bool withSilu;
// Precomputed values and parameters to control the execution of the kernels.
// The number of activations per instance (h * w) and the number of
// activations per block.
int32_t hw, hwPerBlock;
// The number of channels per group and blocks per activation in the C
// dimension.
int32_t cPerBlock, cPerGroup;
// The precomputed stride between instances.
int32_t hwc;
// The inverse of hwc in floats (to compute mean/var).
float invHWC;
// The precomputed number of groups per block.
int32_t groupsPerBlock;
// epsilon, Constant for numerical stability
float eps;
// for NCHW32 int8 use
float dqScaleIn;
float inv_qScale;
};
template <typename T>
class groupNormNHWCSum {
public:
void operator()(GroupNormNHWCParams<T>* params, const gpuStream_t stream);
};
template <typename T>
class groupNormNHWCScale {
public:
void operator()(const GroupNormNHWCParams<T>& params,
const gpuStream_t stream);
};
} // namespace phi } // namespace phi
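A minimal driving sketch for the two functors declared above (assuming a fully populated params object, a valid CUDA stream, and half data); it mirrors GroupNormNHWCKernel and the plugin enqueue earlier in this diff:

// redBuffer holds 2 * n * groups floats and must be zeroed before the sum pass.
cudaMemsetAsync(params.redBuffer, 0,
                2 * sizeof(float) * params.n * params.groups, stream);
phi::groupNormNHWCSum<half> nhwc_sum;
nhwc_sum(&params, stream);     // pass 1: per-group sum and sum of squares
phi::groupNormNHWCScale<half> nhwc_scale;
nhwc_scale(params, stream);    // pass 2: normalize, apply gamma/beta, optional Silu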
...@@ -511,7 +511,11 @@ class GroupNorm(Layer): ...@@ -511,7 +511,11 @@ class GroupNorm(Layer):
"Mean": mean_out, "Mean": mean_out,
"Variance": variance_out, "Variance": variance_out,
}, },
attrs={"epsilon": self._epsilon, "groups": self._num_groups}, attrs={
"epsilon": self._epsilon,
"groups": self._num_groups,
"data_layout": self._data_format,
},
) )
return self._helper.append_activation(group_norm_out, None) return self._helper.append_activation(group_norm_out, None)
......
...@@ -350,6 +350,51 @@ class TestGroupNormOp2_With_NHWC(TestGroupNormOp): ...@@ -350,6 +350,51 @@ class TestGroupNormOp2_With_NHWC(TestGroupNormOp):
self.data_format = "NHWC" self.data_format = "NHWC"
class TestGroupNormFP16Op_With_NHWC(TestGroupNormFP16OP):
def init_test_case(self):
self.attrs['groups'] = 1
self.data_format = "NHWC"
self.attrs['epsilon'] = 0.5
self.shape = (1, 100, 4, 4)
self.dtype = np.float16
class TestGroupNormBF16Op_With_NHWC(TestGroupNormBF16Op):
def setUp(self):
self.op_type = "group_norm"
self.python_api = group_norm_wrapper
self.python_out_sig = ["Y"]
self.data_format = "NHWC"
self.dtype = np.uint16
self.shape = (1, 3, 5, 100)
self.attrs = {
'epsilon': 5e-2,
'groups': 2,
'data_layout': self.data_format,
}
self.compare_between_place = False
self.init_test_case()
input = np.random.random(self.shape).astype(np.float32)
scale = np.random.random([self.shape[3]]).astype(np.float32)
bias = np.random.random([self.shape[3]]).astype(np.float32)
output, mean, var = group_norm_naive(
input,
scale,
bias,
self.attrs['epsilon'],
self.attrs['groups'],
self.data_format,
)
self.inputs = {
'X': convert_float_to_uint16(input),
'Scale': convert_float_to_uint16(scale),
'Bias': convert_float_to_uint16(bias),
}
self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp): class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp):
def init_test_case(self): def init_test_case(self):
self.attrs['groups'] = 1 self.attrs['groups'] = 1
......