// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { class MishPlugin : public PluginTensorRT { private: float threshold_; protected: size_t getSerializationSize() const TRT_NOEXCEPT override { return getBaseSerializationSize() + SerializedSize(threshold_); } // TRT will call this func to serialize the configuration of TRT // It should not be called by users. void serialize(void* buffer) const TRT_NOEXCEPT override { serializeBase(buffer); SerializeValue(&buffer, threshold_); } public: explicit MishPlugin(const float threshold, const bool with_fp16) : threshold_(threshold) { with_fp16_ = with_fp16; } // It was used for tensorrt deserialization. // It should not be called by users. MishPlugin(void const* serialData, size_t serialLength) { deserializeBase(serialData, serialLength); DeserializeValue(&serialData, &serialLength, &threshold_); } ~MishPlugin() {} MishPlugin* clone() const TRT_NOEXCEPT override { return new MishPlugin(threshold_, with_fp16_); } const char* getPluginType() const TRT_NOEXCEPT override { return "mish_plugin"; } int getNbOutputs() const TRT_NOEXCEPT override { return 1; } int initialize() TRT_NOEXCEPT override; bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT override; nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) TRT_NOEXCEPT override; #if IS_TRT_VERSION_LT(8000) int enqueue(int batchSize, const void* const* inputs, void** outputs, #else int enqueue(int batchSize, const void* const* inputs, void* const* outputs, #endif void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; }; class MishPluginCreator : public TensorRTPluginCreator { public: const char* getPluginName() const TRT_NOEXCEPT override { return "mish_plugin"; } const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } nvinfer1::IPluginV2* deserializePlugin( const char* name, const void* serial_data, size_t serial_length) TRT_NOEXCEPT override { return new MishPlugin(serial_data, serial_length); } }; REGISTER_TRT_PLUGIN_V2(MishPluginCreator); class MishPluginDynamic : public DynamicPluginTensorRT { public: explicit MishPluginDynamic(const float threshold, const bool with_fp16) : threshold_(threshold) { with_fp16_ = with_fp16; } MishPluginDynamic(void const* serialData, size_t serialLength) { DeserializeValue(&serialData, &serialLength, &threshold_); DeserializeValue(&serialData, &serialLength, &with_fp16_); } nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { return new MishPluginDynamic(threshold_, with_fp16_); } const char* getPluginType() const TRT_NOEXCEPT override { return "mish_plugin_dynamic"; } int getNbOutputs() const TRT_NOEXCEPT override { return 1; } int initialize() TRT_NOEXCEPT override; size_t getSerializationSize() const TRT_NOEXCEPT override; void serialize(void* buffer) const TRT_NOEXCEPT override; nvinfer1::DimsExprs getOutputDimensions( int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override; bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override {} size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override { return 0; } int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; nvinfer1::DataType getOutputDataType( int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; void destroy() TRT_NOEXCEPT override { delete this; } private: float threshold_; }; class MishPluginDynamicCreator : public TensorRTPluginCreator { public: const char* getPluginName() const TRT_NOEXCEPT override { return "mish_plugin_dynamic"; } const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } nvinfer1::IPluginV2* deserializePlugin( const char* name, const void* serial_data, size_t serial_length) TRT_NOEXCEPT override { auto plugin = new MishPluginDynamic(serial_data, serial_length); return plugin; } }; REGISTER_TRT_PLUGIN_V2(MishPluginDynamicCreator); } // namespace plugin } // namespace tensorrt } // namespace inference } // namespace paddle