fast_rcnn_train.py 7.2 KB
Newer Older
R
Ross Girshick 已提交
1 2 3 4 5
#!/usr/bin/env python

import sys
caffe_path = '../caffe/python'
sys.path.insert(0, caffe_path)
R
Ross Girshick 已提交
6 7

import argparse
R
Ross Girshick 已提交
8
import time
R
Ross Girshick 已提交
9
import numpy as np
R
Ross Girshick 已提交
10
import matplotlib.pyplot as plt
R
Ross Girshick 已提交
11 12
import cv2
import caffe
R
Ross Girshick 已提交
13
import finetuning
14
import fast_rcnn_config as conf
R
Ross Girshick 已提交
15 16
import datasets.pascal_voc
import bbox_regression_targets
R
Ross Girshick 已提交
17

R
Ross Girshick 已提交
18 19 20
from caffe.proto import caffe_pb2
import google.protobuf as pb2

R
Ross Girshick 已提交
21 22 23 24 25 26 27 28 29
def parse_args(argv=None):
    """
    Parse the command-line arguments of the training script.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) the arguments are taken from sys.argv[1:], which is
            what a plain parse_args() call has always done.

    Returns:
        An argparse.Namespace with the attributes
            gpu_id: int, GPU id to use (default 0)
            solver: str or None, path of the solver prototxt (default None)
            epochs: int, number of epochs to train (default 16)
    """
    parser = argparse.ArgumentParser(description='Train a fast R-CNN')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU id to use',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver', help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--epochs', dest='epochs',
                        help='number of epoch to train',
                        default=16, type=int)

    return parser.parse_args(argv)
R
Ross Girshick 已提交
36

R
Ross Girshick 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
class SolverWrapper(object):
    """
    Wrapper around a caffe.SGDSolver that writes its own snapshots.

    Before a snapshot is saved, the weights of the BBOX_LAYER layer are
    multiplied by bbox_stds and its biases are shifted by bbox_means; the
    in-memory parameters are put back to their previous values afterwards.
    """

    # Name of the layer whose parameters are rescaled in snapshot().
    BBOX_LAYER = 'fc8_pascal_bbox'

    def __init__(self, solver_prototxt, pretrained_model=None):
        """
        Create the solver and parse the solver prototxt.

        Args:
            solver_prototxt: path of the solver definition (text protobuf).
            pretrained_model: optional path of a model file whose weights
                are copied into the solver's net.
        """
        # Imported explicitly: `import google.protobuf` on its own does not
        # guarantee that the text_format submodule has been loaded.
        from google.protobuf import text_format

        # Both must be assigned by the caller before snapshot() is used.
        self.bbox_means = None
        self.bbox_stds = None

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print('Loading pretrained model weights from {:s}'
                  .format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Parsed solver settings; snapshot() reads snapshot_prefix from it.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            text_format.Merge(f.read(), self.solver_param)

    def snapshot(self):
        """
        Save the net with the bbox layer scaled by bbox_stds and shifted by
        bbox_means, to
        '<snapshot_prefix>_bbox06_iter_<solver.iter>.caffemodel'.

        The layer's in-memory parameters are restored to their previous
        values even when saving fails, so training can safely continue.

        Raises:
            AssertionError: if bbox_stds or bbox_means has not been set.
        """
        assert self.bbox_stds is not None
        assert self.bbox_means is not None

        # stds varies along axis 2 and means along axis 3 so that they
        # broadcast against the 4-D weight and bias arrays respectively.
        stds = self.bbox_stds.ravel()[np.newaxis, np.newaxis, :, np.newaxis]
        means = self.bbox_means.ravel()[np.newaxis, np.newaxis, np.newaxis, :]

        layer = self.solver.net.params[self.BBOX_LAYER]
        weights, biases = layer[0], layer[1]

        # save original values
        orig_weights = weights.data.copy()
        orig_biases = biases.data.copy()

        filename = self.solver_param.snapshot_prefix + \
            '_bbox06_iter_{:d}'.format(self.solver.iter) + '.caffemodel'
        try:
            # scale and shift with bbox reg unnormalization; then save
            weights.data[...] = orig_weights * stds
            biases.data[...] = orig_biases + means
            self.solver.net.save(str(filename))
        finally:
            # restore net to original state, whether or not the save worked
            weights.data[...] = orig_weights
            biases.data[...] = orig_biases
        print('Wrote snapshot to: {:s}'.format(filename))

# TODO(rbg): move into SolverWrapper
def _fill_input_blobs(net, im_blob, rois_blob, labels_blob,
                      bbox_targets_blob, bbox_loss_weights_blob):
    """Reshape the net's input blobs for one minibatch and copy the data in."""
    base_shape = im_blob.shape
    num_rois = rois_blob.shape[0]
    num_bbox_outputs = 4 * conf.NUM_CLASSES

    # (blob name, 4-D blob shape, source array expanded to 4-D)
    inputs = [
        ('data',
         (base_shape[0], base_shape[1], base_shape[2], base_shape[3]),
         im_blob),
        ('rois',
         (num_rois, 5, 1, 1),
         rois_blob[:, :, np.newaxis, np.newaxis]),
        ('labels',
         (num_rois, 1, 1, 1),
         labels_blob[:, np.newaxis, np.newaxis, np.newaxis]),
        ('bbox_targets',
         (num_rois, num_bbox_outputs, 1, 1),
         bbox_targets_blob[:, :, np.newaxis, np.newaxis]),
        ('bbox_loss_weights',
         (num_rois, num_bbox_outputs, 1, 1),
         bbox_loss_weights_blob[:, :, np.newaxis, np.newaxis]),
    ]

    # Reshape net's input blobs
    for name, shape, _ in inputs:
        net.blobs[name].reshape(*shape)
    # Copy data into net's input blobs
    for name, _, values in inputs:
        net.blobs[name].data[...] = values.astype(np.float32, copy=False)


def train_model(sw, roidb, max_epochs=100):
    """
    Train the wrapped solver on roidb for max_epochs epochs.

    Each epoch visits the roidb entries in a fresh random order, in groups
    of conf.IMS_PER_BATCH; entries left over after the last full group are
    skipped for that epoch. Every group is turned into input blobs by
    finetuning.get_minibatch, loaded into the net, and followed by one
    solver step. sw.snapshot() is called whenever the solver's iteration
    count is a multiple of conf.SNAPSHOT_ITERS.

    Args:
        sw: SolverWrapper holding the solver to step.
        roidb: sequence of roidb entries, indexable by integer.
        max_epochs: number of passes over roidb.

    Returns:
        The underlying solver, sw.solver.
    """
    for _ in range(max_epochs):
        shuffled_inds = np.random.permutation(np.arange(len(roidb)))
        # Keep only as many indices as fill whole minibatches.
        lim = (len(shuffled_inds) // conf.IMS_PER_BATCH) * conf.IMS_PER_BATCH
        shuffled_inds = shuffled_inds[0:lim]
        for start in range(0, len(shuffled_inds), conf.IMS_PER_BATCH):
            db_inds = shuffled_inds[start:start + conf.IMS_PER_BATCH]
            minibatch_db = [roidb[i] for i in db_inds]
            im_blob, rois_blob, labels_blob, \
                bbox_targets_blob, bbox_loss_weights_blob = \
                finetuning.get_minibatch(minibatch_db)

            _fill_input_blobs(sw.solver.net, im_blob, rois_blob, labels_blob,
                              bbox_targets_blob, bbox_loss_weights_blob)

            sw.solver.step(1)
            if sw.solver.iter % conf.SNAPSHOT_ITERS == 0:
                sw.snapshot()
    # The original returned the undefined name `solver` (NameError).
    return sw.solver
R
Ross Girshick 已提交
123

R
Ross Girshick 已提交
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
def training_roidb(imdb):
    """
    Add derived, training-time fields to every entry of imdb.roidb (in
    place) and return the roidb.

    For image i the entry gains:
        'image'        -- imdb.image_path_at(i)
        'max_overlaps' -- per-ROI maximum of 'gt_overlaps' over its columns
        'max_classes'  -- per-ROI column index where that maximum occurs

    Two sanity checks are asserted per image: ROIs whose maximum overlap
    is 0 must have class 0 (background), and ROIs whose maximum overlap is
    positive must have a non-zero (foreground) class.
    """
    roidb = imdb.roidb
    for idx in range(len(imdb.image_index)):
        entry = roidb[idx]
        entry['image'] = imdb.image_path_at(idx)

        # densify gt_overlaps so the row-wise max / argmax can be taken
        dense_overlaps = entry['gt_overlaps'].toarray()
        best_overlap = dense_overlaps.max(axis=1)
        best_class = dense_overlaps.argmax(axis=1)
        entry['max_classes'] = best_class
        entry['max_overlaps'] = best_overlap

        # sanity checks
        # no overlap with any gt box => class must be 0 (background)
        background = best_class[best_overlap == 0]
        assert np.all(background == 0)
        # some overlap with a gt box => class must be a foreground class
        foreground = best_class[best_overlap > 0]
        assert np.all(foreground != 0)

    return roidb
R
Ross Girshick 已提交
152 153

if __name__ == '__main__':
    args = parse_args()

    # Put Caffe into GPU training mode on the requested device.
    caffe.set_phase_train()
    caffe.set_mode_gpu()
    if args.gpu_id is not None:
        caffe.set_device(args.gpu_id)

    imdb_train = datasets.pascal_voc('trainval', '2007')

    # Enrich the roidb with some useful derived quantities ...
    roidb_train = training_roidb(imdb_train)

    # ... and with bounding-box regression targets. The returned means and
    # stds are handed to the solver wrapper below for use in snapshot().
    means, stds = \
        bbox_regression_targets.append_bbox_regression_targets(roidb_train)

    # Alternative model / solver kept for reference:
    # CAFFE_MODEL = '/data/reference_caffe_nets/ilsvrc_2012_train_iter_310k'
    # SOLVER_DEF = './model-defs/pyramid_solver.prototxt'
    CAFFE_MODEL = '/data/reference_caffe_nets/VGG_ILSVRC_16_layers.caffemodel'
    if args.solver is None:
        # No --solver given: fall back to the VGG16 solver definition.
        args.solver = './model-defs/vgg16_solver.prototxt'

    sw = SolverWrapper(args.solver, pretrained_model=CAFFE_MODEL)
    sw.bbox_means = means
    sw.bbox_stds = stds

    train_model(sw, roidb_train, max_epochs=args.epochs)
    # Write a final snapshot once training has finished.
    sw.snapshot()