from __future__ import  absolute_import
# though cupy is not used but without this line, it raise errors...
import cupy as cp
import os

import ipdb
import matplotlib
from tqdm import tqdm

from utils.config import opt
from data.dataset import Dataset, TestDataset, inverse_normalize
from model import FasterRCNNVGG16
from torch.utils import data as data_
from trainer import FasterRCNNTrainer
from utils import array_tool as at
from utils.vis_tool import visdom_bbox
from utils.eval_tool import eval_detection_voc

# fix for ulimit
# https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
import resource

# (soft, hard) open-file limits as they were when this module was imported.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)


def raise_nofile_soft_limit(target=20480):
    """Best-effort raise of the RLIMIT_NOFILE soft limit towards ``target``.

    Presumably needed because DataLoader workers can exhaust the default
    open-file limit (see the issue linked above). A plain
    ``setrlimit(RLIMIT_NOFILE, (20480, hard))`` raises ValueError when the hard
    limit is below 20480 and lowers a soft limit that is already larger. This
    helper instead clamps the request to the hard limit, never lowers the
    current soft limit, and only warns if the OS still rejects the change.

    Args:
        target: desired soft limit on open file descriptors.

    Returns:
        The soft limit in effect after the call.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft == resource.RLIM_INFINITY:
        # Already unlimited; RLIM_INFINITY may be -1, so don't compare numerically.
        return soft
    # setrlimit rejects a soft limit above a finite hard limit.
    wanted = target if hard == resource.RLIM_INFINITY else min(target, hard)
    if wanted <= soft:
        return soft
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (wanted, hard))
    except (ValueError, OSError) as err:
        # Some platforms cap descriptors below the reported hard limit; this
        # tweak is an optimisation, so don't abort training over it.
        import warnings
        warnings.warn('could not raise RLIMIT_NOFILE soft limit to %d: %s' % (wanted, err))
        return soft
    return wanted


raise_nofile_soft_limit()

# Switch matplotlib to the non-interactive Agg backend so figures can be
# rendered to images without a display.
# NOTE(review): this runs after `utils.vis_tool` has been imported; if that
# module imports matplotlib.pyplot, confirm the backend switch still takes
# effect with the installed matplotlib version.
matplotlib.use("agg")


def eval(dataloader, faster_rcnn, test_num=10000):
    """Run ``faster_rcnn`` on up to ``test_num`` batches and score them VOC-style.

    Note: this function shadows the builtin ``eval`` inside this module; the
    name is kept because ``train`` (and possibly other callers) use it.

    Args:
        dataloader: iterable yielding ``(imgs, sizes, gt_bboxes, gt_labels,
            gt_difficults)``. Only element 0 of each ``sizes`` entry is read,
            so this assumes a batch size of 1.
        faster_rcnn: model whose ``predict(imgs, [sizes])`` returns three
            lists (boxes, labels, scores) that are concatenated across batches.
        test_num: maximum number of batches to evaluate.

    Returns:
        The result of ``eval_detection_voc`` with ``use_07_metric=True``;
        callers read its ``'map'`` entry.
    """
    from itertools import islice

    pred_bboxes, pred_labels, pred_scores = [], [], []
    gt_bboxes, gt_labels, gt_difficults = [], [], []
    # islice stops after exactly `test_num` batches. An `ii == test_num` check
    # placed after the accumulation would let test_num + 1 batches through.
    batches = islice(dataloader, test_num)
    for imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_ in tqdm(batches):
        # `sizes` holds two per-batch tensors; take element 0 of each
        # (batch size 1 assumed).
        sizes = [sizes[0][0].item(), sizes[1][0].item()]
        pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes])
        gt_bboxes.extend(gt_bboxes_.numpy())
        gt_labels.extend(gt_labels_.numpy())
        gt_difficults.extend(gt_difficults_.numpy())
        pred_bboxes.extend(pred_bboxes_)
        pred_labels.extend(pred_labels_)
        pred_scores.extend(pred_scores_)

    result = eval_detection_voc(
        pred_bboxes, pred_labels, pred_scores,
        gt_bboxes, gt_labels, gt_difficults,
        use_07_metric=True)
    return result


def _visualize_step(trainer, img, bbox_, label_):
    """Push losses, GT/predicted boxes and confusion matrices for one batch to visdom."""
    # plot loss
    trainer.vis.plot_many(trainer.get_meter_data())

    # plot ground truth bboxes of the first image in the batch
    ori_img_ = inverse_normalize(at.tonumpy(img[0]))
    gt_img = visdom_bbox(ori_img_,
                         at.tonumpy(bbox_[0]),
                         at.tonumpy(label_[0]))
    trainer.vis.img('gt_img', gt_img)

    # plot predicted bboxes for the same image
    _bboxes, _labels, _scores = trainer.faster_rcnn.predict([ori_img_], visualize=True)
    pred_img = visdom_bbox(ori_img_,
                           at.tonumpy(_bboxes[0]),
                           at.tonumpy(_labels[0]).reshape(-1),
                           at.tonumpy(_scores[0]))
    trainer.vis.img('pred_img', pred_img)

    # rpn confusion matrix (meter)
    trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm')
    # roi confusion matrix
    trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float())


def train(**kwargs):
    """Train a FasterRCNNVGG16 detector, checkpointing whenever test mAP improves.

    Keyword arguments are handed to ``opt._parse`` and override entries of the
    shared ``opt`` config (via CLI flags when launched through ``fire``).

    Each epoch makes one shuffled pass over the training set (batch size 1).
    Every ``opt.plot_every`` iterations it sends progress to visdom. It then
    evaluates on up to ``opt.test_num`` test batches, plots/logs the mAP and
    saves a checkpoint if mAP improved.

    The schedule is hard-coded. After epoch index 9 the best checkpoint so far
    is reloaded (if any was saved) and the learning rate is scaled by
    ``opt.lr_decay``. Training stops after epoch index 13 even if
    ``opt.epoch`` is larger.
    """
    opt._parse(kwargs)

    dataset = Dataset(opt)
    print('load data')
    dataloader = data_.DataLoader(dataset,
                                  batch_size=1,
                                  shuffle=True,
                                  # pin_memory=True,
                                  num_workers=opt.num_workers)
    testset = TestDataset(opt)
    test_dataloader = data_.DataLoader(testset,
                                       batch_size=1,
                                       num_workers=opt.test_num_workers,
                                       shuffle=False,
                                       pin_memory=True)
    faster_rcnn = FasterRCNNVGG16()
    print('model construct completed')
    trainer = FasterRCNNTrainer(faster_rcnn).cuda()
    if opt.load_path:
        trainer.load(opt.load_path)
        print('load pretrained model from %s' % opt.load_path)
    trainer.vis.text(dataset.db.label_names, win='labels')

    # Hard-coded schedule (0-based epoch indices).
    lr_decay_epoch = 9
    last_epoch = 13

    best_map = 0
    # Stays None if mAP never rises above 0; the decay step below must not
    # try to reload a checkpoint that was never written.
    best_path = None
    for epoch in range(opt.epoch):
        trainer.reset_meters()
        for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)):
            scale = at.scalar(scale)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            trainer.train_step(img, bbox, label, scale)

            if (ii + 1) % opt.plot_every == 0:
                # Drop into the debugger while the debug sentinel file exists.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                _visualize_step(trainer, img, bbox_, label_)

        eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num)
        # Plot/log exactly once per epoch so the test_map curve has one point
        # per epoch; the lr is read back from the optimizer actually in use.
        trainer.vis.plot('test_map', eval_result['map'])
        lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr']
        log_info = 'lr:{}, map:{},loss:{}'.format(str(lr_),
                                                  str(eval_result['map']),
                                                  str(trainer.get_meter_data()))
        trainer.vis.log(log_info)

        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            best_path = trainer.save(best_map=best_map)
        if epoch == lr_decay_epoch:
            if best_path is not None:
                trainer.load(best_path)
            trainer.faster_rcnn.scale_lr(opt.lr_decay)

        if epoch == last_epoch:
            break


if __name__ == '__main__':
    # Imported inside the guard so that importing this module (e.g. to reuse
    # `eval`) does not require `fire` to be installed.
    import fire

    # With no component, fire.Fire() exposes this module's top-level names as
    # CLI commands, e.g. `python train.py train --<opt_field>=<value>` calls
    # train(**kwargs).
    fire.Fire()