提交 571880ce 编写于 作者: V Vivek Rathod 提交者: TF Object Detection Team

Filter our unrecognized `image_classes_field` entries (i.e -1s)

PiperOrigin-RevId: 336321500
上级 17f2f812
......@@ -88,7 +88,8 @@ def _convert_labeled_classes_to_k_hot(groundtruth_labeled_classes, num_classes):
def _remove_unrecognized_classes(class_ids, unrecognized_label):
"""Returns class ids with unrecognized classes filtered out."""
recognized_indices = tf.where(tf.greater(class_ids, unrecognized_label))
recognized_indices = tf.squeeze(
tf.where(tf.greater(class_ids, unrecognized_label)), -1)
return tf.gather(class_ids, recognized_indices)
......@@ -213,6 +214,8 @@ def transform_input_data(tensor_dict,
out_tensor_dict[labeled_classes_field], num_classes)
if image_classes_field in out_tensor_dict:
out_tensor_dict[image_classes_field] = _remove_unrecognized_classes(
out_tensor_dict[image_classes_field], unrecognized_label=-1)
out_tensor_dict[labeled_classes_field] = _convert_labeled_classes_to_k_hot(
out_tensor_dict[image_classes_field], num_classes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册