diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h index eadffbfb1a66d707598b83d612ce6a7b826aa251..f113f05a35ae74b8ad63dca33d2e385f55c48cd3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h @@ -61,7 +61,6 @@ class StackPluginDynamic : public DynamicPluginTensorRT { serialize_size += SerializedSize(axis_); serialize_size += SerializedSize(num_stack_); - // serialize_size += num_stack_ * sizeof(int64_t); return serialize_size; } @@ -69,7 +68,6 @@ class StackPluginDynamic : public DynamicPluginTensorRT { void serialize(void* buffer) const override { SerializeValue(&buffer, axis_); SerializeValue(&buffer, num_stack_); - // SerializeCudaPointer(&buffer, in_ptr_gpu_, num_stack_); } nvinfer1::DimsExprs getOutputDimensions( @@ -109,6 +107,45 @@ class StackPluginDynamic : public DynamicPluginTensorRT { framework::Tensor in_ptr_tensor_; int64_t* in_ptr_gpu_; }; + +class StackPluginV2Creator : public nvinfer1::IPluginCreator { + public: + StackPluginV2Creator() {} + const char* getPluginName() const override { return "stack_plugin"; } + + const char* getPluginVersion() const override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) override { + auto plugin = new StackPluginDynamic(serial_data, serial_length); + return plugin; + } + + void setPluginNamespace(const char* lib_namespace) override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; + std::vector plugin_attributes_; +}; +REGISTER_TRT_PLUGIN_V2(StackPluginV2Creator); #endif } // namespace plugin