From 214c5553090dbb25e1ccbfbee76f882afa08b051 Mon Sep 17 00:00:00 2001 From: chenyuntc Date: Thu, 21 Dec 2017 11:01:56 +0800 Subject: [PATCH] add support for warm start --- .gitignore | 1 + config.py | 1 + model/faster_rcnn_vgg16.py | 11 ++++++++--- train.py | 7 +++++-- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 38fe8e6..94f052b 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ model/utils/build/ imgs/ *.png *.jpg +misc/ diff --git a/config.py b/config.py index 8c3c2e6..2b5da99 100644 --- a/config.py +++ b/config.py @@ -42,6 +42,7 @@ class Config: use_adam = False use_chainer = False + use_drop = False # debug debug_file = '/tmp/debugf' diff --git a/model/faster_rcnn_vgg16.py b/model/faster_rcnn_vgg16.py index 73e5e81..8695c70 100644 --- a/model/faster_rcnn_vgg16.py +++ b/model/faster_rcnn_vgg16.py @@ -15,9 +15,14 @@ def decom_vgg16(pretrained=True): model = vgg16(pretrained) features = list(model.features)[:30] classifier = model.classifier - # classifier = list(classifier) - # del the last layer - del classifier._modules['6'] + + classifier = list(classifier) + # delete dropout + del classifier[6] + if not opt.use_drop: + del classifier[5] + del classifier[2] + classifier = nn.Sequential(*classifier) # free top3 conv for layer in features[:10]: diff --git a/train.py b/train.py index d7bf82b..961923c 100644 --- a/train.py +++ b/train.py @@ -62,6 +62,7 @@ def train(**kwargs): trainer.load(opt.load_path) print('load pretrained model from %s' % opt.load_path) + trainer.optimizer = trainer.faster_rcnn.get_great_optimizer() trainer.vis.text(dataset.db.label_names, win='labels') best_map = 0 for epoch in range(opt.epoch): @@ -98,7 +99,7 @@ def train(**kwargs): trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm') # roi confusion matrix trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float()) - if best_map>0.6: + if best_map>0.6 and opt.test_num<5000: opt.test_num=10000 best_map = 0 eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num) @@ -106,9 +107,11 @@ def train(**kwargs): if eval_result['map'] > best_map: best_map = eval_result['map'] best_path = trainer.save(best_map=best_map) - else: + if epoch==8: trainer.load(best_path) trainer.faster_rcnn.scale_lr(opt.lr_decay) + if epoch ==0: + trainer.optimizer = trainer.faster_rcnn.get_optimizer() trainer.vis.plot('test_map', eval_result['map']) lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr'] -- GitLab