diff --git a/model.py b/model.py index a8d993e4b943cdd07572d3cb520978ada39ec0f8..48560646acdac9eacdbb62e2977e9fb7896e6662 100644 --- a/model.py +++ b/model.py @@ -781,11 +781,11 @@ def refine_detections_graph(rois, probs, deltas, window, config): # TODO: Filter out boxes with zero area # Filter out background boxes - keep = tf.where(class_ids > 0)[0] + keep = tf.where(class_ids > 0)[:,0] # Filter out low confidence boxes if config.DETECTION_MIN_CONFIDENCE: keep = tf.sets.set_intersection( - keep, tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[0]) + keep, tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[:,0]) # Apply per-class NMS pre_nms_class_ids = tf.gather(class_ids, keep) #class_ids[keep] @@ -799,7 +799,9 @@ def refine_detections_graph(rois, probs, deltas, window, config): nms_keep = [] def nms_keep_map(class_id): - ixs = tf.where(pre_nms_class_ids == class_id)[0] + print('pre_nms_class_ids.shape', pre_nms_class_ids.shape) + print('class_id', class_id.shape) + ixs = tf.where(pre_nms_class_ids == tf.expand_dims(class_id, -1))[0] # Apply NMS class_keep = tf.image.non_max_suppression( @@ -810,15 +812,19 @@ def refine_detections_graph(rois, probs, deltas, window, config): # Map indicies return tf.gather(keep, tf.gather(ixs, class_keep)) - + + print('uniq_pre_nms_class_ids: {}'.format(uniq_pre_nms_class_ids.shape)) nms_keep = tf.to_int64(tf.unique(tf.concat( tf.map_fn(nms_keep_map, uniq_pre_nms_class_ids), axis=0))[0]) print(keep.shape, nms_keep.shape) print(keep.dtype, nms_keep.dtype) - #keep = tf.sets.set_intersection(keep, nms_keep) - #tf.to_int32( + """keep = tf.sets.set_intersection( + tf.expand_dims(keep, 0), + tf.expand_dims(tf.sparse_to_dense(nms_keep), 0))[1] + """#tf.to_int32( #np.intersect1d(keep, nms_keep).astype(np.int32) + result_keep = tf.concat([keep,nms_keep], axis = 0) print('result_keep: {}'.format(result_keep.shape)) output_keep, idx_keep, count_keep = tf.unique_with_counts(result_keep)