From 942bb211e6c97af7d691646a8037790b3851a6eb Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Mon, 26 Oct 2020 09:44:12 +0800 Subject: [PATCH] Add ips to the benchmark information. (#47) --- ppgan/engine/trainer.py | 34 ++++++++++++++++++++-------------- ppgan/utils/timer.py | 12 ++++++++++-- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index e400e4b..0944087 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -77,10 +77,13 @@ class Trainer: self.model.set_input(data) self.model.optimize_parameters() - batch_cost_averager.record(time.time() - step_start_time) + batch_cost_averager.record( + time.time() - step_start_time, + num_samples=self.cfg.get('batch_size', 1)) if i % self.log_interval == 0: self.data_time = reader_cost_averager.get_average() self.step_time = batch_cost_averager.get_average() + self.ips = batch_cost_averager.get_ips_average() self.print_log() reader_cost_averager.reset() @@ -91,8 +94,8 @@ class Trainer: step_start_time = time.time() - self.logger.info('train one epoch time: {}'.format(time.time() - - start_time)) + self.logger.info( + 'train one epoch time: {}'.format(time.time() - start_time)) if self.validate_interval > -1 and epoch % self.validate_interval: self.validate() self.model.lr_scheduler.step() @@ -102,8 +105,8 @@ class Trainer: def validate(self): if not hasattr(self, 'val_dataloader'): - self.val_dataloader = build_dataloader(self.cfg.dataset.val, - is_train=False) + self.val_dataloader = build_dataloader( + self.cfg.dataset.val, is_train=False) metric_result = {} @@ -149,8 +152,8 @@ class Trainer: self.visual('visual_val', visual_results=visual_results) if i % self.log_interval == 0: - self.logger.info('val iter: [%d/%d]' % - (i, len(self.val_dataloader))) + self.logger.info( + 'val iter: [%d/%d]' % (i, len(self.val_dataloader))) for metric_name in metric_result.keys(): metric_result[metric_name] /= len(self.val_dataloader.dataset) @@ -160,8 +163,8 @@ class Trainer: def test(self): if not hasattr(self, 'test_dataloader'): - self.test_dataloader = build_dataloader(self.cfg.dataset.test, - is_train=False) + self.test_dataloader = build_dataloader( + self.cfg.dataset.test, is_train=False) # data[0]: img, data[1]: img path index # test batch size must be 1 @@ -185,8 +188,8 @@ class Trainer: self.visual('visual_test', visual_results=visual_results) if i % self.log_interval == 0: - self.logger.info('Test iter: [%d/%d]' % - (i, len(self.test_dataloader))) + self.logger.info( + 'Test iter: [%d/%d]' % (i, len(self.test_dataloader))) def print_log(self): losses = self.model.get_current_losses() @@ -197,11 +200,14 @@ class Trainer: for k, v in losses.items(): message += '%s: %.3f ' % (k, v) + if hasattr(self, 'step_time'): + message += 'batch_cost: %.5f sec ' % self.step_time + if hasattr(self, 'data_time'): - message += 'reader cost: %.5fs ' % self.data_time + message += 'reader_cost: %.5f sec ' % self.data_time - if hasattr(self, 'step_time'): - message += 'batch cost: %.5fs' % self.step_time + if hasattr(self, 'ips'): + message += 'ips: %.5f images/s' % self.ips # print the message self.logger.info(message) diff --git a/ppgan/utils/timer.py b/ppgan/utils/timer.py index 6b277e5..838dc75 100644 --- a/ppgan/utils/timer.py +++ b/ppgan/utils/timer.py @@ -22,12 +22,20 @@ class TimeAverager(object): def reset(self): self._cnt = 0 self._total_time = 0 + self._total_samples = 0 - def record(self, usetime): + def record(self, usetime, num_samples=None): self._cnt += 1 self._total_time += usetime + if num_samples: + self._total_samples += num_samples def get_average(self): if self._cnt == 0: return 0 - return self._total_time / self._cnt + return self._total_time / float(self._cnt) + + def get_ips_average(self): + if not self._total_samples or self._cnt == 0: + return 0 + return float(self._total_samples) / self._total_time -- GitLab