diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu index e8c9c593a4654c17009f9c1d0a1ba4150e7c8a00..77c00d47d4cea3ee9fdf574172ea6d2b4793b076 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu @@ -178,6 +178,13 @@ void groupNormNHWCSum(const GroupNormNHWCParams ¶ms, cudaStream_t stream) { case 128: groupNormNHWCSumKernel<64><<>>(params); break; + case 8: + groupNormNHWCSumKernel<4><<>>(params); + break; + default: + PADDLE_THROW(platform::errors::Fatal( + "The function groupNormNHWCSum of GroupNormPlugin TRT Plugin " + "encounter error")); } } @@ -277,10 +284,13 @@ void groupNormNHWCScale(const GroupNormNHWCParams ¶ms, case 128: groupNormNHWCScaleKernel<64><<>>(params); break; + case 8: + groupNormNHWCScaleKernel<4><<>>(params); + break; default: - PADDLE_THROW( - platform::errors::Fatal("The function groupNormNHWCScale of " - "GroupNorm TRT Plugin encounter error")); + PADDLE_THROW(platform::errors::Fatal( + "The function groupNormNHWCScale of GroupNormPlugin TRT Plugin " + "encounter error")); } } @@ -610,6 +620,9 @@ int GroupNormPluginDynamic::enqueue( default: cPerBlock = 320; } + if (cPerBlock > input_desc[0].dims.d[1]) { + cPerBlock = 8; + } params_.withSwish = false; params_.dst = static_cast(outputs[0]); diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu index 3e1c8aa5f842b4d283ba0737371e958101448a42..86d57d4da14c145bf85611d75b8f6e79d09fc4c9 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu @@ -254,6 +254,9 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, case 128: prelnGroupNormNHWCSumKernel<64><<>>(params); break; + case 8: + prelnGroupNormNHWCSumKernel<4><<>>(params); + break; default: PADDLE_THROW(platform::errors::Fatal( "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " @@ -375,6 +378,9 @@ void prelnGroupNormNHWCScale(GroupNormNHWCParams const ¶ms, case 128: prelnGroupNormNHWCScaleKernel<64><<>>(params); break; + case 8: + prelnGroupNormNHWCScaleKernel<4><<>>(params); + break; default: PADDLE_THROW(platform::errors::Fatal( "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " @@ -415,6 +421,9 @@ int PrelnGroupnormActPluginDynamic::enqueue( default: cPerBlock = 320; } + if (cPerBlock > input_desc[0].dims.d[1]) { + cPerBlock = 8; + } params_.withSwish = true; params_.dst = static_cast(outputs[1]); params_.eleOut = static_cast(outputs[0]); diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu index 7acd250c50cf72f3913e5cbfa72f0b80d657aba6..cb20f75f8f86d566156d8cb988e2d4f6c46b2ed8 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu @@ -264,6 +264,9 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, case 128: skipGroupNormNHWCSumKernel<64><<>>(params); break; + case 8: + skipGroupNormNHWCSumKernel<4><<>>(params); + break; default: PADDLE_THROW(platform::errors::Fatal( "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin " @@ -384,6 +387,9 @@ void skipGroupNormNHWCScale(GroupNormNHWCParams const ¶ms, case 128: skipGroupNormNHWCScaleKernel<64><<>>(params); break; + case 8: + skipGroupNormNHWCScaleKernel<4><<>>(params); + break; default: PADDLE_THROW(platform::errors::Fatal( "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin " @@ -423,6 +429,9 @@ int SkipGroupnormActPluginDynamic::enqueue( default: cPerBlock = 320; } + if (cPerBlock > input_desc[0].dims.d[1]) { + cPerBlock = 8; + } params_.withSwish = true; params_.dst = static_cast(outputs[0]); params_.srcX = static_cast(inputs[0]);