未验证 提交 0354fba7 编写于 作者: D dyning 提交者: GitHub

Merge pull request #86 from WuHaobo/master

polish mixup and eval
......@@ -74,7 +74,7 @@ def main(args):
compiled_valid_prog = program.compile(config, valid_prog)
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
'valid')
'eval')
if __name__ == '__main__':
......
......@@ -385,19 +385,19 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.value])
if epoch != -1:
if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
else:
logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format(
epoch, mode, idx, fetchs_str))
else:
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
end_str = ''.join([str(m.mean) + ' '
for m in metric_list] + [batch_time.total])
if epoch != -1:
logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str))
else:
if mode == 'eval':
logger.info("END {:s} {:s}s".format(mode, end_str))
else:
logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str))
# save the best model
top1_acc = fetchs["top1"][1].avg
return top1_acc
# return top1_acc in order to save the best model
if mode == 'valid':
return fetchs["top1"][1].avg
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册