提交 bb6b8480 编写于 作者: Z zlsh80826

gelu op use nvinfer_plugin

上级 e0c94b86
......@@ -47,8 +47,27 @@ class GeluOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::GeluPluginDynamic* plugin = new plugin::GeluPluginDynamic();
layer = engine_->AddPluginV2(&input, input_num, plugin);
auto creator =
getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1");
assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
const std::vector<nvinfer1::PluginField> fields{
{ "type_id",
&type,
nvinfer1::PluginFieldType::kINT32,
1 }};
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
fields.size() * sizeof(nvinfer1::PluginField)));
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj =
creator->createPlugin("CustomGeluPluginDynamic", pluginPtr);
layer = engine_->network()->addPluginV2(&input, input_num, *pluginObj);
assert(layer != nullptr);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册