提交 12c78eee 编写于 作者: L Liu Yiqun

Calculate the average time for benchmark.

上级 e41decb6
......@@ -11,6 +11,7 @@ from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs
from ..utils.timer import TimeAverager
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
......@@ -61,30 +62,37 @@ class Trainer:
paddle.DataParallel(net, strategy))
def train(self):
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
for epoch in range(self.start_epoch, self.epochs):
self.current_epoch = epoch
start_time = step_start_time = time.time()
for i, data in enumerate(self.train_dataloader):
data_time = time.time()
reader_cost_averager.record(time.time() - step_start_time)
self.batch_id = i
# unpack data from dataset and apply preprocessing
# data input should be dict
self.model.set_input(data)
self.model.optimize_parameters()
self.data_time = data_time - step_start_time
self.step_time = time.time() - step_start_time
batch_cost_averager.record(time.time() - step_start_time)
if i % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average()
self.print_log()
reader_cost_averager.reset()
batch_cost_averager.reset()
if i % self.visual_interval == 0:
self.visual('visual_train')
step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() -
start_time))
self.logger.info(
'train one epoch time: {}'.format(time.time() - start_time))
if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate()
self.model.lr_scheduler.step()
......@@ -94,8 +102,8 @@ class Trainer:
def validate(self):
if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(self.cfg.dataset.val,
is_train=False)
self.val_dataloader = build_dataloader(
self.cfg.dataset.val, is_train=False)
metric_result = {}
......@@ -141,8 +149,8 @@ class Trainer:
self.visual('visual_val', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' %
(i, len(self.val_dataloader)))
self.logger.info(
'val iter: [%d/%d]' % (i, len(self.val_dataloader)))
for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset)
......@@ -152,8 +160,8 @@ class Trainer:
def test(self):
if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False)
self.test_dataloader = build_dataloader(
self.cfg.dataset.test, is_train=False)
# data[0]: img, data[1]: img path index
# test batch size must be 1
......@@ -177,8 +185,8 @@ class Trainer:
self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader)))
self.logger.info(
'Test iter: [%d/%d]' % (i, len(self.test_dataloader)))
def print_log(self):
losses = self.model.get_current_losses()
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
class TimeAverager(object):
def __init__(self):
self.reset()
def reset(self):
self._cnt = 0
self._total_time = 0
def record(self, usetime):
self._cnt += 1
self._total_time += usetime
def get_average(self):
if self._cnt == 0:
return 0
return self._total_time / self._cnt
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册