train.py 4.9 KB
Newer Older
C
chenyun 已提交
1
from __future__ import  absolute_import
C
backup  
chenyuntc 已提交
2 3
import os

C
chenyuntc 已提交
4 5
import ipdb
import matplotlib
C
backup  
chenyuntc 已提交
6 7
from tqdm import tqdm

C
chenyuntc 已提交
8
from utils.config import opt
C
chenyuntc 已提交
9
from data.dataset import Dataset, TestDataset, inverse_normalize
C
backup  
chenyuntc 已提交
10
from model import FasterRCNNVGG16
C
chenyuntc 已提交
11
from torch.utils import data as data_
C
backup  
chenyuntc 已提交
12
from trainer import FasterRCNNTrainer
C
chenyuntc 已提交
13 14 15
from utils import array_tool as at
from utils.vis_tool import visdom_bbox
from utils.eval_tool import eval_detection_voc
C
backup  
chenyuntc 已提交
16

C
chenyuntc 已提交
17 18 19
# fix for ulimit
# https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
import resource
C
chenyuntc 已提交
20

C
chenyuntc 已提交
21 22 23
# Raise the soft limit on open file descriptors to 20480 (see the PyTorch
# issue linked above).  The soft limit may never exceed the hard limit --
# `setrlimit` raises ValueError in that case -- so clamp the requested value
# to the hard limit unless the hard limit is unlimited.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(
    resource.RLIMIT_NOFILE,
    (20480 if rlimit[1] == resource.RLIM_INFINITY else min(20480, rlimit[1]),
     rlimit[1]))

C
chenyuntc 已提交
24
# Select matplotlib's 'agg' backend, which renders off-screen, so that
# plotting does not need a GUI/display (e.g. on a headless training machine).
matplotlib.use('agg')
C
backup  
chenyuntc 已提交
25

C
chenyuntc 已提交
26

C
chenyuntc 已提交
27 28 29 30
def eval(dataloader, faster_rcnn, test_num=10000):
    """Run the detector over ``dataloader`` and score it with ``eval_detection_voc``.

    NOTE: this function shadows the builtin ``eval``; the name is kept so
    that existing callers keep working.

    Args:
        dataloader: iterable yielding ``(imgs, sizes, gt_bboxes, gt_labels,
            gt_difficults)`` tuples.  The three ground-truth entries must
            provide ``.numpy()``.
        faster_rcnn: model object providing ``predict(imgs, sizes)`` that
            returns three lists ``(bboxes, labels, scores)``.
        test_num: maximum number of batches to evaluate.  Values below 1
            still evaluate a single batch.

    Returns:
        The result of ``eval_detection_voc`` (called with
        ``use_07_metric=True``) on the accumulated predictions and ground
        truth.
    """
    pred_bboxes, pred_labels, pred_scores = [], [], []
    gt_bboxes, gt_labels, gt_difficults = [], [], []
    for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in tqdm(enumerate(dataloader)):
        # Take the first element of each of the two size entries as plain
        # Python numbers -- presumably the (height, width) of the single
        # image in the batch; TODO confirm against TestDataset.
        sizes = [sizes[0][0].item(), sizes[1][0].item()]
        pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes])

        gt_bboxes += list(gt_bboxes_.numpy())
        gt_labels += list(gt_labels_.numpy())
        gt_difficults += list(gt_difficults_.numpy())
        pred_bboxes += pred_bboxes_
        pred_labels += pred_labels_
        pred_scores += pred_scores_

        # `ii` is zero-based, so `ii + 1` batches have been consumed here.
        # (Comparing `ii == test_num` evaluated one batch too many.)
        if ii + 1 >= test_num:
            break

    result = eval_detection_voc(
        pred_bboxes, pred_labels, pred_scores,
        gt_bboxes, gt_labels, gt_difficults,
        use_07_metric=True)
    return result

C
chenyuntc 已提交
47

C
backup  
chenyuntc 已提交
48 49 50 51 52
def train(**kwargs):
    """Train ``FasterRCNNVGG16`` and evaluate it after every epoch.

    Requires CUDA: the trainer and every training batch are moved to the GPU
    with ``.cuda()``.

    Args:
        **kwargs: configuration overrides, forwarded to ``opt._parse``.

    Settings read from ``opt``: ``num_workers``, ``test_num_workers``,
    ``load_path``, ``epoch``, ``plot_every``, ``debug_file``, ``test_num``
    and ``lr_decay``.

    Fixed schedule (``epoch`` is zero-based):
        * after epoch index 9 the best checkpoint saved so far is reloaded
          (if one exists) and the learning rate is scaled by ``opt.lr_decay``;
        * training stops after epoch index 13 even if ``opt.epoch`` is larger.
    """
    opt._parse(kwargs)

    dataset = Dataset(opt)
    print('load data')
    dataloader = data_.DataLoader(dataset,
                                  batch_size=1,
                                  shuffle=True,
                                  # pin_memory=True,
                                  num_workers=opt.num_workers)
    testset = TestDataset(opt)
    test_dataloader = data_.DataLoader(testset,
                                       batch_size=1,
                                       num_workers=opt.test_num_workers,
                                       shuffle=False,
                                       pin_memory=True)
    faster_rcnn = FasterRCNNVGG16()
    print('model construct completed')
    trainer = FasterRCNNTrainer(faster_rcnn).cuda()
    if opt.load_path:
        trainer.load(opt.load_path)
        print('load pretrained model from %s' % opt.load_path)
    trainer.vis.text(dataset.db.label_names, win='labels')

    best_map = 0
    # Path of the best checkpoint saved so far.  It stays None until an
    # evaluation beats `best_map`, so the reload at epoch 9 must check it.
    best_path = None
    for epoch in range(opt.epoch):
        trainer.reset_meters()
        for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)):
            scale = at.scalar(scale)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            trainer.train_step(img, bbox, label, scale)

            if (ii + 1) % opt.plot_every == 0:
                # Creating the file named by `opt.debug_file` drops the
                # running process into the ipdb debugger at the next
                # plotting step.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # plot loss
                trainer.vis.plot_many(trainer.get_meter_data())

                # plot ground truth bboxes
                ori_img_ = inverse_normalize(at.tonumpy(img[0]))
                gt_img = visdom_bbox(ori_img_,
                                     at.tonumpy(bbox_[0]),
                                     at.tonumpy(label_[0]))
                trainer.vis.img('gt_img', gt_img)

                # plot predicted bboxes
                _bboxes, _labels, _scores = trainer.faster_rcnn.predict([ori_img_], visualize=True)
                pred_img = visdom_bbox(ori_img_,
                                       at.tonumpy(_bboxes[0]),
                                       at.tonumpy(_labels[0]).reshape(-1),
                                       at.tonumpy(_scores[0]))
                trainer.vis.img('pred_img', pred_img)

                # rpn confusion matrix(meter)
                trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm')
                # roi confusion matrix
                trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float())

        eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num)
        trainer.vis.plot('test_map', eval_result['map'])
        lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
        log_info = 'lr:{}, map:{},loss:{}'.format(str(lr_),
                                                  str(eval_result['map']),
                                                  str(trainer.get_meter_data()))
        trainer.vis.log(log_info)

        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            best_path = trainer.save(best_map=best_map)
        if epoch == 9:
            # Restart from the best checkpoint before lowering the learning
            # rate.  If no evaluation has beaten the initial best_map of 0
            # there is no checkpoint to reload; keep the current weights.
            if best_path is not None:
                trainer.load(best_path)
            trainer.faster_rcnn.scale_lr(opt.lr_decay)

        if epoch == 13:
            break
C
chenyuntc 已提交
123

C
backup  
chenyuntc 已提交
124

C
chenyuntc 已提交
125
if __name__ == '__main__':
    # Imported here so that merely importing this module does not require
    # the `fire` package.
    import fire

    # Called without a component, python-fire exposes the contents of this
    # module on the command line, e.g. `python train.py train --key=value`
    # calls `train(key=value)`.
    fire.Fire()