From bc150edc388290e9b7c36e6703212370154a5d82 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Tue, 23 Nov 2021 13:27:57 +0800 Subject: [PATCH] [Paddle Inference] Fix_nearest: align_corners != true (#37368) * fix_nearest * fix_nearest * fix_nearest * fix_nearest --- .../tensorrt/convert/nearest_interp_op.cc | 6 +- paddle/fluid/inference/tensorrt/op_teller.cc | 55 ++++++++++--------- .../test_trt_convert_nearest_interp.py | 2 + .../inference/test_trt_nearest_interp_op.py | 6 +- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/nearest_interp_op.cc b/paddle/fluid/inference/tensorrt/convert/nearest_interp_op.cc index fa21442e2db..169df33c7c3 100644 --- a/paddle/fluid/inference/tensorrt/convert/nearest_interp_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/nearest_interp_op.cc @@ -38,8 +38,10 @@ class NearestInterpolateOpConverter : public OpConverter { auto input = engine_->GetITensor(input_name); - auto data_layout = framework::StringToDataLayout( - BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout"))); + auto data_layout = !op_desc.HasAttr("data_layout") + ? framework::DataLayout::kNCHW + : framework::StringToDataLayout(BOOST_GET_CONST( + std::string, op_desc.GetAttr("data_layout"))); auto interp_method = BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method")); bool align_corners = diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 5e9f7e28a20..6f0dec45644 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -667,43 +667,44 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "nearest_interp") { - std::vector attrs{"data_layout", "interp_method", - "align_corners", "scale", - "out_h", "out_w"}; + std::vector attrs{"interp_method", "align_corners", "scale", + "out_h", "out_w"}; for (auto const attr : attrs) { if (!desc.HasAttr(attr)) return false; } - auto data_layout = framework::StringToDataLayout( - BOOST_GET_CONST(std::string, desc.GetAttr("data_layout"))); - if (data_layout != framework::DataLayout::kNCHW && - data_layout != framework::DataLayout::kNHWC) - return false; + if (desc.HasAttr("data_layout")) { + auto data_layout = framework::StringToDataLayout( + BOOST_GET_CONST(std::string, desc.GetAttr("data_layout"))); + if (data_layout != framework::DataLayout::kNCHW && + data_layout != framework::DataLayout::kNHWC) + return false; + } auto interp_method = BOOST_GET_CONST(std::string, desc.GetAttr("interp_method")); if (interp_method != "nearest") return false; - - if (!desc.HasAttr("scale") || !desc.HasAttr("out_h") || - !desc.HasAttr("out_w")) { - return false; - } else { - auto scale = BOOST_GET_CONST(float, desc.GetAttr("scale")); - auto out_h = BOOST_GET_CONST(int, desc.GetAttr("out_h")); - auto out_w = BOOST_GET_CONST(int, desc.GetAttr("out_w")); - if (!(scale > 0.f && (out_h <= 0 && out_w <= 0))) { - if (out_h <= 0) { - VLOG(3) << "out_h must be greater than 0 if scale is not set."; - return false; - } - if (out_w <= 0) { - VLOG(3) << "out_w must be greater than 0 if scale is not set."; - return false; - } + auto scale = BOOST_GET_CONST(float, desc.GetAttr("scale")); + auto out_h = BOOST_GET_CONST(int, desc.GetAttr("out_h")); + auto out_w = BOOST_GET_CONST(int, desc.GetAttr("out_w")); + auto align_corners = BOOST_GET_CONST(bool, desc.GetAttr("align_corners")); + if (!(scale > 0.f && (out_h <= 0 && out_w <= 0))) { + if (out_h <= 0) { + VLOG(3) << "out_h must be greater than 0 if scale is not set."; + return false; } - if ((scale <= 0.f) && with_dynamic_shape) { - VLOG(3) << "dynamic shape not support scale not set."; + if (out_w <= 0) { + VLOG(3) << "out_w must be greater than 0 if scale is not set."; return false; } } + if ((scale <= 0.f) && with_dynamic_shape) { + VLOG(3) << "dynamic shape not support scale not set."; + return false; + } + // When align_corners = true, the paddle's and trt_layer's results has + // diff + if (align_corners && scale != 1) { + return false; + } } if (op_type == "nearest_interp_v2") { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_nearest_interp.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_nearest_interp.py index 134446ffa57..56c0b041da2 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_nearest_interp.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_nearest_interp.py @@ -124,6 +124,8 @@ class TrtConvertNearestInterpTest(TrtLayerAutoScanTest): if program_config.ops[0].attrs[ 'scale'] <= 0 and self.dynamic_shape.min_input_shape: return True + if program_config.ops[0].attrs['align_corners'] == True: + return True return False self.add_skip_case( diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_op.py index 1a58a6c9dda..16eba0a043a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_op.py @@ -61,9 +61,9 @@ class TRTNearestInterpTest(InferencePassTest): def set_params(self): self.bs = 4 - self.scale = 1 + self.scale = 0 self.channels = 3 - self.origin_shape = (32, 32) # HW + self.origin_shape = (4, 4) # HW self.resize_shape = (64, 64) # HW self.align_corners = True self.data_layout = 'NCHW' @@ -114,7 +114,7 @@ class TRTNearestInterpTest2(TRTNearestInterpTest): class TRTNearestInterpTest3(TRTNearestInterpTest): def set_params(self): self.bs = 4 - self.scale = -1 + self.scale = 0 self.channels = 3 self.origin_shape = (32, 32) # HW self.resize_shape = (64, 64) # HW -- GitLab