diff --git a/python/paddle/fluid/tests/unittests/test_checkpoint.py b/python/paddle/fluid/tests/unittests/test_checkpoint.py index b8d82c59b4e2c2a0f9b135d318f3d470976a3db9..150e8822d577be7380d826a473c92402317c0ad2 100644 --- a/python/paddle/fluid/tests/unittests/test_checkpoint.py +++ b/python/paddle/fluid/tests/unittests/test_checkpoint.py @@ -14,6 +14,7 @@ import paddle.fluid as fluid import unittest +import os class TestCheckpoint(unittest.TestCase): @@ -35,8 +36,8 @@ class TestCheckpoint(unittest.TestCase): trainer_args = ["epoch_id", "step_id"] epoch_id, step_id = fluid.io.load_trainer_args( self.dirname, serial, self.trainer_id, trainer_args) - self.assertEqual(self.step_id, step_id) - self.assertEqual(self.epoch_id, epoch_id) + self.assertEqual(self.step_id, int(step_id)) + self.assertEqual(self.epoch_id, int(epoch_id)) program = fluid.Program() with fluid.program_guard(program): @@ -44,6 +45,7 @@ class TestCheckpoint(unittest.TestCase): fluid.io.load_checkpoint(exe, self.dirname, serial, program) fluid.io.clean_checkpoint(self.dirname, delete_dir=True) + self.assertFalse(os.path.isdir(self.dirname)) def save_checkpoint(self): config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,