Commit 58d63fef authored by LielinJiang

fix step count and writer create

Parent e6c28438
@@ -507,6 +507,7 @@ class VisualDL(Callback):
         self.log_dir = log_dir
         self.epochs = None
         self.steps = None
+        self.epoch = 0

     def _is_write(self):
         return ParallelEnv().local_rank == 0
@@ -517,20 +518,24 @@ class VisualDL(Callback):
         self.train_metrics = self.params['metrics']
         assert self.train_metrics
         self._is_fit = True
+        self.train_step = 0

     def on_epoch_begin(self, epoch=None, logs=None):
-        visualdl = try_import('visualdl')
         self.steps = self.params['steps']
         self.epoch = epoch
-        self.train_step = 0
-        self.train_writer = visualdl.LogWriter(self.log_dir)

     def _updates(self, logs, mode):
+        if not self._is_write():
+            return
+
+        if not hasattr(self, 'writer'):
+            visualdl = try_import('visualdl')
+            self.writer = visualdl.LogWriter(self.log_dir)
+
         metrics = getattr(self, '%s_metrics' % (mode))
-        writer = getattr(self, '%s_writer' % (mode))
         current_step = getattr(self, '%s_step' % (mode))
         if mode == 'train':
-            total_step = self.epoch * self.steps + current_step
+            total_step = current_step
         else:
             total_step = self.epoch

@@ -544,7 +549,8 @@ class VisualDL(Callback):
                     temp_value = logs[k]
                 else:
                     continue
-                writer.add_scalar(
+
+                self.writer.add_scalar(
                     tag=temp_tag, step=total_step, value=temp_value)

     def on_train_batch_end(self, step, logs=None):
@@ -552,30 +558,23 @@ class VisualDL(Callback):
         self.train_step += 1

         if self._is_write():
-            if self.steps is None or self.train_step < self.steps:
-                self._updates(logs, 'train')
-
-    def on_epoch_end(self, epoch, logs=None):
-        logs = logs or {}
-        if self._is_write() and (self.steps is not None):
             self._updates(logs, 'train')

     def on_eval_begin(self, logs=None):
-        visualdl = try_import('visualdl')
         self.eval_steps = logs.get('steps', None)
         self.eval_metrics = logs.get('metrics', [])
         self.eval_step = 0
         self.evaled_samples = 0
-        self.eval_writer = visualdl.LogWriter(self.log_dir)

     def on_train_end(self, logs=None):
-        if hasattr(self, 'train_writer'):
-            self.train_writer.close()
-        if hasattr(self, 'eval_writer'):
-            self.eval_writer.close()
+        if hasattr(self, 'writer'):
+            self.writer.close()
+            delattr(self, 'writer')

     def on_eval_end(self, logs=None):
-        self._updates(logs, 'eval')
-
-        if (not hasattr(self, '_is_fit')) and hasattr(self, 'eval_writer'):
-            self.eval_writer.close()
+        if self._is_write():
+            self._updates(logs, 'eval')
+
+            if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'):
+                self.writer.close()
+                delattr(self, 'writer')
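For reference, a minimal usage sketch of how the fixed callback is typically attached through the high-level Model API. The model, dataset, optimizer, and log_dir below are illustrative placeholders rather than part of this commit, and exact import paths may differ between Paddle versions.

# Minimal usage sketch (assumes a Paddle 2.x-style high-level API; the model,
# dataset, and log_dir are illustrative placeholders, not from this commit).
import paddle
from paddle.vision.models import LeNet
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensor

model = paddle.Model(LeNet())
model.prepare(
    paddle.optimizer.Adam(parameters=model.parameters()),
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy())

# With this change, the callback keeps one cumulative train_step across epochs
# and lazily creates a single LogWriter (rank 0 only) the first time it logs,
# closing and dropping it in on_train_end / on_eval_end.
callback = paddle.callbacks.VisualDL(log_dir='visualdl_log')
model.fit(
    MNIST(mode='train', transform=ToTensor()),
    MNIST(mode='test', transform=ToTensor()),
    epochs=2,
    batch_size=64,
    callbacks=callback)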