Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
623dce83
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
623dce83
编写于
11月 02, 2022
作者:
L
Leo Chen
提交者:
GitHub
11月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix TRT UT failures (#47488)
上级
20db5221
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
39 addition
and
5 deletion
+39
-5
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py
...ts/unittests/ir/inference/test_trt_convert_elementwise.py
+3
-3
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py
...sts/unittests/ir/inference/test_trt_convert_group_norm.py
+2
-2
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py
...d/tests/unittests/ir/inference/test_trt_convert_pool2d.py
+34
-0
未找到文件。
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py
浏览文件 @
623dce83
...
...
@@ -382,7 +382,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
def
generate_input
(
shape
):
return
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
for
shape
in
[[
4
],
[
4
,
32
],
[
2
,
64
,
32
],
[
1
,
8
,
16
,
32
]]:
for
shape
in
[[
4
],
[
4
,
32
],
[
2
,
32
,
16
],
[
1
,
8
,
16
,
32
]]:
for
op_type
in
[
"elementwise_add"
,
"elementwise_mul"
,
...
...
@@ -464,8 +464,8 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
"input_data2"
:
[
128
,
128
,
256
],
}
self
.
dynamic_shape
.
opt_input_shape
=
{
"input_data1"
:
[
2
,
64
,
64
],
"input_data2"
:
[
2
,
64
,
64
],
"input_data1"
:
[
2
,
32
,
16
],
"input_data2"
:
[
2
,
32
,
16
],
}
elif
self
.
dims
==
4
:
self
.
dynamic_shape
.
min_input_shape
=
{
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py
浏览文件 @
623dce83
...
...
@@ -129,7 +129,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
False
),
1e-3
),
(
1e-3
,
1e-3
)
# for dynamic_shape
generate_dynamic_shape
(
attrs
)
...
...
@@ -140,7 +140,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest):
self
.
trt_param
.
precision
=
paddle_infer
.
PrecisionType
.
Half
yield
self
.
create_inference_config
(),
generate_trt_nodes_num
(
attrs
,
True
),
1e-3
),
(
1e-3
,
1e-3
)
def
add_skip_trt_case
(
self
):
pass
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py
浏览文件 @
623dce83
...
...
@@ -20,6 +20,7 @@ from functools import partial
from
typing
import
Any
,
Dict
,
List
import
unittest
import
itertools
import
copy
class
TrtConvertPool2dTest
(
TrtLayerAutoScanTest
):
...
...
@@ -188,6 +189,39 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
"The results of some cases are Nan, but the results of TensorRT and GPU are the same."
,
)
def
assert_tensors_near
(
self
,
atol
:
float
,
rtol
:
float
,
tensor
:
Dict
[
str
,
np
.
array
],
baseline
:
Dict
[
str
,
np
.
array
],
):
for
key
,
arr
in
tensor
.
items
():
self
.
assertEqual
(
baseline
[
key
].
shape
,
arr
.
shape
,
'The output shapes are not equal, the baseline shape is '
+
str
(
baseline
[
key
].
shape
)
+
', but got '
+
str
(
arr
.
shape
),
)
# The result of Pool2d may have some elements that is the least value (-65504 for FP16),
# but for FP32 and FP16 precision, their least value are different.
# We set a threshold that is the least value of FP16,
# and make the values less than the threshold to be the threshold.
def
align_less_threshold
(
arr
,
threshold
):
return
np
.
clip
(
arr
,
threshold
,
None
)
fp16_min
=
np
.
finfo
(
np
.
float16
).
min
baseline_threshold
=
align_less_threshold
(
copy
.
deepcopy
(
baseline
[
key
]),
fp16_min
)
arr_threshold
=
align_less_threshold
(
copy
.
deepcopy
(
arr
),
fp16_min
)
np
.
testing
.
assert_allclose
(
baseline_threshold
,
arr_threshold
,
rtol
=
rtol
,
atol
=
atol
)
def
test
(
self
):
self
.
add_skip_trt_case
()
self
.
run_test
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录