train.py 4.6 KB
Newer Older
C
backup  
chenyuntc 已提交
1 2
import os

C
chenyuntc 已提交
3 4
import ipdb
import matplotlib
C
backup  
chenyuntc 已提交
5 6
from tqdm import tqdm

C
chenyuntc 已提交
7 8
import torch as t
from config import opt
C
backup  
chenyuntc 已提交
9
from data.dataset import Dataset,TestDataset
C
backup  
chenyuntc 已提交
10
from model import FasterRCNNVGG16
C
chenyuntc 已提交
11 12
from torch.autograd import Variable
from torch.utils import data as data_
C
backup  
chenyuntc 已提交
13
from trainer import FasterRCNNTrainer
C
chenyuntc 已提交
14
from util import array_tool as at
C
backup  
chenyuntc 已提交
15
from util.vis_tool import visdom_bbox
C
backup  
chenyuntc 已提交
16
from util.eval_tool import eval_detection_voc
C
backup  
chenyuntc 已提交
17

C
chenyuntc 已提交
18
# Select matplotlib's non-interactive Agg backend so figures can be rendered
# without a display (e.g. on a headless training server).
# NOTE(review): matplotlib.use() is only reliable when called before pyplot is
# imported; util.vis_tool (imported above) may already import pyplot — TODO
# confirm, and if so move this call above the project imports.
matplotlib.use('agg')
C
backup  
chenyuntc 已提交
19

C
backup  
chenyuntc 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40

def eval(dataloader, faster_rcnn, test_num=1000):
    """Evaluate ``faster_rcnn`` on at most ``test_num`` batches of ``dataloader``.

    Predictions and ground truth are accumulated over the batches and scored
    with ``eval_detection_voc`` using the PASCAL VOC 2007 11-point metric
    (``use_07_metric=True``).

    NOTE: the name shadows the ``eval`` builtin; it is kept for callers.

    Args:
        dataloader: iterable yielding ``(imgs, sizes, gt_bboxes, gt_labels,
            gt_difficults)`` batches. ``sizes`` is read as ``sizes[0][0]`` and
            ``sizes[1][0]`` — presumably (height, width) of the single image in
            the batch; TODO confirm against ``TestDataset``.
        faster_rcnn: model exposing ``predict2(imgs, [size])`` that returns
            per-image lists of boxes, labels and scores.
        test_num: maximum number of batches to evaluate.

    Returns:
        The result of ``eval_detection_voc``; callers read its ``'map'`` key.
    """
    pred_bboxes, pred_labels, pred_scores = list(), list(), list()
    gt_bboxes, gt_labels, gt_difficults = list(), list(), list()
    # Wrap the dataloader itself (not the enumerate object) so tqdm can use
    # len(dataloader) and show a real total.
    for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in enumerate(tqdm(dataloader)):
        # Check before processing so exactly ``test_num`` batches are used
        # (the previous post-processing check evaluated test_num + 1).
        if ii >= test_num:
            break
        sizes = [sizes[0][0], sizes[1][0]]
        pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict2(imgs, [sizes])
        gt_bboxes += list(gt_bboxes_.numpy())
        gt_labels += list(gt_labels_.numpy())
        gt_difficults += list(gt_difficults_.numpy())
        pred_bboxes += pred_bboxes_
        pred_labels += pred_labels_
        pred_scores += pred_scores_

    # Pass the prediction *scores* (previously pred_labels was passed twice,
    # so detections were ranked by class id instead of confidence).
    result = eval_detection_voc(
        pred_bboxes, pred_labels, pred_scores,
        gt_bboxes, gt_labels, gt_difficults,
        use_07_metric=True)
    return result

C
backup  
chenyuntc 已提交
41 42 43 44 45 46 47 48
def train(**kwargs):
    """Train a VGG16-backed Faster R-CNN, evaluating after every epoch.

    ``kwargs`` are forwarded to ``opt._parse`` and override fields of the
    global ``config.opt`` (e.g. values supplied on the ``fire`` command line).

    Each epoch iterates the shuffled training set one image at a time and
    calls ``trainer.train_step``. Every ``opt.plot_every`` iterations it
    optionally drops into ``ipdb`` (when ``opt.debug_file`` exists) and pushes
    training diagnostics to visdom. After the epoch it applies the learning
    rate decay (after epoch index 6 only), evaluates mAP on the test set,
    logs the results and saves a checkpoint via ``trainer.save()``.
    """
    opt._parse(kwargs)

    dataset = Dataset(opt)
    print('load data')
    dataloader = data_.DataLoader(dataset,
                                  batch_size=1,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=opt.num_workers)
    testset = TestDataset(opt)
    test_dataloader = data_.DataLoader(testset,
                                       batch_size=1,
                                       num_workers=2,
                                       shuffle=True,
                                       pin_memory=True)

    faster_rcnn = FasterRCNNVGG16()
    print('model completed')
    trainer = FasterRCNNTrainer(faster_rcnn).cuda()
    if opt.load_path:
        trainer.load_state_dict(t.load(opt.load_path))
        print('load pretrained model from %s' % opt.load_path)

    trainer.vis.text(dataset.db.label_names, win='labels')

    for epoch in range(opt.epoch):
        trainer.reset_meters()
        # Wrap the dataloader itself so tqdm knows the epoch length.
        for ii, (img, bbox_, label_, scale, ori_img) in enumerate(tqdm(dataloader)):
            scale = at.scalar(scale)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            img, bbox, label = Variable(img), Variable(bbox), Variable(label)
            trainer.train_step(img, bbox, label, scale)

            if (ii + 1) % opt.plot_every == 0:
                # Debug hook: creating the file named by opt.debug_file pauses
                # training in an interactive debugger.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                _plot_training_state(trainer, img, bbox_, label_, ori_img)

        # Hard-coded schedule: decay the learning rate once, after epoch index 6.
        if epoch == 6:
            trainer.faster_rcnn.update_optimizer(opt.lr_decay)

        eval_result = eval(test_dataloader, faster_rcnn)
        trainer.vis.plot('test_map', eval_result['map'])
        # The logged matrix is the RPN confusion matrix, so label it as such
        # (it was previously mislabelled 'roi_cm').
        trainer.vis.log('map:{},loss:{},rpn_cm:{}'.format(str(eval_result),
                                                          str(trainer.get_meter_data()),
                                                          str(trainer.rpn_cm.conf.tolist())))
        trainer.save()


def _plot_training_state(trainer, img, bbox_, label_, ori_img):
    """Push losses, GT/predicted boxes and confusion matrices to visdom."""
    # Loss curves accumulated by the trainer's meters.
    trainer.vis.plot_many(trainer.get_meter_data())

    # Ground-truth boxes drawn on the network input. The affine map presumably
    # undoes a mean≈0.45 / std≈0.225 normalisation — TODO confirm against the
    # dataset transform.
    ori_img_ = (img * 0.225 + 0.45).clamp(min=0, max=1) * 255
    trainer.vis.img('gt_img', visdom_bbox(at.tonumpy(ori_img_)[0],
                                          at.tonumpy(bbox_)[0],
                                          label_[0].numpy()))

    # Predicted boxes on the untransformed image returned by the dataloader.
    _bboxes, _labels, _scores = trainer.faster_rcnn.predict(ori_img)
    trainer.vis.img('pred_img', visdom_bbox(at.tonumpy(ori_img[0]),
                                            at.tonumpy(_bboxes[0]),
                                            at.tonumpy(_labels[0]).reshape(-1)))

    # RPN confusion matrix as text, ROI confusion matrix as an image.
    trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm')
    trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.value(), False).float())
C
chenyuntc 已提交
114

C
backup  
chenyuntc 已提交
115 116 117

if __name__ == '__main__':
    # Expose this module's top-level functions as CLI subcommands through
    # python-fire, e.g. `python train.py train --key=value`.
    import fire

    fire.Fire()