From 12c78eee5c53d785a39efa2564fef16207d451ab Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 23 Sep 2020 13:32:03 +0000 Subject: [PATCH] Calculate the average time for benchmark. --- ppgan/engine/trainer.py | 34 +++++++++++++++++++++------------- ppgan/utils/timer.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 13 deletions(-) create mode 100644 ppgan/utils/timer.py diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 65d8798..fb3de0b 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -11,6 +11,7 @@ from ..datasets.builder import build_dataloader from ..models.builder import build_model from ..utils.visual import tensor2img, save_image from ..utils.filesystem import save, load, makedirs +from ..utils.timer import TimeAverager from ..metric.psnr_ssim import calculate_psnr, calculate_ssim @@ -61,30 +62,37 @@ class Trainer: paddle.DataParallel(net, strategy)) def train(self): + reader_cost_averager = TimeAverager() + batch_cost_averager = TimeAverager() for epoch in range(self.start_epoch, self.epochs): self.current_epoch = epoch start_time = step_start_time = time.time() for i, data in enumerate(self.train_dataloader): - data_time = time.time() + reader_cost_averager.record(time.time() - step_start_time) + self.batch_id = i # unpack data from dataset and apply preprocessing # data input should be dict self.model.set_input(data) self.model.optimize_parameters() - self.data_time = data_time - step_start_time - self.step_time = time.time() - step_start_time + batch_cost_averager.record(time.time() - step_start_time) if i % self.log_interval == 0: + self.data_time = reader_cost_averager.get_average() + self.step_time = batch_cost_averager.get_average() self.print_log() + reader_cost_averager.reset() + batch_cost_averager.reset() + if i % self.visual_interval == 0: self.visual('visual_train') step_start_time = time.time() - self.logger.info('train one epoch time: {}'.format(time.time() - - start_time)) + self.logger.info( + 'train one epoch time: {}'.format(time.time() - start_time)) if self.validate_interval > -1 and epoch % self.validate_interval: self.validate() self.model.lr_scheduler.step() @@ -94,8 +102,8 @@ class Trainer: def validate(self): if not hasattr(self, 'val_dataloader'): - self.val_dataloader = build_dataloader(self.cfg.dataset.val, - is_train=False) + self.val_dataloader = build_dataloader( + self.cfg.dataset.val, is_train=False) metric_result = {} @@ -141,8 +149,8 @@ class Trainer: self.visual('visual_val', visual_results=visual_results) if i % self.log_interval == 0: - self.logger.info('val iter: [%d/%d]' % - (i, len(self.val_dataloader))) + self.logger.info( + 'val iter: [%d/%d]' % (i, len(self.val_dataloader))) for metric_name in metric_result.keys(): metric_result[metric_name] /= len(self.val_dataloader.dataset) @@ -152,8 +160,8 @@ class Trainer: def test(self): if not hasattr(self, 'test_dataloader'): - self.test_dataloader = build_dataloader(self.cfg.dataset.test, - is_train=False) + self.test_dataloader = build_dataloader( + self.cfg.dataset.test, is_train=False) # data[0]: img, data[1]: img path index # test batch size must be 1 @@ -177,8 +185,8 @@ class Trainer: self.visual('visual_test', visual_results=visual_results) if i % self.log_interval == 0: - self.logger.info('Test iter: [%d/%d]' % - (i, len(self.test_dataloader))) + self.logger.info( + 'Test iter: [%d/%d]' % (i, len(self.test_dataloader))) def print_log(self): losses = self.model.get_current_losses() diff --git a/ppgan/utils/timer.py b/ppgan/utils/timer.py new file mode 100644 index 0000000..6b277e5 --- /dev/null +++ b/ppgan/utils/timer.py @@ -0,0 +1,33 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + + +class TimeAverager(object): + def __init__(self): + self.reset() + + def reset(self): + self._cnt = 0 + self._total_time = 0 + + def record(self, usetime): + self._cnt += 1 + self._total_time += usetime + + def get_average(self): + if self._cnt == 0: + return 0 + return self._total_time / self._cnt -- GitLab