diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h index f113f05a35ae74b8ad63dca33d2e385f55c48cd3..b88c39c628abb4b7df254fdcc74bc1206fa795ae 100644 --- a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h @@ -121,7 +121,22 @@ class StackPluginV2Creator : public nvinfer1::IPluginCreator { nvinfer1::IPluginV2* createPlugin( const char* name, const nvinfer1::PluginFieldCollection* fc) override { - return nullptr; + int axis = -1; + int num_stack = -1; + + for (int i = 0; i < fc->nbFields; ++i) { + const std::string name(fc->fields[i].name); + if (name == "axis") { + axis = static_cast(fc->fields[i].data)[0]; + } else if (name == "num_stack") { + num_stack = static_cast(fc->fields[i].data)[0]; + } else { + PADDLE_THROW( + platform::errors::Fatal("Meet an unknown plugin field '" + name + + "' when creating stack op plugin.")); + } + } + return new StackPluginDynamic(axis, num_stack); } nvinfer1::IPluginV2* deserializePlugin(const char* name,