未验证 提交 6c54e0e8 编写于 作者: Zhang Jun 提交者: GitHub

[inference][trt] update trt hardswish plugin to layer (#47745)

上级 7c304580
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h"
namespace paddle {
namespace framework {
......@@ -41,7 +40,6 @@ class HardSwishOpConverter : public OpConverter {
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
const float threshold =
......@@ -69,21 +67,32 @@ class HardSwishOpConverter : public OpConverter {
nvinfer1::ElementWiseOperation::kPROD);
layer = eltwise_layer;
} else {
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::HardSwishPluginDynamic* plugin =
new plugin::HardSwishPluginDynamic(threshold, scale, offset);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
plugin::HardSwishPlugin* plugin =
new plugin::HardSwishPlugin(threshold, scale, offset);
layer = engine_->AddPlugin(&input, input_num, plugin);
}
int32_t rank = input->getDimensions().nbDims;
nvinfer1::Dims constant_shape;
constant_shape.nbDims = rank;
std::fill(constant_shape.d, constant_shape.d + rank, 1);
std::vector<float> weight_threshold_data{threshold};
std::vector<float> weight_scale_data{scale};
std::vector<float> weight_offset_data{offset};
std::vector<float> weight_zero_data{0.f};
auto* threshold_data =
AddConstantLayer(weight_threshold_data.data(), constant_shape);
auto* scale_data =
AddConstantLayer(weight_scale_data.data(), constant_shape);
auto* offset_data =
AddConstantLayer(weight_offset_data.data(), constant_shape);
auto* zero_data =
AddConstantLayer(weight_zero_data.data(), constant_shape);
auto* input_sum_with_offset = Sum(input, offset_data);
auto* pre_max_with_zero = Max(input_sum_with_offset, zero_data);
auto* pre_min_with_threshold = Min(pre_max_with_zero, threshold_data);
auto* pre_prod_with_input = Prod(pre_min_with_threshold, input);
layer = TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*pre_prod_with_input,
*scale_data,
nvinfer1::ElementWiseOperation::kDIV);
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册