未验证 提交 31a04814 编写于 作者: L littletomatodonkey 提交者: GitHub

fix logger (#423)

* fix logger

* fix typo
上级 82fceb5b
...@@ -479,18 +479,24 @@ def run(dataloader, ...@@ -479,18 +479,24 @@ def run(dataloader,
metric_list = [f[1] for f in fetchs.values()] metric_list = [f[1] for f in fetchs.values()]
for m in metric_list: for m in metric_list:
m.reset() m.reset()
batch_time = AverageMeter('elapse', '.3f') batch_time = AverageMeter('elapse', '.5f')
tic = time.time() tic = time.time()
dataloader = dataloader if config.get('use_dali') else dataloader()() dataloader = dataloader if config.get('use_dali') else dataloader()()
for idx, batch in enumerate(dataloader): for idx, batch in enumerate(dataloader):
if idx == 10:
for m in metric_list:
m.reset()
batch_time.reset()
batch_size = batch[0]["feed_image"].shape()[0]
metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list) metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
batch_time.update(time.time() - tic) batch_time.update(time.time() - tic)
tic = time.time()
for i, m in enumerate(metrics): for i, m in enumerate(metrics):
metric_list[i].update(np.mean(m), len(batch[0])) metric_list[i].update(np.mean(m), batch_size)
fetchs_str = ''.join([str(m.value) + ' ' fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.value]) + 's' for m in metric_list] + [batch_time.value]) + 's'
ips_info = " ips: {:.5f} images/sec.".format(batch_size /
batch_time.val)
fetchs_str += ips_info
if vdl_writer: if vdl_writer:
global total_step global total_step
logger.scaler('loss', metrics[0][0], total_step, vdl_writer) logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
...@@ -502,37 +508,26 @@ def run(dataloader, ...@@ -502,37 +508,26 @@ def run(dataloader,
else: else:
epoch_str = "epoch:{:<3d}".format(epoch) epoch_str = "epoch:{:<3d}".format(epoch)
step_str = "{:s} step:{:<4d}".format(mode, idx) step_str = "{:s} step:{:<4d}".format(mode, idx)
# Keep the first 10 batches statistics, They are important for develop
if epoch == 0 and idx < 10:
logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str,
logger.coloring(step_str, "PURPLE"),
logger.coloring(fetchs_str, 'OKGREEN')))
else:
if idx % config.get('print_interval', 10) == 0: if idx % config.get('print_interval', 10) == 0:
logger.info("{:s} {:s} {:s}".format( logger.info("{:s} {:s} {:s}".format(epoch_str
logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str, if idx == 0 else epoch_str,
logger.coloring(step_str, "PURPLE"), step_str, fetchs_str))
logger.coloring(fetchs_str, 'OKGREEN'))) tic = time.time()
if config.get('use_dali'): if config.get('use_dali'):
dataloader.reset() dataloader.reset()
end_str = ''.join([str(m.mean) + ' ' end_str = ''.join([str(m.mean) + ' '
for m in metric_list] + [batch_time.total]) + 's' for m in metric_list] + [batch_time.total]) + 's'
ips_info = "ips: {:.5f} images/sec.".format(batch_size * batch_time.count /
batch_time.sum)
if mode == 'eval': if mode == 'eval':
logger.info("END {:s} {:s}s".format(mode, end_str)) logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info))
else: else:
end_epoch_str = "END epoch:{:<3d}".format(epoch) end_epoch_str = "END epoch:{:<3d}".format(epoch)
logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
logger.info("{:s} {:s} {:s}".format( ips_info))
logger.coloring(end_epoch_str, "RED"),
logger.coloring(mode, "PURPLE"),
logger.coloring(end_str, "OKGREEN")))
# return top1_acc in order to save the best model # return top1_acc in order to save the best model
if mode == 'valid': if mode == 'valid':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册