提交 24bb23fa 编写于 作者: C chengxianbin

supportfunction of incremental training

上级 113e9c7b
...@@ -87,7 +87,7 @@ def main(): ...@@ -87,7 +87,7 @@ def main():
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.")
parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.") parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.") parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
args_opt = parser.parse_args() args_opt = parser.parse_args()
...@@ -157,8 +157,8 @@ def main(): ...@@ -157,8 +157,8 @@ def main():
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale) opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
net = TrainingWrapper(net, opt, loss_scale) net = TrainingWrapper(net, opt, loss_scale)
if args_opt.checkpoint_path != "": if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
......
...@@ -70,7 +70,7 @@ def main(): ...@@ -70,7 +70,7 @@ def main():
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink") parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink")
parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10") parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained checkpoint file path")
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train", parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
...@@ -138,8 +138,8 @@ def main(): ...@@ -138,8 +138,8 @@ def main():
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale) opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
net = TrainingWrapper(net, opt, loss_scale) net = TrainingWrapper(net, opt, loss_scale)
if args_opt.checkpoint_path != "": if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册