未验证 提交 623dce83 编写于 作者: L Leo Chen 提交者: GitHub

Fix TRT UT failures (#47488)

上级 20db5221
...@@ -382,7 +382,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast( ...@@ -382,7 +382,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
def generate_input(shape): def generate_input(shape):
return np.random.random(shape).astype(np.float32) return np.random.random(shape).astype(np.float32)
for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]: for shape in [[4], [4, 32], [2, 32, 16], [1, 8, 16, 32]]:
for op_type in [ for op_type in [
"elementwise_add", "elementwise_add",
"elementwise_mul", "elementwise_mul",
...@@ -464,8 +464,8 @@ class TrtConvertElementwiseTest_two_input_without_broadcast( ...@@ -464,8 +464,8 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
"input_data2": [128, 128, 256], "input_data2": [128, 128, 256],
} }
self.dynamic_shape.opt_input_shape = { self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 64, 64], "input_data1": [2, 32, 16],
"input_data2": [2, 64, 64], "input_data2": [2, 32, 16],
} }
elif self.dims == 4: elif self.dims == 4:
self.dynamic_shape.min_input_shape = { self.dynamic_shape.min_input_shape = {
......
...@@ -129,7 +129,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): ...@@ -129,7 +129,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False attrs, False
), 1e-3 ), (1e-3, 1e-3)
# for dynamic_shape # for dynamic_shape
generate_dynamic_shape(attrs) generate_dynamic_shape(attrs)
...@@ -140,7 +140,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): ...@@ -140,7 +140,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num( yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True attrs, True
), 1e-3 ), (1e-3, 1e-3)
def add_skip_trt_case(self): def add_skip_trt_case(self):
pass pass
......
...@@ -20,6 +20,7 @@ from functools import partial ...@@ -20,6 +20,7 @@ from functools import partial
from typing import Any, Dict, List from typing import Any, Dict, List
import unittest import unittest
import itertools import itertools
import copy
class TrtConvertPool2dTest(TrtLayerAutoScanTest): class TrtConvertPool2dTest(TrtLayerAutoScanTest):
...@@ -188,6 +189,39 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): ...@@ -188,6 +189,39 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
"The results of some cases are Nan, but the results of TensorRT and GPU are the same.", "The results of some cases are Nan, but the results of TensorRT and GPU are the same.",
) )
def assert_tensors_near(
    self,
    atol: float,
    rtol: float,
    tensor: Dict[str, np.ndarray],
    baseline: Dict[str, np.ndarray],
) -> None:
    """Assert that every array in ``tensor`` matches ``baseline`` within tolerance.

    Pool2d results may contain elements equal to the lowest representable
    value, and that value differs between FP32 and FP16 (-65504 for FP16).
    To compare such outputs fairly, both sides are floored at the lowest
    finite FP16 value before the closeness check.

    Args:
        atol: Absolute tolerance forwarded to ``np.testing.assert_allclose``.
        rtol: Relative tolerance forwarded to ``np.testing.assert_allclose``.
        tensor: Outputs under test, keyed by output name.
        baseline: Reference outputs, looked up with the keys of ``tensor``.

    Raises:
        AssertionError: If a pair of arrays differs in shape or in value
            beyond the given tolerances.
        KeyError: If ``baseline`` has no entry for a key of ``tensor``.
    """
    # Loop-invariant: lowest finite FP16 value, used as the common floor.
    fp16_min = np.finfo(np.float16).min

    for key, arr in tensor.items():
        expected = baseline[key]
        self.assertEqual(
            expected.shape,
            arr.shape,
            'The output shapes are not equal, the baseline shape is '
            + str(expected.shape)
            + ', but got '
            + str(arr.shape),
        )
        # np.clip returns a new array, so neither input is modified and no
        # explicit copy is needed before flooring the values.
        np.testing.assert_allclose(
            np.clip(expected, fp16_min, None),
            np.clip(arr, fp16_min, None),
            rtol=rtol,
            atol=atol,
            err_msg='Mismatch in output ' + repr(key),
        )
def test(self): def test(self):
self.add_skip_trt_case() self.add_skip_trt_case()
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册