diff --git a/tools/program.py b/tools/program.py index d7ccce7f07f1550e421d039c7bb87c8e3c9b40cf..9afe718ad733eb5303e7a23d870b18aff4b26a44 100644 --- a/tools/program.py +++ b/tools/program.py @@ -398,7 +398,6 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): else: logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str)) - # save the best model + # return top1_acc in order to save the best model if mode == 'valid': - top1_acc = fetchs["top1"][1].avg - return top1_acc + return fetchs["top1"][1].avg