diff --git a/tools/program.py b/tools/program.py index 5a5745953ec500df408ad6417e7f62ea49a44952..33e087d7085a9b7cdd5b5b385d7f5ea05a38e06a 100644 --- a/tools/program.py +++ b/tools/program.py @@ -383,17 +383,21 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): tic = time.time() for i, m in enumerate(metrics): metric_list[i].update(m[0], len(batch[0])) - fetchs_str = ''.join([str(m.value)+' ' - for m in metric_list]+ [batch_time.value]) + fetchs_str = ''.join([str(m.value) + ' ' + for m in metric_list] + [batch_time.value]) if epoch != -1: logger.info("epoch:{:<3d} {:s} step:{:<4d} {:s}s".format( - epoch, mode, idx, fetchs_str)) + epoch, mode, idx, fetchs_str)) else: - logger.info("{:s} step:{:<4d} {:s}s".format( - mode, idx, fetchs_str)) + logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) - end_str = ''.join([str(m.mean)+' ' for m in metric_list] + [batch_time.total]) - if epoch!= -1: + end_str = ''.join([str(m.mean) + ' ' + for m in metric_list] + [batch_time.total]) + if epoch != -1: logger.info("END epoch:{:<3d} {:s} {:s}s".format(epoch, mode, end_str)) else: logger.info("END {:s} {:s}s".format(mode, end_str)) + + # save the best model + top1_acc = fetchs["top1"][1].avg + return top1_acc diff --git a/tools/train.py b/tools/train.py index 9cbe3edd844ca192eec0d1fe305fba33a45530b2..707a7d7a10bf94fc6b50be1396b9b75a8d563132 100644 --- a/tools/train.py +++ b/tools/train.py @@ -26,6 +26,7 @@ from paddle.fluid.incubate.fleet.collective import fleet from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model, save_model +from ppcls.utils import logger import program @@ -61,6 +62,10 @@ def main(args): startup_prog = fluid.Program() train_prog = fluid.Program() + # best_top1_acc_list[0]: top1 acc + # best_top1_acc_list[1]: epoch id + best_top1_acc_list = [0.0, 0] + train_dataloader, train_fetchs = program.build( config, train_prog, startup_prog, is_train=True) @@ -94,8 +99,16 @@ def main(args): epoch_id, 'train') # 2. validate with validate dataset if config.validate and epoch_id % config.valid_interval == 0: - program.run(valid_dataloader, exe, compiled_valid_prog, - valid_fetchs, epoch_id, 'valid') + top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog, + valid_fetchs, epoch_id, 'valid') + if top1_acc > best_top1_acc_list[0]: + best_top1_acc_list[0] = top1_acc + best_top1_acc_list[1] = epoch_id + logger.info("Best top1 acc: {}, in epoch: {}".format( + best_top1_acc_list[0], best_top1_acc_list[1])) + model_path = os.path.join(config.model_save_dir, + config.ARCHITECTURE["name"]) + save_model(train_prog, model_path, "best_model") # 3. save the persistable model if epoch_id % config.save_interval == 0: