未验证 提交 9623fc43 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

Add VisualDL to replace tb. (#509)

上级 9b6c6837
...@@ -144,7 +144,7 @@ class TqdmNotebookProgressBarHook(RunHook): ...@@ -144,7 +144,7 @@ class TqdmNotebookProgressBarHook(RunHook):
class LoggingHook(RunHook): class LoggingHook(RunHook):
"""log tensor in to screan and tensorboard""" """log tensor in to screan and VisualDL"""
def __init__(self, def __init__(self,
loss, loss,
...@@ -205,7 +205,7 @@ class LoggingHook(RunHook): ...@@ -205,7 +205,7 @@ class LoggingHook(RunHook):
speed = -1. speed = -1.
self.last_state = state self.last_state = state
# log to tensorboard # log to VisualDL
if self.writer is not None: if self.writer is not None:
self.writer.add_scalar('loss', loss, state.gstep) self.writer.add_scalar('loss', loss, state.gstep)
for name, t in zip(self.s_name, s_np): for name, t in zip(self.s_name, s_np):
......
...@@ -48,11 +48,11 @@ __all__ = ['train_and_eval', 'Learner'] ...@@ -48,11 +48,11 @@ __all__ = ['train_and_eval', 'Learner']
def _get_summary_writer(path): def _get_summary_writer(path):
summary_writer = None summary_writer = None
try: try:
from tensorboardX import SummaryWriter from visualdl import LogWriter
if distribution.status.is_master: if distribution.status.is_master:
summary_writer = SummaryWriter(os.path.join(path)) summary_writer = LogWriter(os.path.join(path))
except ImportError: except ImportError:
log.warning('tensorboardX not installed, will not log to tensorboard') log.warning('VisualDL not installed, will not log to VisualDL')
return summary_writer return summary_writer
...@@ -69,7 +69,7 @@ def _log_eval_result(name, eval_result, swriter, state): ...@@ -69,7 +69,7 @@ def _log_eval_result(name, eval_result, swriter, state):
printable.append('{}\t{}'.format(n, val)) printable.append('{}\t{}'.format(n, val))
if swriter is not None: if swriter is not None:
swriter.add_scalar(n, val, state.gstep) swriter.add_scalar(n, val, state.gstep)
log.debug('write to tensorboard %s' % swriter.logdir) log.debug('write to VisualDL %s' % swriter.logdir)
if len(printable): if len(printable):
log.info('*** eval res: %10s ***' % name) log.info('*** eval res: %10s ***' % name)
......
...@@ -4,3 +4,4 @@ six==1.11.0 ...@@ -4,3 +4,4 @@ six==1.11.0
sklearn==0.0 sklearn==0.0
sentencepiece==0.1.8 sentencepiece==0.1.8
jieba==0.39 jieba==0.39
visualdl>=2.0.0b7
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册