From e61d892ad97d5b9adced1f469189a11abc0bb32b Mon Sep 17 00:00:00 2001 From: yangjianfengo1 <125249383+yangjianfengo1@users.noreply.github.com> Date: Wed, 2 Aug 2023 15:29:26 +0800 Subject: [PATCH] [Inference] Replace groupNorm when data types are bf16 and fp16, and data format is NHWC implementation. (#55399) * finish * cpergroup odd * fix bf16 * single channel * code style * jingdu duiqi * add head_file * add bf16 head file * bf16 2 * bf16 * bf16 head * bf16 compile * py test * bf16 compile * bf16 compile * unset py test * nhwc * test * mean var * bf16 success * su * ctest success * use is_same_as * is_same * use is_same * rtol * gpu_stream * del sigmod * fix bfloat16 type * use cuda_bf16_hpp * use_cuda_arch * bfloat162float2 * del inplace_tol * del max_releative_tol * temp store * jingdu duiqi * temp store * plugin * jingdu duiqi * duiqi * include cuda.h * del half * half single * ci * add const * ci * cudamemset * del printf * fp16 test * add half compute * del br16 ci * del ci * ci approve * del fluid include --- .../plugin/common/groupNormPluginCommon.h | 80 -- .../tensorrt/plugin/group_norm_op_plugin.cu | 244 +---- .../tensorrt/plugin/group_norm_op_plugin.h | 6 +- .../plugin/preln_groupnorm_act_op_plugin.cu | 10 +- .../plugin/preln_groupnorm_act_op_plugin.h | 5 +- .../plugin/skip_groupnorm_act_op_plugin.cu | 9 +- .../plugin/skip_groupnorm_act_op_plugin.h | 5 +- paddle/phi/kernels/gpu/group_norm_kernel.cu | 932 ++++++++++++++++-- paddle/phi/kernels/group_norm_kernel.h | 71 ++ python/paddle/nn/layer/norm.py | 6 +- test/legacy_test/test_group_norm_op.py | 45 + 11 files changed, 979 insertions(+), 434 deletions(-) delete mode 100644 paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h diff --git a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h deleted file mode 100644 index 1ba134a6fce..00000000000 --- a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. - __half* dst; - // The output buffer. Layout NHWC. - __half* eleOut; - // The input buffer. Layout NHWC. - __half const* srcX; - // The input buffer. Layout NHWC. - __half const* srcY; - // The gamma scaling factor. - void const* gamma; - // The beta term to add in GN. - void const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; - - // The number of instances in the batch. - int32_t n; - // The height and width of each activation map. 
- int32_t h, w; - // The number of channels. - int32_t c; - // The number of groups. - int32_t groups; - // Do we apply the Silu activation function? - bool withSilu; - - // Precomputed values and parameters to control the execution of the kernels. - - // The number of activations per instance (h * w) and the number of - // activations per block. - int32_t hw, hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock, cPerGroup; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; - // The precomputed number of groups per block. - int32_t groupsPerBlock; - // epsilon, Constant for numerical stability - float eps; - // for NCHW32 int8 use - float dqScaleIn; - float inv_qScale; -}; - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu index 219868c49b4..4d5517ef111 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu @@ -73,123 +73,8 @@ static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { } template -__global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) { - // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[tTHREADS_PER_BLOCK]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); - - // The sums. - float sum = 0.F; - float sumSq = 0.F; - - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + - static_cast(hwi) * params.c + ci; - - // Fetch two channels per thread. - __half2 h2(0, 0); - if (ci < params.c) { - h2 = *reinterpret_cast<__half2 const *>(¶ms.srcX[offset]); - } - - // Extract the two half values. - float2 f2 = __half22float2(h2); - - // Update the sum. - sum += f2.x + f2.y; - // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; - } - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * 2 / params.cPerGroup; - int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; - - // Do the segmented scan. - GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - // 2 channels per thread - if (cj == params.cPerGroup - 2) { - smem[gi] = make_float2(out.sum, out.sumSq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The global group index. 
- int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { - return; - } - - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); - atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); -} - -void groupNormNHWCSum(const GroupNormNHWCParams ¶ms, cudaStream_t stream) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = divUp(params.c, params.cPerBlock); - // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); - // The number of instances. - grid.z = params.n; - - switch (params.cPerBlock) { - case 320: - groupNormNHWCSumKernel<160><<>>(params); - break; - case 480: - groupNormNHWCSumKernel<256><<>>(params); - break; - case 256: - groupNormNHWCSumKernel<128><<>>(params); - break; - case 128: - groupNormNHWCSumKernel<64><<>>(params); - break; - case 8: - groupNormNHWCSumKernel<4><<>>(params); - break; - default: - PADDLE_THROW(platform::errors::Fatal( - "The function groupNormNHWCSum of GroupNormPlugin TRT Plugin " - "encounter error")); - } -} - -template -__global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) { +__global__ void groupNormNCHW32SumKernelQDQ( + const GroupNormNHWCParams<__half> params) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -281,7 +166,7 @@ __global__ void groupNormNCHW32SumKernelQDQ(const GroupNormNHWCParams params) { atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); } -void groupNormNCHW32SumQDQ(const GroupNormNHWCParams ¶ms, +void groupNormNCHW32SumQDQ(const GroupNormNHWCParams<__half> ¶ms, cudaStream_t stream) { dim3 grid; @@ -313,7 +198,7 @@ void groupNormNCHW32SumQDQ(const GroupNormNHWCParams ¶ms, template __global__ void groupNormNCHW32ScaleKernelQDQ( - const GroupNormNHWCParams params) { + const GroupNormNHWCParams<__half> params) { // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). @@ -405,7 +290,7 @@ __global__ void groupNormNCHW32ScaleKernelQDQ( } } -void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams ¶ms, +void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams<__half> ¶ms, cudaStream_t stream) { dim3 grid; @@ -439,112 +324,6 @@ void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams ¶ms, } } -template -__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) { - // The instance in the batch. - int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / params.cPerGroup; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; - if (gi < params.groups) { - sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; - sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; - } - - // Load gamma/beta. - float2 gammaF2, betaF2; - if (ci < params.c) { - gammaF2 = __half22float2(*reinterpret_cast( - reinterpret_cast(params.gamma) + ci)); - betaF2 = __half22float2(*reinterpret_cast( - reinterpret_cast(params.beta) + ci)); - } - - // Compute the mean. 
- float mean = sum * params.invHWC; - // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); - // Compute the inverse of the stddev. - float invStdDev = rsqrtf(var + params.eps); - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); - - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; - - // Fetch two channels per thread. - __half2 h2(0, 0); - if (ci < params.c) { - h2 = *reinterpret_cast<__half2 const *>(¶ms.srcX[offset]); - } - - // Extract the two half values. - float2 f2 = __half22float2(h2); - - // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; - - // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; - - // Apply Silu if needed. - if (params.withSilu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - // Store the scaled values. - if (ci < params.c) { - *reinterpret_cast<__half2 *>(¶ms.dst[offset]) = __float22half2_rn(f2); - } - } -} - -void groupNormNHWCScale(const GroupNormNHWCParams ¶ms, - cudaStream_t stream) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = divUp(params.c, params.cPerBlock); - // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); - // The number of instances. - grid.z = params.n; - - switch (params.cPerBlock) { - case 320: - groupNormNHWCScaleKernel<160><<>>(params); - break; - case 480: - groupNormNHWCScaleKernel<256><<>>(params); - break; - case 256: - groupNormNHWCScaleKernel<128><<>>(params); - break; - case 128: - groupNormNHWCScaleKernel<64><<>>(params); - break; - case 8: - groupNormNHWCScaleKernel<4><<>>(params); - break; - default: - PADDLE_THROW(platform::errors::Fatal( - "The function groupNormNHWCScale of GroupNormPlugin TRT Plugin " - "encounter error")); - } -} - int GroupNormPlugin::initialize() TRT_NOEXCEPT { if (!with_fp16_) { // if use fp32 @@ -886,9 +665,10 @@ int GroupNormPluginDynamic::enqueue( params_.withSilu = with_silu_; params_.dst = static_cast(outputs[0]); params_.srcX = static_cast(inputs[0]); - params_.gamma = scale_gpu_; - params_.beta = bias_gpu_; + params_.gamma = reinterpret_cast(scale_gpu_); + params_.beta = reinterpret_cast(bias_gpu_); params_.redBuffer = static_cast(workspace); + params_.var_data = nullptr; params_.n = input_desc[0].dims.d[0]; params_.h = input_desc[0].dims.d[2]; params_.w = input_desc[0].dims.d[3]; @@ -903,13 +683,17 @@ int GroupNormPluginDynamic::enqueue( params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); params_.groupsPerBlock = cPerBlock / params_.cPerGroup; params_.eps = eps_; + params_.var_data = nullptr; cudaMemsetAsync(params_.redBuffer, 0, 2 * sizeof(float) * params_.n * groups_, stream); - groupNormNHWCSum(params_, stream); - groupNormNHWCScale(params_, stream); + + phi::groupNormNHWCSum nhwc_sum; + nhwc_sum(¶ms_, stream); + phi::groupNormNHWCScale nhwc_scale; + nhwc_scale(params_, stream); } else { PADDLE_THROW(platform::errors::Fatal( "The Groupnorm TRT Plugin's only support nchw or nhwc8 input")); diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h 
index 8f521c8bf58..e76d802f853 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h @@ -21,13 +21,15 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/phi/kernels/group_norm_kernel.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { + +using phi::GroupNormNHWCParams; class GroupNormPlugin : public PluginTensorRT { public: size_t getSerializationSize() const TRT_NOEXCEPT override { @@ -287,7 +289,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT { float eps_; std::vector mean_shape_; std::vector variance_shape_; - GroupNormNHWCParams params_; + GroupNormNHWCParams params_; bool with_silu_; bool with_fp16_; bool with_int8_; diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu index d3ca36770a4..01a91662c2f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu @@ -120,7 +120,8 @@ struct GroupSumsOp { }; template -__global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) { +__global__ void prelnGroupNormNHWCSumKernel( + GroupNormNHWCParams<__half> params) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -212,7 +213,7 @@ __global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) { atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); } -void prelnGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, +void prelnGroupNormNHWCSum(GroupNormNHWCParams<__half> const ¶ms, cudaStream_t stream) { // Make sure the values are as we expect. PADDLE_ENFORCE_EQ(params.c % params.cPerBlock, @@ -272,7 +273,8 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, } template -__global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) { +__global__ void prelnGroupNormNHWCScaleKernel( + GroupNormNHWCParams<__half> params) { // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). @@ -343,7 +345,7 @@ __global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) { } } -void prelnGroupNormNHWCScale(GroupNormNHWCParams const ¶ms, +void prelnGroupNormNHWCScale(GroupNormNHWCParams<__half> const ¶ms, cudaStream_t stream) { // Make sure the dimensions are aligned with what we expect. PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h index 501372b9c32..e4c76e2d652 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h @@ -21,13 +21,14 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/phi/kernels/group_norm_kernel.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { +using phi::GroupNormNHWCParams; class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT { public: PrelnGroupnormActPluginDynamic(const float* scale, @@ -173,7 +174,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT { std::vector bias_; std::shared_ptr scale_gpu_; std::shared_ptr bias_gpu_; - GroupNormNHWCParams params_; + GroupNormNHWCParams<__half> params_; int groups_; float eps_; bool with_silu_; diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu index 997205e9189..45bd8688da1 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu @@ -131,7 +131,7 @@ struct GroupSumsOp { }; template -__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) { +__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams<__half> params) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -224,7 +224,7 @@ __global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) { atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); } -void skipGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, +void skipGroupNormNHWCSum(GroupNormNHWCParams<__half> const ¶ms, cudaStream_t stream) { // Make sure the values are as we expect. PADDLE_ENFORCE_EQ( @@ -282,7 +282,8 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, } template -__global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) { +__global__ void skipGroupNormNHWCScaleKernel( + GroupNormNHWCParams<__half> params) { // The instance in the batch. int32_t ni = blockIdx.z; // The channel loaded by that thread (2 channels per thread for F16x2). @@ -353,7 +354,7 @@ __global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) { } } -void skipGroupNormNHWCScale(GroupNormNHWCParams const ¶ms, +void skipGroupNormNHWCScale(GroupNormNHWCParams<__half> const ¶ms, cudaStream_t stream) { // Make sure the dimensions are aligned with what we expect. PADDLE_ENFORCE_EQ(params.c % params.cPerBlock, diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h index 5ed9dd14e71..0a93559f5ee 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h @@ -21,13 +21,14 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/phi/kernels/group_norm_kernel.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { +using phi::GroupNormNHWCParams; class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT { public: SkipGroupnormActPluginDynamic(const float* scale, @@ -168,7 +169,7 @@ class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT { std::vector bias_; std::shared_ptr scale_gpu_; std::shared_ptr bias_gpu_; - GroupNormNHWCParams params_; + GroupNormNHWCParams<__half> params_; int groups_; float eps_; bool with_fp16_; diff --git a/paddle/phi/kernels/gpu/group_norm_kernel.cu b/paddle/phi/kernels/gpu/group_norm_kernel.cu index ef39abd9394..73ae42c549b 100644 --- a/paddle/phi/kernels/gpu/group_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/group_norm_kernel.cu @@ -26,6 +26,676 @@ namespace phi { +static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; } + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +#ifdef PADDLE_CUDA_BF16 +__host__ __device__ inline float2 bfloat1622float2(const __nv_bfloat162 a) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __bfloat1622float2(a); +#else + float hi_float; + float lo_float; + lo_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).x); + hi_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).y); + return make_float2(lo_float, hi_float); +#endif +} + +__host__ __device__ inline __nv_bfloat162 float22bfloat162_rn(const float2 a) { + __nv_bfloat162 val; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + val = __float22bfloat162_rn(a); +#else + val.x = __float2bfloat16_rn(a.x); + val.y = __float2bfloat16_rn(a.y); +#endif + return val; +} + +#endif + +template +__host__ __device__ inline float __2float(const T a) { + return static_cast(a); +} + +template <> +__host__ __device__ inline float __2float<__half>(const __half a) { + return __half2float(a); +} + +template +__host__ __device__ inline T __2dst(const float a) { + return static_cast(a); +} + +template <> +__host__ __device__ inline __half __2dst<__half>(const float a) { + return __float2half(a); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. + float sum; + // The sum of squares. + float sumSq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const& a, + GroupSums const& b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sumSq = b.flag ? 
b.sumSq : (a.sumSq + b.sumSq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { + int32_t maxDivisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { + maxDivisor = divisor1; + } + if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { + maxDivisor = divisor2; + } + } + } + return maxDivisor; +} + +template +inline __device__ void UpdateSum(const T* srcX, float* sum, float* sumSq) { + float src_data = phi::__2float(*srcX); + *sum += src_data; + *sumSq += src_data * src_data; +} + +template <> +inline __device__ void UpdateSum<__half, 2>(const __half* srcX, + float* sum, + float* sumSq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(srcX); + float2 f2 = __half22float2(h2); + *sum += f2.x + f2.y; + *sumSq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void UpdateSum( + const phi::dtype::float16* srcX, float* sum, float* sumSq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(srcX); + float2 f2 = __half22float2(h2); + *sum += f2.x + f2.y; + *sumSq += f2.x * f2.x + f2.y * f2.y; +} + +#ifdef PADDLE_CUDA_BF16 +template <> +inline __device__ void UpdateSum( + const phi::dtype::bfloat16* srcX, float* sum, float* sumSq) { + __nv_bfloat162 h2 = *reinterpret_cast<__nv_bfloat162 const*>(srcX); + float2 f2 = phi::bfloat1622float2(h2); + *sum += f2.x + f2.y; + *sumSq += f2.x * f2.x + f2.y * f2.y; +} +#endif + +template +__global__ void groupNormNHWCSumSingerChannelKernel( + const GroupNormNHWCParams params) { + // The instance in the batch. + __shared__ float2 smem[THREADS_PER_BLOCK]; + int32_t ni = blockIdx.z; + int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x; + if (ci >= params.c) { + return; + } + // The first activation loaded by that block. + int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + + // The sums. + float sum = 0.F; + float sumSq = 0.F; + + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The offset. + int64_t offset = static_cast(ni) * params.hwc + + static_cast(hwi) * params.c + ci; + float src_data = *reinterpret_cast(¶ms.srcX[offset]); + UpdateSum(¶ms.srcX[offset], &sum, &sumSq); + } + + smem[threadIdx.x] = make_float2(sum, sumSq); + + __syncthreads(); + + float2 sums = smem[threadIdx.x]; + + atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + ci], + sums.x * params.invHWC); + atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + ci], sums.y); +} + +template +__global__ void groupNormNHWCSumKernel(const GroupNormNHWCParams params) { + // The object in charge of doing the sums for the different blocks. + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage tempStorage; + // Allocate shared memory for BlockScan. + // Allocate shared memory for the groups. We could reduce the amount of shared + // memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; + + // The instance in the batch. + int32_t ni = blockIdx.z; + // The channel loaded by that thread (2 channels per thread for F16x2). + int32_t ci = + blockIdx.x * params.cPerBlock + threadIdx.x * THREADS_PER_CHANNEL; + if (ci >= params.c || threadIdx.x * THREADS_PER_CHANNEL >= params.cPerBlock) { + return; + } + // The first activation loaded by that block. 
+ int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + + // The sums. + float sum = 0.F; + float sumSq = 0.F; + + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The offset. + int64_t offset = static_cast(ni) * params.hwc + + static_cast(hwi) * params.c + ci; + float src_data = *reinterpret_cast(¶ms.srcX[offset]); + UpdateSum(¶ms.srcX[offset], &sum, &sumSq); + } + + // The group that thread works on and the channel in the group (modulus). + int32_t gi = + ci / params.cPerGroup - blockIdx.x * params.cPerBlock / params.cPerGroup; + int32_t cj = ci % params.cPerGroup; + int flag = (cj == 0 || threadIdx.x == 0) ? 1 : 0; + GroupSums inp{flag, sum, sumSq}; + GroupSums out; + BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + + if (cj == params.cPerGroup - THREADS_PER_CHANNEL || + threadIdx.x * THREADS_PER_CHANNEL == + params.cPerBlock - THREADS_PER_CHANNEL) { + smem[gi] = make_float2(out.sum, out.sumSq); + } + + __syncthreads(); + + int32_t gj = ci / params.cPerGroup; + if (cj == params.cPerGroup - THREADS_PER_CHANNEL || + threadIdx.x * THREADS_PER_CHANNEL == + params.cPerBlock - THREADS_PER_CHANNEL) { + float2 sums = smem[gi]; + atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], + sums.x * params.invHWC); + atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); + } +} + +template +void groupNormNHWCSum::operator()(GroupNormNHWCParams* params, + gpuStream_t stream) { + dim3 grid; + grid.x = divUp(params->c, params->cPerBlock); + grid.y = divUp(params->hw, params->hwPerBlock); + grid.z = params->n; + if (params->cPerGroup % 2 == 0) { + switch (params->cPerBlock) { + case 512: + case 480: + groupNormNHWCSumKernel<<>>(*params); + break; + case 320: + groupNormNHWCSumKernel<<>>(*params); + break; + case 256: + groupNormNHWCSumKernel<<>>(*params); + break; + case 128: + groupNormNHWCSumKernel<<>>(*params); + break; + default: + grid.x = divUp(params->c, 128); + params->cPerBlock = 128; + groupNormNHWCSumKernel<<>>(*params); + } + } else { + if (params->cPerGroup != 1) { + switch (params->cPerBlock) { + case 512: + groupNormNHWCSumKernel<<>>(*params); + break; + case 480: + groupNormNHWCSumKernel<<>>(*params); + break; + case 320: + groupNormNHWCSumKernel<<>>(*params); + break; + case 256: + groupNormNHWCSumKernel<<>>(*params); + break; + case 128: + groupNormNHWCSumKernel<<>>(*params); + break; + default: + grid.x = divUp(params->c, 128); + params->cPerBlock = 128; + groupNormNHWCSumKernel<<>>(*params); + } + } else { + switch (params->cPerBlock) { + case 512: + groupNormNHWCSumSingerChannelKernel + <<>>(*params); + break; + case 480: + groupNormNHWCSumSingerChannelKernel + <<>>(*params); + break; + case 320: + groupNormNHWCSumSingerChannelKernel + <<>>(*params); + break; + case 256: + groupNormNHWCSumSingerChannelKernel + <<>>(*params); + break; + case 128: + groupNormNHWCSumSingerChannelKernel + <<>>(*params); + break; + default: + grid.x = divUp(params->c, 128); + params->cPerBlock = 128; + groupNormNHWCSumSingerChannelKernel + <<>>(*params); + } + } + } +} +template class groupNormNHWCSum; + +template +inline __device__ void GroupNormCompute(int32_t hwBegin, + int32_t hwEnd, + int32_t ci, + const GroupNormNHWCParams& params, + float mean, + float invStdDev) { + float gamma = + phi::__2float(*(reinterpret_cast(params.gamma) + ci)); + float beta = + phi::__2float(*(reinterpret_cast(params.beta) + ci)); + for (int32_t hwi = hwBegin; hwi < hwEnd; 
++hwi) { + // The src/dst offset. + int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci; + const float src_data = phi::__2float(params.srcX[offset]); + // Normalize the channels. + float dst_data = (src_data - mean) * invStdDev; + // Scale by gamma and add beta. + dst_data = gamma * dst_data + beta; + + // Apply Silu if needed. + if (params.withSilu) { + dst_data = dst_data * sigmoid(dst_data); + } + + // Store the scaled values. + *reinterpret_cast(¶ms.dst[offset]) = phi::__2dst(dst_data); + } +} + +template <> +inline __device__ void GroupNormCompute( + int32_t hwBegin, + int32_t hwEnd, + int32_t ci, + const GroupNormNHWCParams& params, + float mean, + float invStdDev) { + float2 gammaF2, betaF2; + gammaF2 = __half22float2(*reinterpret_cast<__half2 const*>( + reinterpret_cast(params.gamma) + ci)); + betaF2 = __half22float2(*reinterpret_cast<__half2 const*>( + reinterpret_cast(params.beta) + ci)); + + // Iterate over the activations to compute the sums. + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The src/dst offset. + int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci; + + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(¶ms.srcX[offset]); + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Silu if needed. + if (params.withSilu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + // Store the scaled values. + *reinterpret_cast<__half2*>(¶ms.dst[offset]) = __float22half2_rn(f2); + } +} + +template <> +inline __device__ void GroupNormCompute<__half, 2>( + int32_t hwBegin, + int32_t hwEnd, + int32_t ci, + const GroupNormNHWCParams<__half>& params, + float mean, + float invStdDev) { + float2 gammaF2, betaF2; + gammaF2 = __half22float2(*reinterpret_cast<__half2 const*>( + reinterpret_cast(params.gamma) + ci)); + betaF2 = __half22float2(*reinterpret_cast<__half2 const*>( + reinterpret_cast(params.beta) + ci)); + + // Iterate over the activations to compute the sums. + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The src/dst offset. + int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci; + + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(¶ms.srcX[offset]); + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Silu if needed. + if (params.withSilu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + // Store the scaled values. 
+ *reinterpret_cast<__half2*>(¶ms.dst[offset]) = __float22half2_rn(f2); + } +} + +#ifdef PADDLE_CUDA_BF16 +template <> +inline __device__ void GroupNormCompute( + int32_t hwBegin, + int32_t hwEnd, + int32_t ci, + const GroupNormNHWCParams& params, + float mean, + float invStdDev) { + float2 gammaF2, betaF2; + gammaF2 = phi::bfloat1622float2(*reinterpret_cast<__nv_bfloat162 const*>( + reinterpret_cast<__nv_bfloat16 const*>(params.gamma) + ci)); + betaF2 = phi::bfloat1622float2(*reinterpret_cast<__nv_bfloat162 const*>( + reinterpret_cast<__nv_bfloat16 const*>(params.beta) + ci)); + + // Iterate over the activations to compute the sums. + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The src/dst offset. + int64_t offset = (int64_t)blockIdx.z * params.hwc + hwi * params.c + ci; + + // Fetch two channels per thread. + __nv_bfloat162 h2 = + *reinterpret_cast<__nv_bfloat162 const*>(¶ms.srcX[offset]); + + // Extract the two half values. + float2 f2 = phi::bfloat1622float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Silu if needed. + if (params.withSilu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + // Store the scaled values. + *reinterpret_cast<__nv_bfloat162*>(¶ms.dst[offset]) = + phi::float22bfloat162_rn(f2); + } +} +#endif + +template +__global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) { + // The instance in the batch. + int32_t ni = blockIdx.z; + // The channel loaded by that thread (2 channels per thread for F16x2). + int32_t ci = + blockIdx.x * params.cPerBlock + threadIdx.x * THREADS_PER_CHANNEL; + + // The group that thread works on and the channel in the group (modulus). + int32_t gi = ci / params.cPerGroup; + + if (ci >= params.c || gi >= params.groups) { + return; + } + + // Load the sum and sum of squares for the group. + + float mean = params.redBuffer[(2 * ni + 0) * params.groups + gi]; + float sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + + // Compute the variance. + float var = sumSq * params.invHWC - (mean * mean); + + if (params.var_data != nullptr) { + params.var_data[ni * params.groups + gi] = var; + } + // Compute the inverse of the stddev. + float invStdDev = rsqrtf(var + params.eps); + + // The first activation loaded by that block. + int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + GroupNormCompute( + hwBegin, hwEnd, ci, params, mean, invStdDev); +} + +template +void groupNormNHWCScale::operator()(const GroupNormNHWCParams& params, + gpuStream_t stream) { + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = divUp(params.c, params.cPerBlock); + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. 
+ grid.z = params.n; + + if (params.cPerGroup % 2 == 0) { + switch (params.cPerBlock) { + case 512: + case 480: + groupNormNHWCScaleKernel<<>>(params); + break; + case 320: + groupNormNHWCScaleKernel<<>>(params); + break; + case 256: + groupNormNHWCScaleKernel<<>>(params); + break; + case 128: + groupNormNHWCScaleKernel<<>>(params); + break; + default: + grid.x = divUp(params.c, 128); + groupNormNHWCScaleKernel<<>>(params); + } + } else { + switch (params.cPerBlock) { + case 512: + groupNormNHWCScaleKernel<<>>(params); + break; + case 480: + groupNormNHWCScaleKernel<<>>(params); + break; + case 320: + groupNormNHWCScaleKernel<<>>(params); + break; + case 256: + groupNormNHWCScaleKernel<<>>(params); + break; + case 128: + groupNormNHWCScaleKernel<<>>(params); + break; + default: + grid.x = divUp(params.c, 128); + groupNormNHWCScaleKernel<<>>(params); + } + } +} +template class groupNormNHWCScale; + +template +void GroupNormNHWCKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { + using AccT = typename phi::dtype::MPTypeTrait::Type; + GroupNormNHWCParams params_; + params_.withSilu = false; + + const auto x_dims = x.dims(); + dev_ctx.template Alloc(y); + const T* x_data = x.data(); + T* y_data = y->data(); + const auto scale_ptr = scale.get_ptr(); + const auto bias_ptr = bias.get_ptr(); + const T* scale_data = nullptr; + if (scale_ptr) scale_data = scale_ptr->data(); + const T* bias_data = nullptr; + if (bias_ptr) bias_data = bias_ptr->data(); + params_.n = x_dims[0]; + params_.c = x_dims[3]; + params_.h = x_dims[1]; + params_.w = x_dims[2]; + + dev_ctx.template Alloc(mean); + dev_ctx.template Alloc(var); + auto* mean_data = mean->data(); + auto* var_data = var->data(); + params_.var_data = var_data; + + int32_t cPerBlock = 320; + int32_t maxBlocksPerHW = 1024; + switch (params_.c) { + case 2048: + case 1024: + cPerBlock = 512; + break; + case 960: + case 1920: + cPerBlock = 480; + break; + case 512: + case 256: + cPerBlock = 256; + break; + case 128: + cPerBlock = 128; + break; + default: + cPerBlock = 320; + } + params_.groups = groups; + params_.cPerGroup = params_.c / params_.groups; + if (cPerBlock % params_.cPerGroup != 0) { + cPerBlock = params_.cPerGroup; + } + params_.srcX = reinterpret_cast(x_data); + params_.dst = reinterpret_cast(y_data); + + params_.gamma = scale_data; + params_.beta = bias_data; + params_.hw = params_.h * params_.w; + const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); + params_.hwPerBlock = divUp(params_.hw, blocksPerHW); + params_.cPerBlock = cPerBlock; + params_.hwc = params_.hw * params_.c; + params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.eps = epsilon; + auto stream = dev_ctx.stream(); + DenseTensor redBuffer; + int buffer_sizes = 2 * params_.n * groups; + redBuffer.Resize({1, buffer_sizes}); + params_.redBuffer = dev_ctx.template Alloc(&redBuffer); +#ifdef PADDLE_WITH_HIP + hipMemset(params_.redBuffer, 0, buffer_sizes * sizeof(float)); +#else + cudaMemset(params_.redBuffer, 0, buffer_sizes * sizeof(float)); +#endif + groupNormNHWCSum nhwc_sum; + nhwc_sum(¶ms_, stream); + groupNormNHWCScale nhwc_scale; + nhwc_scale(params_, stream); +#ifdef PADDLE_WITH_HIP + phi::backends::gpu::GpuMemcpyAsync(mean_data, + params_.redBuffer, + params_.n * groups * sizeof(float), + hipMemcpyDeviceToHost, + stream); +#else + 
phi::backends::gpu::GpuMemcpyAsync(mean_data, + params_.redBuffer, + params_.n * groups * sizeof(float), + cudaMemcpyDeviceToHost, + stream); +#endif +} + template __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, @@ -117,27 +787,128 @@ __global__ void GroupNormForward(const T* x, } } +template +void GroupNormDirectCUDAFunctor::operator()( + gpuStream_t stream, + const T* input, + std::vector input_shape, + const T* bias, + const T* scale, + AccT* temp_variance, + int groups, + float eps, + T* output, + AccT* mean, + AccT* variance, + const DataLayout data_layout) { + const auto input_ddim = phi::make_ddim(input_shape); + const int C = + (data_layout == DataLayout::kNCHW ? input_ddim[1] + : input_ddim[input_ddim.size() - 1]); + const int group_size = C / groups; + const int W = + (data_layout == DataLayout::kNCHW ? input_ddim[input_ddim.size() - 1] + : input_ddim[input_ddim.size() - 2]); + + int image_size = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < input_ddim.size(); ++i) { + image_size *= input_ddim[i]; + } + } else { + for (int i = 1; i < input_ddim.size() - 1; ++i) { + image_size *= input_ddim[i]; + } + } +#ifdef __HIPCC__ + int block_size = std::max(std::min(256, image_size), 64); +#else + int block_size = std::min(1024, image_size); +#endif + dim3 grid(group_size, groups, input_ddim[0]); + dim3 threads(block_size, 1, 1); + if (data_layout == DataLayout::kNCHW) { + constexpr int vec_size = sizeof(float4) / sizeof(T); + int size = group_size * image_size; // group element size + const int max_num_threads = 1024; + int max_block_size = std::min(size / vec_size, max_num_threads); + int block_size_nchw = 1; + while (block_size_nchw < max_block_size) { + block_size_nchw *= 2; + } + + block_size_nchw = std::max(block_size_nchw, phi::kps::details::kWarpSize); + dim3 grids(input_ddim[0] * groups); + dim3 blocks(block_size_nchw); + + if (size < vec_size * block_size_nchw) { + phi::ScalarGetMeanAndVarNCHW + <<>>(input, mean, temp_variance, size); + } else { + phi::VectorizedGetMeanAndVarNCHW + <<>>(input, mean, temp_variance, size); + } + } else { +#ifdef PADDLE_WITH_HIP + hipMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups); + hipMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups); +#else + cudaMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups); + cudaMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups); +#endif + + phi::GroupNormForwardGetMeanAndVar + <<>>(input, + input_ddim[0], + C, + W, + image_size, + groups, + group_size, + mean, + temp_variance); + } + GroupNormForward + <<>>(input, + mean, + temp_variance, + scale, + bias, + input_ddim[0], + C, + W, + image_size, + groups, + group_size, + static_cast(eps), + output, + variance, + data_layout); +} +template class GroupNormDirectCUDAFunctor; +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) +template class GroupNormDirectCUDAFunctor; +#endif + template -void GroupNormKernel(const Context& dev_ctx, - const DenseTensor& x, - const paddle::optional& scale, - const paddle::optional& bias, - float epsilon, - int groups, - const std::string& data_layout_str, - DenseTensor* y, - DenseTensor* mean, - DenseTensor* var) { +void GroupNormGeneralCaseKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { using AccT = typename phi::dtype::MPTypeTrait::Type; const DataLayout 
data_layout = phi::StringToDataLayout(data_layout_str); const auto scale_ptr = scale.get_ptr(); const auto bias_ptr = bias.get_ptr(); - const auto x_dims = x.dims(); const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); const int group_size = C / groups; - const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2]); @@ -235,108 +1006,51 @@ void GroupNormKernel(const Context& dev_ctx, data_layout); } -template -void GroupNormDirectCUDAFunctor::operator()( - gpuStream_t stream, - const T* input, - std::vector input_shape, - const T* bias, - const T* scale, - AccT* temp_variance, - int groups, - float eps, - T* output, - AccT* mean, - AccT* variance, - const DataLayout data_layout) { - const auto input_ddim = phi::make_ddim(input_shape); - const int C = - (data_layout == DataLayout::kNCHW ? input_ddim[1] - : input_ddim[input_ddim.size() - 1]); - const int group_size = C / groups; - const int W = - (data_layout == DataLayout::kNCHW ? input_ddim[input_ddim.size() - 1] - : input_ddim[input_ddim.size() - 2]); - - int image_size = 1; - if (data_layout == DataLayout::kNCHW) { - for (int i = 2; i < input_ddim.size(); ++i) { - image_size *= input_ddim[i]; - } - } else { - for (int i = 1; i < input_ddim.size() - 1; ++i) { - image_size *= input_ddim[i]; - } +template +void GroupNormKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon, + int groups, + const std::string& data_layout_str, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { + using std::is_same; + if (is_same::value && data_layout_str == "NHWC") { + GroupNormNHWCKernel(dev_ctx, + x, + scale, + bias, + epsilon, + groups, + data_layout_str, + y, + mean, + var); + return; } -#ifdef __HIPCC__ - int block_size = std::max(std::min(256, image_size), 64); -#else - int block_size = std::min(1024, image_size); -#endif - dim3 grid(group_size, groups, input_ddim[0]); - dim3 threads(block_size, 1, 1); - if (data_layout == DataLayout::kNCHW) { - constexpr int vec_size = sizeof(float4) / sizeof(T); - int size = group_size * image_size; // group element size - const int max_num_threads = 1024; - int max_block_size = std::min(size / vec_size, max_num_threads); - int block_size_nchw = 1; - while (block_size_nchw < max_block_size) { - block_size_nchw *= 2; - } - - block_size_nchw = std::max(block_size_nchw, phi::kps::details::kWarpSize); - dim3 grids(input_ddim[0] * groups); - dim3 blocks(block_size_nchw); - if (size < vec_size * block_size_nchw) { - phi::ScalarGetMeanAndVarNCHW - <<>>(input, mean, temp_variance, size); - } else { - phi::VectorizedGetMeanAndVarNCHW - <<>>(input, mean, temp_variance, size); - } - } else { -#ifdef PADDLE_WITH_HIP - hipMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups); - hipMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups); -#else - cudaMemset(mean, 0, sizeof(AccT) * input_ddim[0] * groups); - cudaMemset(temp_variance, 0, sizeof(AccT) * input_ddim[0] * groups); +#ifdef PADDLE_CUDA_BF16 + if (is_same::value && data_layout_str == "NHWC") { + GroupNormNHWCKernel(dev_ctx, + x, + scale, + bias, + epsilon, + groups, + data_layout_str, + y, + mean, + var); + return; + } #endif - phi::GroupNormForwardGetMeanAndVar - <<>>(input, - input_ddim[0], - C, - W, - image_size, - groups, - group_size, - mean, - temp_variance); - } - GroupNormForward - <<>>(input, - mean, - temp_variance, - scale, - bias, - input_ddim[0], - C, - W, - image_size, - groups, - 
group_size, - static_cast(eps), - output, - variance, - data_layout); + GroupNormGeneralCaseKernel( + dev_ctx, x, scale, bias, epsilon, groups, data_layout_str, y, mean, var); } -template class GroupNormDirectCUDAFunctor; -#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) -template class GroupNormDirectCUDAFunctor; -#endif } // namespace phi diff --git a/paddle/phi/kernels/group_norm_kernel.h b/paddle/phi/kernels/group_norm_kernel.h index f3e39ddbeb3..9acdeca0e67 100644 --- a/paddle/phi/kernels/group_norm_kernel.h +++ b/paddle/phi/kernels/group_norm_kernel.h @@ -18,6 +18,11 @@ #include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/core/dense_tensor.h" +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif +#include namespace phi { @@ -52,4 +57,70 @@ class GroupNormDirectCUDAFunctor { }; #endif +template +struct GroupNormNHWCParams { + // The output buffer. Layout NHWC. + T* dst; + // The output buffer. Layout NHWC. + T* eleOut; + // The input buffer. Layout NHWC. + T const* srcX; + // The input buffer. Layout NHWC. + T const* srcY; + // The gamma scaling factor. + void const* gamma; + // The beta term to add in GN. + void const* beta; + // The temporary buffer to do the global parallel reduction. Size: + // BLOCKS_PER_BATCH x C x 2. + float* redBuffer; + + float* var_data; + + // The number of instances in the batch. + int32_t n; + // The height and width of each activation map. + int32_t h, w; + // The number of channels. + int32_t c; + // The number of groups. + int32_t groups; + // Do we apply the Silu activation function? + bool withSilu; + + // Precomputed values and parameters to control the execution of the kernels. + + // The number of activations per instance (h * w) and the number of + // activations per block. + int32_t hw, hwPerBlock; + // The number of channels per group and blocks per activation in the C + // dimension. + int32_t cPerBlock, cPerGroup; + + // The precomputed stride between instances. + int32_t hwc; + // The inverse of hwc in floats (to compute mean/var). + float invHWC; + // The precomputed number of groups per block. 
+ int32_t groupsPerBlock; + // epsilon, Constant for numerical stability + float eps; + // for NCHW32 int8 use + float dqScaleIn; + float inv_qScale; +}; + +template +class groupNormNHWCSum { + public: + void operator()(GroupNormNHWCParams* params, const gpuStream_t stream); +}; + +template +class groupNormNHWCScale { + public: + void operator()(const GroupNormNHWCParams& params, + const gpuStream_t stream); +}; + } // namespace phi diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index f27d109d7d0..c85fa4f60ce 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -511,7 +511,11 @@ class GroupNorm(Layer): "Mean": mean_out, "Variance": variance_out, }, - attrs={"epsilon": self._epsilon, "groups": self._num_groups}, + attrs={ + "epsilon": self._epsilon, + "groups": self._num_groups, + "data_layout": self._data_format, + }, ) return self._helper.append_activation(group_norm_out, None) diff --git a/test/legacy_test/test_group_norm_op.py b/test/legacy_test/test_group_norm_op.py index ef1c5e6384e..9b3c145b87a 100644 --- a/test/legacy_test/test_group_norm_op.py +++ b/test/legacy_test/test_group_norm_op.py @@ -350,6 +350,51 @@ class TestGroupNormOp2_With_NHWC(TestGroupNormOp): self.data_format = "NHWC" +class TestGroupNormFP16Op_With_NHWC(TestGroupNormFP16OP): + def init_test_case(self): + self.attrs['groups'] = 1 + self.data_format = "NHWC" + self.attrs['epsilon'] = 0.5 + self.shape = (1, 100, 4, 4) + self.dtype = np.float16 + + +class TestGroupNormBF16Op_With_NHWC(TestGroupNormBF16Op): + def setUp(self): + self.op_type = "group_norm" + self.python_api = group_norm_wrapper + self.python_out_sig = ["Y"] + self.data_format = "NHWC" + self.dtype = np.uint16 + self.shape = (1, 3, 5, 100) + self.attrs = { + 'epsilon': 5e-2, + 'groups': 2, + 'data_layout': self.data_format, + } + self.compare_between_place = False + self.init_test_case() + + input = np.random.random(self.shape).astype(np.float32) + scale = np.random.random([self.shape[3]]).astype(np.float32) + bias = np.random.random([self.shape[3]]).astype(np.float32) + output, mean, var = group_norm_naive( + input, + scale, + bias, + self.attrs['epsilon'], + self.attrs['groups'], + self.data_format, + ) + + self.inputs = { + 'X': convert_float_to_uint16(input), + 'Scale': convert_float_to_uint16(scale), + 'Bias': convert_float_to_uint16(bias), + } + self.outputs = {'Y': output, 'Mean': mean, 'Variance': var} + + class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp): def init_test_case(self): self.attrs['groups'] = 1 -- GitLab