提交 97691cb0 作者: Ross Girshick

add timers; deduplicate aliased boxes

上级 4673c923
......@@ -5,7 +5,7 @@ caffe_path = '../caffe/python'
sys.path.insert(0, caffe_path)
import argparse
import time
from utils.timer import Timer
import numpy as np
import matplotlib.pyplot as plt
import cv2
......@@ -75,7 +75,7 @@ def _get_blobs(im, rois):
blobs = {'data' : None, 'rois' : None}
blobs['data'], im_scale_factors = _get_image_blob(im)
blobs['rois'] = _get_rois_blob(rois, im_scale_factors)
return blobs
return blobs, im_scale_factors
def _bbox_pred(boxes, box_deltas):
if boxes.shape[0] == 0:
......@@ -120,7 +120,18 @@ def _clip_boxes(boxes, im_shape):
return boxes
def im_detect(net, im, boxes):
blobs = _get_blobs(im, boxes)
# TODO: remove duplicates
blobs, im_scale_factors = _get_blobs(im, boxes)
# v = np.array([1, 1e3, 1e6, 1e9, 1e12])
# hashes = blobs['rois'][:, :, 0, 0].dot(v.T)
hashes = (blobs['rois'][:, :, 0, 0] *
np.array([[1, 1e3, 1e6, 1e9, 1e12]])).sum(axis=1)
_, index, inv_index = np.unique(hashes, return_index=True,
return_inverse=True)
blobs['rois'] = blobs['rois'][index, :, :, :]
boxes = boxes[index, :]
# reshape network inputs
base_shape = blobs['data'].shape
num_rois = blobs['rois'].shape[0]
......@@ -136,7 +147,25 @@ def im_detect(net, im, boxes):
pred_boxes = _bbox_pred(boxes, box_deltas)
pred_boxes = _clip_boxes(pred_boxes, im.shape)
scores = scores[inv_index, :]
pred_boxes = pred_boxes[inv_index, :]
# TODO(rbg): try variant where we predict boxes and then score those
# Need to compute all cls_rois and then deduplicate
# for i in xrange(1, scores.shape[1]):
# cls_rois_blob = _get_rois_blob(pred_boxes[:, i*4:(i+1)*4],
# im_scale_factors)
# t = Timer()
# t.tic()
# blobs_out = net.forward(data=blobs['data'].astype(np.float32,
# copy=False),
# rois=cls_rois_blob.astype(np.float32,
# copy=False),
# start='roi_pool5')
# print t.toc()
# cls_scores = blobs_out['fc8_pascal'][:, :, 0, 0]
# scores[:, i] = cls_scores[:, i] - cls_scores[:, 0]
return scores, pred_boxes
def _vis_detections(im, class_name, dets):
......@@ -198,17 +227,17 @@ def fast_rcnn_test(net, imdb):
all_boxes = [[[] for _ in xrange(num_images)]
for _ in xrange(imdb.num_classes)]
time_det = 0.
# timers
_t = {'im_detect' : Timer(), 'misc' : Timer()}
roidb = imdb.roidb
for i in xrange(num_images):
im = cv2.imread(imdb.image_path_at(i))
time_t = time.time()
_t['im_detect'].tic()
scores, boxes = im_detect(net, im, roidb[i]['boxes'])
time_det += time.time() - time_t
print 'im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, num_images,
time_det / (i + 1))
_t['im_detect'].toc()
time_t = time.time()
_t['misc'].tic()
for j in xrange(1, imdb.num_classes):
inds = np.where((scores[:, j] > thresh[j]) &
(roidb[i]['gt_classes'] == 0))[0]
......@@ -234,7 +263,11 @@ def fast_rcnn_test(net, imdb):
if 0:
keep = utils.cython_nms.nms(all_boxes[j][i], 0.3)
_vis_detections(im, imdb.classes[j], all_boxes[j][i][keep, :])
print 'other: {:.3f}s'.format(time.time() - time_t)
_t['misc'].toc()
print 'im_detect: {:d}/{:d} {:.3f}s {:.3f}s' \
.format(i + 1, num_images, _t['im_detect'].average_time,
_t['misc'].average_time)
for j in xrange(1, imdb.num_classes):
for i in xrange(num_images):
......@@ -246,8 +279,6 @@ def fast_rcnn_test(net, imdb):
_write_voc_results_file(imdb, all_boxes)
# Prune boxes with score < final threshold
# Save boxes
# Write results file and call matlab to evaluate
if __name__ == '__main__':
......@@ -256,11 +287,13 @@ if __name__ == '__main__':
caffe.set_phase_test()
caffe.set_mode_gpu()
GPU_ID = 3
GPU_ID = 2
if GPU_ID is not None:
caffe.set_device(GPU_ID)
net = caffe.Net(prototxt, caffemodel)
print '!!!!!!!!!!!!!!!!!! REMOVE mean/std??? !!!!!!!!!!!!!!!!!!'
# TODO(rbg): save net with these changes during training snapshots
import scipy.io
stats = scipy.io.loadmat('../rcnn/data/voc_2007_means_stds.mat')
stds = stats['stds'].ravel()[np.newaxis, np.newaxis, :, np.newaxis]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册