diff --git a/parl/utils/logger.py b/parl/utils/logger.py index 402c8cf1bec0b1e8fb900805cd4d5a4ee978b748..46815a417cd3a0f16a2298ca09193356855217eb 100644 --- a/parl/utils/logger.py +++ b/parl/utils/logger.py @@ -143,8 +143,8 @@ def set_dir(dirname): _FILE_HANDLER.close() del _FILE_HANDLER - if not os.path.isdir(dirname): - _makedirs(dirname) + shutil.rmtree(dirname, ignore_errors=True) + _makedirs(dirname) LOG_DIR = dirname _set_file(os.path.join(dirname, 'log.log')) diff --git a/parl/utils/tests/logger_test.py b/parl/utils/tests/logger_test.py index 8c57521aeadd284514263adcea27e3cc54c83e31..3a847f45309af388240a3e1b46e39d3d2a5a2dc0 100644 --- a/parl/utils/tests/logger_test.py +++ b/parl/utils/tests/logger_test.py @@ -15,6 +15,7 @@ import unittest from parl.utils import logger import threading as th +import os.path class TestLogger(unittest.TestCase): @@ -45,6 +46,15 @@ class TestLogger(unittest.TestCase): logger.auto_set_dir(action='n') logger.auto_set_dir(action='k') + def test_set_dir(self): + logger.set_dir('./logger_dir') + temp_file = './logger_dir/tmp.file' + with open(temp_file, 'w') as t_file: + t_file.write("Are you OK? From Mr.Lei") + self.assertTrue(os.path.isfile(temp_file)) + logger.set_dir('./logger_dir') + self.assertFalse(os.path.isfile(temp_file)) + if __name__ == '__main__': unittest.main()