diff --git a/parl/utils/summary.py b/parl/utils/summary.py index 575fc6b9976906e43dddeb7da2ea1ef32d4644c1..bc3578ef384222a4e55b7b9af90f36d9a7fccb4c 100644 --- a/parl/utils/summary.py +++ b/parl/utils/summary.py @@ -12,34 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tensorboardX import SummaryWriter -from parl.utils import logger - -__all__ = [] - -_writer = None -_WRITTER_METHOD = ['add_scalar', 'add_histogram', 'close', 'flush'] - - -def create_file_after_first_call(func_name): - def call(*args, **kwargs): - global _writer - if _writer is None: - logdir = logger.get_dir() - if logdir is None: - logdir = logger.auto_set_dir(action='d') - logger.warning( - "[tensorboard] logdir is None, will save tensorboard files to {}" - .format(logdir)) - _writer = SummaryWriter(logdir=logger.get_dir()) - func = getattr(_writer, func_name) - func(*args, **kwargs) - _writer.flush() - - return call - - -# export writter functions -for func_name in _WRITTER_METHOD: - locals()[func_name] = create_file_after_first_call(func_name) - __all__.append(func_name) +try: + from parl.utils.visualdl import * +except: + from parl.utils.tensorboard import * diff --git a/parl/utils/tensorboard.py b/parl/utils/tensorboard.py index 575fc6b9976906e43dddeb7da2ea1ef32d4644c1..3fef518196216986f33f187c215b8aa4834003d5 100644 --- a/parl/utils/tensorboard.py +++ b/parl/utils/tensorboard.py @@ -14,6 +14,7 @@ from tensorboardX import SummaryWriter from parl.utils import logger +from parl.utils.machine_info import get_ip_address __all__ = [] @@ -29,8 +30,8 @@ def create_file_after_first_call(func_name): if logdir is None: logdir = logger.auto_set_dir(action='d') logger.warning( - "[tensorboard] logdir is None, will save tensorboard files to {}" - .format(logdir)) + "[tensorboard] logdir is None, will save tensorboard files to {}\nView the data using: tensorboard --logdir=./{} --host={}" + .format(logdir, logdir, get_ip_address())) _writer = SummaryWriter(logdir=logger.get_dir()) func = getattr(_writer, func_name) func(*args, **kwargs) diff --git a/parl/utils/tests/summary_test.py b/parl/utils/tests/summary_test.py index 670abcccc35e3039cb1e93ea3462855bb84a503a..401051c5debd3ed69d3cba54bcdda1c9ef75f12c 100644 --- a/parl/utils/tests/summary_test.py +++ b/parl/utils/tests/summary_test.py @@ -20,7 +20,8 @@ import os class TestUtils(unittest.TestCase): def tearDown(self): - summary.flush() + if hasattr(summary, 'flush'): + summary.flush() def test_add_scalar(self): x = range(100) @@ -29,6 +30,8 @@ class TestUtils(unittest.TestCase): self.assertTrue(os.path.exists('./train_log/summary_test')) def test_add_histogram(self): + if not hasattr(summary, 'add_histogram'): + return for i in range(10): x = np.random.random(1000) summary.add_histogram('distribution centers', x + i, i) diff --git a/parl/utils/visualdl.py b/parl/utils/visualdl.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf1aa08e313440a47a73936d1be61ca9701f166 --- /dev/null +++ b/parl/utils/visualdl.py @@ -0,0 +1,46 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from visualdl import LogWriter +from parl.utils import logger +from parl.utils.machine_info import get_ip_address + +__all__ = [] + +_writer = None +_WRITTER_METHOD = ['add_scalar'] + + +def create_file_after_first_call(func_name): + def call(*args, **kwargs): + global _writer + if _writer is None: + logdir = logger.get_dir() + if logdir is None: + logdir = logger.auto_set_dir(action='d') + logger.warning( + "[VisualDL] logdir is None, will save VisualDL files to {}\nView the data using: visualdl --logdir=./{} --host={}" + .format(logdir, logdir, get_ip_address())) + _writer = LogWriter(logdir=logger.get_dir()) + func = getattr(_writer, func_name) + func(*args, **kwargs) + _writer.flush() + + return call + + +# export writter functions +for func_name in _WRITTER_METHOD: + locals()[func_name] = create_file_after_first_call(func_name) + __all__.append(func_name) diff --git a/setup.py b/setup.py index 1e5e74ef9a26d7b21c736a107aee6809f3a89b33..90c6465024884e2dcc9de09a2bf737533cf417e0 100644 --- a/setup.py +++ b/setup.py @@ -77,9 +77,10 @@ setup( "cloudpickle==1.2.1", "tensorboardX==1.8", "tb-nightly==1.15.0a20190801", - "flask==1.0.4", + "flask>=1.0.4", "click", "psutil>=5.6.2", + "visualdl>=2.0.0b;python_version>='3' and platform_system=='Linux'", ], classifiers=[ 'Intended Audience :: Developers',