提交 01e69849 编写于 作者: C Cory Pruce

adding waleeds tips

上级 d6db1229
......@@ -781,11 +781,11 @@ def refine_detections_graph(rois, probs, deltas, window, config):
# TODO: Filter out boxes with zero area
# Filter out background boxes
keep = tf.where(class_ids > 0)[0]
keep = tf.where(class_ids > 0)[:,0]
# Filter out low confidence boxes
if config.DETECTION_MIN_CONFIDENCE:
keep = tf.sets.set_intersection(
keep, tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[0])
keep, tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[:,0])
# Apply per-class NMS
pre_nms_class_ids = tf.gather(class_ids, keep) #class_ids[keep]
......@@ -799,7 +799,9 @@ def refine_detections_graph(rois, probs, deltas, window, config):
nms_keep = []
def nms_keep_map(class_id):
ixs = tf.where(pre_nms_class_ids == class_id)[0]
print('pre_nms_class_ids.shape', pre_nms_class_ids.shape)
print('class_id', class_id.shape)
ixs = tf.where(pre_nms_class_ids == tf.expand_dims(class_id, -1))[0]
# Apply NMS
class_keep = tf.image.non_max_suppression(
......@@ -810,15 +812,19 @@ def refine_detections_graph(rois, probs, deltas, window, config):
# Map indicies
return tf.gather(keep, tf.gather(ixs, class_keep))
print('uniq_pre_nms_class_ids: {}'.format(uniq_pre_nms_class_ids.shape))
nms_keep = tf.to_int64(tf.unique(tf.concat(
tf.map_fn(nms_keep_map, uniq_pre_nms_class_ids), axis=0))[0])
print(keep.shape, nms_keep.shape)
print(keep.dtype, nms_keep.dtype)
#keep = tf.sets.set_intersection(keep, nms_keep)
#tf.to_int32(
"""keep = tf.sets.set_intersection(
tf.expand_dims(keep, 0),
tf.expand_dims(tf.sparse_to_dense(nms_keep), 0))[1]
"""#tf.to_int32(
#np.intersect1d(keep, nms_keep).astype(np.int32)
result_keep = tf.concat([keep,nms_keep], axis = 0)
print('result_keep: {}'.format(result_keep.shape))
output_keep, idx_keep, count_keep = tf.unique_with_counts(result_keep)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册