提交 3b5e3f9b 编写于 作者: T tangwei12

update checkpoint unittest

上级 951fa744
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
import unittest import unittest
import os
class TestCheckpoint(unittest.TestCase): class TestCheckpoint(unittest.TestCase):
...@@ -35,8 +36,8 @@ class TestCheckpoint(unittest.TestCase): ...@@ -35,8 +36,8 @@ class TestCheckpoint(unittest.TestCase):
trainer_args = ["epoch_id", "step_id"] trainer_args = ["epoch_id", "step_id"]
epoch_id, step_id = fluid.io.load_trainer_args( epoch_id, step_id = fluid.io.load_trainer_args(
self.dirname, serial, self.trainer_id, trainer_args) self.dirname, serial, self.trainer_id, trainer_args)
self.assertEqual(self.step_id, step_id) self.assertEqual(self.step_id, int(step_id))
self.assertEqual(self.epoch_id, epoch_id) self.assertEqual(self.epoch_id, int(epoch_id))
program = fluid.Program() program = fluid.Program()
with fluid.program_guard(program): with fluid.program_guard(program):
...@@ -44,6 +45,7 @@ class TestCheckpoint(unittest.TestCase): ...@@ -44,6 +45,7 @@ class TestCheckpoint(unittest.TestCase):
fluid.io.load_checkpoint(exe, self.dirname, serial, program) fluid.io.load_checkpoint(exe, self.dirname, serial, program)
fluid.io.clean_checkpoint(self.dirname, delete_dir=True) fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
self.assertFalse(os.path.isdir(self.dirname))
def save_checkpoint(self): def save_checkpoint(self):
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints, config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册