未验证 提交 a0f1dba3 编写于 作者: L LielinJiang 提交者: GitHub

Add visualdl callback function (#27565)

* add visualdl callback
上级 9b3ef597
......@@ -13,12 +13,14 @@
# limitations under the License.
import os
import numbers
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.utils import try_import
from .progressbar import ProgressBar
__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint']
__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL']
def config_callbacks(callbacks=None,
......@@ -471,3 +473,111 @@ class ModelCheckpoint(Callback):
path = '{}/final'.format(self.save_dir)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)
class VisualDL(Callback):
"""VisualDL callback function
Args:
log_dir (str): The directory to save visualdl log file.
Examples:
.. code-block:: python
import paddle
from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')
net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
## uncomment following lines to fit model with visualdl callback function
# callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
# model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)
"""
def __init__(self, log_dir):
self.log_dir = log_dir
self.epochs = None
self.steps = None
self.epoch = 0
def _is_write(self):
return ParallelEnv().local_rank == 0
def on_train_begin(self, logs=None):
self.epochs = self.params['epochs']
assert self.epochs
self.train_metrics = self.params['metrics']
assert self.train_metrics
self._is_fit = True
self.train_step = 0
def on_epoch_begin(self, epoch=None, logs=None):
self.steps = self.params['steps']
self.epoch = epoch
def _updates(self, logs, mode):
if not self._is_write():
return
if not hasattr(self, 'writer'):
visualdl = try_import('visualdl')
self.writer = visualdl.LogWriter(self.log_dir)
metrics = getattr(self, '%s_metrics' % (mode))
current_step = getattr(self, '%s_step' % (mode))
if mode == 'train':
total_step = current_step
else:
total_step = self.epoch
for k in metrics:
if k in logs:
temp_tag = mode + '/' + k
if isinstance(logs[k], (list, tuple)):
temp_value = logs[k][0]
elif isinstance(logs[k], numbers.Number):
temp_value = logs[k]
else:
continue
self.writer.add_scalar(
tag=temp_tag, step=total_step, value=temp_value)
def on_train_batch_end(self, step, logs=None):
logs = logs or {}
self.train_step += 1
if self._is_write():
self._updates(logs, 'train')
def on_eval_begin(self, logs=None):
self.eval_steps = logs.get('steps', None)
self.eval_metrics = logs.get('metrics', [])
self.eval_step = 0
self.evaled_samples = 0
def on_train_end(self, logs=None):
if hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')
def on_eval_end(self, logs=None):
if self._is_write():
self._updates(logs, 'eval')
if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')
......@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import time
import random
import tempfile
import shutil
import paddle
from paddle import Model
from paddle.static import InputSpec
......@@ -102,6 +104,32 @@ class TestCallbacks(unittest.TestCase):
self.verbose = 2
self.run_callback()
def test_visualdl_callback(self):
# visualdl not support python3
if sys.version_info < (3, ):
return
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')
net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(
optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
model.fit(train_dataset,
eval_dataset,
batch_size=64,
callbacks=callback)
if __name__ == '__main__':
unittest.main()
......@@ -2,3 +2,4 @@ PyGithub
coverage
pycrypto ; platform_system != "Windows"
mock
visualdl ; python_version>="3.5"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册