From 8a5bbae6ac7a386716ff46df810073667f793959 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Tue, 22 Jun 2021 17:20:46 +0800 Subject: [PATCH] Fix the save path problem of UT test_pass_builder. (#33717) --- .../tests/unittests/test_pass_builder.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py index cd463ea040..023ceeaa73 100644 --- a/python/paddle/fluid/tests/unittests/test_pass_builder.py +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -23,6 +23,7 @@ import unittest import os import sys import math +import tempfile class TestPassBuilder(unittest.TestCase): @@ -98,17 +99,17 @@ class TestPassBuilder(unittest.TestCase): pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) - current_path = os.path.abspath(os.path.dirname(__file__)) - graph_viz_path = current_path + os.sep + 'tmp' + os.sep + 'test_viz_pass' - viz_pass.set("graph_viz_path", graph_viz_path) - - self.check_network_convergence( - use_cuda=core.is_compiled_with_cuda(), - build_strategy=build_strategy) - try: - os.stat(graph_viz_path) - except os.error: - self.assertFalse(True) + with tempfile.TemporaryDirectory(prefix="dot_path_") as tmpdir: + graph_viz_path = os.path.join(tmpdir, 'test_viz_pass.dot') + viz_pass.set("graph_viz_path", graph_viz_path) + + self.check_network_convergence( + use_cuda=core.is_compiled_with_cuda(), + build_strategy=build_strategy) + try: + os.stat(graph_viz_path) + except os.error: + self.assertFalse(True) if __name__ == '__main__': -- GitLab