diff --git a/tools/train.py b/tools/train.py index 707a7d7a10bf94fc6b50be1396b9b75a8d563132..65a0c8f90306bacc52ee13350f3da645b1420ad6 100644 --- a/tools/train.py +++ b/tools/train.py @@ -62,9 +62,7 @@ def main(args): startup_prog = fluid.Program() train_prog = fluid.Program() - # best_top1_acc_list[0]: top1 acc - # best_top1_acc_list[1]: epoch id - best_top1_acc_list = [0.0, 0] + best_top1_acc_list = (0.0, -1) # (top1_acc, epoch_id) train_dataloader, train_fetchs = program.build( config, train_prog, startup_prog, is_train=True) @@ -97,24 +95,25 @@ def main(args): # 1. train with train dataset program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, epoch_id, 'train') - # 2. validate with validate dataset - if config.validate and epoch_id % config.valid_interval == 0: - top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog, - valid_fetchs, epoch_id, 'valid') - if top1_acc > best_top1_acc_list[0]: - best_top1_acc_list[0] = top1_acc - best_top1_acc_list[1] = epoch_id - logger.info("Best top1 acc: {}, in epoch: {}".format( - best_top1_acc_list[0], best_top1_acc_list[1])) + if int(os.environ.get("PADDLE_TRAINERS_ID", 0)) == 0: + # 2. validate with validate dataset + if config.validate and epoch_id % config.valid_interval == 0: + top1_acc = program.run(valid_dataloader, exe, + compiled_valid_prog, valid_fetchs, + epoch_id, 'valid') + if top1_acc > best_top1_acc_list[0]: + best_top1_acc_list = (top1_acc, epoch_id) + logger.info("Best top1 acc: {}, in epoch: {}".format( + *best_top1_acc_list)) + model_path = os.path.join(config.model_save_dir, + config.ARCHITECTURE["name"]) + save_model(train_prog, model_path, "best_model") + + # 3. save the persistable model + if epoch_id % config.save_interval == 0: model_path = os.path.join(config.model_save_dir, config.ARCHITECTURE["name"]) - save_model(train_prog, model_path, "best_model") - - # 3. save the persistable model - if epoch_id % config.save_interval == 0: - model_path = os.path.join(config.model_save_dir, - config.ARCHITECTURE["name"]) - save_model(train_prog, model_path, epoch_id) + save_model(train_prog, model_path, epoch_id) if __name__ == '__main__':