未验证 提交 8a5bbae6 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix the save path problem of UT test_pass_builder. (#33717)

上级 dd4297cd
...@@ -23,6 +23,7 @@ import unittest ...@@ -23,6 +23,7 @@ import unittest
import os import os
import sys import sys
import math import math
import tempfile
class TestPassBuilder(unittest.TestCase): class TestPassBuilder(unittest.TestCase):
...@@ -98,17 +99,17 @@ class TestPassBuilder(unittest.TestCase): ...@@ -98,17 +99,17 @@ class TestPassBuilder(unittest.TestCase):
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
current_path = os.path.abspath(os.path.dirname(__file__)) with tempfile.TemporaryDirectory(prefix="dot_path_") as tmpdir:
graph_viz_path = current_path + os.sep + 'tmp' + os.sep + 'test_viz_pass' graph_viz_path = os.path.join(tmpdir, 'test_viz_pass.dot')
viz_pass.set("graph_viz_path", graph_viz_path) viz_pass.set("graph_viz_path", graph_viz_path)
self.check_network_convergence( self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(), use_cuda=core.is_compiled_with_cuda(),
build_strategy=build_strategy) build_strategy=build_strategy)
try: try:
os.stat(graph_viz_path) os.stat(graph_viz_path)
except os.error: except os.error:
self.assertFalse(True) self.assertFalse(True)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册