未验证 提交 a0f1dba3 编写于 作者: L LielinJiang 提交者: GitHub

Add visualdl callback function (#27565)

* add visualdl callback
上级 9b3ef597
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
import os import os
import numbers
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.utils import try_import
from .progressbar import ProgressBar from .progressbar import ProgressBar
__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint'] __all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL']
def config_callbacks(callbacks=None, def config_callbacks(callbacks=None,
...@@ -471,3 +473,111 @@ class ModelCheckpoint(Callback): ...@@ -471,3 +473,111 @@ class ModelCheckpoint(Callback):
path = '{}/final'.format(self.save_dir) path = '{}/final'.format(self.save_dir)
print('save checkpoint at {}'.format(os.path.abspath(path))) print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path) self.model.save(path)
class VisualDL(Callback):
"""VisualDL callback function
Args:
log_dir (str): The directory to save visualdl log file.
Examples:
.. code-block:: python
import paddle
from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')
net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
## uncomment following lines to fit model with visualdl callback function
# callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
# model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)
"""
def __init__(self, log_dir):
self.log_dir = log_dir
self.epochs = None
self.steps = None
self.epoch = 0
def _is_write(self):
return ParallelEnv().local_rank == 0
def on_train_begin(self, logs=None):
self.epochs = self.params['epochs']
assert self.epochs
self.train_metrics = self.params['metrics']
assert self.train_metrics
self._is_fit = True
self.train_step = 0
def on_epoch_begin(self, epoch=None, logs=None):
self.steps = self.params['steps']
self.epoch = epoch
def _updates(self, logs, mode):
if not self._is_write():
return
if not hasattr(self, 'writer'):
visualdl = try_import('visualdl')
self.writer = visualdl.LogWriter(self.log_dir)
metrics = getattr(self, '%s_metrics' % (mode))
current_step = getattr(self, '%s_step' % (mode))
if mode == 'train':
total_step = current_step
else:
total_step = self.epoch
for k in metrics:
if k in logs:
temp_tag = mode + '/' + k
if isinstance(logs[k], (list, tuple)):
temp_value = logs[k][0]
elif isinstance(logs[k], numbers.Number):
temp_value = logs[k]
else:
continue
self.writer.add_scalar(
tag=temp_tag, step=total_step, value=temp_value)
def on_train_batch_end(self, step, logs=None):
logs = logs or {}
self.train_step += 1
if self._is_write():
self._updates(logs, 'train')
def on_eval_begin(self, logs=None):
self.eval_steps = logs.get('steps', None)
self.eval_metrics = logs.get('metrics', [])
self.eval_step = 0
self.evaled_samples = 0
def on_train_end(self, logs=None):
if hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')
def on_eval_end(self, logs=None):
if self._is_write():
self._updates(logs, 'eval')
if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')
...@@ -12,11 +12,13 @@ ...@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import unittest import unittest
import time import time
import random import random
import tempfile import tempfile
import shutil import shutil
import paddle
from paddle import Model from paddle import Model
from paddle.static import InputSpec from paddle.static import InputSpec
...@@ -102,6 +104,32 @@ class TestCallbacks(unittest.TestCase): ...@@ -102,6 +104,32 @@ class TestCallbacks(unittest.TestCase):
self.verbose = 2 self.verbose = 2
self.run_callback() self.run_callback()
def test_visualdl_callback(self):
# visualdl not support python3
if sys.version_info < (3, ):
return
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')
net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(
optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
model.fit(train_dataset,
eval_dataset,
batch_size=64,
callbacks=callback)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -2,3 +2,4 @@ PyGithub ...@@ -2,3 +2,4 @@ PyGithub
coverage coverage
pycrypto ; platform_system != "Windows" pycrypto ; platform_system != "Windows"
mock mock
visualdl ; python_version>="3.5"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册