未验证 提交 942bb211 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add ips to the benchmark information. (#47)

上级 2b570d02
......@@ -77,10 +77,13 @@ class Trainer:
self.model.set_input(data)
self.model.optimize_parameters()
batch_cost_averager.record(time.time() - step_start_time)
batch_cost_averager.record(
time.time() - step_start_time,
num_samples=self.cfg.get('batch_size', 1))
if i % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average()
self.ips = batch_cost_averager.get_ips_average()
self.print_log()
reader_cost_averager.reset()
......@@ -91,8 +94,8 @@ class Trainer:
step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() -
start_time))
self.logger.info(
'train one epoch time: {}'.format(time.time() - start_time))
if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate()
self.model.lr_scheduler.step()
......@@ -102,8 +105,8 @@ class Trainer:
def validate(self):
if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(self.cfg.dataset.val,
is_train=False)
self.val_dataloader = build_dataloader(
self.cfg.dataset.val, is_train=False)
metric_result = {}
......@@ -149,8 +152,8 @@ class Trainer:
self.visual('visual_val', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' %
(i, len(self.val_dataloader)))
self.logger.info(
'val iter: [%d/%d]' % (i, len(self.val_dataloader)))
for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset)
......@@ -160,8 +163,8 @@ class Trainer:
def test(self):
if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False)
self.test_dataloader = build_dataloader(
self.cfg.dataset.test, is_train=False)
# data[0]: img, data[1]: img path index
# test batch size must be 1
......@@ -185,8 +188,8 @@ class Trainer:
self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader)))
self.logger.info(
'Test iter: [%d/%d]' % (i, len(self.test_dataloader)))
def print_log(self):
losses = self.model.get_current_losses()
......@@ -197,11 +200,14 @@ class Trainer:
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
if hasattr(self, 'step_time'):
message += 'batch_cost: %.5f sec ' % self.step_time
if hasattr(self, 'data_time'):
message += 'reader cost: %.5fs ' % self.data_time
message += 'reader_cost: %.5f sec ' % self.data_time
if hasattr(self, 'step_time'):
message += 'batch cost: %.5fs' % self.step_time
if hasattr(self, 'ips'):
message += 'ips: %.5f images/s' % self.ips
# print the message
self.logger.info(message)
......
......@@ -22,12 +22,20 @@ class TimeAverager(object):
def reset(self):
self._cnt = 0
self._total_time = 0
self._total_samples = 0
def record(self, usetime):
def record(self, usetime, num_samples=None):
self._cnt += 1
self._total_time += usetime
if num_samples:
self._total_samples += num_samples
def get_average(self):
if self._cnt == 0:
return 0
return self._total_time / self._cnt
return self._total_time / float(self._cnt)
def get_ips_average(self):
if not self._total_samples or self._cnt == 0:
return 0
return float(self._total_samples) / self._total_time
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册