From e35afed7e0d6875f9ee791dbf02cc15dd2a75a7a Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Tue, 21 Mar 2023 15:57:45 +0800 Subject: [PATCH] [Paddle-TRT] fix GN when params.c% params.cPerBlock != 0 (#51836) * fix GN when params.c% params.cPerBlock != 0 * fix GN when params.cnot divisable by params.cPerBlock --- .../inference/tensorrt/plugin/group_norm_op_plugin.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu index 279a0058963..219868c49b4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu @@ -159,7 +159,7 @@ void groupNormNHWCSum(const GroupNormNHWCParams ¶ms, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = divUp(params.c, params.cPerBlock); // The number of blocks to compute all the activations in a given instance. grid.y = divUp(params.hw, params.hwPerBlock); // The number of instances. @@ -286,7 +286,7 @@ void groupNormNCHW32SumQDQ(const GroupNormNHWCParams ¶ms, dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = divUp(params.c, params.cPerBlock); // The number of blocks to compute all the activations in a given instance. grid.y = divUp(params.hw, params.hwPerBlock); // The number of instances. @@ -410,7 +410,7 @@ void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams ¶ms, dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = divUp(params.c, params.cPerBlock); // The number of blocks to compute all the activations in a given instance. grid.y = divUp(params.hw, params.hwPerBlock); // The number of instances. @@ -516,7 +516,7 @@ void groupNormNHWCScale(const GroupNormNHWCParams ¶ms, dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = divUp(params.c, params.cPerBlock); // The number of blocks to compute all the activations in a given instance. grid.y = divUp(params.hw, params.hwPerBlock); // The number of instances. -- GitLab