From 3d232b29742ea1b6bcc6bb1da6571ec6a20917bd Mon Sep 17 00:00:00 2001 From: ccrrong <101700995+ccrrong@users.noreply.github.com> Date: Fri, 17 Jun 2022 11:29:09 +0800 Subject: [PATCH] add bilinear interp v2 converter (#43307) * add bilinear_interp_v2 converter --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../tensorrt/convert/bilinear_interp_v2_op.cc | 133 ++++++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 95 +++++++++++++ .../test_trt_convert_bilinear_interp_v2.py | 132 +++++++++++++++++ 5 files changed, 362 insertions(+) create mode 100644 paddle/fluid/inference/tensorrt/convert/bilinear_interp_v2_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index bb495860c9..76331bfe7c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1943,6 +1943,7 @@ USE_TRT_CONVERTER(multiclass_nms); USE_TRT_CONVERTER(multiclass_nms3); USE_TRT_CONVERTER(nearest_interp); USE_TRT_CONVERTER(nearest_interp_v2); +USE_TRT_CONVERTER(bilinear_interp_v2); USE_TRT_CONVERTER(reshape); USE_TRT_CONVERTER(reduce_sum); USE_TRT_CONVERTER(gather_nd); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 4c52d91fa1..e6c372e205 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -52,6 +52,7 @@ list( conv3d_op.cc mish_op.cc nearest_interp_v2_op.cc + bilinear_interp_v2_op.cc pool3d_op.cc deformable_conv_op.cc preln_emb_eltwise_layernorm.cc diff --git a/paddle/fluid/inference/tensorrt/convert/bilinear_interp_v2_op.cc b/paddle/fluid/inference/tensorrt/convert/bilinear_interp_v2_op.cc new file mode 100644 index 0000000000..f0e56082b8 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/bilinear_interp_v2_op.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class BilinearInterpolateV2OpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid bilinear_interp_v2 op"; + + framework::OpDesc op_desc(op, nullptr); + + std::string input_name = op_desc.Input("X").front(); + std::string output_name = op_desc.Output("Out").front(); + + auto input = engine_->GetITensor(input_name); + + auto data_layout = framework::StringToDataLayout( + BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout"))); + auto interp_method = + BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method")); + bool align_corners = + BOOST_GET_CONST(bool, op_desc.GetAttr("align_corners")); + auto align_mode = BOOST_GET_CONST(int, op_desc.GetAttr("align_mode")); + + auto resize_inputs = op_desc.Inputs(); + auto input_names = op_desc.Input("X"); + auto out_h = BOOST_GET_CONST(int, op_desc.GetAttr("out_h")); + auto out_w = BOOST_GET_CONST(int, op_desc.GetAttr("out_w")); + + auto layer = TRT_ENGINE_ADD_LAYER(engine_, Resize, *input); + if (align_mode == 0 && !align_corners) { + layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR); + } + + auto in_dim = input->getDimensions(); + float scale_h = 1.f; + float scale_w = 1.f; + + // Scales Priority: Scale(tensor) > scale(attr) > out_d/out_h/out_w(attr) + bool has_scale_input_attr = + (resize_inputs.find("Scale") != resize_inputs.end()); + bool has_scale_input = + has_scale_input_attr && (op_desc.Input("Scale").size() > 0); + if (has_scale_input) { + auto* scale_var = scope.FindVar(op_desc.Input("Scale")[0]); + auto* scale_tensor = scale_var->GetMutable(); + auto* scale_d = scale_tensor->data(); + scale_h = scale_d[0]; + scale_w = scale_d[1]; + } else { + const std::vector scale_attr = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("scale")); + if (scale_attr.size() > 1) { + scale_h = scale_attr[0]; + scale_w = scale_attr[1]; + } + } + + // axis are different in static/dynamic mode + bool with_dynamic = engine_->with_dynamic_shape(); + int h_axis = (data_layout == framework::DataLayout::kNCHW) + with_dynamic; + int w_axis = + (data_layout == framework::DataLayout::kNCHW) + 1 + with_dynamic; + + if (scale_w > 0. && scale_h > 0.) { + out_h = static_cast(in_dim.d[h_axis] * scale_h); + out_w = static_cast(in_dim.d[w_axis] * scale_w); + } + + if (out_h > 0 && out_w > 0) { + scale_h = + static_cast(out_h) / static_cast(in_dim.d[h_axis]); + scale_w = + static_cast(out_w) / static_cast(in_dim.d[w_axis]); + } + + std::vector scales; + + if (engine_->with_dynamic_shape()) { + scales.push_back(1.f); + } + + if (data_layout == framework::DataLayout::kNCHW) { + scales.push_back(1.f); + scales.push_back(scale_h); + scales.push_back(scale_w); + } else if (data_layout == framework::DataLayout::kNHWC) { + scales.push_back(scale_h); + scales.push_back(scale_w); + scales.push_back(1.f); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Data layout must be NCHW or NHWC.")); + } + + layer->setScales(scales.data(), scales.size()); + RreplenishLayerAndOutput(layer, "bilinear_interp_v2", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(bilinear_interp_v2, BilinearInterpolateV2OpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index e5c52193f1..d7e481bfa9 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -144,6 +144,7 @@ struct SimpleOpTypeSetTeller : public Teller { "conv3d_transpose", "mish", "nearest_interp_v2", + "bilinear_interp_v2", "pool3d", "deformable_conv", "relu6", @@ -239,6 +240,7 @@ struct SimpleOpTypeSetTeller : public Teller { "conv3d", "conv3d_transpose", "mish", + "bilinear_interp_v2", "nearest_interp_v2", "pool3d", "deformable_conv", @@ -875,6 +877,99 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } + if (op_type == "bilinear_interp_v2") { + std::vector attrs{"data_layout", "interp_method", + "align_corners", "scale", + "out_h", "out_w"}; + for (auto const attr : attrs) { + if (!desc.HasAttr(attr)) { + VLOG(3) << "The op_type " << op_type << " doesn't have the attr " + << attr << " and return false"; + return false; + } + } + + auto resize_inputs = desc.Inputs(); + if (resize_inputs.find("SizeTensor") != resize_inputs.end()) { + if (desc.Input("SizeTensor").size() >= 1) { + VLOG(3) + << "The Paddle-TRT doesn't support the SizeTensor for op_type " + << op_type; + return false; + } + } + + if (resize_inputs.find("OutSize") != resize_inputs.end()) { + if (desc.Input("OutSize").size() >= 1) { + VLOG(3) << "The Paddle-TRT doesn't support the OutSize for op_type " + << op_type; + return false; + } + } + + auto data_layout = framework::StringToDataLayout( + BOOST_GET_CONST(std::string, desc.GetAttr("data_layout"))); + if (data_layout != framework::DataLayout::kNCHW && + data_layout != framework::DataLayout::kNHWC) { + VLOG(3) << "The op_type " << op_type + << " is not NCHW or NHWC return false"; + return false; + } + auto interp_method = + BOOST_GET_CONST(std::string, desc.GetAttr("interp_method")); + if (interp_method != "bilinear") { + VLOG(3) << "The interp_method of op_type " << op_type + << " is not bilinear"; + return false; + } + + auto align_corners = BOOST_GET_CONST(bool, desc.GetAttr("align_corners")); + if (align_corners != false) { + VLOG(3) + << "The bilinear_interp_v2 only supports align_corners with false."; + return false; + } + + bool has_scale_input_size = + (resize_inputs.find("Scale") != resize_inputs.end()); + + if (has_scale_input_size && desc.Input("Scale").size() != 1) { + const std::vector scale = + BOOST_GET_CONST(std::vector, desc.GetAttr("scale")); + if (scale.size() <= 1) { + if (!desc.HasAttr("out_h") || !desc.HasAttr("out_w")) { + VLOG(3) << "The op_type " << op_type + << " doesn't have Scale and the scale size <=1 and without " + "out_h / out_w, it will return false"; + return false; + } + auto out_h = BOOST_GET_CONST(int, desc.GetAttr("out_h")); + auto out_w = BOOST_GET_CONST(int, desc.GetAttr("out_w")); + if (!(out_h <= 0 && out_w <= 0)) { + if (out_h <= 0) { + VLOG(3) << "The op_type " << op_type + << "'s out_h must be greater than 0 if scale is not set."; + return false; + } + if (out_w <= 0) { + VLOG(3) << "The op_type " << op_type + << "'s out_w must be greater than 0 if scale is not set."; + return false; + } + } + } else { + for (size_t i = 0; i < scale.size(); i++) { + if (scale[i] <= 0 && with_dynamic_shape) { + VLOG(3) << "dynamic shape not support Attr(scale[" << i << "]) " + << scale[i] + << " less than 1 and Input(Scale) vector not set."; + return false; + } + } + } + } + } + if (op_type == "hard_swish") { if (desc.Input("X").size() != 1) { VLOG(3) << "HardSwish op has only 1 input, but got " diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py new file mode 100644 index 0000000000..3fe041db93 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + + +class TrtConvertBilinearInterpV2Test(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + weights = program_config.weights + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + return True + + def sample_program_configs(self): + + def generate_input1(attrs: List[Dict[str, Any]]): + return np.ones([1, 3, 64, 64]).astype(np.float32) + + def generate_input2(attrs: List[Dict[str, Any]]): + return np.random.uniform(low=0.5, high=6.0, + size=(2)).astype("float32") + + for data_layout in ["NCHW", "NHWC"]: + for scale_y in [2.0, -1.0, 0.0]: + for scale_x in [2.0, -1.0, 0.0]: + scale = [scale_y, scale_x] + for out_h in [32, 64, 128, 192]: + for out_w in [32, 64]: + dics = [{ + "data_layout": data_layout, + "interp_method": "bilinear", + "align_corners": False, + "align_mode": 0, + "scale": scale, + "out_h": out_h, + "out_w": out_w + }] + + ops_config = [{ + "op_type": "bilinear_interp_v2", + "op_inputs": { + "X": ["input_data"], + "Scale": ["input_scale"] + }, + "op_outputs": { + "Out": ["bilinear_interp_v2_output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "input_scale": + TensorConfig( + data_gen=partial(generate_input2, dics)) + }, + inputs={ + "input_data": + TensorConfig( + data_gen=partial(generate_input1, dics)) + }, + outputs=["bilinear_interp_v2_output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 64, 64]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-2 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-2 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab