未验证 提交 acab7daf 编写于 作者: W wenbin 提交者: GitHub

gn bug fix (#49658)

* gn bug fix

* bug fix

* gn bug fix
上级 a227ae2b
...@@ -178,6 +178,13 @@ void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) { ...@@ -178,6 +178,13 @@ void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
case 128: case 128:
groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params); groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break; break;
case 8:
groupNormNHWCSumKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of GroupNormPlugin TRT Plugin "
"encounter error"));
} }
} }
...@@ -277,10 +284,13 @@ void groupNormNHWCScale(const GroupNormNHWCParams &params, ...@@ -277,10 +284,13 @@ void groupNormNHWCScale(const GroupNormNHWCParams &params,
case 128: case 128:
groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params); groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break; break;
case 8:
groupNormNHWCScaleKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default: default:
PADDLE_THROW( PADDLE_THROW(platform::errors::Fatal(
platform::errors::Fatal("The function groupNormNHWCScale of " "The function groupNormNHWCScale of GroupNormPlugin TRT Plugin "
"GroupNorm TRT Plugin encounter error")); "encounter error"));
} }
} }
...@@ -610,6 +620,9 @@ int GroupNormPluginDynamic::enqueue( ...@@ -610,6 +620,9 @@ int GroupNormPluginDynamic::enqueue(
default: default:
cPerBlock = 320; cPerBlock = 320;
} }
if (cPerBlock > input_desc[0].dims.d[1]) {
cPerBlock = 8;
}
params_.withSwish = false; params_.withSwish = false;
params_.dst = static_cast<half *>(outputs[0]); params_.dst = static_cast<half *>(outputs[0]);
......
...@@ -254,6 +254,9 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params, ...@@ -254,6 +254,9 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params,
case 128: case 128:
prelnGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params); prelnGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break; break;
case 8:
prelnGroupNormNHWCSumKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default: default:
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin "
...@@ -375,6 +378,9 @@ void prelnGroupNormNHWCScale(GroupNormNHWCParams const &params, ...@@ -375,6 +378,9 @@ void prelnGroupNormNHWCScale(GroupNormNHWCParams const &params,
case 128: case 128:
prelnGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params); prelnGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break; break;
case 8:
prelnGroupNormNHWCScaleKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default: default:
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin "
...@@ -415,6 +421,9 @@ int PrelnGroupnormActPluginDynamic::enqueue( ...@@ -415,6 +421,9 @@ int PrelnGroupnormActPluginDynamic::enqueue(
default: default:
cPerBlock = 320; cPerBlock = 320;
} }
if (cPerBlock > input_desc[0].dims.d[1]) {
cPerBlock = 8;
}
params_.withSwish = true; params_.withSwish = true;
params_.dst = static_cast<half *>(outputs[1]); params_.dst = static_cast<half *>(outputs[1]);
params_.eleOut = static_cast<half *>(outputs[0]); params_.eleOut = static_cast<half *>(outputs[0]);
......
...@@ -264,6 +264,9 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const &params, ...@@ -264,6 +264,9 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const &params,
case 128: case 128:
skipGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params); skipGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break; break;
case 8:
skipGroupNormNHWCSumKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default: default:
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin " "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin "
...@@ -384,6 +387,9 @@ void skipGroupNormNHWCScale(GroupNormNHWCParams const &params, ...@@ -384,6 +387,9 @@ void skipGroupNormNHWCScale(GroupNormNHWCParams const &params,
case 128: case 128:
skipGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params); skipGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break; break;
case 8:
skipGroupNormNHWCScaleKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default: default:
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin " "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin "
...@@ -423,6 +429,9 @@ int SkipGroupnormActPluginDynamic::enqueue( ...@@ -423,6 +429,9 @@ int SkipGroupnormActPluginDynamic::enqueue(
default: default:
cPerBlock = 320; cPerBlock = 320;
} }
if (cPerBlock > input_desc[0].dims.d[1]) {
cPerBlock = 8;
}
params_.withSwish = true; params_.withSwish = true;
params_.dst = static_cast<half *>(outputs[0]); params_.dst = static_cast<half *>(outputs[0]);
params_.srcX = static_cast<half const *>(inputs[0]); params_.srcX = static_cast<half const *>(inputs[0]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册