From a0f1dba37fc93283c9af893d1045968bec782e90 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Wed, 30 Sep 2020 17:14:28 +0800 Subject: [PATCH] Add visualdl callback function (#27565) * add visualdl callback --- python/paddle/hapi/callbacks.py | 112 +++++++++++++++++++++++++- python/paddle/tests/test_callbacks.py | 28 +++++++ python/unittest_py/requirements.txt | 1 + 3 files changed, 140 insertions(+), 1 deletion(-) diff --git a/python/paddle/hapi/callbacks.py b/python/paddle/hapi/callbacks.py index 69b7fedd72..4a1751b331 100644 --- a/python/paddle/hapi/callbacks.py +++ b/python/paddle/hapi/callbacks.py @@ -13,12 +13,14 @@ # limitations under the License. import os +import numbers from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.utils import try_import from .progressbar import ProgressBar -__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint'] +__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL'] def config_callbacks(callbacks=None, @@ -471,3 +473,111 @@ class ModelCheckpoint(Callback): path = '{}/final'.format(self.save_dir) print('save checkpoint at {}'.format(os.path.abspath(path))) self.model.save(path) + + +class VisualDL(Callback): + """VisualDL callback function + Args: + log_dir (str): The directory to save visualdl log file. + + Examples: + .. code-block:: python + + import paddle + from paddle.static import InputSpec + + inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')] + labels = [InputSpec([None, 1], 'int64', 'label')] + + train_dataset = paddle.vision.datasets.MNIST(mode='train') + eval_dataset = paddle.vision.datasets.MNIST(mode='test') + + net = paddle.vision.LeNet() + model = paddle.Model(net, inputs, labels) + + optim = paddle.optimizer.Adam(0.001, parameters=net.parameters()) + model.prepare(optimizer=optim, + loss=paddle.nn.CrossEntropyLoss(), + metrics=paddle.metric.Accuracy()) + + ## uncomment following lines to fit model with visualdl callback function + # callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir') + # model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback) + + """ + + def __init__(self, log_dir): + self.log_dir = log_dir + self.epochs = None + self.steps = None + self.epoch = 0 + + def _is_write(self): + return ParallelEnv().local_rank == 0 + + def on_train_begin(self, logs=None): + self.epochs = self.params['epochs'] + assert self.epochs + self.train_metrics = self.params['metrics'] + assert self.train_metrics + self._is_fit = True + self.train_step = 0 + + def on_epoch_begin(self, epoch=None, logs=None): + self.steps = self.params['steps'] + self.epoch = epoch + + def _updates(self, logs, mode): + if not self._is_write(): + return + if not hasattr(self, 'writer'): + visualdl = try_import('visualdl') + self.writer = visualdl.LogWriter(self.log_dir) + + metrics = getattr(self, '%s_metrics' % (mode)) + current_step = getattr(self, '%s_step' % (mode)) + + if mode == 'train': + total_step = current_step + else: + total_step = self.epoch + + for k in metrics: + if k in logs: + temp_tag = mode + '/' + k + + if isinstance(logs[k], (list, tuple)): + temp_value = logs[k][0] + elif isinstance(logs[k], numbers.Number): + temp_value = logs[k] + else: + continue + + self.writer.add_scalar( + tag=temp_tag, step=total_step, value=temp_value) + + def on_train_batch_end(self, step, logs=None): + logs = logs or {} + self.train_step += 1 + + if self._is_write(): + self._updates(logs, 'train') + + def on_eval_begin(self, logs=None): + self.eval_steps = logs.get('steps', None) + self.eval_metrics = logs.get('metrics', []) + self.eval_step = 0 + self.evaled_samples = 0 + + def on_train_end(self, logs=None): + if hasattr(self, 'writer'): + self.writer.close() + delattr(self, 'writer') + + def on_eval_end(self, logs=None): + if self._is_write(): + self._updates(logs, 'eval') + + if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'): + self.writer.close() + delattr(self, 'writer') diff --git a/python/paddle/tests/test_callbacks.py b/python/paddle/tests/test_callbacks.py index f0d9a132b9..b9442c46b8 100644 --- a/python/paddle/tests/test_callbacks.py +++ b/python/paddle/tests/test_callbacks.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import unittest import time import random import tempfile import shutil +import paddle from paddle import Model from paddle.static import InputSpec @@ -102,6 +104,32 @@ class TestCallbacks(unittest.TestCase): self.verbose = 2 self.run_callback() + def test_visualdl_callback(self): + # visualdl not support python3 + if sys.version_info < (3, ): + return + + inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')] + labels = [InputSpec([None, 1], 'int64', 'label')] + + train_dataset = paddle.vision.datasets.MNIST(mode='train') + eval_dataset = paddle.vision.datasets.MNIST(mode='test') + + net = paddle.vision.LeNet() + model = paddle.Model(net, inputs, labels) + + optim = paddle.optimizer.Adam(0.001, parameters=net.parameters()) + model.prepare( + optimizer=optim, + loss=paddle.nn.CrossEntropyLoss(), + metrics=paddle.metric.Accuracy()) + + callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir') + model.fit(train_dataset, + eval_dataset, + batch_size=64, + callbacks=callback) + if __name__ == '__main__': unittest.main() diff --git a/python/unittest_py/requirements.txt b/python/unittest_py/requirements.txt index 56c8be862f..389d45fc6b 100644 --- a/python/unittest_py/requirements.txt +++ b/python/unittest_py/requirements.txt @@ -2,3 +2,4 @@ PyGithub coverage pycrypto ; platform_system != "Windows" mock +visualdl ; python_version>="3.5" -- GitLab