diff --git a/parl/__init__.py b/parl/__init__.py index 4567facaac043d28eb781df153b73b034093451d..81e9f16a7b3a1e2281cd5a4dd2090051c3588e4f 100644 --- a/parl/__init__.py +++ b/parl/__init__.py @@ -17,6 +17,9 @@ __version__ = "1.1" generates new PARL python API """ +# trick to solve importing error +from tensorboardX import SummaryWriter + from parl.utils.utils import _HAS_FLUID if _HAS_FLUID: diff --git a/parl/utils/tensorboard.py b/parl/utils/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..1fe2c5bd56e1e93aef8a675b9ab059096bb5d58e --- /dev/null +++ b/parl/utils/tensorboard.py @@ -0,0 +1,26 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tensorboardX import SummaryWriter +from parl.utils import logger + +__all__ = [] + +_writer = SummaryWriter(logdir=logger.get_dir()) +_WRITTER_METHOD = ['add_scalar', 'add_histogram', 'close', 'flush'] + +# export writter functions +for func in _WRITTER_METHOD: + locals()[func] = getattr(_writer, func) + __all__.append(func) diff --git a/parl/utils/tests/tensorboard_test.py b/parl/utils/tests/tensorboard_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cf1649c95e784ba09f8120a80634174423499390 --- /dev/null +++ b/parl/utils/tests/tensorboard_test.py @@ -0,0 +1,35 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from parl.utils import tensorboard +import numpy as np + + +class TestUtils(unittest.TestCase): + def tearDown(self): + tensorboard.flush() + + def test_add_scalar(self): + x = range(100) + for i in x: + tensorboard.add_scalar('y=2x', i * 2, i) + + def test_add_histogram(self): + for i in range(10): + x = np.random.random(1000) + tensorboard.add_histogram('distribution centers', x + i, i) + + +if __name__ == '__main__': + unittest.main() diff --git a/setup.py b/setup.py index 7def3f12393e3c534841a06745cc8857b87d5728..979da29af08219bbc8f0d9c59f83182e6f631290 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,8 @@ setup( "pyarrow==0.13.0", "scipy>=1.0.0", "cloudpickle==1.0.0", + "tensorboardX", + "tensorboard", ], classifiers=[ 'Intended Audience :: Developers',