From a0b125d99a1876a1afda392453069fb981b0c704 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Sat, 10 Oct 2020 05:33:44 +0000 Subject: [PATCH] imporve msg info --- tools/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/train.py b/tools/train.py index 71863033..bb2ee9fa 100644 --- a/tools/train.py +++ b/tools/train.py @@ -58,7 +58,6 @@ def main(args): if use_gpu: gpu_id = ParallelEnv().dev_id place = paddle.CUDAPlace(gpu_id) - print("[gry debug]gpu_id: ", gpu_id) else: place = paddle.CPUPlace() @@ -85,6 +84,7 @@ def main(args): valid_dataloader = Reader(config, 'valid', places=place)() best_top1_acc = 0.0 # best top1 acc record + best_top1_epoch = 0 for epoch_id in range(config.epochs): net.train() # 1. train with train dataset @@ -99,14 +99,14 @@ def main(args): None, epoch_id, 'valid') if top1_acc > best_top1_acc: best_top1_acc = top1_acc - message = "The best top1 acc {:.5f}, in epoch: {:d}".format( - best_top1_acc, epoch_id) - logger.info("{:s}".format(logger.coloring(message, "RED"))) + best_top1_epoch = epoch_id if epoch_id % config.save_interval == 0: - model_path = os.path.join(config.model_save_dir, config.ARCHITECTURE["name"]) save_model(net, optimizer, model_path, "best_model") + message = "The best top1 acc {:.5f}, in epoch: {:d}".format( + best_top1_acc, epoch_id) + logger.info("{:s}".format(logger.coloring(message, "RED"))) # 3. save the persistable model if epoch_id % config.save_interval == 0: -- GitLab