提交 3b5e3f9b 编写于 作者: T tangwei12

update checkpoint unittest

上级 951fa744
......@@ -14,6 +14,7 @@
import paddle.fluid as fluid
import unittest
import os
class TestCheckpoint(unittest.TestCase):
......@@ -35,8 +36,8 @@ class TestCheckpoint(unittest.TestCase):
trainer_args = ["epoch_id", "step_id"]
epoch_id, step_id = fluid.io.load_trainer_args(
self.dirname, serial, self.trainer_id, trainer_args)
self.assertEqual(self.step_id, step_id)
self.assertEqual(self.epoch_id, epoch_id)
self.assertEqual(self.step_id, int(step_id))
self.assertEqual(self.epoch_id, int(epoch_id))
program = fluid.Program()
with fluid.program_guard(program):
......@@ -44,6 +45,7 @@ class TestCheckpoint(unittest.TestCase):
fluid.io.load_checkpoint(exe, self.dirname, serial, program)
fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
self.assertFalse(os.path.isdir(self.dirname))
def save_checkpoint(self):
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册