diff --git a/tools/train.py b/tools/train.py index 3cd9d459d9b979efa93e3738699e88413aa891c6..45ac37f9e71a873d0e5bbe1da5f51fd03d8c76e9 100644 --- a/tools/train.py +++ b/tools/train.py @@ -98,13 +98,12 @@ def main(args): if top1_acc > best_top1_acc: best_top1_acc = top1_acc best_top1_epoch = epoch_id - if epoch_id % config.save_interval == 0: - model_path = os.path.join(config.model_save_dir, - config.ARCHITECTURE["name"]) - save_model(net, optimizer, model_path, "best_model") + model_path = os.path.join(config.model_save_dir, + config.ARCHITECTURE["name"]) + save_model(net, optimizer, model_path, "best_model") message = "The best top1 acc {:.5f}, in epoch: {:d}".format( best_top1_acc, best_top1_epoch) - logger.info("{:s}".format(logger.coloring(message, "RED"))) + logger.info(message) # 3. save the persistable model if epoch_id % config.save_interval == 0: