From fba54ddec18f9a57bed44d306435fe960f5e7356 Mon Sep 17 00:00:00 2001 From: carryyu <569782149@qq.com> Date: Sun, 25 Sep 2022 16:34:37 +0800 Subject: [PATCH] Fix unit test in A10 GPU (#46450) * Disable TF32 to solve accuracy for test_trt_conv_pass and test_trt_deformable_conv in A10 GPU. --- .../fluid/tests/unittests/ir/inference/test_trt_conv_pass.py | 2 ++ .../tests/unittests/ir/inference/test_trt_deformable_conv.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py index a934c264e4..3a16f4f7d5 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_conv_pass.py @@ -22,6 +22,8 @@ import paddle.fluid.core as core from paddle.fluid.core import PassVersionChecker from paddle.fluid.core import AnalysisConfig +os.environ['NVIDIA_TF32_OVERRIDE'] = '0' + class TensorRTSubgraphPassConvTest(InferencePassTest): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py index 3bed89e74f..6a0e98539a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_deformable_conv.py @@ -14,6 +14,7 @@ from __future__ import print_function +import os import unittest import numpy as np from inference_pass_test import InferencePassTest @@ -22,6 +23,8 @@ import paddle.fluid.core as core from paddle.fluid.core import PassVersionChecker from paddle.fluid.core import AnalysisConfig +os.environ['NVIDIA_TF32_OVERRIDE'] = '0' + class TRTDeformableConvTest(InferencePassTest): -- GitLab