From 58d63fef926b6153d8432cac7c72bcdc4ad11114 Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Sun, 27 Sep 2020 13:41:08 +0000 Subject: [PATCH] fix step count and writer create --- python/paddle/hapi/callbacks.py | 41 ++++++++++++++++----------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/python/paddle/hapi/callbacks.py b/python/paddle/hapi/callbacks.py index e8496ef69e..1c516f3575 100644 --- a/python/paddle/hapi/callbacks.py +++ b/python/paddle/hapi/callbacks.py @@ -507,6 +507,7 @@ class VisualDL(Callback): self.log_dir = log_dir self.epochs = None self.steps = None + self.epoch = 0 def _is_write(self): return ParallelEnv().local_rank == 0 @@ -517,20 +518,24 @@ class VisualDL(Callback): self.train_metrics = self.params['metrics'] assert self.train_metrics self._is_fit = True + self.train_step = 0 def on_epoch_begin(self, epoch=None, logs=None): - visualdl = try_import('visualdl') self.steps = self.params['steps'] self.epoch = epoch - self.train_step = 0 - self.train_writer = visualdl.LogWriter(self.log_dir) def _updates(self, logs, mode): + if not self._is_write(): + return + if not hasattr(self, 'writer'): + visualdl = try_import('visualdl') + self.writer = visualdl.LogWriter(self.log_dir) + metrics = getattr(self, '%s_metrics' % (mode)) - writer = getattr(self, '%s_writer' % (mode)) current_step = getattr(self, '%s_step' % (mode)) + if mode == 'train': - total_step = self.epoch * self.steps + current_step + total_step = current_step else: total_step = self.epoch @@ -544,7 +549,8 @@ class VisualDL(Callback): temp_value = logs[k] else: continue - writer.add_scalar( + + self.writer.add_scalar( tag=temp_tag, step=total_step, value=temp_value) def on_train_batch_end(self, step, logs=None): @@ -552,30 +558,23 @@ class VisualDL(Callback): self.train_step += 1 if self._is_write(): - if self.steps is None or self.train_step < self.steps: - self._updates(logs, 'train') - - def on_epoch_end(self, epoch, logs=None): - logs = logs or {} - if self._is_write() and (self.steps is not None): self._updates(logs, 'train') def on_eval_begin(self, logs=None): - visualdl = try_import('visualdl') self.eval_steps = logs.get('steps', None) self.eval_metrics = logs.get('metrics', []) self.eval_step = 0 self.evaled_samples = 0 - self.eval_writer = visualdl.LogWriter(self.log_dir) def on_train_end(self, logs=None): - if hasattr(self, 'train_writer'): - self.train_writer.close() - if hasattr(self, 'eval_writer'): - self.eval_writer.close() + if hasattr(self, 'writer'): + self.writer.close() + delattr(self, 'writer') def on_eval_end(self, logs=None): - self._updates(logs, 'eval') + if self._is_write(): + self._updates(logs, 'eval') - if (not hasattr(self, '_is_fit')) and hasattr(self, 'eval_writer'): - self.eval_writer.close() + if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'): + self.writer.close() + delattr(self, 'writer') -- GitLab