diff --git a/tools/program.py b/tools/program.py index 33e087d7085a9b7cdd5b5b385d7f5ea05a38e06a..5cee1f80f58349cc6d9ca040c4e6fee1f4d509ef 100644 --- a/tools/program.py +++ b/tools/program.py @@ -385,19 +385,20 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): metric_list[i].update(m[0], len(batch[0])) fetchs_str = ''.join([str(m.value) + ' ' for m in metric_list] + [batch_time.value]) - if epoch != -1: + if mode == 'valid': + logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) + else: logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format( epoch, mode, idx, fetchs_str)) - else: - logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) end_str = ''.join([str(m.mean) + ' ' for m in metric_list] + [batch_time.total]) - if epoch != -1: - logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str)) - else: + if mode == 'valid': logger.info("END {:s} {:s}s".format(mode, end_str)) + else: + logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str)) # save the best model - top1_acc = fetchs["top1"][1].avg - return top1_acc + if mode == 'valid': + top1_acc = fetchs["top1"][1].avg + return top1_acc