Commit 859b73c1 authored by Waleed Abdulla

Remove unnecessary coordinate conversions.

Make DetectionLayer return normalized coordinates
to avoid unnecessary conversion to pixels and back
to normalized.
Parent e6290877
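For context, a minimal NumPy sketch of the round-trip this commit removes (values are illustrative, not taken from the diff). Boxes leave the proposal stage normalized, were scaled to pixels and rounded for detection refinement, then divided back to normalized for the mask head:

```python
import numpy as np

# Illustrative only: the conversion round-trip removed by this commit.
h, w = 1024, 1024                                   # hypothetical molded image size
boxes = np.array([[0.1, 0.2, 0.5, 0.6]])            # [N, (y1, x1, y2, x2)], normalized

boxes_px = np.rint(boxes * np.array([h, w, h, w]))  # old: to pixels, round to int
boxes_back = boxes_px / np.array([h, w, h, w])      # old: back to normalized

# The integer rounding makes the round-trip lossy on top of being redundant;
# staying in normalized coordinates avoids both the cost and the quantization.
print(np.abs(boxes - boxes_back).max())             # small but nonzero in general
```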
@@ -223,10 +223,10 @@ def apply_box_deltas_graph(boxes, deltas):
 def clip_boxes_graph(boxes, window):
     """
-    boxes: [N, 4] each row is y1, x1, y2, x2
+    boxes: [N, (y1, x1, y2, x2)]
     window: [4] in the form y1, x1, y2, x2
     """
-    # Split corners
+    # Split
     wy1, wx1, wy2, wx2 = tf.split(window, 4)
     y1, x1, y2, x2 = tf.split(boxes, 4, axis=1)
     # Clip
@@ -670,18 +670,6 @@ class DetectionTargetLayer(KE.Layer):
 #  Detection Layer
 ############################################################

-def clip_to_window(window, boxes):
-    """
-    window: (y1, x1, y2, x2). The window in the image we want to clip to.
-    boxes: [N, (y1, x1, y2, x2)]
-    """
-    boxes[:, 0] = np.maximum(np.minimum(boxes[:, 0], window[2]), window[0])
-    boxes[:, 1] = np.maximum(np.minimum(boxes[:, 1], window[3]), window[1])
-    boxes[:, 2] = np.maximum(np.minimum(boxes[:, 2], window[2]), window[0])
-    boxes[:, 3] = np.maximum(np.minimum(boxes[:, 3], window[3]), window[1])
-    return boxes
-
 def refine_detections_graph(rois, probs, deltas, window, config):
     """Refine classified proposals and filter overlaps and return final
     detections.
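The deleted clip_to_window ran in NumPy on pixel boxes after the graph; with clipping now done in-graph by clip_boxes_graph on normalized boxes, it has no remaining callers. For reference, its four per-column lines are equivalent to one vectorized clip (a sketch, not code from the repo):

```python
import numpy as np

def clip_to_window_np(window, boxes):
    """Vectorized equivalent of the removed helper: clamp each y into
    [window[0], window[2]] and each x into [window[1], window[3]]."""
    wy1, wx1, wy2, wx2 = window
    return np.clip(boxes, [wy1, wx1, wy1, wx1], [wy2, wx2, wy2, wx2])

# Works identically for pixel or normalized coordinates:
print(clip_to_window_np((0., 0., 1., 1.), np.array([[-0.1, 0.2, 0.7, 1.3]])))
# -> [[0.  0.2 0.7 1. ]]
```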
@@ -695,7 +683,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
             that contains the image excluding the padding.

     Returns detections shaped: [N, (y1, x1, y2, x2, class_id, score)] where
-        coordinates are in image domain.
+        coordinates are normalized.
     """
     # Class IDs per ROI
     class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
@@ -708,14 +696,8 @@ def refine_detections_graph(rois, probs, deltas, window, config):
     # Shape: [boxes, (y1, x1, y2, x2)] in normalized coordinates
     refined_rois = apply_box_deltas_graph(
         rois, deltas_specific * config.BBOX_STD_DEV)
-    # Convert coordiates to image domain
-    # TODO: better to keep them normalized until later
-    height, width = config.IMAGE_SHAPE[:2]
-    refined_rois *= tf.constant([height, width, height, width], dtype=tf.float32)
     # Clip boxes to image window
     refined_rois = clip_boxes_graph(refined_rois, window)
-    # Round and cast to int since we're deadling with pixels now
-    refined_rois = tf.to_int32(tf.rint(refined_rois))

     # TODO: Filter out boxes with zero area
@@ -741,7 +723,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
         ixs = tf.where(tf.equal(pre_nms_class_ids, class_id))[:, 0]
         # Apply NMS
         class_keep = tf.image.non_max_suppression(
-                tf.to_float(tf.gather(pre_nms_rois, ixs)),
+                tf.gather(pre_nms_rois, ixs),
                 tf.gather(pre_nms_scores, ixs),
                 max_output_size=config.DETECTION_MAX_INSTANCES,
                 iou_threshold=config.DETECTION_NMS_THRESHOLD)
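Dropping tf.to_float works because refined_rois now stays float throughout (there is no longer an int cast to undo), and tf.image.non_max_suppression is scale-agnostic: IoU ratios are identical whether boxes are in pixels or in normalized [0, 1] coordinates. A small standalone check (TF 1.x style, matching the code above; values are made up):

```python
import numpy as np
import tensorflow as tf  # TF 1.x, as used by this codebase

boxes_px = np.array([[100., 100., 300., 300.],
                     [110., 110., 310., 310.],
                     [400., 400., 500., 500.]], dtype=np.float32)
boxes_norm = boxes_px / 1024.0  # same boxes, normalized
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)

with tf.Session() as sess:
    keep_px = sess.run(tf.image.non_max_suppression(
        boxes_px, scores, max_output_size=10, iou_threshold=0.5))
    keep_norm = sess.run(tf.image.non_max_suppression(
        boxes_norm, scores, max_output_size=10, iou_threshold=0.5))

# Same survivors either way, since IoU is invariant to uniform scaling.
assert (keep_px == keep_norm).all()
```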
@@ -773,9 +755,9 @@ def refine_detections_graph(rois, probs, deltas, window, config):
     keep = tf.gather(keep, top_ids)

     # Arrange output as [N, (y1, x1, y2, x2, class_id, score)]
-    # Coordinates are in image domain.
+    # Coordinates are normalized.
     detections = tf.concat([
-        tf.to_float(tf.gather(refined_rois, keep)),
+        tf.gather(refined_rois, keep),
         tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis],
         tf.gather(class_scores, keep)[..., tf.newaxis]
         ], axis=1)
@@ -792,7 +774,7 @@ class DetectionLayer(KE.Layer):

     Returns:
     [batch, num_detections, (y1, x1, y2, x2, class_id, class_score)] where
-    coordinates are in image domain
+    coordinates are normalized.
     """

     def __init__(self, config=None, **kwargs):
@@ -807,13 +789,15 @@ class DetectionLayer(KE.Layer):
         # Run detection refinement graph on each item in the batch
         _, _, window, _ = parse_image_meta_graph(image_meta)
+        window = norm_boxes_graph(window, self.config.IMAGE_SHAPE[:2])
         detections_batch = utils.batch_slice(
             [rois, mrcnn_class, mrcnn_bbox, window],
             lambda x, y, w, z: refine_detections_graph(x, y, w, z, self.config),
             self.config.IMAGES_PER_GPU)

         # Reshape output
-        # [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels
+        # [batch, num_detections, (y1, x1, y2, x2, class_score)] in
+        # normalized coordinates
         return tf.reshape(
             detections_batch,
             [self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])
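norm_boxes_graph brings the pixel-space window from the image meta into the same normalized space as the ROIs before refine_detections_graph sees it. A NumPy sketch of that conversion, assuming the (height - 1, width - 1) convention of the repo's utils.norm_boxes (in pixel coordinates (y2, x2) lies outside the box, in normalized coordinates inside, hence the shift):

```python
import numpy as np

def norm_boxes_np(boxes, shape):
    """Sketch of a pixel -> normalized conversion (cf. utils.norm_boxes).
    boxes: [..., (y1, x1, y2, x2)] in pixels; shape: (height, width)."""
    h, w = shape
    scale = np.array([h - 1, w - 1, h - 1, w - 1])
    shift = np.array([0, 0, 1, 1])
    return ((boxes - shift) / scale).astype(np.float32)

# A window spanning the whole 1024x1024 molded image maps to the unit box.
print(norm_boxes_np(np.array([0, 0, 1024, 1024]), (1024, 1024)))
# -> [0. 0. 1. 1.]
```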
@@ -2009,18 +1993,13 @@ class MaskRCNN():
                                      train_bn=config.TRAIN_BN)

             # Detections
-            # output is [batch, num_detections, (y1, x1, y2, x2, class_id, score)] in image coordinates
+            # output is [batch, num_detections, (y1, x1, y2, x2, class_id, score)] in
+            # normalized coordinates
             detections = DetectionLayer(config, name="mrcnn_detection")(
                 [rpn_rois, mrcnn_class, mrcnn_bbox, input_image_meta])

-            # Convert boxes to normalized coordinates
-            # TODO: let DetectionLayer return normalized coordinates to avoid
-            #       unnecessary conversions
-            h, w = config.IMAGE_SHAPE[:2]
-            detection_boxes = KL.Lambda(
-                lambda x: x[..., :4] / np.array([h, w, h, w]))(detections)
-
             # Create masks for detections
+            detection_boxes = KL.Lambda(lambda x: x[..., :4])(detections)
             mrcnn_mask = build_fpn_mask_graph(detection_boxes, mrcnn_feature_maps,
                                               config.IMAGE_SHAPE,
                                               config.MASK_POOL_SIZE,
@@ -2353,16 +2332,18 @@ class MaskRCNN():
         windows = np.stack(windows)
         return molded_images, image_metas, windows

-    def unmold_detections(self, detections, mrcnn_mask, image_shape, window):
+    def unmold_detections(self, detections, mrcnn_mask, original_image_shape,
+                          image_shape, window):
         """Reformats the detections of one image from the format of the neural
         network output to a format suitable for use in the rest of the
         application.

-        detections: [N, (y1, x1, y2, x2, class_id, score)]
+        detections: [N, (y1, x1, y2, x2, class_id, score)] in normalized coordinates
         mrcnn_mask: [N, height, width, num_classes]
-        image_shape: [height, width, depth] Original size of the image before resizing
-        window: [y1, x1, y2, x2] Box in the image where the real image is
-                excluding the padding.
+        original_image_shape: [H, W, C] Original image shape before resizing
+        image_shape: [H, W, C] Shape of the image after resizing and padding
+        window: [y1, x1, y2, x2] Pixel coordinates of box in the image where the real
+                image is excluding the padding.

         Returns:
         boxes: [N, (y1, x1, y2, x2)] Bounding boxes in pixels
@@ -2381,19 +2362,21 @@ class MaskRCNN():
         scores = detections[:N, 5]
         masks = mrcnn_mask[np.arange(N), :, :, class_ids]

-        # Compute scale and shift to translate coordinates to image domain.
-        h_scale = image_shape[0] / (window[2] - window[0])
-        w_scale = image_shape[1] / (window[3] - window[1])
-        scale = min(h_scale, w_scale)
-        shift = window[:2]  # y, x
-        scales = np.array([scale, scale, scale, scale])
-        shifts = np.array([shift[0], shift[1], shift[0], shift[1]])
-
-        # Translate bounding boxes to image domain
-        boxes = np.multiply(boxes - shifts, scales).astype(np.int32)
-
-        # Filter out detections with zero area. Often only happens in early
-        # stages of training when the network weights are still a bit random.
+        # Translate normalized coordinates in the resized image to pixel
+        # coordinates in the original image before resizing
+        window = utils.norm_boxes(window, image_shape[:2])
+        wy1, wx1, wy2, wx2 = window
+        shift = np.array([wy1, wx1, wy1, wx1])
+        wh = wy2 - wy1  # window height
+        ww = wx2 - wx1  # window width
+        scale = np.array([wh, ww, wh, ww])
+        # Convert boxes to normalized coordinates on the window
+        boxes = np.divide(boxes - shift, scale)
+        # Convert boxes to pixel coordinates on the original image
+        boxes = utils.denorm_boxes(boxes, original_image_shape[:2])

+        # Filter out detections with zero area. Happens in early training when
+        # network weights are still random
         exclude_ix = np.where(
             (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
         if exclude_ix.shape[0] > 0:
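Putting the new unmolding math together: a worked sketch (all shapes and box values hypothetical) that follows the same steps as the code above, using stand-ins for what utils.norm_boxes / utils.denorm_boxes are used for here:

```python
import numpy as np

def norm_boxes_np(boxes, shape):
    # Pixel -> normalized (cf. utils.norm_boxes); (y2, x2) shifts inside the box.
    h, w = shape
    scale = np.array([h - 1, w - 1, h - 1, w - 1])
    return ((boxes - np.array([0, 0, 1, 1])) / scale).astype(np.float32)

def denorm_boxes_np(boxes, shape):
    # Normalized -> pixel (cf. utils.denorm_boxes); inverse of the above.
    h, w = shape
    scale = np.array([h - 1, w - 1, h - 1, w - 1])
    return np.around(boxes * scale + np.array([0, 0, 1, 1])).astype(np.int32)

# Hypothetical shapes: a 512x256 image resized and padded to 1024x1024, with
# the real content occupying the window (0, 256, 1024, 768) of the molded image.
original_image_shape = (512, 256)
image_shape = (1024, 1024)
window = np.array([0, 256, 1024, 768])

box = np.array([[0.25, 0.375, 0.50, 0.625]])  # normalized, in the molded image

# Same steps as unmold_detections above:
wy1, wx1, wy2, wx2 = norm_boxes_np(window, image_shape)
shift = np.array([wy1, wx1, wy1, wx1])
scale = np.array([wy2 - wy1, wx2 - wx1, wy2 - wy1, wx2 - wx1])
box = (box - shift) / scale                        # normalized on the window
print(denorm_boxes_np(box, original_image_shape))  # pixels, original image
# -> [[128  64 256 192]]
```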
@@ -2407,7 +2390,7 @@ class MaskRCNN():
         full_masks = []
         for i in range(N):
             # Convert neural network mask to full size mask
-            full_mask = utils.unmold_mask(masks[i], boxes[i], image_shape)
+            full_mask = utils.unmold_mask(masks[i], boxes[i], original_image_shape)
             full_masks.append(full_mask)
         full_masks = np.stack(full_masks, axis=-1)\
             if full_masks else np.empty((0,) + masks.shape[1:3])
@@ -2447,7 +2430,8 @@ class MaskRCNN():
         for i, image in enumerate(images):
             final_rois, final_class_ids, final_scores, final_masks =\
                 self.unmold_detections(detections[i], mrcnn_mask[i],
-                                       image.shape, windows[i])
+                                       image.shape, molded_images[i].shape,
+                                       windows[i])
             results.append({
                 "rois": final_rois,
                 "class_ids": final_class_ids,
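End to end, callers are unaffected: unmold_detections still hands back pixel-space boxes on the original image. A usage sketch, assuming the enclosing method is the repo's detect() and that `model` is already built in inference mode with weights loaded:

```python
# Hypothetical usage; `model` and `image` are assumed to exist.
results = model.detect([image], verbose=0)
r = results[0]
# r["rois"] are pixel coordinates on the original image, so downstream code
# never sees the normalized coordinates used inside the graph.
print(r["rois"].shape, r["class_ids"].shape)
```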