From 8570e418f81970585d22551e9daa073d57adae51 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Sun, 9 Aug 2020 22:19:05 +0800 Subject: [PATCH] implement stack op createPlugin --- .../inference/tensorrt/plugin/stack_op_plugin.h | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h index f113f05a35a..b88c39c628a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h @@ -121,7 +121,22 @@ class StackPluginV2Creator : public nvinfer1::IPluginCreator { nvinfer1::IPluginV2* createPlugin( const char* name, const nvinfer1::PluginFieldCollection* fc) override { - return nullptr; + int axis = -1; + int num_stack = -1; + + for (int i = 0; i < fc->nbFields; ++i) { + const std::string name(fc->fields[i].name); + if (name == "axis") { + axis = static_cast(fc->fields[i].data)[0]; + } else if (name == "num_stack") { + num_stack = static_cast(fc->fields[i].data)[0]; + } else { + PADDLE_THROW( + platform::errors::Fatal("Meet an unknown plugin field '" + name + + "' when creating stack op plugin.")); + } + } + return new StackPluginDynamic(axis, num_stack); } nvinfer1::IPluginV2* deserializePlugin(const char* name, -- GitLab