From c061c0827ab0dfb74b90f1fa945665f569edc123 Mon Sep 17 00:00:00 2001 From: Zhang Jun Date: Wed, 2 Nov 2022 20:38:06 +0800 Subject: [PATCH] [inference][trt] bilinear support OutSize input (#47495) * add bilinear OutSize --- .../tensorrt/convert/bilinear_interp_v2_op.cc | 35 ++++++++++++++++--- paddle/fluid/inference/tensorrt/op_teller.cc | 4 +-- .../test_trt_convert_bilinear_interp_v2.py | 4 +-- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/bilinear_interp_v2_op.cc b/paddle/fluid/inference/tensorrt/convert/bilinear_interp_v2_op.cc index fca5424875..bcd8462c99 100644 --- a/paddle/fluid/inference/tensorrt/convert/bilinear_interp_v2_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/bilinear_interp_v2_op.cc @@ -52,6 +52,7 @@ class BilinearInterpolateV2OpConverter : public OpConverter { auto resize_inputs = op_desc.Inputs(); auto input_names = op_desc.Input("X"); + auto out_h = PADDLE_GET_CONST(int, op_desc.GetAttr("out_h")); auto out_w = PADDLE_GET_CONST(int, op_desc.GetAttr("out_w")); @@ -94,6 +95,15 @@ class BilinearInterpolateV2OpConverter : public OpConverter { out_w = static_cast(in_dim.d[w_axis] * scale_w); } + // Priority: Input(OutSize) > attr(out_h/out_w) > attr(scale) + nvinfer1::ITensor* outsize_tensor = nullptr; + if (engine_->with_dynamic_shape() && + resize_inputs.find("OutSize") != resize_inputs.end()) { + if (op_desc.Input("OutSize").size() >= 1) { + outsize_tensor = engine_->GetITensor(op_desc.Input("OutSize")[0]); + } + } + if (out_h > 0 && out_w > 0) { scale_h = static_cast(out_h) / static_cast(in_dim.d[h_axis]); @@ -102,11 +112,9 @@ class BilinearInterpolateV2OpConverter : public OpConverter { } std::vector scales; - if (engine_->with_dynamic_shape()) { scales.push_back(1.f); } - if (data_layout == phi::DataLayout::kNCHW) { scales.push_back(1.f); scales.push_back(scale_h); @@ -115,12 +123,29 @@ class BilinearInterpolateV2OpConverter : public OpConverter { scales.push_back(scale_h); scales.push_back(scale_w); scales.push_back(1.f); + } + + if (engine_->with_dynamic_shape()) { + if (outsize_tensor != nullptr) { + std::vector outsize_itensors; + auto* input_shape = Shape(input); + outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 0)); + + if (data_layout == phi::DataLayout::kNCHW) { + outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 1)); + outsize_itensors.push_back(outsize_tensor); + } else if (data_layout == phi::DataLayout::kNHWC) { + outsize_itensors.push_back(outsize_tensor); + outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 3)); + } + layer->setInput(1, *Concat(outsize_itensors)); + } else { + layer->setScales(scales.data(), scales.size()); + } } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Data layout must be NCHW or NHWC.")); + layer->setScales(scales.data(), scales.size()); } - layer->setScales(scales.data(), scales.size()); RreplenishLayerAndOutput( layer, "bilinear_interp_v2", {output_name}, test_mode); } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6680b16a35..85aa5e82d7 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -803,8 +803,8 @@ struct SimpleOpTypeSetTeller : public Teller { } if (resize_inputs.find("OutSize") != resize_inputs.end()) { - if (desc.Input("OutSize").size() >= 1) { - VLOG(3) << "The Paddle-TRT doesn't support the OutSize for op_type " + if (!with_dynamic_shape) { + VLOG(3) << "Static shape don't support the OutSize for op_type " << op_type; return false; } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py index 5015e7e36b..3854d0e86b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py @@ -41,8 +41,8 @@ class TrtConvertBilinearInterpV2Test(TrtLayerAutoScanTest): ) for data_layout in ["NCHW", "NHWC"]: - for scale_y in [2.0, -1.0, 0.0]: - for scale_x in [2.0, -1.0, 0.0]: + for scale_y in [2.0, 1.0]: + for scale_x in [2.0, 1.0]: scale = [scale_y, scale_x] for out_h in [32, 64, 128, 192]: for out_w in [32, 64]: -- GitLab