From acab7daf0c0fda7a8fa82b1df8d82c9bbf83e8d2 Mon Sep 17 00:00:00 2001 From: wenbin Date: Tue, 10 Jan 2023 10:34:43 +0800 Subject: [PATCH] gn bug fix (#49658) * gn bug fix * bug fix * gn bug fix --- .../tensorrt/plugin/group_norm_op_plugin.cu | 19 ++++++++++++++++--- .../plugin/preln_groupnorm_act_op_plugin.cu | 9 +++++++++ .../plugin/skip_groupnorm_act_op_plugin.cu | 9 +++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu index e8c9c593a4..77c00d47d4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu @@ -178,6 +178,13 @@ void groupNormNHWCSum(const GroupNormNHWCParams ¶ms, cudaStream_t stream) { case 128: groupNormNHWCSumKernel<64><<>>(params); break; + case 8: + groupNormNHWCSumKernel<4><<>>(params); + break; + default: + PADDLE_THROW(platform::errors::Fatal( + "The function groupNormNHWCSum of GroupNormPlugin TRT Plugin " + "encounter error")); } } @@ -277,10 +284,13 @@ void groupNormNHWCScale(const GroupNormNHWCParams ¶ms, case 128: groupNormNHWCScaleKernel<64><<>>(params); break; + case 8: + groupNormNHWCScaleKernel<4><<>>(params); + break; default: - PADDLE_THROW( - platform::errors::Fatal("The function groupNormNHWCScale of " - "GroupNorm TRT Plugin encounter error")); + PADDLE_THROW(platform::errors::Fatal( + "The function groupNormNHWCScale of GroupNormPlugin TRT Plugin " + "encounter error")); } } @@ -610,6 +620,9 @@ int GroupNormPluginDynamic::enqueue( default: cPerBlock = 320; } + if (cPerBlock > input_desc[0].dims.d[1]) { + cPerBlock = 8; + } params_.withSwish = false; params_.dst = static_cast(outputs[0]); diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu index 3e1c8aa5f8..86d57d4da1 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu @@ -254,6 +254,9 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, case 128: prelnGroupNormNHWCSumKernel<64><<>>(params); break; + case 8: + prelnGroupNormNHWCSumKernel<4><<>>(params); + break; default: PADDLE_THROW(platform::errors::Fatal( "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " @@ -375,6 +378,9 @@ void prelnGroupNormNHWCScale(GroupNormNHWCParams const ¶ms, case 128: prelnGroupNormNHWCScaleKernel<64><<>>(params); break; + case 8: + prelnGroupNormNHWCScaleKernel<4><<>>(params); + break; default: PADDLE_THROW(platform::errors::Fatal( "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " @@ -415,6 +421,9 @@ int PrelnGroupnormActPluginDynamic::enqueue( default: cPerBlock = 320; } + if (cPerBlock > input_desc[0].dims.d[1]) { + cPerBlock = 8; + } params_.withSwish = true; params_.dst = static_cast(outputs[1]); params_.eleOut = static_cast(outputs[0]); diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu index 7acd250c50..cb20f75f8f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu @@ -264,6 +264,9 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, case 128: skipGroupNormNHWCSumKernel<64><<>>(params); break; + case 8: + skipGroupNormNHWCSumKernel<4><<>>(params); + break; default: PADDLE_THROW(platform::errors::Fatal( "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin " @@ -384,6 +387,9 @@ void skipGroupNormNHWCScale(GroupNormNHWCParams const ¶ms, case 128: skipGroupNormNHWCScaleKernel<64><<>>(params); break; + case 8: + skipGroupNormNHWCScaleKernel<4><<>>(params); + break; default: PADDLE_THROW(platform::errors::Fatal( "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin " @@ -423,6 +429,9 @@ int SkipGroupnormActPluginDynamic::enqueue( default: cPerBlock = 320; } + if (cPerBlock > input_desc[0].dims.d[1]) { + cPerBlock = 8; + } params_.withSwish = true; params_.dst = static_cast(outputs[0]); params_.srcX = static_cast(inputs[0]); -- GitLab