未验证 提交 24843fcb 编写于 作者: Q Qi Li 提交者: GitHub

[Cherry-pick] Fix ut tempfile v23 (#43387)

* fix unit test temp file, test=develop (#43155)

* add cleanup code, test=develop (#43305)
上级 689e0999
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
import unittest import unittest
import os
import numpy as np import numpy as np
import tempfile
import paddle import paddle
import paddle.static import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
...@@ -29,6 +31,11 @@ class TestBase(IPUOpTest): ...@@ -29,6 +31,11 @@ class TestBase(IPUOpTest):
self.set_data_feed() self.set_data_feed()
self.set_feed_attr() self.set_feed_attr()
self.set_attrs() self.set_attrs()
self.temp_dir = tempfile.TemporaryDirectory()
self.model_path = os.path.join(self.temp_dir.name, "weight_decay")
def tearDown(self):
self.temp_dir.cleanup()
def set_atol(self): def set_atol(self):
self.atol = 1e-6 self.atol = 1e-6
...@@ -83,7 +90,7 @@ class TestBase(IPUOpTest): ...@@ -83,7 +90,7 @@ class TestBase(IPUOpTest):
place = paddle.CPUPlace() place = paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
paddle.static.save(main_prog, "weight_decay") paddle.static.save(main_prog, self.model_path)
if run_ipu: if run_ipu:
feed_list = [image.name] feed_list = [image.name]
......
...@@ -16,6 +16,8 @@ from __future__ import print_function ...@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import os
import tempfile
from op_test import OpTest, randomize_probability from op_test import OpTest, randomize_probability
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
...@@ -29,6 +31,8 @@ class TestLoadOpXpu(unittest.TestCase): ...@@ -29,6 +31,8 @@ class TestLoadOpXpu(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.model_path = os.path.join(self.temp_dir.name, "model")
self.ones = np.ones((4, 4)).astype('float32') self.ones = np.ones((4, 4)).astype('float32')
main_prog = fluid.Program() main_prog = fluid.Program()
start_prog = fluid.Program() start_prog = fluid.Program()
...@@ -44,14 +48,17 @@ class TestLoadOpXpu(unittest.TestCase): ...@@ -44,14 +48,17 @@ class TestLoadOpXpu(unittest.TestCase):
exe = fluid.Executor(fluid.XPUPlace(0)) exe = fluid.Executor(fluid.XPUPlace(0))
exe.run(start_prog) exe.run(start_prog)
fluid.io.save_persistables( fluid.io.save_persistables(
exe, dirname="./model", main_program=main_prog) exe, dirname=self.model_path, main_program=main_prog)
def tearDown(self):
self.temp_dir.cleanup()
def test_load_xpu(self): def test_load_xpu(self):
main_prog = fluid.Program() main_prog = fluid.Program()
start_prog = fluid.Program() start_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog): with fluid.program_guard(main_prog, start_prog):
var = layers.create_tensor(dtype='float32') var = layers.create_tensor(dtype='float32')
layers.load(var, file_path='./model/w') layers.load(var, file_path=self.model_path + '/w')
exe = fluid.Executor(fluid.XPUPlace(0)) exe = fluid.Executor(fluid.XPUPlace(0))
exe.run(start_prog) exe.run(start_prog)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册