diff --git a/tools/train_multi_platform.py b/tools/train_multi_platform.py index 558339893b414884873088308cd8082eda10b0d8..ba024e89d287e66bd6228b9219349ed1728c10c4 100644 --- a/tools/train_multi_platform.py +++ b/tools/train_multi_platform.py @@ -120,37 +120,36 @@ def main(args): # 1. train with train dataset program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, epoch_id, 'train', vdl_writer) - if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: - # 2. validate with validate dataset - if config.validate and epoch_id % config.valid_interval == 0: - if config.get('use_ema'): - logger.info(logger.coloring("EMA validate start...")) - with ema.apply(exe): - top1_acc = program.run(valid_dataloader, exe, - compiled_valid_prog, - valid_fetchs, epoch_id, 'valid') - logger.info(logger.coloring("EMA validate over!")) - - top1_acc = program.run(valid_dataloader, exe, - compiled_valid_prog, valid_fetchs, - epoch_id, 'valid') - if top1_acc > best_top1_acc: - best_top1_acc = top1_acc - message = "The best top1 acc {:.5f}, in epoch: {:d}".format( - best_top1_acc, epoch_id) - logger.info("{:s}".format(logger.coloring(message, "RED"))) - if epoch_id % config.save_interval == 0: - - model_path = os.path.join(config.model_save_dir, - config.ARCHITECTURE["name"]) - save_model(train_prog, model_path, - "best_model_in_epoch_" + str(epoch_id)) - - # 3. save the persistable model - if epoch_id % config.save_interval == 0: - model_path = os.path.join(config.model_save_dir, - config.ARCHITECTURE["name"]) - save_model(train_prog, model_path, epoch_id) + + # 2. validate with validate dataset + if config.validate and epoch_id % config.valid_interval == 0: + if config.get('use_ema'): + logger.info(logger.coloring("EMA validate start...")) + with ema.apply(exe): + top1_acc = program.run(valid_dataloader, exe, + compiled_valid_prog, valid_fetchs, + epoch_id, 'valid') + logger.info(logger.coloring("EMA validate over!")) + + top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog, + valid_fetchs, epoch_id, 'valid') + if top1_acc > best_top1_acc: + best_top1_acc = top1_acc + message = "The best top1 acc {:.5f}, in epoch: {:d}".format( + best_top1_acc, epoch_id) + logger.info("{:s}".format(logger.coloring(message, "RED"))) + if epoch_id % config.save_interval == 0: + + model_path = os.path.join(config.model_save_dir, + config.ARCHITECTURE["name"]) + save_model(train_prog, model_path, + "best_model_in_epoch_" + str(epoch_id)) + + # 3. save the persistable model + if epoch_id % config.save_interval == 0: + model_path = os.path.join(config.model_save_dir, + config.ARCHITECTURE["name"]) + save_model(train_prog, model_path, epoch_id) if __name__ == '__main__':