未验证 提交 6c54e0e8 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] update trt hardswish plugin to layer (#47745)

上级 7c304580
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -41,7 +40,6 @@ class HardSwishOpConverter : public OpConverter { ...@@ -41,7 +40,6 @@ class HardSwishOpConverter : public OpConverter {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
int input_num = op_desc.Input("X").size();
auto* input = engine_->GetITensor(op_desc.Input("X")[0]); auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
const float threshold = const float threshold =
...@@ -69,21 +67,32 @@ class HardSwishOpConverter : public OpConverter { ...@@ -69,21 +67,32 @@ class HardSwishOpConverter : public OpConverter {
nvinfer1::ElementWiseOperation::kPROD); nvinfer1::ElementWiseOperation::kPROD);
layer = eltwise_layer; layer = eltwise_layer;
} else { } else {
if (engine_->with_dynamic_shape()) { int32_t rank = input->getDimensions().nbDims;
#if IS_TRT_VERSION_GE(6000) nvinfer1::Dims constant_shape;
plugin::HardSwishPluginDynamic* plugin = constant_shape.nbDims = rank;
new plugin::HardSwishPluginDynamic(threshold, scale, offset); std::fill(constant_shape.d, constant_shape.d + rank, 1);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin); std::vector<float> weight_threshold_data{threshold};
#else std::vector<float> weight_scale_data{scale};
PADDLE_THROW(platform::errors::Fatal( std::vector<float> weight_offset_data{offset};
"You are running the TRT Dynamic Shape mode, need to confirm that " std::vector<float> weight_zero_data{0.f};
"your TRT version is no less than 6.0")); auto* threshold_data =
#endif AddConstantLayer(weight_threshold_data.data(), constant_shape);
} else { auto* scale_data =
plugin::HardSwishPlugin* plugin = AddConstantLayer(weight_scale_data.data(), constant_shape);
new plugin::HardSwishPlugin(threshold, scale, offset); auto* offset_data =
layer = engine_->AddPlugin(&input, input_num, plugin); AddConstantLayer(weight_offset_data.data(), constant_shape);
} auto* zero_data =
AddConstantLayer(weight_zero_data.data(), constant_shape);
auto* input_sum_with_offset = Sum(input, offset_data);
auto* pre_max_with_zero = Max(input_sum_with_offset, zero_data);
auto* pre_min_with_threshold = Min(pre_max_with_zero, threshold_data);
auto* pre_prod_with_input = Prod(pre_min_with_threshold, input);
layer = TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*pre_prod_with_input,
*scale_data,
nvinfer1::ElementWiseOperation::kDIV);
} }
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册