未验证 提交 9a560122 编写于 作者: W Waleed 提交者: GitHub

Merge pull request #167 from Cpruce/cpruce/TfDetectionLayer

Convert Detection Layer from Python to TensorFlow ops.
......@@ -220,6 +220,7 @@ def clip_boxes_graph(boxes, window):
y2 = tf.maximum(tf.minimum(y2, wy2), wy1)
x2 = tf.maximum(tf.minimum(x2, wx2), wx1)
clipped = tf.concat([y1, x1, y2, x2], axis=1, name="clipped_boxes")
clipped.set_shape((clipped.shape[0], 4))
return clipped
......@@ -666,7 +667,7 @@ def clip_to_window(window, boxes):
return boxes
def refine_detections(rois, probs, deltas, window, config):
def refine_detections_graph(rois, probs, deltas, window, config):
"""Refine classified proposals and filter overlaps and return final
detections.
......@@ -678,64 +679,96 @@ def refine_detections(rois, probs, deltas, window, config):
window: (y1, x1, y2, x2) in image coordinates. The part of the image
that contains the image excluding the padding.
Returns detections shaped: [N, (y1, x1, y2, x2, class_id, score)]
Returns detections shaped: [N, (y1, x1, y2, x2, class_id, score)] where
coordinates are in image domain.
"""
# Class IDs per ROI
class_ids = np.argmax(probs, axis=1)
class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
# Class probability of the top class of each ROI
class_scores = probs[np.arange(class_ids.shape[0]), class_ids]
indices = tf.stack([tf.range(probs.shape[0]), class_ids], axis=1)
class_scores = tf.gather_nd(probs, indices)
# Class-specific bounding box deltas
deltas_specific = deltas[np.arange(deltas.shape[0]), class_ids]
deltas_specific = tf.gather_nd(deltas, indices)
# Apply bounding box deltas
# Shape: [boxes, (y1, x1, y2, x2)] in normalized coordinates
refined_rois = utils.apply_box_deltas(
refined_rois = apply_box_deltas_graph(
rois, deltas_specific * config.BBOX_STD_DEV)
# Convert coordiates to image domain
# TODO: better to keep them normalized until later
height, width = config.IMAGE_SHAPE[:2]
refined_rois *= np.array([height, width, height, width])
refined_rois *= tf.constant([height, width, height, width], dtype=tf.float32)
# Clip boxes to image window
refined_rois = clip_to_window(window, refined_rois)
refined_rois = clip_boxes_graph(refined_rois, window)
# Round and cast to int since we're deadling with pixels now
refined_rois = np.rint(refined_rois).astype(np.int32)
refined_rois = tf.to_int32(tf.rint(refined_rois))
# TODO: Filter out boxes with zero area
# Filter out background boxes
keep = np.where(class_ids > 0)[0]
keep = tf.where(class_ids > 0)[:, 0]
# Filter out low confidence boxes
if config.DETECTION_MIN_CONFIDENCE:
keep = np.intersect1d(
keep, np.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[0])
conf_keep = tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[:, 0]
keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
tf.expand_dims(conf_keep, 0))
keep = tf.sparse_tensor_to_dense(keep)[0]
# Apply per-class NMS
pre_nms_class_ids = class_ids[keep]
pre_nms_scores = class_scores[keep]
pre_nms_rois = refined_rois[keep]
nms_keep = []
for class_id in np.unique(pre_nms_class_ids):
# Pick detections of this class
ixs = np.where(pre_nms_class_ids == class_id)[0]
# 1. Prepare variables
pre_nms_class_ids = tf.gather(class_ids, keep)
pre_nms_scores = tf.gather(class_scores, keep)
pre_nms_rois = tf.gather(refined_rois, keep)
unique_pre_nms_class_ids = tf.unique(pre_nms_class_ids)[0]
def nms_keep_map(class_id):
"""Apply Non-Maximum Suppression on ROIs of the given class."""
# Indices of ROIs of the given class
ixs = tf.where(tf.equal(pre_nms_class_ids, class_id))[:, 0]
# Apply NMS
class_keep = utils.non_max_suppression(
pre_nms_rois[ixs], pre_nms_scores[ixs],
config.DETECTION_NMS_THRESHOLD)
class_keep = tf.image.non_max_suppression(
tf.to_float(tf.gather(pre_nms_rois, ixs)),
tf.gather(pre_nms_scores, ixs),
max_output_size=config.DETECTION_MAX_INSTANCES,
iou_threshold=config.DETECTION_NMS_THRESHOLD)
# Map indicies
class_keep = keep[ixs[class_keep]]
nms_keep = np.union1d(nms_keep, class_keep)
keep = np.intersect1d(keep, nms_keep).astype(np.int32)
class_keep = tf.gather(keep, tf.gather(ixs, class_keep))
# Pad with -1 so returned tensors have the same shape
gap = config.DETECTION_MAX_INSTANCES - tf.shape(class_keep)[0]
class_keep = tf.pad(class_keep, [(0, gap)],
mode='CONSTANT', constant_values=-1)
# Set shape so map_fn() can infer result shape
class_keep.set_shape([config.DETECTION_MAX_INSTANCES])
return class_keep
# 2. Map over class IDs
nms_keep = tf.map_fn(nms_keep_map, unique_pre_nms_class_ids,
dtype=tf.int64)
# 3. Merge results into one list, and remove -1 padding
nms_keep = tf.reshape(nms_keep, [-1])
nms_keep = tf.gather(nms_keep, tf.where(nms_keep > -1)[:, 0])
# 4. Compute intersection between keep and nms_keep
keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
tf.expand_dims(nms_keep, 0))
keep = tf.sparse_tensor_to_dense(keep)[0]
# Keep top detections
roi_count = config.DETECTION_MAX_INSTANCES
top_ids = np.argsort(class_scores[keep])[::-1][:roi_count]
keep = keep[top_ids]
class_scores_keep = tf.gather(class_scores, keep)
num_keep = tf.minimum(tf.shape(class_scores_keep)[0], roi_count)
top_ids = tf.nn.top_k(class_scores_keep, k=num_keep, sorted=True)[1]
keep = tf.gather(keep, top_ids)
# Arrange output as [N, (y1, x1, y2, x2, class_id, score)]
# Coordinates are in image domain.
result = np.hstack((refined_rois[keep],
class_ids[keep][..., np.newaxis],
class_scores[keep][..., np.newaxis]))
return result
detections = tf.concat([
tf.to_float(tf.gather(refined_rois, keep)),
tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis],
tf.gather(class_scores, keep)[..., tf.newaxis]
], axis=1)
# Pad with zeros if detections < DETECTION_MAX_INSTANCES
gap = config.DETECTION_MAX_INSTANCES - tf.shape(detections)[0]
detections = tf.pad(detections, [(0, gap), (0, 0)], "CONSTANT")
return detections
class DetectionLayer(KE.Layer):
......@@ -743,7 +776,8 @@ class DetectionLayer(KE.Layer):
returns the final detection boxes.
Returns:
[batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels
[batch, num_detections, (y1, x1, y2, x2, class_id, class_score)] where
coordinates are in image domain
"""
def __init__(self, config=None, **kwargs):
......@@ -751,29 +785,23 @@ class DetectionLayer(KE.Layer):
self.config = config
def call(self, inputs):
def wrapper(rois, mrcnn_class, mrcnn_bbox, image_meta):
detections_batch = []
_, _, window, _ = parse_image_meta(image_meta)
for b in range(self.config.BATCH_SIZE):
detections = refine_detections(
rois[b], mrcnn_class[b], mrcnn_bbox[b], window[b], self.config)
# Pad with zeros if detections < DETECTION_MAX_INSTANCES
gap = self.config.DETECTION_MAX_INSTANCES - detections.shape[0]
assert gap >= 0
if gap > 0:
detections = np.pad(
detections, [(0, gap), (0, 0)], 'constant', constant_values=0)
detections_batch.append(detections)
# Stack detections and cast to float32
# TODO: track where float64 is introduced
detections_batch = np.array(detections_batch).astype(np.float32)
# Reshape output
# [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels
return np.reshape(detections_batch, [self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])
# Return wrapped function
return tf.py_func(wrapper, inputs, tf.float32)
rois = inputs[0]
mrcnn_class = inputs[1]
mrcnn_bbox = inputs[2]
image_meta = inputs[3]
# Run detection refinement graph on each item in the batch
_, _, window, _ = parse_image_meta_graph(image_meta)
detections_batch = utils.batch_slice(
[rois, mrcnn_class, mrcnn_bbox, window],
lambda x, y, w, z: refine_detections_graph(x, y, w, z, self.config),
self.config.IMAGES_PER_GPU)
# Reshape output
# [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels
return tf.reshape(
detections_batch,
[self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])
def compute_output_shape(self, input_shape):
return (None, self.config.DETECTION_MAX_INSTANCES, 6)
......@@ -2456,8 +2484,7 @@ class MaskRCNN():
############################################################
def compose_image_meta(image_id, image_shape, window, active_class_ids):
"""Takes attributes of an image and puts them in one 1D array. Use
parse_image_meta() to parse the values back.
"""Takes attributes of an image and puts them in one 1D array.
image_id: An int ID of the image. Useful for debugging.
image_shape: [height, width, channels]
......@@ -2476,18 +2503,6 @@ def compose_image_meta(image_id, image_shape, window, active_class_ids):
return meta
# Two functions (for Numpy and TF) to parse image_meta tensors.
def parse_image_meta(meta):
"""Parses an image info Numpy array to its components.
See compose_image_meta() for more details.
"""
image_id = meta[:, 0]
image_shape = meta[:, 1:4]
window = meta[:, 4:8] # (y1, x1, y2, x2) window of image in in pixels
active_class_ids = meta[:, 8:]
return image_id, image_shape, window, active_class_ids
def parse_image_meta_graph(meta):
"""Parses a tensor that contains image attributes to its components.
See compose_image_meta() for more details.
......@@ -2496,7 +2511,7 @@ def parse_image_meta_graph(meta):
"""
image_id = meta[:, 0]
image_shape = meta[:, 1:4]
window = meta[:, 4:8]
window = meta[:, 4:8] # (y1, x1, y2, x2) window of image in in pixels
active_class_ids = meta[:, 8:]
return [image_id, image_shape, window, active_class_ids]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册