From 4cd8a78ad72551ce2055e9acb6ca169099ed051c Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Tue, 11 Jan 2022 16:51:41 +0800 Subject: [PATCH] [cherry-pick]mish trt plugin (#38866) * add mish trt plugin, compile & install success, run error. test=develop * modify code of mish plugin * upgrade mish trt plugin * modify code according to review * add TRT_NOEXCEPT for mish trt plugin * add unittest for mish trt plugin * remove unnecessary check of mish in op_teller.cc * fix some problem of trt8 * add check and modify unittest while converting mish to trt plugin Co-authored-by: dengkaipeng --- paddle/fluid/framework/ir/is_test_pass.cc | 2 +- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../inference/tensorrt/convert/mish_op.cc | 74 ++++++ .../tensorrt/convert/test_mish_op.cc | 47 ++++ paddle/fluid/inference/tensorrt/op_teller.cc | 41 ++- .../inference/tensorrt/plugin/CMakeLists.txt | 1 + .../tensorrt/plugin/mish_op_plugin.cu | 235 ++++++++++++++++++ .../tensorrt/plugin/mish_op_plugin.h | 175 +++++++++++++ .../ir/inference/test_trt_activation_pass.py | 36 +++ .../ir/inference/test_trt_convert_mish.py | 174 +++++++++++++ 11 files changed, 785 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/mish_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/test_mish_op.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_mish.py diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc index 25bf03f426..a97873e82f 100644 --- a/paddle/fluid/framework/ir/is_test_pass.cc +++ b/paddle/fluid/framework/ir/is_test_pass.cc @@ -35,7 +35,7 @@ void IsTestPass::ApplyImpl(ir::Graph* graph) const { "hard_shrink", "hard_sigmoid", "relu6", "soft_relu", "swish", "thresholded_relu", "log", "square", "softplus", - "softsign", "silu"}; + "softsign", "silu", "mish"}; for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 771bd09dde..d9d6dc4f26 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1414,6 +1414,7 @@ USE_TRT_CONVERTER(tile); USE_TRT_CONVERTER(conv3d); USE_TRT_CONVERTER(conv3d_transpose); USE_TRT_CONVERTER(pool3d); +USE_TRT_CONVERTER(mish); #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index b73fa57bcd..aabd7b2fa0 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -18,6 +18,7 @@ nv_library(tensorrt_converter tile_op.cc conv3d_op.cc pool3d_op.cc + mish_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/mish_op.cc b/paddle/fluid/inference/tensorrt/convert/mish_op.cc new file mode 100644 index 0000000000..6b646d9935 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/mish_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * Mish converter from fluid to tensorRT. + */ +class MishOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert fluid Mish op to tensorrt Mish plugin"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + int input_num = op_desc.Input("X").size(); + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + + const float threshold = + op_desc.HasAttr("threshold") + ? BOOST_GET_CONST(float, op_desc.GetAttr("threshold")) + : 20.0f; + + nvinfer1::ILayer* layer = nullptr; + if (engine_->with_dynamic_shape()) { + bool with_fp16 = + engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + plugin::MishPluginDynamic* plugin = + new plugin::MishPluginDynamic(threshold, with_fp16); + layer = engine_->AddDynamicPlugin(&input, input_num, plugin); + } else { + bool with_fp16 = + engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + plugin::MishPlugin* plugin = new plugin::MishPlugin(threshold, with_fp16); + layer = engine_->AddPlugin(&input, input_num, plugin); + } + + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "mish", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(mish, MishOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_mish_op.cc b/paddle/fluid/inference/tensorrt/convert/test_mish_op.cc new file mode 100644 index 0000000000..c84c30255f --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_mish_op.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(mish_op, test_mish) { + std::unordered_set parameters; + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("mish-X", nvinfer1::Dims3(3, 2, 2)); + validator.DeclOutputVar("mish-Out", nvinfer1::Dims3(3, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("mish"); + desc.SetInput("X", {"mish-X"}); + desc.SetOutput("Out", {"mish-Out"}); + + desc.SetAttr("threshold", 20.0f); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(mish); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 1c3f4e9ec8..a310ce81a8 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -169,7 +169,8 @@ struct SimpleOpTypeSetTeller : public Teller { "reduce_mean", "conv3d", "conv3d_transpose", - "pool3d"}; + "pool3d", + "mish"}; }; bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, @@ -1160,6 +1161,44 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, #endif } + if (op_type == "mish") { + if (desc.Input("X").size() != 1) { + VLOG(3) << "Invalid input X's size of mish TRT converter. " + "Expected 1, received " + << desc.Input("X").size() << "."; + return false; + } + if (desc.Output("Out").size() != 1) { + VLOG(3) << "Invalid output Out's size of mish TRT converter. " + "Expected 1, received " + << desc.Output("Out").size() << "."; + return false; + } + + auto* block = desc.Block(); + if (block == nullptr) { + VLOG(3) << "The block desc is nullptr, we can't continue to analyze. " + "Developers need to check whether block_desc is passed in " + "the pass."; + return false; + } + + auto x_var_name = desc.Input("X")[0]; + auto* x_var_desc = block->FindVar(x_var_name); + const auto x_shape = x_var_desc->GetShape(); + if (x_shape.size() == 1) { + VLOG(3) << "mish op does not support input's dim is 1 in tensorrt."; + return false; + } + + if (!with_dynamic_shape) { + if (x_shape.size() == 2) { + VLOG(3) << "mish op does not support input's dim is 2 in tensorrt."; + return false; + } + } + } + if (op_type == "roi_align") { if (!with_dynamic_shape) { VLOG(3) << "TRT roi align plugin only accept the dynamic shape, " diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 8f948e61f1..e2685c9fa3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -10,6 +10,7 @@ nv_library(tensorrt_plugin roi_align_op_plugin.cu gather_nd_op_plugin.cu pool3d_op_plugin.cu + mish_op_plugin.cu DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.cu new file mode 100644 index 0000000000..6e268e7b0b --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.cu @@ -0,0 +1,235 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "glog/logging.h" +#include "paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +int MishPlugin::initialize() TRT_NOEXCEPT { return 0; } + +bool MishPlugin::supportsFormat( + nvinfer1::DataType type, nvinfer1::PluginFormat format) const TRT_NOEXCEPT { + if (with_fp16_) { + return ((type == nvinfer1::DataType::kFLOAT || + type == nvinfer1::DataType::kHALF) && + (format == nvinfer1::PluginFormat::kLINEAR)); + } else { + return ((type == nvinfer1::DataType::kFLOAT) && + (format == nvinfer1::PluginFormat::kLINEAR)); + } +} + +nvinfer1::Dims MishPlugin::getOutputDimensions(int index, + const nvinfer1::Dims* in_dims, + int nb_inputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(nb_inputs, 1, platform::errors::InvalidArgument( + "We expect [number of inputs] == 1" + "in TRT Mish op plugin, but got " + "[number of inputs] = %d.", + nb_inputs)); + PADDLE_ENFORCE_LT(index, this->getNbOutputs(), + platform::errors::InvalidArgument( + "We expect [index] < [number of outputs]" + "in TRT Mish op plugin, but got " + "[index] = %d, [number of outputs] = %d.", + index, this->getNbOutputs())); + nvinfer1::Dims const& input_dims = in_dims[0]; + nvinfer1::Dims output_dims = input_dims; + return output_dims; +} + +template +__device__ T kTanh(T x) { + return tanh(x); +} + +template <> +__device__ half kTanh(half x) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + const float tmp = tanhf(__half2float(x)); + return __float2half(tmp); +#endif +} + +template +__device__ T kSoftplus(T x, T threshold) { + return x > threshold ? x : log(exp(x) + static_cast(1.0f)); +} + +template <> +__device__ half kSoftplus(half x, half threshold) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + return x > threshold ? x : hlog(hexp(x) + static_cast(1.0f)); +#endif +} + +template +__global__ void mish_kernel(float threshold, int n, const T* input, T* output) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + const T in = input[idx]; + output[idx] = in * kTanh(kSoftplus(in, static_cast(threshold))); + } +} + +template <> +__global__ void mish_kernel(float threshold, int n, const half* input, + half* output) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + const half in = input[idx]; + output[idx] = + in * kTanh(kSoftplus(in, static_cast(threshold))); + } +#endif +} + +#if IS_TRT_VERSION_LT(8000) +int MishPlugin::enqueue(int batchSize, const void* const* inputs, + void** outputs, +#else +int MishPlugin::enqueue(int batchSize, const void* const* inputs, + void* const* outputs, +#endif + void* workspace, cudaStream_t stream) TRT_NOEXCEPT { + const auto& input_dims = this->getInputDims(0); + int num = batchSize; + for (int i = 0; i < input_dims.nbDims; i++) { + num *= input_dims.d[i]; + } + + const int block_size = 256; + const int grid_size = (num + block_size - 1) / block_size; + + auto type = getDataType(); + if (type == nvinfer1::DataType::kFLOAT) { + VLOG(1) << "TRT Plugin DataType selected. Mish-->fp32"; + const float* input = static_cast(inputs[0]); + float* output = static_cast(outputs[0]); + mish_kernel<<>>(threshold_, num, + input, output); + } else if (type == nvinfer1::DataType::kHALF) { + VLOG(1) << "TRT Plugin DataType selected. Mish-->fp16"; + const half* input = static_cast(inputs[0]); + half* output = static_cast(outputs[0]); + mish_kernel<<>>(threshold_, num, + input, output); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The Mish TRT Plugin's input type should be float or half.")); + } + + return cudaGetLastError() != cudaSuccess; +} + +// Dynamic Plugin below. +int MishPluginDynamic::initialize() TRT_NOEXCEPT { + getPluginNamespace(); + return 0; +} + +size_t MishPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { + return SerializedSize(threshold_) + SerializedSize(with_fp16_); +} + +void MishPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT { + SerializeValue(&buffer, threshold_); + SerializeValue(&buffer, with_fp16_); +} + +nvinfer1::DimsExprs MishPluginDynamic::getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT { + return inputs[0]; +} + +bool MishPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, platform::errors::InvalidArgument( + "The input of mish plugin shoule not be nullptr.")); + + PADDLE_ENFORCE_LT( + pos, nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, nb_inputs + nb_outputs)); + + const nvinfer1::PluginTensorDesc& in = in_out[pos]; + if (pos == 0) { + if (with_fp16_) { + return (in.type == nvinfer1::DataType::kFLOAT || + in.type == nvinfer1::DataType::kHALF) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } else { + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + } + const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType MishPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( + "The Mish Plugin only has one input, so the " + "index value should be 0, but get %d.", + index)); + return input_types[0]; +} + +int MishPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + auto input_dims = input_desc[0].dims; + size_t num = ProductDim(input_dims); + const int block_size = 256; + const int grid_size = (num + block_size - 1) / block_size; + + auto input_type = input_desc[0].type; + if (input_type == nvinfer1::DataType::kFLOAT) { + VLOG(1) << "TRT Plugin DataType selected. Mish-->fp32"; + const float* input = static_cast(inputs[0]); + float* output = static_cast(outputs[0]); + mish_kernel<<>>(threshold_, num, + input, output); + } else if (input_type == nvinfer1::DataType::kHALF) { + VLOG(1) << "TRT Plugin DataType selected. Mish-->fp16"; + const half* input = static_cast(inputs[0]); + half* output = static_cast(outputs[0]); + mish_kernel<<>>(threshold_, num, + input, output); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The Mish TRT Plugin's input type should be float or half.")); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h new file mode 100644 index 0000000000..75390666ea --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h @@ -0,0 +1,175 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class MishPlugin : public PluginTensorRT { + private: + float threshold_; + + protected: + size_t getSerializationSize() const TRT_NOEXCEPT override { + return getBaseSerializationSize() + SerializedSize(threshold_); + } + + // TRT will call this func to serialize the configuration of TRT + // It should not be called by users. + void serialize(void* buffer) const TRT_NOEXCEPT override { + serializeBase(buffer); + SerializeValue(&buffer, threshold_); + } + + public: + explicit MishPlugin(const float threshold, const bool with_fp16) + : threshold_(threshold) { + with_fp16_ = with_fp16; + } + + // It was used for tensorrt deserialization. + // It should not be called by users. + MishPlugin(void const* serialData, size_t serialLength) { + deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &threshold_); + } + + ~MishPlugin() {} + MishPlugin* clone() const TRT_NOEXCEPT override { + return new MishPlugin(threshold_, with_fp16_); + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "mish_plugin"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + int initialize() TRT_NOEXCEPT override; + bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) + const TRT_NOEXCEPT override; + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int nbInputDims) TRT_NOEXCEPT override; +#if IS_TRT_VERSION_LT(8000) + int enqueue(int batchSize, const void* const* inputs, void** outputs, +#else + int enqueue(int batchSize, const void* const* inputs, void* const* outputs, +#endif + void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; +}; + +class MishPluginCreator : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "mish_plugin"; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + nvinfer1::IPluginV2* deserializePlugin( + const char* name, const void* serial_data, + size_t serial_length) TRT_NOEXCEPT override { + return new MishPlugin(serial_data, serial_length); + } +}; + +REGISTER_TRT_PLUGIN_V2(MishPluginCreator); + +class MishPluginDynamic : public DynamicPluginTensorRT { + public: + explicit MishPluginDynamic(const float threshold, const bool with_fp16) + : threshold_(threshold) { + with_fp16_ = with_fp16; + } + MishPluginDynamic(void const* serialData, size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &threshold_); + DeserializeValue(&serialData, &serialLength, &with_fp16_); + } + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new MishPluginDynamic(threshold_, with_fp16_); + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "mish_plugin_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + int initialize() TRT_NOEXCEPT override; + + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + private: + float threshold_; +}; + +class MishPluginDynamicCreator : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "mish_plugin_dynamic"; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + nvinfer1::IPluginV2* deserializePlugin( + const char* name, const void* serial_data, + size_t serial_length) TRT_NOEXCEPT override { + auto plugin = new MishPluginDynamic(serial_data, serial_length); + return plugin; + } +}; + +REGISTER_TRT_PLUGIN_V2(MishPluginDynamicCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_activation_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_activation_pass.py index 8e196f5081..62825caf51 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_activation_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_activation_pass.py @@ -139,6 +139,42 @@ class TensorRTSubgraphPassDynamicSwishFp16SerializeTest( return fluid.layers.swish(x) +class TensorRTSubgraphPassMishTest(TensorRTSubgraphPassActivationTest): + def setUpTensorRTParam(self): + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False) + + def append_act(self, x): + return fluid.layers.mish(x) + + +class TensorRTSubgraphPassMishFp16SerializeTest( + TensorRTSubgraphPassActivationTest): + def setUpTensorRTParam(self): + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Half, True, False) + + def append_act(self, x): + return fluid.layers.mish(x) + + +class TensorRTSubgraphPassDynamicMishFp16SerializeTest( + TensorRTSubgraphPassActivationTest): + def setUpTensorRTParam(self): + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Half, False, False) + self.dynamic_shape_params = TensorRTSubgraphPassActivationTest.DynamicShapeParam( + { + 'data': [1, 6, 8, 8] + }, {'data': [1, 6, 512, 512]}, {'data': [1, 6, 256, 256]}, False) + + def append_act(self, x): + return fluid.layers.mish(x) + + class TensorRTSubgraphPassPreluAllTest(TensorRTSubgraphPassActivationTest): def append_act(self, x): return fluid.layers.prelu(x, mode='all') diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_mish.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_mish.py new file mode 100644 index 0000000000..d223fd529a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_mish.py @@ -0,0 +1,174 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + + +class TrtConvertMishTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input(batch, dim1, dim2, dim3): + shape = [batch] + if dim1 != 0: + shape.append(dim1) + if dim2 != 0: + shape.append(dim2) + if dim3 != 0: + shape.append(dim3) + return np.random.random(shape).astype(np.float32) + + for batch in [1, 4]: + for dim1 in [0, 3]: + for dim2 in [0, 16]: + for dim3 in [0, 32]: + for thre in [5.0, 20.0]: + self.dim1 = dim1 + self.dim2 = dim2 + self.dim3 = dim3 + + if dim1 == 0 and dim2 != 0: + continue + if dim1 == 0 and dim2 == 0 and dim3 != 0: + continue + + ops_config = [{ + "op_type": "mish", + "op_inputs": { + "X": ["input_data"] + }, + "op_outputs": { + "Out": ["mish_output_data"] + }, + "op_attrs": { + "threshold": thre + } + }] + + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input, batch, + dim1, dim2, dim3)) + }, + outputs=["mish_output_data"]) + + yield program_config + + def sample_predictor_configs(self, program_config): + def generate_dynamic_shape(attrs): + if self.dim1 == 0: + self.dynamic_shape.min_input_shape = {"input_data": [1], } + self.dynamic_shape.max_input_shape = {"input_data": [4], } + self.dynamic_shape.opt_input_shape = {"input_data": [2], } + else: + if self.dim2 == 0 and self.dim3 == 0: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 1], + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 64], + } + self.dynamic_shape.opt_input_shape = { + "input_data": [2, 3], + } + elif self.dim2 != 0 and self.dim3 != 0: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 1, 1, 1], + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 64, 128, 128], + } + self.dynamic_shape.opt_input_shape = { + "input_data": [2, 3, 16, 32], + } + elif self.dim3 == 0: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 1, 1], + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 64, 256], + } + self.dynamic_shape.opt_input_shape = { + "input_data": [2, 3, 128], + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + if self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0: + return True + return False + + self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT, + "Trt does not support 1-dimensional input.") + + def teller2(program_config, predictor_config): + if (len(self.dynamic_shape.min_input_shape) == 0): + if self.dim1 != 0 and self.dim2 == 0 and self.dim3 == 0: + return True + return False + + self.add_skip_case( + teller2, SkipReasons.TRT_NOT_SUPPORT, + "Need to repair the case: the output of GPU and tensorrt has diff when the input dimension is 2 in static shape mode." + ) + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab