From 0c968b9db42be0f68cc08262eb58a72e11ff036a Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Wed, 6 Apr 2022 19:27:56 +0800 Subject: [PATCH] add div plugin and add filter (#41243) --- paddle/fluid/inference/tensorrt/op_teller.cc | 8 +++++++ .../tensorrt/plugin/elementwise_op_plugin.cu | 21 +++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index cfdccecb5c8..85c5dc7107f 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1007,6 +1007,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto* y_var_desc = block->FindVar(desc.Input("Y")[0]); const auto x_shape = x_var_desc->GetShape(); const auto y_shape = y_var_desc->GetShape(); + if (op_type == "elementwise_add" && y_var_desc->Persistable()) { + if (y_shape.size() != 1) { + return false; + } + if (y_shape[0] != x_shape[1]) { + return false; + } + } if (x_shape.size() == 1 && y_shape.size() == 1) { VLOG(3) << "Now trt may not support two 1d tensor elementwise op."; return false; diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu index d6a1cdb9e68..c9163e62a2e 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu @@ -30,6 +30,11 @@ template struct Mul { __device__ T operator()(const T &a, const T &b) const { return a * b; } }; + +template +struct Div { + __device__ T operator()(const T &a, const T &b) const { return a / b; } +}; } // namespace details template @@ -130,6 +135,10 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs, elementwise_kernel<<>>( num, x, y, out, prev_size_, batch_size * midd_size_, post_size_, details::Mul()); + } else if (type_ == "div") { + elementwise_kernel<<>>( + num, x, y, out, prev_size_, batch_size * midd_size_, post_size_, + details::Div()); } else { PADDLE_THROW(platform::errors::Fatal( "The %s type elementwise is not implemented in trt plugin.", type_)); @@ -242,11 +251,15 @@ int ElementwisePluginDynamic::enqueue( } else if (type_ == "mul") { elementwise_kernel<<>>( num, x, y, out, prev_size, midd_size, post_size, details::Mul()); + } else if (type_ == "div") { + elementwise_kernel<<>>( + num, x, y, out, prev_size, midd_size, post_size, details::Div()); } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Paddle-TRT only support elementwise operation: {add, mul} currently, " - "but got %s.", - type_)); + PADDLE_THROW( + platform::errors::Unimplemented("Paddle-TRT only support elementwise " + "operation: {add, mul, div} currently, " + "but got %s.", + type_)); } return cudaGetLastError() != cudaSuccess; -- GitLab