From d90db9bd27bcef082a48f1c58a904d6df803741e Mon Sep 17 00:00:00 2001
From: weishengying <63448337+weishengying@users.noreply.github.com>
Date: Thu, 29 Sep 2022 11:27:38 +0800
Subject: [PATCH] Fix the half precision problem of general plugin (#46580)

GenericPlugin::supportsFormatCombination() used to accept every format
combination. It now requires float32 for every tensor, except position 1
of gather_nd and yolo_box, which must be int32. Every code path returns
a value, including positions >= 2 of those two ops.

GenericPluginTeller additionally rejects yolo_box ops that have neither
the iou_aware nor the iou_aware_factor attribute.

The dynamic-shape fp16 skip cases are removed from the instance_norm and
yolo_box TensorRT converter tests.

---
 paddle/fluid/inference/tensorrt/op_teller.cc          |  5 ++++-
 .../fluid/inference/tensorrt/plugin/generic_plugin.cu | 10 +++++++++-
 .../ir/inference/test_trt_convert_instance_norm.py    | 11 -----------
 .../ir/inference/test_trt_convert_yolo_box.py         | 11 -----------
 4 files changed, 13 insertions(+), 24 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 3db57ea196..5756efefa5 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -2326,7 +2326,10 @@ struct GenericPluginTeller : public Teller {
     if (!with_dynamic_shape) {
       return false;
     }
-
+    if (op_type == "yolo_box") {
+      if (!desc.HasAttr("iou_aware") && !desc.HasAttr("iou_aware_factor"))
+        return false;
+    }
     if (use_no_calib_int8) {
       return false;
     } else {
diff --git a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu
index febabc6d8e..d9afa475bf 100644
--- a/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu
@@ -287,7 +287,15 @@ bool GenericPlugin::supportsFormatCombination(
     const nvinfer1::PluginTensorDesc* in_out,
     int nb_inputs,
     int nb_outputs) TRT_NOEXCEPT {
-  return true;
+  // Position 1 of gather_nd / yolo_box is the only int32 tensor. Every other
+  // position must be float32; this also covers pos >= 2 for those two ops,
+  // which used to fall off the end of the function without returning.
+  const bool has_int32_input =
+      op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box";
+  if (has_int32_input && pos == 1) {
+    return in_out[pos].type == nvinfer1::DataType::kINT32;
+  }
+  return in_out[pos].type == nvinfer1::DataType::kFLOAT;
 }
 
 nvinfer1::DataType GenericPlugin::getOutputDataType(
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py
index a9eef0e296..99c34a587b 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_instance_norm.py
@@ -144,17 +144,6 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
 
     def add_skip_trt_case(self):
 
-        def teller1(program_config, predictor_config):
-            if len(
-                    self.dynamic_shape.min_input_shape
-            ) != 0 and self.trt_param.precision == paddle_infer.PrecisionType.Half:
-                return True
-            return False
-
-        self.add_skip_case(
-            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "The output has diff between gpu and trt in dynamic fp16 mode.")
-
         def teller2(program_config, predictor_config):
             if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
                 return True
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_yolo_box.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_yolo_box.py
index 001c1a1ccb..ddab5b8c52 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_yolo_box.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_yolo_box.py
@@ -165,17 +165,6 @@ class TrtConvertYoloBoxTest(TrtLayerAutoScanTest):
 
     def add_skip_trt_case(self):
 
-        def teller1(program_config, predictor_config):
-            if len(
-                    self.dynamic_shape.min_input_shape
-            ) != 0 and self.trt_param.precision == paddle_infer.PrecisionType.Half:
-                return True
-            return False
-
-        self.add_skip_case(
-            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "The output has diff between gpu and trt in dynamic fp16 mode.")
-
         def teller2(program_config, predictor_config):
             if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
                 return True
-- 
GitLab