diff --git a/pdseg/train.py b/pdseg/train.py index 4ae5ad121c99f17a74c512fb15b56e57fcaed153..4f6a90e003c0b2997daceab684b7199f52c9aafc 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -438,7 +438,8 @@ def train(cfg): except Exception as e: print(e) - if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 and cfg.TRAINER_ID == 0: + if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 + or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0: ckpt_dir = save_checkpoint(exe, train_prog, epoch) if args.do_eval: