diff --git a/examples/A2C/atari_model.py b/examples/A2C/atari_model.py index 54f74bf55956f8b645968a098dd6a2214139624d..5cc7bb4f3578e8f6c7c560be010a3b2d223a9e70 100644 --- a/examples/A2C/atari_model.py +++ b/examples/A2C/atari_model.py @@ -14,7 +14,6 @@ import parl import paddle.fluid as fluid -from paddle.fluid.param_attr import ParamAttr from parl import layers diff --git a/parl/utils/logger.py b/parl/utils/logger.py index b53dba5a045859a0986dedafe20887cae73c3cea..61353a4f4ab02e99fc9e1983a95020c9594e9427 100644 --- a/parl/utils/logger.py +++ b/parl/utils/logger.py @@ -18,6 +18,7 @@ import os import os.path import sys from termcolor import colored +import shutil __all__ = ['set_dir', 'get_dir', 'set_level'] @@ -140,5 +141,6 @@ mod = sys.modules['__main__'] if hasattr(mod, '__file__'): basename = os.path.basename(mod.__file__) auto_dirname = os.path.join('log_dir', basename[:basename.rfind('.')]) + shutil.rmtree(auto_dirname, ignore_errors=True) set_dir(auto_dirname) _logger.info("Argv: " + ' '.join(sys.argv)) diff --git a/parl/utils/tensorboard.py b/parl/utils/tensorboard.py index 1fe2c5bd56e1e93aef8a675b9ab059096bb5d58e..56d1c8687908f1d80d660a8ca1946d64668b31cc 100644 --- a/parl/utils/tensorboard.py +++ b/parl/utils/tensorboard.py @@ -17,10 +17,23 @@ from parl.utils import logger __all__ = [] -_writer = SummaryWriter(logdir=logger.get_dir()) +_writer = None _WRITTER_METHOD = ['add_scalar', 'add_histogram', 'close', 'flush'] + +def create_file_after_first_call(func_name): + def call(*args, **kwargs): + global _writer + if _writer is None: + _writer = SummaryWriter(logdir=logger.get_dir()) + func = getattr(_writer, func_name) + func(*args, **kwargs) + _writer.flush() + + return call + + # export writter functions -for func in _WRITTER_METHOD: - locals()[func] = getattr(_writer, func) - __all__.append(func) +for func_name in _WRITTER_METHOD: + locals()[func_name] = create_file_after_first_call(func_name) + __all__.append(func_name)