From bb6b84800ffedcf5fca95b68beafa55851087137 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Fri, 7 Aug 2020 10:03:27 +0800 Subject: [PATCH] gelu op use nvinfer_plugin --- .../inference/tensorrt/convert/gelu_op.cc | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/gelu_op.cc b/paddle/fluid/inference/tensorrt/convert/gelu_op.cc index 7927b6cd1bb..c10fffc706e 100644 --- a/paddle/fluid/inference/tensorrt/convert/gelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/gelu_op.cc @@ -47,8 +47,27 @@ class GeluOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { #if IS_TRT_VERSION_GE(6000) - plugin::GeluPluginDynamic* plugin = new plugin::GeluPluginDynamic(); - layer = engine_->AddPluginV2(&input, input_num, plugin); + auto creator = + getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1"); + assert(creator != nullptr); + int type = static_cast((engine_->WithFp16() == 1) + ? nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT); + const std::vector fields{ + { "type_id", + &type, + nvinfer1::PluginFieldType::kINT32, + 1 }}; + nvinfer1::PluginFieldCollection* pluginPtr = + static_cast( + malloc(sizeof(*pluginPtr) + + fields.size() * sizeof(nvinfer1::PluginField))); + pluginPtr->nbFields = static_cast(fields.size()); + pluginPtr->fields = fields.data(); + auto pluginObj = + creator->createPlugin("CustomGeluPluginDynamic", pluginPtr); + layer = engine_->network()->addPluginV2(&input, input_num, *pluginObj); + assert(layer != nullptr); #else PADDLE_THROW(platform::errors::Fatal( "You are running the TRT Dynamic Shape mode, need to confirm that " -- GitLab