提交 761c3f4f 编写于 作者: W Waleed Abdulla

Handle COCO crowds.

上级 de3ad95c
......@@ -171,6 +171,14 @@ class CocoDataset(utils.Dataset):
# and end up rounded out. Skip those objects.
if m.max() < 1:
continue
# Is it a crowd? If so, use a negative class ID.
if annotation['iscrowd']:
# Use negative class ID for crowds
class_id *= -1
# For crowd masks, annToMask() sometimes returns a mask
# smaller than the given dimensions. If so, replace it with an
# all-ones mask covering the whole image.
if m.shape[0] != image_info["height"] or m.shape[1] != image_info["width"]:
m = np.ones([image_info["height"], image_info["width"]], dtype=bool)
instance_masks.append(m)
class_ids.append(class_id)
......
......@@ -494,16 +494,32 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
gt_masks = tf.gather(gt_masks, tf.where(non_zeros)[:, 0], axis=2,
name="trim_gt_masks")
# Handle COCO crowds
# A crowd box in COCO is a bounding box around several instances. Exclude
# them from training. A crowd box is given a negative class ID.
crowd_ix = tf.where(gt_class_ids < 0)[:, 0]
non_crowd_ix = tf.where(gt_class_ids > 0)[:, 0]
crowd_boxes = tf.gather(gt_boxes, crowd_ix)
crowd_masks = tf.gather(gt_masks, crowd_ix, axis=2)
gt_class_ids = tf.gather(gt_class_ids, non_crowd_ix)
gt_boxes = tf.gather(gt_boxes, non_crowd_ix)
gt_masks = tf.gather(gt_masks, non_crowd_ix, axis=2)
# Compute overlaps matrix [proposals, gt_boxes]
overlaps = overlaps_graph(proposals, gt_boxes)
# Compute overlaps with crowd boxes [proposals, crowds]
crowd_overlaps = overlaps_graph(proposals, crowd_boxes)
crowd_iou_max = tf.reduce_max(crowd_overlaps, axis=1)
no_crowd_bool = (crowd_iou_max < 0.001)
# Determine positive and negative ROIs
roi_iou_max = tf.reduce_max(overlaps, axis=1)
# 1. Positive ROIs are those with >= 0.5 IoU with a GT box
positive_roi_bool = (roi_iou_max >= 0.5)
positive_indices = tf.where(positive_roi_bool)[:, 0]
# 2. Negative ROIs are those with < 0.5 with every GT box
negative_indices = tf.where(roi_iou_max < 0.5)[:, 0]
# 2. Negative ROIs are those with < 0.5 with every GT box. Skip crowds.
negative_indices = tf.where(tf.logical_and(roi_iou_max < 0.5, no_crowd_bool))[:, 0]
# Subsample ROIs. Aim for 33% positive
# Positive ROIs
......@@ -1357,6 +1373,23 @@ def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config):
# RPN bounding boxes: [max anchors per image, (dy, dx, log(dh), log(dw))]
rpn_bbox = np.zeros((config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4))
# Handle COCO crowds
# A crowd box in COCO is a bounding box around several instances. Exclude
# them from training. A crowd box is given a negative class ID.
crowd_ix = np.where(gt_class_ids < 0)[0]
if crowd_ix.shape[0] > 0:
# Filter out crowds from ground truth class IDs and boxes
non_crowd_ix = np.where(gt_class_ids > 0)[0]
crowd_boxes = gt_boxes[crowd_ix]
gt_class_ids = gt_class_ids[non_crowd_ix]
gt_boxes = gt_boxes[non_crowd_ix]
# Compute overlaps with crowd boxes [anchors, crowds]
crowd_overlaps = utils.compute_overlaps(anchors, crowd_boxes)
crowd_iou_max = np.amax(crowd_overlaps, axis=1)
no_crowd_bool = (crowd_iou_max < 0.001)
else:
# There are no crowd boxes, so no anchor intersects a crowd
no_crowd_bool = np.ones([anchors.shape[0]], dtype=bool)
# Compute overlaps [num_anchors, num_gt_boxes]
overlaps = utils.compute_overlaps(anchors, gt_boxes)
......@@ -1369,10 +1402,11 @@ def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config):
# However, don't keep any GT box unmatched (rare, but happens). Instead,
# match it to the closest anchor (even if its max IoU is < 0.3).
#
# 1. Set negative anchors first. It gets overwritten if a gt box is matched to them.
# 1. Set negative anchors first. They get overwritten below if a GT box is
# matched to them. Skip anchors in crowd areas.
anchor_iou_argmax = np.argmax(overlaps, axis=1)
anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
rpn_match[anchor_iou_max < 0.3] = -1
rpn_match[(anchor_iou_max < 0.3) & (no_crowd_bool)] = -1
# 2. Set an anchor for each GT box (regardless of IoU value).
# TODO: If multiple anchors have the same IoU match all of them
gt_iou_argmax = np.argmax(overlaps, axis=0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册