提交 58d63fef 编写于 作者: L LielinJiang

fix step count and writer create

上级 e6c28438
......@@ -507,6 +507,7 @@ class VisualDL(Callback):
self.log_dir = log_dir
self.epochs = None
self.steps = None
self.epoch = 0
def _is_write(self):
return ParallelEnv().local_rank == 0
......@@ -517,20 +518,24 @@ class VisualDL(Callback):
self.train_metrics = self.params['metrics']
assert self.train_metrics
self._is_fit = True
self.train_step = 0
def on_epoch_begin(self, epoch=None, logs=None):
visualdl = try_import('visualdl')
self.steps = self.params['steps']
self.epoch = epoch
self.train_step = 0
self.train_writer = visualdl.LogWriter(self.log_dir)
def _updates(self, logs, mode):
if not self._is_write():
return
if not hasattr(self, 'writer'):
visualdl = try_import('visualdl')
self.writer = visualdl.LogWriter(self.log_dir)
metrics = getattr(self, '%s_metrics' % (mode))
writer = getattr(self, '%s_writer' % (mode))
current_step = getattr(self, '%s_step' % (mode))
if mode == 'train':
total_step = self.epoch * self.steps + current_step
total_step = current_step
else:
total_step = self.epoch
......@@ -544,7 +549,8 @@ class VisualDL(Callback):
temp_value = logs[k]
else:
continue
writer.add_scalar(
self.writer.add_scalar(
tag=temp_tag, step=total_step, value=temp_value)
def on_train_batch_end(self, step, logs=None):
......@@ -552,30 +558,23 @@ class VisualDL(Callback):
self.train_step += 1
if self._is_write():
if self.steps is None or self.train_step < self.steps:
self._updates(logs, 'train')
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
if self._is_write() and (self.steps is not None):
self._updates(logs, 'train')
def on_eval_begin(self, logs=None):
visualdl = try_import('visualdl')
self.eval_steps = logs.get('steps', None)
self.eval_metrics = logs.get('metrics', [])
self.eval_step = 0
self.evaled_samples = 0
self.eval_writer = visualdl.LogWriter(self.log_dir)
def on_train_end(self, logs=None):
if hasattr(self, 'train_writer'):
self.train_writer.close()
if hasattr(self, 'eval_writer'):
self.eval_writer.close()
if hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')
def on_eval_end(self, logs=None):
self._updates(logs, 'eval')
if self._is_write():
self._updates(logs, 'eval')
if (not hasattr(self, '_is_fit')) and hasattr(self, 'eval_writer'):
self.eval_writer.close()
if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册