From 623dce83f072b6ac86f1c95cb38f574ca8df0703 Mon Sep 17 00:00:00 2001 From: Leo Chen <39020268+leo0519@users.noreply.github.com> Date: Wed, 2 Nov 2022 13:56:42 +0800 Subject: [PATCH] Fix TRT UT failures (#47488) --- .../inference/test_trt_convert_elementwise.py | 6 ++-- .../inference/test_trt_convert_group_norm.py | 4 +-- .../ir/inference/test_trt_convert_pool2d.py | 34 +++++++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index 3699055475..8420c9cdaa 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -382,7 +382,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast( def generate_input(shape): return np.random.random(shape).astype(np.float32) - for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]: + for shape in [[4], [4, 32], [2, 32, 16], [1, 8, 16, 32]]: for op_type in [ "elementwise_add", "elementwise_mul", @@ -464,8 +464,8 @@ class TrtConvertElementwiseTest_two_input_without_broadcast( "input_data2": [128, 128, 256], } self.dynamic_shape.opt_input_shape = { - "input_data1": [2, 64, 64], - "input_data2": [2, 64, 64], + "input_data1": [2, 32, 16], + "input_data2": [2, 32, 16], } elif self.dims == 4: self.dynamic_shape.min_input_shape = { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py index 18f4093417..cc4e719585 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py @@ -129,7 +129,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): self.trt_param.precision = paddle_infer.PrecisionType.Half yield self.create_inference_config(), generate_trt_nodes_num( attrs, False - ), 1e-3 + ), (1e-3, 1e-3) # for dynamic_shape generate_dynamic_shape(attrs) @@ -140,7 +140,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): self.trt_param.precision = paddle_infer.PrecisionType.Half yield self.create_inference_config(), generate_trt_nodes_num( attrs, True - ), 1e-3 + ), (1e-3, 1e-3) def add_skip_trt_case(self): pass diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py index 7bdaab0ee8..cf57fd3beb 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py @@ -20,6 +20,7 @@ from functools import partial from typing import Any, Dict, List import unittest import itertools +import copy class TrtConvertPool2dTest(TrtLayerAutoScanTest): @@ -188,6 +189,39 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): "The results of some cases are Nan, but the results of TensorRT and GPU are the same.", ) + def assert_tensors_near( + self, + atol: float, + rtol: float, + tensor: Dict[str, np.array], + baseline: Dict[str, np.array], + ): + for key, arr in tensor.items(): + self.assertEqual( + baseline[key].shape, + arr.shape, + 'The output shapes are not equal, the baseline shape is ' + + str(baseline[key].shape) + + ', but got ' + + str(arr.shape), + ) + + # The result of Pool2d may have some elements that is the least value (-65504 for FP16), + # but for FP32 and FP16 precision, their least value are different. + # We set a threshold that is the least value of FP16, + # and make the values less than the threshold to be the threshold. + def align_less_threshold(arr, threshold): + return np.clip(arr, threshold, None) + + fp16_min = np.finfo(np.float16).min + baseline_threshold = align_less_threshold( + copy.deepcopy(baseline[key]), fp16_min + ) + arr_threshold = align_less_threshold(copy.deepcopy(arr), fp16_min) + np.testing.assert_allclose( + baseline_threshold, arr_threshold, rtol=rtol, atol=atol + ) + def test(self): self.add_skip_trt_case() self.run_test() -- GitLab