未验证 提交 e35afed7 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT] fix GN when params.c % params.cPerBlock != 0 (#51836)

* fix GN when params.c % params.cPerBlock != 0

* fix GN when params.c is not divisible by params.cPerBlock
上级 f47a5f7f
...@@ -159,7 +159,7 @@ void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) { ...@@ -159,7 +159,7 @@ void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
dim3 grid; dim3 grid;
// The number of blocks to compute all the channels. // The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock; grid.x = divUp(params.c, params.cPerBlock);
// The number of blocks to compute all the activations in a given instance. // The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock); grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances. // The number of instances.
...@@ -286,7 +286,7 @@ void groupNormNCHW32SumQDQ(const GroupNormNHWCParams &params, ...@@ -286,7 +286,7 @@ void groupNormNCHW32SumQDQ(const GroupNormNHWCParams &params,
dim3 grid; dim3 grid;
// The number of blocks to compute all the channels. // The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock; grid.x = divUp(params.c, params.cPerBlock);
// The number of blocks to compute all the activations in a given instance. // The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock); grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances. // The number of instances.
...@@ -410,7 +410,7 @@ void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams &params, ...@@ -410,7 +410,7 @@ void groupNormNCHW32ScaleQDQ(const GroupNormNHWCParams &params,
dim3 grid; dim3 grid;
// The number of blocks to compute all the channels. // The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock; grid.x = divUp(params.c, params.cPerBlock);
// The number of blocks to compute all the activations in a given instance. // The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock); grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances. // The number of instances.
...@@ -516,7 +516,7 @@ void groupNormNHWCScale(const GroupNormNHWCParams &params, ...@@ -516,7 +516,7 @@ void groupNormNHWCScale(const GroupNormNHWCParams &params,
dim3 grid; dim3 grid;
// The number of blocks to compute all the channels. // The number of blocks to compute all the channels.
grid.x = params.c / params.cPerBlock; grid.x = divUp(params.c, params.cPerBlock);
// The number of blocks to compute all the activations in a given instance. // The number of blocks to compute all the activations in a given instance.
grid.y = divUp(params.hw, params.hwPerBlock); grid.y = divUp(params.hw, params.hwPerBlock);
// The number of instances. // The number of instances.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册