From c528f1d4f3d95421fe4aacd9a981a0402a9e2de7 Mon Sep 17 00:00:00 2001
From: Pei Yang
Date: Tue, 14 Apr 2020 09:48:49 +0800
Subject: [PATCH] [Paddle-TRT] Add hard_sigmoid and hard_swish support
 (support MobilenetV3) (#23672)

* add hard_sigmoid trt op converter

* add hard_swish op converter and plugin. test=develop

* add macro to adapt lower trt version. test=develop
---
 .../fluid/inference/api/analysis_predictor.cc |  2 +
 .../inference/tensorrt/convert/CMakeLists.txt |  2 +-
 .../tensorrt/convert/hard_sigmoid_op.cc       | 55 ++++++++++++
 .../tensorrt/convert/hard_swish_op.cc         | 72 ++++++++++++++++
 paddle/fluid/inference/tensorrt/op_teller.cc  |  2 +
 .../inference/tensorrt/plugin/CMakeLists.txt  |  2 +-
 .../tensorrt/plugin/hard_swish_op_plugin.cu   | 86 +++++++++++++++++++
 .../tensorrt/plugin/hard_swish_op_plugin.h    | 80 +++++++++++++++++
 8 files changed, 299 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/inference/tensorrt/convert/hard_sigmoid_op.cc
 create mode 100644 paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc
 create mode 100644 paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu
 create mode 100644 paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h

diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 1639df7d4b..6768a490cd 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -961,6 +961,8 @@ USE_TRT_CONVERTER(batch_norm);
 USE_TRT_CONVERTER(concat);
 USE_TRT_CONVERTER(dropout);
 USE_TRT_CONVERTER(pad);
+USE_TRT_CONVERTER(hard_sigmoid);
+USE_TRT_CONVERTER(hard_swish);
 USE_TRT_CONVERTER(split);
 USE_TRT_CONVERTER(prelu);
 USE_TRT_CONVERTER(conv2d_transpose);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index a5989bedd8..13f323f4bd 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -4,7 +4,7 @@ nv_library(tensorrt_converter
 batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
 pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
 shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
-emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc
+emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc hard_sigmoid_op.cc hard_swish_op.cc
 DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
diff --git a/paddle/fluid/inference/tensorrt/convert/hard_sigmoid_op.cc b/paddle/fluid/inference/tensorrt/convert/hard_sigmoid_op.cc
new file mode 100644
index 0000000000..301b4140ac
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/hard_sigmoid_op.cc
@@ -0,0 +1,55 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * HardSigmoidOp, IActivationLayer in TRT. This layer doesn't have weights.
+ */
+class HardSigmoidOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+#if IS_TRT_VERSION_GE(5000)
+    VLOG(3) << "convert a fluid HardSigmoid op to a tensorrt IActivationLayer "
+               "without bias";
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    float slope = boost::get<float>(op_desc.GetAttr("slope"));
+    float offset = boost::get<float>(op_desc.GetAttr("offset"));
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input,
+                                       nvinfer1::ActivationType::kHARD_SIGMOID);
+    layer->setAlpha(slope);
+    layer->setBeta(offset);
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "hard_sigmoid", {output_name}, test_mode);
+#else
+    PADDLE_THROW(platform::errors::Fatal(
+        "Hard sigmoid TRT converter is only supported on TRT 5 or higher. "
+        "Please confirm your TRT version is no less than 5.0."));
+#endif
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(hard_sigmoid, HardSigmoidOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc b/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc
new file mode 100644
index 0000000000..809dc415c3
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc
@@ -0,0 +1,72 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * HardSwish converter from fluid to tensorRT.
+ */
+class HardSwishOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(4) << "convert fluid HardSwish op to tensorrt HardSwish plugin";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    int input_num = op_desc.Input("X").size();
+    PADDLE_ENFORCE_EQ(
+        input_num, 1,
+        platform::errors::InvalidArgument(
+            "HardSwish op has only 1 input, but got %d", input_num));
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    // Get output
+    size_t output_num = op_desc.Output("Out").size();
+    PADDLE_ENFORCE_EQ(
+        output_num, 1,
+        platform::errors::InvalidArgument(
+            "HardSwish op has only 1 output, but got %d", output_num));
+
+    const float threshold =
+        op_desc.HasAttr("threshold")
+            ? boost::get<float>(op_desc.GetAttr("threshold"))
+            : 6.0f;
+    const float scale = op_desc.HasAttr("scale")
+                            ? boost::get<float>(op_desc.GetAttr("scale"))
+                            : 6.0f;
+    const float offset = op_desc.HasAttr("offset")
+                             ? boost::get<float>(op_desc.GetAttr("offset"))
+                             : 3.0f;
+
+    nvinfer1::ILayer* layer = nullptr;
+
+    plugin::HardSwishPlugin* plugin =
+        new plugin::HardSwishPlugin(threshold, scale, offset);
+    layer = engine_->AddPlugin(&input, input_num, plugin);
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(hard_swish, HardSwishOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 26eb26926f..671c40e5ba 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -23,6 +23,7 @@ struct SimpleOpTypeSetTeller : public Teller {
   SimpleOpTypeSetTeller() {
 #if IS_TRT_VERSION_GE(5130)
     teller_set.insert("relu6");
+    teller_set.insert("hard_sigmoid");
 #endif
 #if IS_TRT_VERSION_GE(6000)
     teller_set.insert("fused_embedding_eltwise_layernorm");
@@ -54,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "relu",
       "softmax",
       "sigmoid",
+      "hard_swish",
       "depthwise_conv2d",
       "batch_norm",
       "concat",
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index 86edc85712..dc3e75389e 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -3,5 +3,5 @@ nv_library(tensorrt_plugin
 prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
 pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
 instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
-qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu
+qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu hard_swish_op_plugin.cu
 DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
diff --git a/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu
new file mode 100644
index 0000000000..8b2d0ac3cf
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu
@@ -0,0 +1,86 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <cassert>
+#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+HardSwishPlugin* CreateHardSwishPluginDeserialize(const void* buffer,
+                                                  size_t length) {
+  return new HardSwishPlugin(buffer, length);
+}
+
+REGISTER_TRT_PLUGIN("hard_swish_plugin", CreateHardSwishPluginDeserialize);
+
+nvinfer1::Dims HardSwishPlugin::getOutputDimensions(
+    int index, const nvinfer1::Dims* in_dims, int nb_inputs) {
+  assert(nb_inputs == 1);
+  assert(index < this->getNbOutputs());
+  nvinfer1::Dims const& input_dims = in_dims[0];
+  nvinfer1::Dims output_dims = input_dims;
+  return output_dims;
+}
+
+template <typename T>
+__device__ T kMax(T a, T b) {
+  return a > b ? a : b;
+}
+
+template <typename T>
+__device__ T kMin(T a, T b) {
+  return a < b ? a : b;
+}
+
+template <typename T, unsigned TPB>
+__global__ void hard_swish_kernel(float threshold, float scale, float offset,
+                                  int n, const T* input, T* output) {
+  const int idx = blockIdx.x * TPB + threadIdx.x;
+  if (idx < n) {
+    const T in = input[idx];
+    output[idx] = in / scale * kMin<T>(kMax<T>(in + offset, 0), threshold);
+  }
+}
+
+int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs,
+                             void** outputs, void*, cudaStream_t stream) {
+  const auto& input_dims = this->getInputDims(0);
+  int num = batch_size;
+  for (int i = 0; i < input_dims.nbDims; i++) {
+    num *= input_dims.d[i];
+  }
+  float threshold = threshold_;
+  float scale = scale_;
+  float offset = offset_;
+
+  const int block_size = 256;
+  const int grid_size = (num + block_size - 1) / block_size;
+
+  const float* input = static_cast<const float*>(inputs[0]);
+  float* output = static_cast<float*>(outputs[0]);
+  hard_swish_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
+      threshold, scale, offset, num, input, output);
+
+  return cudaGetLastError() != cudaSuccess;
+}
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h
new file mode 100644
index 0000000000..2e1e1d03ba
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <stdio.h>
+#include <cassert>
+#include <string>
+#include <vector>
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+class HardSwishPlugin : public PluginTensorRT {
+ public:
+  HardSwishPlugin(const float threshold, const float scale, const float offset)
+      : threshold_(threshold), scale_(scale), offset_(offset) {}
+
+  // It is used for tensorrt deserialization.
+  // It should not be called by users.
+  HardSwishPlugin(void const* serialData, size_t serialLength) {
+    deserializeBase(serialData, serialLength);
+    DeserializeValue(&serialData, &serialLength, &threshold_);
+    DeserializeValue(&serialData, &serialLength, &scale_);
+    DeserializeValue(&serialData, &serialLength, &offset_);
+  }
+
+  ~HardSwishPlugin() {}
+  HardSwishPlugin* clone() const override {
+    return new HardSwishPlugin(threshold_, scale_, offset_);
+  }
+
+  const char* getPluginType() const override { return "hard_swish_plugin"; }
+  int getNbOutputs() const override { return 1; }
+  int initialize() override { return 0; }
+  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
+                                     int nbInputDims) override;
+  int enqueue(int batchSize, const void* const* inputs, void** outputs,
+              void* workspace, cudaStream_t stream) override;
+
+ protected:
+  float threshold_;
+  float scale_;
+  float offset_;
+
+  size_t getSerializationSize() override {
+    return getBaseSerializationSize() + SerializedSize(threshold_) +
+           SerializedSize(scale_) + SerializedSize(offset_) +
+           SerializedSize(getPluginType());
+  }
+
+  // TRT calls this function to serialize the plugin's configuration.
+  // It should not be called by users.
+  void serialize(void* buffer) override {
+    SerializeValue(&buffer, getPluginType());
+    serializeBase(buffer);
+    SerializeValue(&buffer, threshold_);
+    SerializeValue(&buffer, scale_);
+    SerializeValue(&buffer, offset_);
+  }
+};
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
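
Note: hard_sigmoid lowers to TensorRT's built-in kHARD_SIGMOID activation
(y = clip(slope * x + offset, 0, 1), with alpha = slope and beta = offset),
which is why it only needs the TRT-version guard. hard_swish has no built-in
TensorRT activation at these versions, so the converter falls back to
engine_->AddPlugin with the HardSwishPlugin above, whose kernel computes
y = x / scale * clip(x + offset, 0, threshold) per element (Paddle defaults:
threshold = 6, scale = 6, offset = 3). The sketch below is a minimal CPU
reference of both formulas, e.g. for checking engine outputs against expected
values. It is illustrative only, not part of the patch: the *_ref names are
made up here, and the slope = 0.2 / offset = 0.5 defaults for hard_sigmoid
follow Paddle's operator defaults rather than anything read by this patch.

// Minimal CPU reference for the two activations lowered in this patch.
#include <algorithm>
#include <cstdio>
#include <initializer_list>

// hard_sigmoid: what kHARD_SIGMOID computes with alpha = slope, beta = offset.
static float hard_sigmoid_ref(float x, float slope = 0.2f,
                              float offset = 0.5f) {
  return std::min(std::max(slope * x + offset, 0.0f), 1.0f);
}

// hard_swish: what hard_swish_kernel computes per element.
static float hard_swish_ref(float x, float threshold = 6.0f,
                            float scale = 6.0f, float offset = 3.0f) {
  return x / scale * std::min(std::max(x + offset, 0.0f), threshold);
}

int main() {
  for (float x : {-4.0f, -1.0f, 0.0f, 1.0f, 4.0f}) {
    std::printf("x=%5.1f  hard_sigmoid=%.4f  hard_swish=%.4f\n", x,
                hard_sigmoid_ref(x), hard_swish_ref(x));
  }
  return 0;
}

On serialization: serialize() writes the plugin type string first so that the
deserialization entry registered via REGISTER_TRT_PLUGIN("hard_swish_plugin",
CreateHardSwishPluginDeserialize) can dispatch on it; the deserializing
constructor then restores the base plugin fields via deserializeBase and reads
threshold_, scale_, and offset_ in the same order they were written,
presumably after the factory has consumed the type string.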