diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py index cd463ea0405f56f29699e8222a9ec69ac64f6445..023ceeaa73acc29c74881640284662c079bca4c5 100644 --- a/python/paddle/fluid/tests/unittests/test_pass_builder.py +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -23,6 +23,7 @@ import unittest import os import sys import math +import tempfile class TestPassBuilder(unittest.TestCase): @@ -98,17 +99,17 @@ class TestPassBuilder(unittest.TestCase): pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) - current_path = os.path.abspath(os.path.dirname(__file__)) - graph_viz_path = current_path + os.sep + 'tmp' + os.sep + 'test_viz_pass' - viz_pass.set("graph_viz_path", graph_viz_path) - - self.check_network_convergence( - use_cuda=core.is_compiled_with_cuda(), - build_strategy=build_strategy) - try: - os.stat(graph_viz_path) - except os.error: - self.assertFalse(True) + with tempfile.TemporaryDirectory(prefix="dot_path_") as tmpdir: + graph_viz_path = os.path.join(tmpdir, 'test_viz_pass.dot') + viz_pass.set("graph_viz_path", graph_viz_path) + + self.check_network_convergence( + use_cuda=core.is_compiled_with_cuda(), + build_strategy=build_strategy) + try: + os.stat(graph_viz_path) + except os.error: + self.assertFalse(True) if __name__ == '__main__':