未验证 提交 01486ec1 编写于 作者: D dyning 提交者: GitHub

Merge pull request #81 from shippingwang/refine_1

refine log format
...@@ -42,18 +42,18 @@ class AverageMeter(object): ...@@ -42,18 +42,18 @@ class AverageMeter(object):
@property @property
def total(self): def total(self):
return '[{self.name}_sum: {self.sum:{self.fmt}}]'.format(self=self) return '{self.name}_sum: {self.sum:{self.fmt}}'.format(self=self)
@property @property
def total_minute(self): def total_minute(self):
return '[{self.name}_sum: {s:{self.fmt}} min]'.format( return '{self.name}_sum: {s:{self.fmt}} min'.format(
s=self.sum / 60, self=self) s=self.sum / 60, self=self)
@property @property
def mean(self): def mean(self):
return '[{self.name}_avg: {self.avg:{self.fmt}}]'.format( return '{self.name}_avg: {self.avg:{self.fmt}}'.format(
self=self) if self.need_avg else '' self=self) if self.need_avg else ''
@property @property
def value(self): def value(self):
return '[{self.name}: {self.val:{self.fmt}}]'.format(self=self) return '{self.name}: {self.val:{self.fmt}}'.format(self=self)
...@@ -73,7 +73,7 @@ def main(args): ...@@ -73,7 +73,7 @@ def main(args):
valid_dataloader.set_sample_list_generator(valid_reader, place) valid_dataloader.set_sample_list_generator(valid_reader, place)
compiled_valid_prog = program.compile(config, valid_prog) compiled_valid_prog = program.compile(config, valid_prog)
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, 0, program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
'valid') 'valid')
......
...@@ -145,10 +145,6 @@ def main(): ...@@ -145,10 +145,6 @@ def main():
output = output.flatten() output = output.flatten()
if i >= 10: if i >= 10:
test_time += time.time() - start_time test_time += time.time() - start_time
cls = np.argmax(output)
score = output[cls]
logger.info("class: {0}".format(cls))
logger.info("score: {0}".format(score))
fp_message = "FP16" if args.use_fp16 else "FP32" fp_message = "FP16" if args.use_fp16 else "FP32"
logger.info("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format( logger.info("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
......
...@@ -375,7 +375,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): ...@@ -375,7 +375,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
metric_list = [f[1] for f in fetchs.values()] metric_list = [f[1] for f in fetchs.values()]
for m in metric_list: for m in metric_list:
m.reset() m.reset()
batch_time = AverageMeter('cost', '.3f') batch_time = AverageMeter('elapse', '.3f')
tic = time.time() tic = time.time()
for idx, batch in enumerate(dataloader()): for idx, batch in enumerate(dataloader()):
metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list) metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
...@@ -383,9 +383,17 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): ...@@ -383,9 +383,17 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
tic = time.time() tic = time.time()
for i, m in enumerate(metrics): for i, m in enumerate(metrics):
metric_list[i].update(m[0], len(batch[0])) metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([m.value fetchs_str = ''.join([str(m.value)+' '
for m in metric_list] + [batch_time.value]) for m in metric_list]+ [batch_time.value])
logger.info("[epoch:{:3d}][{:s}][step:{:4d}]{:s}".format( if epoch != -1:
logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format(
epoch, mode, idx, fetchs_str)) epoch, mode, idx, fetchs_str))
end_str = ''.join([m.mean for m in metric_list] + [batch_time.total]) else:
logger.info("END [epoch:{:3d}][{:s}]{:s}".format(epoch, mode, end_str)) logger.info("{:s} step:{:<4d} {:s}s".format(
mode, idx, fetchs_str))
end_str = ''.join([str(m.mean)+' ' for m in metric_list] + [batch_time.total])
if epoch!= -1:
logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str))
else:
logger.info("END {:s} {:s}s".format(mode, end_str))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册