提交 4ec44c0f 编写于 作者: S Shenoy 提交者: Cory Pruce

minor changes & more cleanup

上级 f98c6a74
...@@ -744,7 +744,7 @@ def refine_detections_graph(rois, probs, deltas, window, config): ...@@ -744,7 +744,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
iou_threshold=config.DETECTION_NMS_THRESHOLD) iou_threshold=config.DETECTION_NMS_THRESHOLD)
# Map indicies # Map indicies
cur_keep_indexes = tf.gather(keep, tf.gather(ixs, class_keep)) cur_keep_indexes = tf.gather(tf.cast(keep,tf.int32), tf.gather(ixs, class_keep))
return i+1, tf.concat([ret,cur_keep_indexes], axis=0) return i+1, tf.concat([ret,cur_keep_indexes], axis=0)
nums_iters = tf.shape(uniq_pre_nms_class_ids)[0] # unique class ids nums_iters = tf.shape(uniq_pre_nms_class_ids)[0] # unique class ids
...@@ -759,9 +759,9 @@ def refine_detections_graph(rois, probs, deltas, window, config): ...@@ -759,9 +759,9 @@ def refine_detections_graph(rois, probs, deltas, window, config):
# remove initial_value background # remove initial_value background
nms_keep = tf.gather(nms_keep, tf.where(nms_keep >= 0)[:,0]) nms_keep = tf.gather(nms_keep, tf.where(nms_keep >= 0)[:,0])
keep = tf.sparse_tensor_to_dense(tf.sets.set_intersection(tf.expand_dims(keep,0), tf.expand_dims(nms_keep,0)))[0]
keep = tf.cast(keep, tf.int32) keep = tf.cast(keep, tf.int32)
keep = tf.sparse_tensor_to_dense(tf.sets.set_intersection(tf.expand_dims(keep,0), tf.expand_dims(nms_keep,0)))[0]
# Keep top detections # Keep top detections
roi_count = tf.convert_to_tensor(config.DETECTION_MAX_INSTANCES) roi_count = tf.convert_to_tensor(config.DETECTION_MAX_INSTANCES)
...@@ -775,7 +775,6 @@ def refine_detections_graph(rois, probs, deltas, window, config): ...@@ -775,7 +775,6 @@ def refine_detections_graph(rois, probs, deltas, window, config):
refined_rois_keep = tf.gather(tf.to_float(refined_rois), keep) refined_rois_keep = tf.gather(tf.to_float(refined_rois), keep)
class_ids_keep = tf.gather(tf.to_float(class_ids), keep)[..., tf.newaxis] class_ids_keep = tf.gather(tf.to_float(class_ids), keep)[..., tf.newaxis]
class_scores_keep = tf.gather(class_scores, keep)[..., tf.newaxis] class_scores_keep = tf.gather(class_scores, keep)[..., tf.newaxis]
# Arrange output as [N, (y1, x1, y2, x2, class_id, score)] # Arrange output as [N, (y1, x1, y2, x2, class_id, score)]
# Coordinates are in image domain. # Coordinates are in image domain.
detections = tf.concat((refined_rois_keep, class_ids_keep, detections = tf.concat((refined_rois_keep, class_ids_keep,
...@@ -784,14 +783,12 @@ def refine_detections_graph(rois, probs, deltas, window, config): ...@@ -784,14 +783,12 @@ def refine_detections_graph(rois, probs, deltas, window, config):
# Pad with zeros if detections < DETECTION_MAX_INSTANCES # Pad with zeros if detections < DETECTION_MAX_INSTANCES
num_detections = tf.shape(detections)[0] num_detections = tf.shape(detections)[0]
gap = roi_count - num_detections gap = roi_count - num_detections
print(gap, roi_count, num_detections)
pred = tf.less(tf.constant(0), gap) pred = tf.less(tf.constant(0), gap)
#assert gap >= 0 #assert gap >= 0
#if gap > 0: #if gap > 0:
# paddings = tf.constant([[0, gap], [0, 0]]) # paddings = tf.constant([[0, gap], [0, 0]])
# detections = tf.pad(detections, paddings, "CONSTANT") # detections = tf.pad(detections, paddings, "CONSTANT")
def pad_detections(): def pad_detections():
print(detections.shape)
return tf.pad(detections, [(0, gap), (0, 0)], "CONSTANT") return tf.pad(detections, [(0, gap), (0, 0)], "CONSTANT")
detections = tf.cond(pred, pad_detections, lambda: detections) detections = tf.cond(pred, pad_detections, lambda: detections)
...@@ -816,11 +813,9 @@ class DetectionLayer(KE.Layer): ...@@ -816,11 +813,9 @@ class DetectionLayer(KE.Layer):
mrcnn_class = inputs[1] mrcnn_class = inputs[1]
mrcnn_bbox = inputs[2] mrcnn_bbox = inputs[2]
image_meta = inputs[3] image_meta = inputs[3]
print(rois.shape, mrcnn_class.shape, mrcnn_bbox.shape, image_meta.shape)
#parse_image_meta can be reused as slicing works same way in TF & numpy #parse_image_meta can be reused as slicing works same way in TF & numpy
_, _, window, _ = parse_image_meta(image_meta) _, _, window, _ = parse_image_meta_graph(image_meta)
print('window after: ', window.shape)
detections_batch = utils.batch_slice( detections_batch = utils.batch_slice(
[rois, mrcnn_class, mrcnn_bbox, window], [rois, mrcnn_class, mrcnn_bbox, window],
lambda x, y, w, z: refine_detections_graph(x, y, w, z, self.config), lambda x, y, w, z: refine_detections_graph(x, y, w, z, self.config),
...@@ -832,7 +827,6 @@ class DetectionLayer(KE.Layer): ...@@ -832,7 +827,6 @@ class DetectionLayer(KE.Layer):
#detections_batch = np.array(detections_batch).astype(np.float32) #detections_batch = np.array(detections_batch).astype(np.float32)
# Reshape output # Reshape output
# [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels # [batch, num_detections, (y1, x1, y2, x2, class_score)] in pixels
return tf.reshape( return tf.reshape(
detections_batch, detections_batch,
[self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6]) [self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册