From 2480bdef6c60dfd56076cad8d561255b343642f3 Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Tue, 8 Dec 2020 14:53:18 +0800 Subject: [PATCH] change hard_swish from plugin to layer (#29177) * change hard_swish from plugin to layer * add ut when threshold != scale --- .../tensorrt/convert/hard_swish_op.cc | 20 +++++++++++++------ .../ir/inference/test_trt_subgraph_pass.py | 6 ++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc b/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc index 967f79a164..57f8fa1351 100644 --- a/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc @@ -65,13 +65,21 @@ class HardSwishOpConverter : public OpConverter { const float offset = op_desc.HasAttr("offset") ? BOOST_GET_CONST(float, op_desc.GetAttr("offset")) : 3.0f; - nvinfer1::ILayer* layer = nullptr; - - plugin::HardSwishPlugin* plugin = - new plugin::HardSwishPlugin(threshold, scale, offset); - layer = engine_->AddPlugin(&input, input_num, plugin); - + if (threshold == scale) { + auto* hsig_layer = TRT_ENGINE_ADD_LAYER( + engine_, Activation, *input, nvinfer1::ActivationType::kHARD_SIGMOID); + hsig_layer->setAlpha(1.0 / scale); + hsig_layer->setBeta(offset / scale); + nvinfer1::IElementWiseLayer* eltwise_layer = TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *input, *(hsig_layer->getOutput(0)), + nvinfer1::ElementWiseOperation::kPROD); + layer = eltwise_layer; + } else { + plugin::HardSwishPlugin* plugin = + new plugin::HardSwishPlugin(threshold, scale, offset); + layer = engine_->AddPlugin(&input, input_num, plugin); + } auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py index 77457efa39..e5cee55a31 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py @@ -346,6 +346,12 @@ class TensorRTSubgraphPassHardSigmoidTest(TensorRTSubgraphPassActivationTest): return fluid.layers.hard_sigmoid(x) +class TensorRTSubgraphPassHardSwishPluginTest( + TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.hard_swish(x, threshold=4.0, scale=8.0) + + class TensorRTSubgraphPassClipTest(TensorRTSubgraphPassActivationTest): def append_act(self, x): return fluid.layers.clip(x, 0, 1) -- GitLab