Unverified commit 6c9d59a3 authored by J jerrywgz, committed by GitHub

Evaluation for Faster R-CNN model

return im_id in reader for evaluation
add COCO evaluation for Faster R-CNN model
visualize detection result when image_path is provided
Parent e92a1c9c
......@@ -68,3 +68,58 @@ def clip_xyxy_to_image(x1, y1, x2, y2, height, width):
x2 = np.minimum(width - 1., np.maximum(0., x2))
y2 = np.minimum(height - 1., np.maximum(0., y2))
return x1, y1, x2, y2
def nms(dets, thresh):
"""Apply classic DPM-style greedy NMS."""
if dets.shape[0] == 0:
return []
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
ndets = dets.shape[0]
    suppressed = np.zeros(ndets, dtype=np.int32)
# nominal indices
# _i, _j
# sorted indices
# i, j
# temp variables for box i's (the box currently under consideration)
# ix1, iy1, ix2, iy2, iarea
# variables for computing overlap with box j (lower scoring box)
# xx1, yy1, xx2, yy2
# w, h
# inter, ovr
for _i in range(ndets):
i = order[_i]
if suppressed[i] == 1:
continue
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= thresh:
suppressed[j] = 1
return np.where(suppressed == 0)[0]
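# A minimal self-check sketch for the NMS above (toy boxes in
# [x1, y1, x2, y2, score] layout); runs only when this module is executed
# directly.
if __name__ == '__main__':
    dets = np.array([
        [10., 10., 50., 50., 0.9],
        [12., 12., 52., 52., 0.8],
        [100., 100., 140., 140., 0.7],
    ], dtype=np.float32)
    # boxes 0 and 1 overlap at IoU ~0.83, so the lower-scoring box 1 is
    # suppressed at thresh=0.5; box 2 is disjoint and survives
    print(nms(dets, thresh=0.5))  # -> [0 2]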
import os
import time
import numpy as np
import argparse
import functools
from eval_helper import get_nmsed_box
import paddle
import paddle.fluid as fluid
import reader
from utility import add_arguments, print_arguments
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
# A special mAP metric for the COCO dataset, which averages AP over different IoU thresholds.
# To run this COCO evaluation, [cocoapi](https://github.com/cocodataset/cocoapi) is needed.
import models.model_builder as model_builder
import models.resnet as resnet
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset', str, 'coco2017', "coco2014, coco2017.")
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU.")
add_arg('data_dir', str, 'data/COCO17', "The data root path.")
add_arg('model_dir', str, '', "The model path.")
add_arg('nms_threshold', float, 0.5, "NMS threshold.")
add_arg('score_threshold', float, 0.05, "Score threshold for NMS.")
add_arg('confs_threshold', float, 9., "Confidence threshold to draw bbox.")
add_arg('image_path', str, '', "The image used to inference and visualize.")
add_arg('anchor_sizes', int, [32,64,128,256,512], "The size of anchors.")
add_arg('aspect_ratios', float, [0.5,1.0,2.0], "The ratio of anchors.")
add_arg('ap_version', str, 'cocoMAP', "cocoMAP.")
add_arg('max_size', int, 1333, "The max size of the longer image side after resizing.")
add_arg('scales', int, [800], "The target size of the shorter image side after resizing.")
add_arg('mean_value', float, [102.9801, 115.9465, 122.7717], "Pixel mean values (BGR order).")
add_arg('class_num', int, 81, "Class number.")
add_arg('variance', float, [1.,1.,1.,1.], "The variance of anchors.")
# yapf: enable
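# For reference, a sketch of exercising these flags programmatically; the
# checkpoint directory below is hypothetical:
#
#   args = parser.parse_args(['--dataset', 'coco2017',
#                             '--data_dir', 'data/COCO17',
#                             '--model_dir', 'output/faster_rcnn_final'])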
def eval(args):
if '2014' in args.dataset:
test_list = 'annotations/instances_val2014.json'
elif '2017' in args.dataset:
test_list = 'annotations/instances_val2017.json'
image_shape = [3, args.max_size, args.max_size]
class_nums = args.class_num
batch_size = args.batch_size
    cocoGt = COCO(os.path.join(args.data_dir, test_list))
numId_to_catId_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())}
category_ids = cocoGt.getCatIds()
label_list = {
item['id']: item['name']
for item in cocoGt.loadCats(category_ids)
}
    label_list[0] = 'background'
print(label_list)
model = model_builder.FasterRCNN(
cfg=args,
add_conv_body_func=resnet.add_ResNet50_conv4_body,
add_roi_box_head_func=resnet.add_ResNet_roi_conv5_head,
use_pyreader=False,
is_train=False,
use_random=False)
model.build_model(image_shape)
rpn_rois, confs, locs = model.eval_out()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if args.model_dir:
def if_exist(var):
return os.path.exists(os.path.join(args.model_dir, var.name))
fluid.io.load_vars(exe, args.model_dir, predicate=if_exist)
# yapf: enable
test_reader = reader.test(args, batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
dts_res = []
fetch_list = [rpn_rois, confs, locs]
for batch_id, data in enumerate(test_reader()):
start = time.time()
#image, gt_box, gt_label, is_crowd, im_info, im_id = data[0]
im_info = []
for i in range(len(data)):
im_info.append(data[i][4])
rpn_rois_v, confs_v, locs_v = exe.run(
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data),
return_numpy=False)
new_lod, nmsed_out = get_nmsed_box(args, rpn_rois_v, confs_v, locs_v,
class_nums, im_info,
numId_to_catId_map)
for i in range(len(data)):
if str(data[i][5]) in args.image_path:
draw_bounding_box_on_image(args.image_path, nmsed_out,
args.confs_threshold, label_list)
dts_res += get_dt_res(new_lod, nmsed_out, data)
end = time.time()
print('batch id: {}, time: {}'.format(batch_id, end - start))
with open("detection_result.json", 'w') as outfile:
json.dump(dts_res, outfile)
print("start evaluate using coco api")
cocoDt = cocoGt.loadRes("detection_result.json")
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
#cocoEval.params.imgIds = im_id
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
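# For reference, each entry written to detection_result.json follows the
# standard COCO detection result format; the values here are illustrative:
#
#   {
#       "image_id": 139,                      # COCO image id
#       "category_id": 1,                     # id after numId_to_catId_map
#       "bbox": [412.8, 157.6, 53.1, 138.0],  # [x, y, width, height]
#       "score": 0.98                         # post-NMS confidence
#   }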
def get_dt_res(lod, nmsed_out, data):
dts_res = []
nmsed_out_v = np.array(nmsed_out)
    assert (len(lod) == args.batch_size + 1), \
        "Error in LoD tensor offset dimension. Lod({}) vs. batch_size({})"\
        .format(len(lod), args.batch_size)
k = 0
for i in range(args.batch_size):
dt_num_this_img = lod[i + 1] - lod[i]
image_id = int(data[i][-1])
image_width = int(data[i][4][1])
        image_height = int(data[i][4][0])  # im_info is [height, width, scale]
for j in range(dt_num_this_img):
dt = nmsed_out_v[k]
k = k + 1
xmin, ymin, xmax, ymax, score, category_id = dt.tolist()
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
dt_res = {
'image_id': image_id,
'category_id': category_id,
'bbox': bbox,
'score': score
}
dts_res.append(dt_res)
return dts_res
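# The "+ 1" in w/h above follows the inclusive integer-coordinate convention
# used throughout this code: a box covering pixel columns x1..x2 is
# x2 - x1 + 1 wide. A standalone sketch of the same conversion (hypothetical
# helper, not used elsewhere):
def xyxy_to_coco_xywh(box):
    """Convert an inclusive [x1, y1, x2, y2] box to COCO [x, y, w, h]."""
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1 + 1, y2 - y1 + 1]

assert xyxy_to_coco_xywh([10, 10, 50, 50]) == [10, 10, 41, 41]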
def draw_bounding_box_on_image(image_path, nms_out, confs_threshold,
label_list):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
for dt in nms_out:
        xmin, ymin, xmax, ymax, score, category_id = dt.tolist()
        if score < confs_threshold:
            continue
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=4,
fill='red')
if image.mode == 'RGB':
draw.text((xmin, ymin), label_list[int(category_id)], (255, 255, 0))
image_name = image_path.split('/')[-1]
print("image with bbox drawed saved as {}".format(image_name))
image.save(image_name)
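# Hypothetical usage sketch (illustrative path; nmsed_out rows are assumed
# to be [xmin, ymin, xmax, ymax, score, category_id] from get_nmsed_box):
#
#   draw_bounding_box_on_image('data/COCO17/val2017/000000000139.jpg',
#                              nmsed_out, args.confs_threshold, label_list)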
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_args = reader.Settings(args)
eval(data_args)
import os
import numpy as np
import paddle.fluid as fluid
import math
import box_utils
def box_decoder(target_box, prior_box, prior_box_var):
proposals = np.zeros_like(target_box, dtype=np.float32)
prior_box_loc = np.zeros_like(prior_box, dtype=np.float32)
prior_box_loc[:, 0] = prior_box[:, 2] - prior_box[:, 0] + 1.
prior_box_loc[:, 1] = prior_box[:, 3] - prior_box[:, 1] + 1.
prior_box_loc[:, 2] = (prior_box[:, 2] + prior_box[:, 0]) / 2
prior_box_loc[:, 3] = (prior_box[:, 3] + prior_box[:, 1]) / 2
pred_bbox = np.zeros_like(target_box, dtype=np.float32)
for i in range(prior_box.shape[0]):
dw = np.minimum(prior_box_var[2] * target_box[i, 2::4],
np.log(1000. / 16.))
dh = np.minimum(prior_box_var[3] * target_box[i, 3::4],
np.log(1000. / 16.))
pred_bbox[i, 0::4] = prior_box_var[0] * target_box[
i, 0::4] * prior_box_loc[i, 0] + prior_box_loc[i, 2]
pred_bbox[i, 1::4] = prior_box_var[1] * target_box[
i, 1::4] * prior_box_loc[i, 1] + prior_box_loc[i, 3]
pred_bbox[i, 2::4] = np.exp(dw) * prior_box_loc[i, 0]
pred_bbox[i, 3::4] = np.exp(dh) * prior_box_loc[i, 1]
proposals[:, 0::4] = pred_bbox[:, 0::4] - pred_bbox[:, 2::4] / 2
proposals[:, 1::4] = pred_bbox[:, 1::4] - pred_bbox[:, 3::4] / 2
proposals[:, 2::4] = pred_bbox[:, 0::4] + pred_bbox[:, 2::4] / 2 - 1
proposals[:, 3::4] = pred_bbox[:, 1::4] + pred_bbox[:, 3::4] / 2 - 1
return proposals
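# Sanity-check sketch for the decoder (unit variances): zero deltas recover
# the prior box up to the half-pixel shift from the center/size convention
# above:
#
#   prior = np.array([[10., 10., 50., 50.]], dtype=np.float32)
#   deltas = np.zeros((1, 4), dtype=np.float32)
#   box_decoder(deltas, prior, np.ones(4))  # -> [[9.5, 9.5, 49.5, 49.5]]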
def clip_tiled_boxes(boxes, im_shape):
"""Clip boxes to image boundaries. im_shape is [height, width] and boxes
has shape (N, 4 * num_tiled_boxes)."""
assert boxes.shape[1] % 4 == 0, \
'boxes.shape[1] is {:d}, but must be divisible by 4.'.format(
boxes.shape[1]
)
# x1 >= 0
boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
# y1 >= 0
boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
# x2 < im_shape[1]
boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
# y2 < im_shape[0]
boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
return boxes
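# Clipping sketch: an out-of-range box against an 80x100 (height x width)
# image is clamped to valid pixel coordinates:
#
#   boxes = np.array([[-5., -5., 120., 90.]], dtype=np.float32)
#   clip_tiled_boxes(boxes, [80, 100])  # -> [[0., 0., 99., 79.]]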
def get_nmsed_box(args, rpn_rois, confs, locs, class_nums, im_info,
numId_to_catId_map):
lod = rpn_rois.lod()[0]
rpn_rois_v = np.array(rpn_rois)
variance_v = np.array([0.1, 0.1, 0.2, 0.2])
confs_v = np.array(confs)
locs_v = np.array(locs)
rois = box_decoder(locs_v, rpn_rois_v, variance_v)
im_results = [[] for _ in range(len(lod) - 1)]
new_lod = [0]
for i in range(len(lod) - 1):
start = lod[i]
end = lod[i + 1]
        if start == end:
            # keep lod alignment for images with no proposals
            im_results[i] = np.zeros((0, 6), dtype=np.float32)
            new_lod.append(new_lod[-1])
            continue
rois_n = rois[start:end, :]
rois_n = rois_n / im_info[i][2]
rois_n = clip_tiled_boxes(rois_n, im_info[i][:2])
cls_boxes = [[] for _ in range(class_nums)]
scores_n = confs_v[start:end, :]
for j in range(1, class_nums):
inds = np.where(scores_n[:, j] > args.score_threshold)[0]
scores_j = scores_n[inds, j]
rois_j = rois_n[inds, j * 4:(j + 1) * 4]
dets_j = np.hstack((rois_j, scores_j[:, np.newaxis])).astype(
np.float32, copy=False)
keep = box_utils.nms(dets_j, args.nms_threshold)
nms_dets = dets_j[keep, :]
#add labels
cat_id = numId_to_catId_map[j]
label = np.array([cat_id for _ in range(len(keep))])
nms_dets = np.hstack((nms_dets, label[:, np.newaxis])).astype(
np.float32, copy=False)
cls_boxes[j] = nms_dets
# Limit to max_per_image detections **over all classes**
image_scores = np.hstack(
[cls_boxes[j][:, -2] for j in range(1, class_nums)])
if len(image_scores) > 100:
image_thresh = np.sort(image_scores)[-100]
for j in range(1, class_nums):
keep = np.where(cls_boxes[j][:, -2] >= image_thresh)[0]
cls_boxes[j] = cls_boxes[j][keep, :]
im_results_n = np.vstack([cls_boxes[j] for j in range(1, class_nums)])
im_results[i] = im_results_n
new_lod.append(len(im_results_n) + new_lod[-1])
im_results = np.vstack([im_results[k] for k in range(len(lod) - 1)])
return new_lod, im_results
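# new_lod holds cumulative per-image offsets into the stacked results, so
# callers can recover the detections of image i as a slice (sketch):
#
#   new_lod, im_results = get_nmsed_box(...)
#   dets_i = im_results[new_lod[i]:new_lod[i + 1]]
#   # each row: [xmin, ymin, xmax, ymax, score, category_id]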
......@@ -36,16 +36,22 @@ class FasterRCNN(object):
rpn_cls_loss, rpn_reg_loss = self.rpn_loss()
return loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss,
def eval_out(self):
return [self.rpn_rois, self.cls_score, self.bbox_pred]
def build_input(self, image_shape):
if self.use_pyreader:
self.py_reader = fluid.layers.py_reader(
capacity=64,
shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1], [-1, 3]],
lod_levels=[0, 1, 1, 1, 0],
dtypes=["float32", "float32", "int32", "int32", "float32"],
shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1], [-1, 3],
[-1, 1]],
lod_levels=[0, 1, 1, 1, 0, 0],
dtypes=[
"float32", "float32", "int32", "int32", "float32", "int32"
],
use_double_buffer=True)
self.image, self.gt_box, self.gt_label, self.is_crowd, \
self.im_info = fluid.layers.read_file(self.py_reader)
self.im_info, self.im_id = fluid.layers.read_file(self.py_reader)
else:
self.image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
......@@ -61,10 +67,13 @@ class FasterRCNN(object):
append_batch_size=False)
self.im_info = fluid.layers.data(
name='im_info', shape=[3], dtype='float32')
self.im_id = fluid.layers.data(
name='im_id', shape=[1], dtype='int32')
def feeds(self):
return [
self.image, self.gt_box, self.gt_label, self.is_crowd, self.im_info
self.image, self.gt_box, self.gt_label, self.is_crowd, self.im_info,
self.im_id
]
def rpn_heads(self, rpn_input):
......@@ -126,18 +135,20 @@ class FasterRCNN(object):
rpn_cls_score_prob = fluid.layers.sigmoid(
self.rpn_cls_score, name='rpn_cls_score_prob')
pre_nms_top_n = 12000 if self.is_train else 6000
post_nms_top_n = 2000 if self.is_train else 1000
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
scores=rpn_cls_score_prob,
bbox_deltas=self.rpn_bbox_pred,
im_info=self.im_info,
anchors=self.anchor,
variances=self.var,
pre_nms_top_n=12000,
post_nms_top_n=2000,
pre_nms_top_n=pre_nms_top_n,
post_nms_top_n=post_nms_top_n,
nms_thresh=0.7,
min_size=0.0,
eta=1.0)
self.rpn_rois = rpn_rois
if self.is_train:
outs = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
......@@ -161,9 +172,13 @@ class FasterRCNN(object):
self.bbox_outside_weights = outs[4]
def fast_rcnn_heads(self, roi_input):
if self.is_train:
pool_rois = self.rois
else:
pool_rois = self.rpn_rois
pool = fluid.layers.roi_pool(
input=roi_input,
rois=self.rois,
rois=pool_rois,
pooled_height=14,
pooled_width=14,
spatial_scale=0.0625)
......
......@@ -73,6 +73,8 @@ def coco(settings, mode, batch_size=None, shuffle=False):
batch_out = []
for roidb in roidbs:
im, im_scales = data_utils.get_image_blob(roidb, settings)
im_id = roidb['id']
im_height = np.round(roidb['height'] * im_scales)
im_width = np.round(roidb['width'] * im_scales)
im_info = np.array(
......@@ -80,10 +82,10 @@ def coco(settings, mode, batch_size=None, shuffle=False):
gt_boxes = roidb['gt_boxes'].astype('float32')
gt_classes = roidb['gt_classes'].astype('int32')
is_crowd = roidb['is_crowd'].astype('int32')
if gt_boxes.shape[0] == 0:
if mode == 'train' and gt_boxes.shape[0] == 0:
continue
batch_out.append((im, gt_boxes, gt_classes, is_crowd, im_info))
batch_out.append(
(im, gt_boxes, gt_classes, is_crowd, im_info, im_id))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
......
......@@ -161,7 +161,6 @@ class JsonDataset(object):
def _extend_with_flipped_entries(self, roidb):
"""Flip each entry in the given roidb and return a new roidb that is the
concatenation of the original roidb and the flipped entries.
"Flipping" an entry means that that image and associated metadata (e.g.,
ground truth boxes and object proposals) are horizontally flipped.
"""
......
......@@ -145,6 +145,7 @@ def train(cfg):
def train_step(epoc_id):
start_time = time.time()
prev_start_time = start_time
start = start_time
every_pass_loss = []
for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time
......