未验证 提交 623dce83 编写于 作者: L Leo Chen 提交者: GitHub

Fix TRT UT failures (#47488)

上级 20db5221
......@@ -382,7 +382,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]:
for shape in [[4], [4, 32], [2, 32, 16], [1, 8, 16, 32]]:
for op_type in [
"elementwise_add",
"elementwise_mul",
......@@ -464,8 +464,8 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
"input_data2": [128, 128, 256],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 64, 64],
"input_data2": [2, 64, 64],
"input_data1": [2, 32, 16],
"input_data2": [2, 32, 16],
}
elif self.dims == 4:
self.dynamic_shape.min_input_shape = {
......
......@@ -129,7 +129,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-3
), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
......@@ -140,7 +140,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-3
), (1e-3, 1e-3)
def add_skip_trt_case(self):
pass
......
......@@ -20,6 +20,7 @@ from functools import partial
from typing import Any, Dict, List
import unittest
import itertools
import copy
class TrtConvertPool2dTest(TrtLayerAutoScanTest):
......@@ -188,6 +189,39 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
"The results of some cases are Nan, but the results of TensorRT and GPU are the same.",
)
def assert_tensors_near(
    self,
    atol: float,
    rtol: float,
    tensor: Dict[str, np.ndarray],
    baseline: Dict[str, np.ndarray],
):
    """Check every array in ``tensor`` against the same-named ``baseline`` array.

    For each key in ``tensor`` the shapes must be equal and the values must
    agree within ``atol``/``rtol`` after both arrays have been clamped from
    below at the most negative finite FP16 value.

    Args:
        atol: Absolute tolerance forwarded to ``np.testing.assert_allclose``.
        rtol: Relative tolerance forwarded to ``np.testing.assert_allclose``.
        tensor: Output arrays to check, keyed by name.
        baseline: Reference arrays, looked up with the keys of ``tensor``.

    Raises:
        KeyError: If ``baseline`` has no entry for a key of ``tensor``.
        AssertionError: If the clamped values differ beyond the tolerances
            (raised by ``np.testing.assert_allclose``). A shape mismatch is
            reported through ``self.assertEqual``.
    """
    # Pool2d results may contain elements equal to the lowest representable
    # value (-65504 for FP16), and that lowest value differs between FP32 and
    # FP16. Clamp everything below the FP16 minimum up to the FP16 minimum so
    # that such elements compare equal across precisions.
    fp16_min = np.finfo(np.float16).min
    for key, arr in tensor.items():
        expected = baseline[key]
        self.assertEqual(
            expected.shape,
            arr.shape,
            'The output shapes are not equal, the baseline shape is '
            + str(expected.shape)
            + ', but got '
            + str(arr.shape),
        )
        # np.clip returns a new array when no ``out`` is given, so the
        # caller's arrays are left untouched without an explicit copy.
        np.testing.assert_allclose(
            np.clip(expected, fp16_min, None),
            np.clip(arr, fp16_min, None),
            rtol=rtol,
            atol=atol,
        )
def test(self):
    """Call ``add_skip_trt_case()`` and then ``run_test()``, in that order."""
    self.add_skip_trt_case()
    self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册