未验证 提交 acab7daf 编写于 作者: W wenbin 提交者: GitHub

Fix bug in GroupNorm (gn) TRT plugins (#49658)

* gn bug fix

* bug fix

* gn bug fix
上级 a227ae2b
......@@ -178,6 +178,13 @@ void groupNormNHWCSum(const GroupNormNHWCParams &params, cudaStream_t stream) {
case 128:
groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
groupNormNHWCSumKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of GroupNormPlugin TRT Plugin "
"encounter error"));
}
}
......@@ -277,10 +284,13 @@ void groupNormNHWCScale(const GroupNormNHWCParams &params,
case 128:
groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
groupNormNHWCScaleKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(
platform::errors::Fatal("The function groupNormNHWCScale of "
"GroupNorm TRT Plugin encounter error"));
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCScale of GroupNormPlugin TRT Plugin "
"encounter error"));
}
}
......@@ -610,6 +620,9 @@ int GroupNormPluginDynamic::enqueue(
default:
cPerBlock = 320;
}
if (cPerBlock > input_desc[0].dims.d[1]) {
cPerBlock = 8;
}
params_.withSwish = false;
params_.dst = static_cast<half *>(outputs[0]);
......
......@@ -254,6 +254,9 @@ void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params,
case 128:
prelnGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
prelnGroupNormNHWCSumKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin "
......@@ -375,6 +378,9 @@ void prelnGroupNormNHWCScale(GroupNormNHWCParams const &params,
case 128:
prelnGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
prelnGroupNormNHWCScaleKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin "
......@@ -415,6 +421,9 @@ int PrelnGroupnormActPluginDynamic::enqueue(
default:
cPerBlock = 320;
}
if (cPerBlock > input_desc[0].dims.d[1]) {
cPerBlock = 8;
}
params_.withSwish = true;
params_.dst = static_cast<half *>(outputs[1]);
params_.eleOut = static_cast<half *>(outputs[0]);
......
......@@ -264,6 +264,9 @@ void skipGroupNormNHWCSum(GroupNormNHWCParams const &params,
case 128:
skipGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
skipGroupNormNHWCSumKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin "
......@@ -384,6 +387,9 @@ void skipGroupNormNHWCScale(GroupNormNHWCParams const &params,
case 128:
skipGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
break;
case 8:
skipGroupNormNHWCScaleKernel<4><<<grid, 4, 0, stream>>>(params);
break;
default:
PADDLE_THROW(platform::errors::Fatal(
"The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin "
......@@ -423,6 +429,9 @@ int SkipGroupnormActPluginDynamic::enqueue(
default:
cPerBlock = 320;
}
if (cPerBlock > input_desc[0].dims.d[1]) {
cPerBlock = 8;
}
params_.withSwish = true;
params_.dst = static_cast<half *>(outputs[0]);
params_.srcX = static_cast<half const *>(inputs[0]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册