From 24bb23fafbff6465f064e96bc5d1b636a33e9dfb Mon Sep 17 00:00:00 2001 From: chengxianbin Date: Wed, 20 May 2020 23:15:15 +0800 Subject: [PATCH] support function of incremental training --- example/ssd_coco2017/train.py | 6 +++--- example/yolov3_coco2017/train.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/example/ssd_coco2017/train.py b/example/ssd_coco2017/train.py index a89d558c6..75f9a6d31 100644 --- a/example/ssd_coco2017/train.py +++ b/example/ssd_coco2017/train.py @@ -87,7 +87,7 @@ def main(): parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") - parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.") + parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") args_opt = parser.parse_args() @@ -157,8 +157,8 @@ def main(): opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale) net = TrainingWrapper(net, opt, loss_scale) - if args_opt.checkpoint_path != "": - param_dict = load_checkpoint(args_opt.checkpoint_path) + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) load_param_into_net(net, param_dict) callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] diff --git a/example/yolov3_coco2017/train.py b/example/yolov3_coco2017/train.py index cfa3580b8..62329bf88 100644 --- a/example/yolov3_coco2017/train.py +++ b/example/yolov3_coco2017/train.py @@ -70,7 +70,7 @@ def main(): parser.add_argument("--mode", type=str, 
default="sink", help="Run sink mode or not, default is sink") parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") - parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") + parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained checkpoint file path") parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train", @@ -138,8 +138,8 @@ def main(): opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale) net = TrainingWrapper(net, opt, loss_scale) - if args_opt.checkpoint_path != "": - param_dict = load_checkpoint(args_opt.checkpoint_path) + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) load_param_into_net(net, param_dict) callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] -- GitLab