From 3b5e3f9be4b97f15aac809b851cb328bbf424437 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 5 Jun 2018 18:05:06 +0800 Subject: [PATCH] update checkpoint unittest --- python/paddle/fluid/tests/unittests/test_checkpoint.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_checkpoint.py b/python/paddle/fluid/tests/unittests/test_checkpoint.py index b8d82c59b4e..150e8822d57 100644 --- a/python/paddle/fluid/tests/unittests/test_checkpoint.py +++ b/python/paddle/fluid/tests/unittests/test_checkpoint.py @@ -14,6 +14,7 @@ import paddle.fluid as fluid import unittest +import os class TestCheckpoint(unittest.TestCase): @@ -35,8 +36,8 @@ class TestCheckpoint(unittest.TestCase): trainer_args = ["epoch_id", "step_id"] epoch_id, step_id = fluid.io.load_trainer_args( self.dirname, serial, self.trainer_id, trainer_args) - self.assertEqual(self.step_id, step_id) - self.assertEqual(self.epoch_id, epoch_id) + self.assertEqual(self.step_id, int(step_id)) + self.assertEqual(self.epoch_id, int(epoch_id)) program = fluid.Program() with fluid.program_guard(program): @@ -44,6 +45,7 @@ class TestCheckpoint(unittest.TestCase): fluid.io.load_checkpoint(exe, self.dirname, serial, program) fluid.io.clean_checkpoint(self.dirname, delete_dir=True) + self.assertFalse(os.path.isdir(self.dirname)) def save_checkpoint(self): config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints, -- GitLab