From a1abb7c9b9bc214c39d6967a0e0f2bba5da9a1fe Mon Sep 17 00:00:00 2001 From: wenbin Date: Wed, 11 May 2022 10:38:11 +0800 Subject: [PATCH] swish refactor (#42610) * swish refactor * bug fix * trt7 non-linear bug fix --- .../inference/tensorrt/convert/swish_op.cc | 2 +- .../tensorrt/plugin/swish_op_plugin.cu | 43 +++++++++++++++---- .../tensorrt/plugin/swish_op_plugin.h | 25 +++++++++-- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/swish_op.cc b/paddle/fluid/inference/tensorrt/convert/swish_op.cc index b2e394d14eb..0df5c013d34 100644 --- a/paddle/fluid/inference/tensorrt/convert/swish_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/swish_op.cc @@ -75,7 +75,7 @@ class SwishOpConverter : public OpConverter { bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); plugin::SwishPlugin* plugin = new plugin::SwishPlugin(beta, with_fp16); - layer = engine_->AddPlugin(&input, input_num, plugin); + layer = engine_->AddPluginV2Ext(&input, input_num, plugin); } auto output_name = op_desc.Output("Out")[0]; diff --git a/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu index 9720719fd0b..2c2fad74b9a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu @@ -24,6 +24,16 @@ namespace tensorrt { namespace plugin { int SwishPlugin::initialize() TRT_NOEXCEPT { return 0; } +void SwishPlugin::terminate() TRT_NOEXCEPT {} + +bool SwishPlugin::supportsFormat( + nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT { + if (with_fp16_) { + return type == nvinfer1::DataType::kFLOAT || + type == nvinfer1::DataType::kHALF; + } + return type == nvinfer1::DataType::kFLOAT; +} nvinfer1::Dims SwishPlugin::getOutputDimensions(int index, const nvinfer1::Dims *inputDims, @@ -85,17 +95,29 @@ int SwishPlugin::enqueue(int batch_size, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT { #endif - // input dims is CHW. const auto &input_dims = this->getInputDims(0); - const float *input = reinterpret_cast(inputs[0]); - float *output = reinterpret_cast(outputs)[0]; int num = batch_size; for (int i = 0; i < input_dims.nbDims; i++) { num *= input_dims.d[i]; } int threads = 1024; int blocks = (num + threads - 1) / threads; - swish_kernel<<>>(num, input, output, beta_); + auto type = getDataType(); + if (type == nvinfer1::DataType::kFLOAT) { + VLOG(1) << "TRT Plugin DataType selected. Swish-->fp32"; + const float *input = reinterpret_cast(inputs[0]); + float *output = reinterpret_cast(outputs)[0]; + swish_kernel<<>>(num, input, output, beta_); + } else if (type == nvinfer1::DataType::kHALF) { + VLOG(1) << "TRT Plugin DataType selected. Swish-->fp16"; + const half *input = reinterpret_cast(inputs[0]); + half *output = reinterpret_cast(outputs)[0]; + swish_kernel<<>>(num, input, output, + (half)beta_); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The Swish TRT Plugin's input type should be float or half.")); + } return cudaGetLastError() != cudaSuccess; } @@ -140,12 +162,15 @@ bool SwishPluginDynamic::supportsFormatCombination( const nvinfer1::PluginTensorDesc &in = in_out[pos]; if (pos == 0) { if (with_fp16_) { - return (in.type == nvinfer1::DataType::kFLOAT || - in.type == nvinfer1::DataType::kHALF) && - (in.format == nvinfer1::TensorFormat::kLINEAR); + bool res = (in.type == nvinfer1::DataType::kFLOAT || + in.type == nvinfer1::DataType::kHALF); +// encounter trt crash bug +#if IS_TRT_VERSION_LT(8000) + res = res && (in.format == nvinfer1::TensorFormat::kLINEAR); +#endif + return res; } else { - return (in.type == nvinfer1::DataType::kFLOAT) && - (in.format == nvinfer1::TensorFormat::kLINEAR); + return in.type == nvinfer1::DataType::kFLOAT; } } const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; diff --git a/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h index c4bdc5f9215..aa8fdce23fa 100644 --- a/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h @@ -26,7 +26,7 @@ namespace inference { namespace tensorrt { namespace plugin { -class SwishPlugin : public PluginTensorRT { +class SwishPlugin : public PluginTensorRTV2Ext { private: float beta_; @@ -55,13 +55,24 @@ class SwishPlugin : public PluginTensorRT { int initialize() TRT_NOEXCEPT override; - SwishPlugin* clone() const TRT_NOEXCEPT override { - return new SwishPlugin(beta_, with_fp16_); + nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override { + auto* plugin = new SwishPlugin(beta_, with_fp16_); + plugin->data_format_ = data_format_; + plugin->data_type_ = data_type_; + plugin->input_dims_ = input_dims_; + return plugin; } const char* getPluginType() const TRT_NOEXCEPT override { return "swish_plugin"; } + + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT override { + return input_types[0]; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) TRT_NOEXCEPT override; @@ -71,6 +82,12 @@ class SwishPlugin : public PluginTensorRT { int enqueue(int batchSize, const void* const* inputs, void* const* outputs, #endif void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; + + void terminate() TRT_NOEXCEPT override; + void destroy() TRT_NOEXCEPT override { delete this; } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "2"; } + bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) + const TRT_NOEXCEPT override; }; class SwishPluginCreator : public TensorRTPluginCreator { @@ -79,7 +96,7 @@ class SwishPluginCreator : public TensorRTPluginCreator { return "swish_plugin"; } - const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "2"; } nvinfer1::IPluginV2* deserializePlugin( const char* name, const void* serial_data, -- GitLab