diff --git a/research/object_detection/inputs.py b/research/object_detection/inputs.py index 699ce7e2d2b11a9c5b929845b78af01afa182a83..d95b67fd5efad00891c3582920a3d89fff8de060 100644 --- a/research/object_detection/inputs.py +++ b/research/object_detection/inputs.py @@ -88,7 +88,8 @@ def _convert_labeled_classes_to_k_hot(groundtruth_labeled_classes, num_classes): def _remove_unrecognized_classes(class_ids, unrecognized_label): """Returns class ids with unrecognized classes filtered out.""" - recognized_indices = tf.where(tf.greater(class_ids, unrecognized_label)) + recognized_indices = tf.squeeze( + tf.where(tf.greater(class_ids, unrecognized_label)), -1) return tf.gather(class_ids, recognized_indices) @@ -213,6 +214,8 @@ def transform_input_data(tensor_dict, out_tensor_dict[labeled_classes_field], num_classes) if image_classes_field in out_tensor_dict: + out_tensor_dict[image_classes_field] = _remove_unrecognized_classes( + out_tensor_dict[image_classes_field], unrecognized_label=-1) out_tensor_dict[labeled_classes_field] = _convert_labeled_classes_to_k_hot( out_tensor_dict[image_classes_field], num_classes)