未验证 提交 942bb211 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add ips to the benchmark information. (#47)

上级 2b570d02
...@@ -77,10 +77,13 @@ class Trainer: ...@@ -77,10 +77,13 @@ class Trainer:
self.model.set_input(data) self.model.set_input(data)
self.model.optimize_parameters() self.model.optimize_parameters()
batch_cost_averager.record(time.time() - step_start_time) batch_cost_averager.record(
time.time() - step_start_time,
num_samples=self.cfg.get('batch_size', 1))
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average() self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average() self.step_time = batch_cost_averager.get_average()
self.ips = batch_cost_averager.get_ips_average()
self.print_log() self.print_log()
reader_cost_averager.reset() reader_cost_averager.reset()
...@@ -91,8 +94,8 @@ class Trainer: ...@@ -91,8 +94,8 @@ class Trainer:
step_start_time = time.time() step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() - self.logger.info(
start_time)) 'train one epoch time: {}'.format(time.time() - start_time))
if self.validate_interval > -1 and epoch % self.validate_interval: if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate() self.validate()
self.model.lr_scheduler.step() self.model.lr_scheduler.step()
...@@ -102,8 +105,8 @@ class Trainer: ...@@ -102,8 +105,8 @@ class Trainer:
def validate(self): def validate(self):
if not hasattr(self, 'val_dataloader'): if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(self.cfg.dataset.val, self.val_dataloader = build_dataloader(
is_train=False) self.cfg.dataset.val, is_train=False)
metric_result = {} metric_result = {}
...@@ -149,8 +152,8 @@ class Trainer: ...@@ -149,8 +152,8 @@ class Trainer:
self.visual('visual_val', visual_results=visual_results) self.visual('visual_val', visual_results=visual_results)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' % self.logger.info(
(i, len(self.val_dataloader))) 'val iter: [%d/%d]' % (i, len(self.val_dataloader)))
for metric_name in metric_result.keys(): for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset) metric_result[metric_name] /= len(self.val_dataloader.dataset)
...@@ -160,8 +163,8 @@ class Trainer: ...@@ -160,8 +163,8 @@ class Trainer:
def test(self): def test(self):
if not hasattr(self, 'test_dataloader'): if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test, self.test_dataloader = build_dataloader(
is_train=False) self.cfg.dataset.test, is_train=False)
# data[0]: img, data[1]: img path index # data[0]: img, data[1]: img path index
# test batch size must be 1 # test batch size must be 1
...@@ -185,8 +188,8 @@ class Trainer: ...@@ -185,8 +188,8 @@ class Trainer:
self.visual('visual_test', visual_results=visual_results) self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' % self.logger.info(
(i, len(self.test_dataloader))) 'Test iter: [%d/%d]' % (i, len(self.test_dataloader)))
def print_log(self): def print_log(self):
losses = self.model.get_current_losses() losses = self.model.get_current_losses()
...@@ -197,11 +200,14 @@ class Trainer: ...@@ -197,11 +200,14 @@ class Trainer:
for k, v in losses.items(): for k, v in losses.items():
message += '%s: %.3f ' % (k, v) message += '%s: %.3f ' % (k, v)
if hasattr(self, 'step_time'):
message += 'batch_cost: %.5f sec ' % self.step_time
if hasattr(self, 'data_time'): if hasattr(self, 'data_time'):
message += 'reader cost: %.5fs ' % self.data_time message += 'reader_cost: %.5f sec ' % self.data_time
if hasattr(self, 'step_time'): if hasattr(self, 'ips'):
message += 'batch cost: %.5fs' % self.step_time message += 'ips: %.5f images/s' % self.ips
# print the message # print the message
self.logger.info(message) self.logger.info(message)
......
...@@ -22,12 +22,20 @@ class TimeAverager(object): ...@@ -22,12 +22,20 @@ class TimeAverager(object):
def reset(self): def reset(self):
self._cnt = 0 self._cnt = 0
self._total_time = 0 self._total_time = 0
self._total_samples = 0
def record(self, usetime): def record(self, usetime, num_samples=None):
self._cnt += 1 self._cnt += 1
self._total_time += usetime self._total_time += usetime
if num_samples:
self._total_samples += num_samples
def get_average(self): def get_average(self):
if self._cnt == 0: if self._cnt == 0:
return 0 return 0
return self._total_time / self._cnt return self._total_time / float(self._cnt)
def get_ips_average(self):
if not self._total_samples or self._cnt == 0:
return 0
return float(self._total_samples) / self._total_time
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册