提交 214c5553 编写于 作者: C chenyuntc

add support for warm start

上级 50f8eaf9
...@@ -14,3 +14,4 @@ model/utils/build/ ...@@ -14,3 +14,4 @@ model/utils/build/
imgs/ imgs/
*.png *.png
*.jpg *.jpg
misc/
...@@ -42,6 +42,7 @@ class Config: ...@@ -42,6 +42,7 @@ class Config:
use_adam = False use_adam = False
use_chainer = False use_chainer = False
use_drop = False
# debug # debug
debug_file = '/tmp/debugf' debug_file = '/tmp/debugf'
......
...@@ -15,9 +15,14 @@ def decom_vgg16(pretrained=True): ...@@ -15,9 +15,14 @@ def decom_vgg16(pretrained=True):
model = vgg16(pretrained) model = vgg16(pretrained)
features = list(model.features)[:30] features = list(model.features)[:30]
classifier = model.classifier classifier = model.classifier
# classifier = list(classifier)
# del the last layer classifier = list(classifier)
del classifier._modules['6'] # delete dropout
del classifier[6]
if not opt.use_drop:
del classifier[5]
del classifier[2]
classifier = nn.Sequential(*classifier)
# free top3 conv # free top3 conv
for layer in features[:10]: for layer in features[:10]:
......
...@@ -62,6 +62,7 @@ def train(**kwargs): ...@@ -62,6 +62,7 @@ def train(**kwargs):
trainer.load(opt.load_path) trainer.load(opt.load_path)
print('load pretrained model from %s' % opt.load_path) print('load pretrained model from %s' % opt.load_path)
trainer.optimizer = trainer.faster_rcnn.get_great_optimizer()
trainer.vis.text(dataset.db.label_names, win='labels') trainer.vis.text(dataset.db.label_names, win='labels')
best_map = 0 best_map = 0
for epoch in range(opt.epoch): for epoch in range(opt.epoch):
...@@ -98,7 +99,7 @@ def train(**kwargs): ...@@ -98,7 +99,7 @@ def train(**kwargs):
trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm') trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm')
# roi confusion matrix # roi confusion matrix
trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float()) trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float())
if best_map>0.6: if best_map>0.6 and opt.test_num<5000:
opt.test_num=10000 opt.test_num=10000
best_map = 0 best_map = 0
eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num) eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num)
...@@ -106,9 +107,11 @@ def train(**kwargs): ...@@ -106,9 +107,11 @@ def train(**kwargs):
if eval_result['map'] > best_map: if eval_result['map'] > best_map:
best_map = eval_result['map'] best_map = eval_result['map']
best_path = trainer.save(best_map=best_map) best_path = trainer.save(best_map=best_map)
else: if epoch==8:
trainer.load(best_path) trainer.load(best_path)
trainer.faster_rcnn.scale_lr(opt.lr_decay) trainer.faster_rcnn.scale_lr(opt.lr_decay)
if epoch ==0:
trainer.optimizer = trainer.faster_rcnn.get_optimizer()
trainer.vis.plot('test_map', eval_result['map']) trainer.vis.plot('test_map', eval_result['map'])
lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr'] lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册