diff --git a/research/object_detection/core/standard_fields.py b/research/object_detection/core/standard_fields.py
index 9bcca662cd366ff1cb1e1a1673d68bf20e034be3..789296c38bb3186f008803d282d73e3bded1c965 100644
--- a/research/object_detection/core/standard_fields.py
+++ b/research/object_detection/core/standard_fields.py
@@ -70,6 +70,9 @@ class InputDataFields(object):
     groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
     groundtruth_keypoint_weights: groundtruth weight factor for keypoints.
     groundtruth_label_weights: groundtruth label weights.
+    groundtruth_verified_neg_classes: groundtruth verified negative classes.
+    groundtruth_not_exhaustive_classes: groundtruth not-exhaustively labeled
+      classes.
     groundtruth_weights: groundtruth weight factor for bounding boxes.
     groundtruth_dp_num_points: The number of DensePose sampled points for each
       instance.
@@ -120,6 +123,8 @@ class InputDataFields(object):
   groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
   groundtruth_keypoint_weights = 'groundtruth_keypoint_weights'
   groundtruth_label_weights = 'groundtruth_label_weights'
+  groundtruth_verified_neg_classes = 'groundtruth_verified_neg_classes'
+  groundtruth_not_exhaustive_classes = 'groundtruth_not_exhaustive_classes'
   groundtruth_weights = 'groundtruth_weights'
   groundtruth_dp_num_points = 'groundtruth_dp_num_points'
   groundtruth_dp_part_ids = 'groundtruth_dp_part_ids'
diff --git a/research/object_detection/eval_util_test.py b/research/object_detection/eval_util_test.py
index 0c6fdb753c5609ebfff121385c5270a144990e7e..a39a5ff16749fdfbb091448c444c02de5d524b36 100644
--- a/research/object_detection/eval_util_test.py
+++ b/research/object_detection/eval_util_test.py
@@ -85,6 +85,8 @@ class EvalUtilTest(test_case.TestCase, parameterized.TestCase):
     groundtruth_boxes = tf.constant([[0., 0., 1., 1.]])
     groundtruth_classes = tf.constant([1])
     groundtruth_instance_masks = tf.ones(shape=[1, 20, 20], dtype=tf.uint8)
+    original_image_spatial_shapes = tf.constant([[20, 20]], dtype=tf.int32)
+
     groundtruth_keypoints = tf.constant([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
     if resized_groundtruth_masks:
       groundtruth_instance_masks = tf.ones(shape=[1, 10, 10], dtype=tf.uint8)
@@ -100,6 +102,8 @@ class EvalUtilTest(test_case.TestCase, parameterized.TestCase):
       groundtruth_keypoints = tf.tile(
           tf.expand_dims(groundtruth_keypoints, 0),
           multiples=[batch_size, 1, 1])
+      original_image_spatial_shapes = tf.tile(original_image_spatial_shapes,
+                                              multiples=[batch_size, 1])

     detections = {
         detection_fields.detection_boxes: detection_boxes,
@@ -112,7 +116,10 @@ class EvalUtilTest(test_case.TestCase, parameterized.TestCase):
         input_data_fields.groundtruth_boxes: groundtruth_boxes,
         input_data_fields.groundtruth_classes: groundtruth_classes,
         input_data_fields.groundtruth_keypoints: groundtruth_keypoints,
-        input_data_fields.groundtruth_instance_masks: groundtruth_instance_masks
+        input_data_fields.groundtruth_instance_masks:
+            groundtruth_instance_masks,
+        input_data_fields.original_image_spatial_shape:
+            original_image_spatial_shapes
     }
     if batch_size > 1:
       return eval_util.result_dict_for_batched_example(
diff --git a/research/object_detection/metrics/coco_evaluation.py b/research/object_detection/metrics/coco_evaluation.py
index 35b336acd0172e4971634d32cd5b5f09db7361e3..9c5e3056eb5bbd42510d9b41cfcd5df3cbc00268 100644
--- a/research/object_detection/metrics/coco_evaluation.py
+++ b/research/object_detection/metrics/coco_evaluation.py
@@ -1191,18 +1191,20 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
                   groundtruth_instance_masks_batched,
                   groundtruth_is_crowd_batched, num_gt_boxes_per_image,
                   detection_scores_batched, detection_classes_batched,
-                  detection_masks_batched, num_det_boxes_per_image):
+                  detection_masks_batched, num_det_boxes_per_image,
+                  original_image_spatial_shape):
       """Update op for metrics."""

       for (image_id, groundtruth_boxes, groundtruth_classes,
            groundtruth_instance_masks, groundtruth_is_crowd, num_gt_box,
            detection_scores, detection_classes,
-           detection_masks, num_det_box) in zip(
+           detection_masks, num_det_box, original_image_shape) in zip(
               image_id_batched, groundtruth_boxes_batched,
               groundtruth_classes_batched, groundtruth_instance_masks_batched,
               groundtruth_is_crowd_batched, num_gt_boxes_per_image,
               detection_scores_batched, detection_classes_batched,
-               detection_masks_batched, num_det_boxes_per_image):
+               detection_masks_batched, num_det_boxes_per_image,
+               original_image_spatial_shape):
         self.add_single_ground_truth_image_info(
             image_id, {
                 'groundtruth_boxes':
@@ -1210,7 +1212,8 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
                 'groundtruth_classes':
                     groundtruth_classes[:num_gt_box],
                 'groundtruth_instance_masks':
-                    groundtruth_instance_masks[:num_gt_box],
+                    groundtruth_instance_masks[:num_gt_box,
+                        :original_image_shape[0], :original_image_shape[1]],
                 'groundtruth_is_crowd':
                     groundtruth_is_crowd[:num_gt_box]
             })
@@ -1218,13 +1221,16 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
           image_id, {
               'detection_scores': detection_scores[:num_det_box],
               'detection_classes': detection_classes[:num_det_box],
-              'detection_masks': detection_masks[:num_det_box]
+              'detection_masks': detection_masks[:num_det_box,
+                  :original_image_shape[0], :original_image_shape[1]]
          })

    # Unpack items from the evaluation dictionary.
    input_data_fields = standard_fields.InputDataFields
    detection_fields = standard_fields.DetectionResultFields
    image_id = eval_dict[input_data_fields.key]
+    original_image_spatial_shape = eval_dict[
+        input_data_fields.original_image_spatial_shape]
    groundtruth_boxes = eval_dict[input_data_fields.groundtruth_boxes]
    groundtruth_classes = eval_dict[input_data_fields.groundtruth_classes]
    groundtruth_instance_masks = eval_dict[
@@ -1276,7 +1282,7 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
        image_id, groundtruth_boxes, groundtruth_classes,
        groundtruth_instance_masks, groundtruth_is_crowd,
        num_gt_boxes_per_image, detection_scores, detection_classes,
-        detection_masks, num_det_boxes_per_image
+        detection_masks, num_det_boxes_per_image, original_image_spatial_shape
    ], [])

  def get_estimator_eval_metric_ops(self, eval_dict):
diff --git a/research/object_detection/metrics/coco_evaluation_test.py b/research/object_detection/metrics/coco_evaluation_test.py
index eda3590cfbf5f6ec76b20f857ff45f87d9255085..4d6dc2c1b562db294dead2798eb2f7de23963a7e 100644
--- a/research/object_detection/metrics/coco_evaluation_test.py
+++ b/research/object_detection/metrics/coco_evaluation_test.py
@@ -1601,6 +1601,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
     groundtruth_boxes = tf.placeholder(tf.float32, shape=(None, 4))
     groundtruth_classes = tf.placeholder(tf.float32, shape=(None))
     groundtruth_masks = tf.placeholder(tf.uint8, shape=(None, None, None))
+    original_image_spatial_shape = tf.placeholder(tf.int32, shape=(None, 2))
     detection_scores = tf.placeholder(tf.float32, shape=(None))
     detection_classes = tf.placeholder(tf.float32, shape=(None))
     detection_masks = tf.placeholder(tf.uint8, shape=(None, None, None))
@@ -1612,6 +1613,8 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
         input_data_fields.groundtruth_boxes: groundtruth_boxes,
         input_data_fields.groundtruth_classes: groundtruth_classes,
         input_data_fields.groundtruth_instance_masks: groundtruth_masks,
+        input_data_fields.original_image_spatial_shape:
+            original_image_spatial_shape,
         detection_fields.detection_scores: detection_scores,
         detection_fields.detection_classes: detection_classes,
         detection_fields.detection_masks: detection_masks,
@@ -1637,6 +1640,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
                        np.ones([50, 50], dtype=np.uint8), ((0, 70), (0, 70)),
                        mode='constant')
              ]),
+          original_image_spatial_shape: np.array([[120, 120]]),
          detection_scores:
              np.array([.9, .8]),
          detection_classes:
@@ -1661,6 +1665,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
     groundtruth_boxes = tf.placeholder(tf.float32, shape=(None, 4))
     groundtruth_classes = tf.placeholder(tf.float32, shape=(None))
     groundtruth_masks = tf.placeholder(tf.uint8, shape=(None, None, None))
+    original_image_spatial_shape = tf.placeholder(tf.int32, shape=(None, 2))
     detection_scores = tf.placeholder(tf.float32, shape=(None))
     detection_classes = tf.placeholder(tf.float32, shape=(None))
     detection_masks = tf.placeholder(tf.uint8, shape=(None, None, None))
@@ -1672,6 +1677,8 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
         input_data_fields.groundtruth_boxes: groundtruth_boxes,
         input_data_fields.groundtruth_classes: groundtruth_classes,
         input_data_fields.groundtruth_instance_masks: groundtruth_masks,
+        input_data_fields.original_image_spatial_shape:
+            original_image_spatial_shape,
         detection_fields.detection_scores: detection_scores,
         detection_fields.detection_classes: detection_classes,
         detection_fields.detection_masks: detection_masks,
@@ -1701,6 +1708,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
                        np.ones([50, 50], dtype=np.uint8), ((0, 70), (0, 70)),
                        mode='constant')
              ]),
+          original_image_spatial_shape: np.array([[120, 120], [120, 120]]),
          detection_scores:
              np.array([.9, .8]),
          detection_classes:
@@ -1725,6 +1733,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
                                          dtype=np.uint8),
                                  ((0, 0), (10, 10), (10, 10)),
                                  mode='constant'),
+          original_image_spatial_shape: np.array([[70, 70]]),
          detection_scores: np.array([.8]),
          detection_classes: np.array([1]),
          detection_masks: np.pad(np.ones([1, 50, 50], dtype=np.uint8),
@@ -1740,6 +1749,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
                                          dtype=np.uint8),
                                  ((0, 0), (10, 10), (10, 10)),
                                  mode='constant'),
+          original_image_spatial_shape: np.array([[45, 45]]),
          detection_scores: np.array([.8]),
          detection_classes: np.array([1]),
          detection_masks: np.pad(np.ones([1, 25, 25],
@@ -1778,6 +1788,7 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
     groundtruth_classes = tf.placeholder(tf.float32, shape=(batch_size, None))
     groundtruth_masks = tf.placeholder(
         tf.uint8, shape=(batch_size, None, None, None))
+    original_image_spatial_shape = tf.placeholder(tf.int32, shape=(None, 2))
     detection_scores = tf.placeholder(tf.float32, shape=(batch_size, None))
     detection_classes = tf.placeholder(tf.float32, shape=(batch_size, None))
     detection_masks = tf.placeholder(
@@ -1790,6 +1801,8 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
         input_data_fields.groundtruth_boxes: groundtruth_boxes,
         input_data_fields.groundtruth_classes: groundtruth_classes,
         input_data_fields.groundtruth_instance_masks: groundtruth_masks,
+        input_data_fields.original_image_spatial_shape:
+            original_image_spatial_shape,
         detection_fields.detection_scores: detection_scores,
         detection_fields.detection_classes: detection_classes,
         detection_fields.detection_masks: detection_masks,
@@ -1826,6 +1839,8 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
                  mode='constant')
          ],
                   axis=0),
+          original_image_spatial_shape: np.array(
+              [[100, 100], [100, 100], [100, 100]]),
          detection_scores:
              np.array([[.8], [.8], [.8]]),
          detection_classes:
diff --git a/research/object_detection/metrics/lvis_evaluation.py b/research/object_detection/metrics/lvis_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..95fc12d2032fb7afba780468984fd91cdbc3c0c8
--- /dev/null
+++ b/research/object_detection/metrics/lvis_evaluation.py
@@ -0,0 +1,443 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class for evaluating object detections with LVIS metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import re
+
+from lvis import results as lvis_results
+
+import numpy as np
+from six.moves import zip
+import tensorflow.compat.v1 as tf
+
+from object_detection.core import standard_fields as fields
+from object_detection.metrics import lvis_tools
+from object_detection.utils import object_detection_evaluation
+
+
+def convert_masks_to_binary(masks):
+  """Converts masks to 0 or 1 and uint8 type."""
+  return (masks > 0).astype(np.uint8)
+
+
+class LVISMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
+  """Class to evaluate LVIS mask metrics."""
+
+  def __init__(self,
+               categories):
+    """Constructor.
+
+    Args:
+      categories: A list of dicts, each of which has the following keys -
+        'id': (required) an integer id uniquely identifying this category.
+        'name': (required) string representing category name e.g., 'cat', 'dog'.
+    """
+    super(LVISMaskEvaluator, self).__init__(categories)
+    self._image_ids_with_detections = set([])
+    self._groundtruth_list = []
+    self._detection_masks_list = []
+    self._category_id_set = set([cat['id'] for cat in self._categories])
+    self._annotation_id = 1
+    self._image_id_to_mask_shape_map = {}
+    self._image_id_to_verified_neg_classes = {}
+    self._image_id_to_not_exhaustive_classes = {}
+
+  def clear(self):
+    """Clears the state to prepare for a fresh evaluation."""
+    self._image_id_to_mask_shape_map.clear()
+    self._image_ids_with_detections.clear()
+    self._image_id_to_verified_neg_classes.clear()
+    self._image_id_to_not_exhaustive_classes.clear()
+    self._groundtruth_list = []
+    self._detection_masks_list = []
+
+  def add_single_ground_truth_image_info(self,
+                                         image_id,
+                                         groundtruth_dict):
+    """Adds groundtruth for a single image to be used for evaluation.
+
+    If the image has already been added, a warning is logged, and groundtruth
+    is ignored.
+
+    Args:
+      image_id: A unique string/integer identifier for the image.
+      groundtruth_dict: A dictionary containing -
+        InputDataFields.groundtruth_boxes: float32 numpy array of shape
+          [num_boxes, 4] containing `num_boxes` groundtruth boxes of the format
+          [ymin, xmin, ymax, xmax] in absolute image coordinates.
+        InputDataFields.groundtruth_classes: integer numpy array of shape
+          [num_boxes] containing 1-indexed groundtruth classes for the boxes.
+        InputDataFields.groundtruth_instance_masks: uint8 numpy array of shape
+          [num_masks, image_height, image_width] containing groundtruth masks.
+          The elements of the array must be in {0, 1}.
+        InputDataFields.groundtruth_verified_neg_classes: [num_classes]
+          float indicator vector with values in {0, 1}.
+        InputDataFields.groundtruth_not_exhaustive_classes: [num_classes]
+          float indicator vector with values in {0, 1}.
+        InputDataFields.groundtruth_area (optional): float numpy array of
+          shape [num_boxes] containing the area (in the original absolute
+          coordinates) of the annotated object.
+
+    Raises:
+      ValueError: if groundtruth_dict is missing a required field.
+    """
+    if image_id in self._image_id_to_mask_shape_map:
+      tf.logging.warning('Ignoring ground truth with image id %s since it was '
+                         'previously added', image_id)
+      return
+    for key in [fields.InputDataFields.groundtruth_boxes,
+                fields.InputDataFields.groundtruth_classes,
+                fields.InputDataFields.groundtruth_instance_masks,
+                fields.InputDataFields.groundtruth_verified_neg_classes,
+                fields.InputDataFields.groundtruth_not_exhaustive_classes]:
+      if key not in groundtruth_dict.keys():
+        raise ValueError('groundtruth_dict missing entry: {}'.format(key))
+
+    groundtruth_instance_masks = groundtruth_dict[
+        fields.InputDataFields.groundtruth_instance_masks]
+    groundtruth_instance_masks = convert_masks_to_binary(
+        groundtruth_instance_masks)
+    verified_neg_classes_shape = groundtruth_dict[
+        fields.InputDataFields.groundtruth_verified_neg_classes].shape
+    not_exhaustive_classes_shape = groundtruth_dict[
+        fields.InputDataFields.groundtruth_not_exhaustive_classes].shape
+    if verified_neg_classes_shape != (len(self._category_id_set),):
+      raise ValueError('Invalid shape for groundtruth_verified_neg_classes.')
+    if not_exhaustive_classes_shape != (len(self._category_id_set),):
+      raise ValueError('Invalid shape for groundtruth_not_exhaustive_classes.')
+    self._image_id_to_verified_neg_classes[image_id] = np.flatnonzero(
+        groundtruth_dict[
+            fields.InputDataFields.groundtruth_verified_neg_classes]
+        == 1).tolist()
+    self._image_id_to_not_exhaustive_classes[image_id] = np.flatnonzero(
+        groundtruth_dict[
+            fields.InputDataFields.groundtruth_not_exhaustive_classes]
+        == 1).tolist()
+
+    # Drop optional fields if empty tensor.
+    groundtruth_area = groundtruth_dict.get(
+        fields.InputDataFields.groundtruth_area)
+    if groundtruth_area is not None and not groundtruth_area.shape[0]:
+      groundtruth_area = None
+
+    self._groundtruth_list.extend(
+        lvis_tools.ExportSingleImageGroundtruthToLVIS(
+            image_id=image_id,
+            next_annotation_id=self._annotation_id,
+            category_id_set=self._category_id_set,
+            groundtruth_boxes=groundtruth_dict[
+                fields.InputDataFields.groundtruth_boxes],
+            groundtruth_classes=groundtruth_dict[
+                fields.InputDataFields.groundtruth_classes],
+            groundtruth_masks=groundtruth_instance_masks,
+            groundtruth_area=groundtruth_area)
+    )
+
+    self._annotation_id += groundtruth_dict[fields.InputDataFields.
+                                            groundtruth_boxes].shape[0]
+    self._image_id_to_mask_shape_map[image_id] = groundtruth_dict[
+        fields.InputDataFields.groundtruth_instance_masks].shape
+
+  def add_single_detected_image_info(self,
+                                     image_id,
+                                     detections_dict):
+    """Adds detections for a single image to be used for evaluation.
+
+    If a detection has already been added for this image id, a warning is
+    logged, and the detection is skipped.
+
+    Args:
+      image_id: A unique string/integer identifier for the image.
+      detections_dict: A dictionary containing -
+        DetectionResultFields.detection_scores: float32 numpy array of shape
+          [num_boxes] containing detection scores for the boxes.
+        DetectionResultFields.detection_classes: integer numpy array of shape
+          [num_boxes] containing 1-indexed detection classes for the boxes.
+        DetectionResultFields.detection_masks: optional uint8 numpy array of
+          shape [num_boxes, image_height, image_width] containing instance
+          masks corresponding to the boxes. The elements of the array must be
+          in {0, 1}.
+
+    Raises:
+      ValueError: If groundtruth for the image_id is not available.
+    """
+    if image_id not in self._image_id_to_mask_shape_map:
+      raise ValueError('Missing groundtruth for image id: {}'.format(image_id))
+
+    if image_id in self._image_ids_with_detections:
+      tf.logging.warning('Ignoring detection with image id %s since it was '
+                         'previously added', image_id)
+      return
+
+    groundtruth_masks_shape = self._image_id_to_mask_shape_map[image_id]
+    detection_masks = detections_dict[fields.DetectionResultFields.
+                                      detection_masks]
+    if groundtruth_masks_shape[1:] != detection_masks.shape[1:]:
+      raise ValueError('Spatial shape of groundtruth masks and detection masks '
+                       'are incompatible: {} vs {}'.format(
+                           groundtruth_masks_shape,
+                           detection_masks.shape))
+    detection_masks = convert_masks_to_binary(detection_masks)
+
+    self._detection_masks_list.extend(
+        lvis_tools.ExportSingleImageDetectionMasksToLVIS(
+            image_id=image_id,
+            category_id_set=self._category_id_set,
+            detection_masks=detection_masks,
+            detection_scores=detections_dict[
+                fields.DetectionResultFields.detection_scores],
+            detection_classes=detections_dict[
+                fields.DetectionResultFields.detection_classes]))
+    self._image_ids_with_detections.update([image_id])
+
+  def evaluate(self):
+    """Evaluates the detection masks and returns a dictionary of LVIS metrics.
+
+    Returns:
+      A dictionary holding LVIS mask metrics keyed by metric name.
+    """
+    tf.logging.info('Performing evaluation on %d images.',
+                    len(self._image_id_to_mask_shape_map.keys()))
+    # pylint: disable=g-complex-comprehension
+    groundtruth_dict = {
+        'annotations': self._groundtruth_list,
+        'images': [
+            {
+                'id': image_id,
+                'height': shape[1],
+                'width': shape[2],
+                'neg_category_ids':
+                    self._image_id_to_verified_neg_classes[image_id],
+                'not_exhaustive_category_ids':
+                    self._image_id_to_not_exhaustive_classes[image_id]
+            } for image_id, shape in self._image_id_to_mask_shape_map.items()],
+        'categories': self._categories
+    }
+    # pylint: enable=g-complex-comprehension
+    lvis_wrapped_groundtruth = lvis_tools.LVISWrapper(groundtruth_dict)
+    detections = lvis_results.LVISResults(lvis_wrapped_groundtruth,
+                                          self._detection_masks_list)
+    mask_evaluator = lvis_tools.LVISEvalWrapper(
+        lvis_wrapped_groundtruth, detections, iou_type='segm')
+    mask_metrics = mask_evaluator.ComputeMetrics()
+    mask_metrics = {'DetectionMasks_' + key: value
+                    for key, value in iter(mask_metrics.items())}
+    return mask_metrics
+
+  def add_eval_dict(self, eval_dict):
+    """Observes an evaluation result dict for a single example.
+
+    When executing eagerly, once all examples have been observed by this
+    method you can use `.evaluate()` to get the final metrics.
+
+    When using `tf.estimator.Estimator` for evaluation this function is used by
+    `get_estimator_eval_metric_ops()` to construct the metric update op.
+
+    Args:
+      eval_dict: A dictionary that holds tensors for evaluating an object
+        detection model, returned from
+        eval_util.result_dict_for_single_example().
+
+    Returns:
+      None when executing eagerly, or an update_op that can be used to update
+      the eval metrics in `tf.estimator.EstimatorSpec`.
+    """
+    def update_op(image_id_batched, groundtruth_boxes_batched,
+                  groundtruth_classes_batched,
+                  groundtruth_instance_masks_batched,
+                  groundtruth_verified_neg_classes_batched,
+                  groundtruth_not_exhaustive_classes_batched,
+                  num_gt_boxes_per_image,
+                  detection_scores_batched, detection_classes_batched,
+                  detection_masks_batched, num_det_boxes_per_image,
+                  original_image_spatial_shape):
+      """Update op for metrics."""
+
+      for (image_id, groundtruth_boxes, groundtruth_classes,
+           groundtruth_instance_masks, groundtruth_verified_neg_classes,
+           groundtruth_not_exhaustive_classes, num_gt_box,
+           detection_scores, detection_classes,
+           detection_masks, num_det_box, original_image_shape) in zip(
+               image_id_batched, groundtruth_boxes_batched,
+               groundtruth_classes_batched, groundtruth_instance_masks_batched,
+               groundtruth_verified_neg_classes_batched,
+               groundtruth_not_exhaustive_classes_batched,
+               num_gt_boxes_per_image,
+               detection_scores_batched, detection_classes_batched,
+               detection_masks_batched, num_det_boxes_per_image,
+               original_image_spatial_shape):
+        self.add_single_ground_truth_image_info(
+            image_id, {
+                input_data_fields.groundtruth_boxes:
+                    groundtruth_boxes[:num_gt_box],
+                input_data_fields.groundtruth_classes:
+                    groundtruth_classes[:num_gt_box],
+                input_data_fields.groundtruth_instance_masks:
+                    groundtruth_instance_masks[:num_gt_box,
+                        :original_image_shape[0], :original_image_shape[1]],
+                input_data_fields.groundtruth_verified_neg_classes:
+                    groundtruth_verified_neg_classes,
+                input_data_fields.groundtruth_not_exhaustive_classes:
+                    groundtruth_not_exhaustive_classes
+            })
+        self.add_single_detected_image_info(
+            image_id, {
+                'detection_scores': detection_scores[:num_det_box],
+                'detection_classes': detection_classes[:num_det_box],
+                'detection_masks': detection_masks[:num_det_box,
+                    :original_image_shape[0], :original_image_shape[1]]
+            })
+
+    # Unpack items from the evaluation dictionary.
+    input_data_fields = fields.InputDataFields
+    detection_fields = fields.DetectionResultFields
+    image_id = eval_dict[input_data_fields.key]
+    original_image_spatial_shape = eval_dict[
+        input_data_fields.original_image_spatial_shape]
+    groundtruth_boxes = eval_dict[input_data_fields.groundtruth_boxes]
+    groundtruth_classes = eval_dict[input_data_fields.groundtruth_classes]
+    groundtruth_instance_masks = eval_dict[
+        input_data_fields.groundtruth_instance_masks]
+    groundtruth_verified_neg_classes = eval_dict[
+        input_data_fields.groundtruth_verified_neg_classes]
+    groundtruth_not_exhaustive_classes = eval_dict[
+        input_data_fields.groundtruth_not_exhaustive_classes]
+
+    num_gt_boxes_per_image = eval_dict.get(
+        input_data_fields.num_groundtruth_boxes, None)
+    detection_scores = eval_dict[detection_fields.detection_scores]
+    detection_classes = eval_dict[detection_fields.detection_classes]
+    detection_masks = eval_dict[detection_fields.detection_masks]
+    num_det_boxes_per_image = eval_dict.get(detection_fields.num_detections,
+                                            None)
+
+    if not image_id.shape.as_list():
+      # Apply a batch dimension to all tensors.
+      image_id = tf.expand_dims(image_id, 0)
+      groundtruth_boxes = tf.expand_dims(groundtruth_boxes, 0)
+      groundtruth_classes = tf.expand_dims(groundtruth_classes, 0)
+      groundtruth_instance_masks = tf.expand_dims(groundtruth_instance_masks, 0)
+      groundtruth_verified_neg_classes = tf.expand_dims(
+          groundtruth_verified_neg_classes, 0)
+      groundtruth_not_exhaustive_classes = tf.expand_dims(
+          groundtruth_not_exhaustive_classes, 0)
+      detection_scores = tf.expand_dims(detection_scores, 0)
+      detection_classes = tf.expand_dims(detection_classes, 0)
+      detection_masks = tf.expand_dims(detection_masks, 0)
+
+      if num_gt_boxes_per_image is None:
+        num_gt_boxes_per_image = tf.shape(groundtruth_boxes)[1:2]
+      else:
+        num_gt_boxes_per_image = tf.expand_dims(num_gt_boxes_per_image, 0)
+
+      if num_det_boxes_per_image is None:
+        num_det_boxes_per_image = tf.shape(detection_scores)[1:2]
+      else:
+        num_det_boxes_per_image = tf.expand_dims(num_det_boxes_per_image, 0)
+    else:
+      if num_gt_boxes_per_image is None:
+        num_gt_boxes_per_image = tf.tile(
+            tf.shape(groundtruth_boxes)[1:2],
+            multiples=tf.shape(groundtruth_boxes)[0:1])
+      if num_det_boxes_per_image is None:
+        num_det_boxes_per_image = tf.tile(
+            tf.shape(detection_scores)[1:2],
+            multiples=tf.shape(detection_scores)[0:1])
+
+    return tf.py_func(update_op, [
+        image_id, groundtruth_boxes, groundtruth_classes,
+        groundtruth_instance_masks, groundtruth_verified_neg_classes,
+        groundtruth_not_exhaustive_classes,
+        num_gt_boxes_per_image, detection_scores, detection_classes,
+        detection_masks, num_det_boxes_per_image, original_image_spatial_shape
+    ], [])
+
+  def get_estimator_eval_metric_ops(self, eval_dict):
+    """Returns a dictionary of eval metric ops.
+
+    Note that once value_op is called, the detections and groundtruth added via
+    update_op are cleared.
+
+    Args:
+      eval_dict: A dictionary that holds tensors for evaluating object detection
+        performance. For single-image evaluation, this dictionary may be
+        produced from eval_util.result_dict_for_single_example(). For
+        multi-image evaluation, `eval_dict` should contain the fields
+        'num_groundtruth_boxes_per_image' and 'num_det_boxes_per_image' to
+        properly unpad the tensors from the batch.
+
+    Returns:
+      a dictionary of metric names to tuple of value_op and update_op that can
+      be used as eval metric ops in tf.estimator.EstimatorSpec. Note that all
+      update ops must be run together and similarly all value ops must be run
+      together to guarantee correct behaviour.
+    """
+    update_op = self.add_eval_dict(eval_dict)
+    metric_names = ['DetectionMasks_Precision/mAP',
+                    'DetectionMasks_Precision/mAP@.50IOU',
+                    'DetectionMasks_Precision/mAP@.75IOU',
+                    'DetectionMasks_Precision/mAP (small)',
+                    'DetectionMasks_Precision/mAP (medium)',
+                    'DetectionMasks_Precision/mAP (large)',
+                    'DetectionMasks_Recall/AR@1',
+                    'DetectionMasks_Recall/AR@10',
+                    'DetectionMasks_Recall/AR@100',
+                    'DetectionMasks_Recall/AR@100 (small)',
+                    'DetectionMasks_Recall/AR@100 (medium)',
+                    'DetectionMasks_Recall/AR@100 (large)']
+    if getattr(self, '_include_metrics_per_category', False):
+      for category_dict in self._categories:
+        metric_names.append('DetectionMasks_PerformanceByCategory/mAP/' +
+                            category_dict['name'])
+
+    def first_value_func():
+      self._metrics = self.evaluate()
+      self.clear()
+      return np.float32(self._metrics[metric_names[0]])
+
+    def value_func_factory(metric_name):
+      def value_func():
+        return np.float32(self._metrics[metric_name])
+      return value_func
+
+    # Ensure that the metrics are only evaluated once.
+    first_value_op = tf.py_func(first_value_func, [], tf.float32)
+    eval_metric_ops = {metric_names[0]: (first_value_op, update_op)}
+    with tf.control_dependencies([first_value_op]):
+      for metric_name in metric_names[1:]:
+        eval_metric_ops[metric_name] = (tf.py_func(
+            value_func_factory(metric_name), [], np.float32), update_op)
+    return eval_metric_ops
+
+  def dump_detections_to_json_file(self, json_output_path):
+    """Saves the detections into json_output_path in the format used by MS COCO.
+
+    Args:
+      json_output_path: String containing the output file's path. It can also
+        be None, in which case nothing will be written to the output file.
+    """
+    if json_output_path:
+      pattern = re.compile(r'\d+\.\d{8,}')
+      def mround(match):
+        return '{:.2f}'.format(float(match.group()))
+
+      with tf.io.gfile.GFile(json_output_path, 'w') as fid:
+        json_string = json.dumps(self._detection_masks_list)
+        fid.write(re.sub(pattern, mround, json_string))
+
+      tf.logging.info('Dumping detections to output json file: %s',
+                      json_output_path)
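Usage sketch (illustrative, not part of the patch): in TF2/eager mode the evaluator above can be driven directly, without the tf.py_func plumbing; the test file that follows exercises the same surface. The category list, image id, and mask values below are assumptions for the example; the {0, 1} indicator vectors and mask shapes follow the docstrings above.

    import numpy as np
    from object_detection.core import standard_fields as fields
    from object_detection.metrics import lvis_evaluation

    categories = [{'id': 1, 'name': 'person', 'frequency': 'f'},
                  {'id': 2, 'name': 'dog', 'frequency': 'c'}]
    evaluator = lvis_evaluation.LVISMaskEvaluator(categories)

    # A single uint8 instance mask on a 256x256 canvas, values in {0, 1}.
    masks = np.zeros([1, 256, 256], dtype=np.uint8)
    masks[0, 100:200, 100:200] = 1

    evaluator.add_single_ground_truth_image_info(
        image_id='image1',
        groundtruth_dict={
            # Boxes are [ymin, xmin, ymax, xmax] in absolute coordinates.
            fields.InputDataFields.groundtruth_boxes:
                np.array([[100., 100., 200., 200.]]),
            fields.InputDataFields.groundtruth_classes: np.array([1]),
            fields.InputDataFields.groundtruth_instance_masks: masks,
            # {0, 1} indicator vectors, one entry per category above.
            fields.InputDataFields.groundtruth_verified_neg_classes:
                np.array([0, 0]),
            fields.InputDataFields.groundtruth_not_exhaustive_classes:
                np.array([0, 0]),
        })
    evaluator.add_single_detected_image_info(
        image_id='image1',
        detections_dict={
            fields.DetectionResultFields.detection_masks: masks,
            fields.DetectionResultFields.detection_scores: np.array([.9]),
            fields.DetectionResultFields.detection_classes: np.array([1]),
        })
    metrics = evaluator.evaluate()  # keys are prefixed, e.g. 'DetectionMasks_AP'

Note the detection masks must match the spatial shape of the groundtruth masks registered for the same image id, as enforced in add_single_detected_image_info().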
diff --git a/research/object_detection/metrics/lvis_evaluation_test.py b/research/object_detection/metrics/lvis_evaluation_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a095b8abb58c5be3b4c93107c52d8da0e11f5c4
--- /dev/null
+++ b/research/object_detection/metrics/lvis_evaluation_test.py
@@ -0,0 +1,182 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow_models.object_detection.metrics.lvis_evaluation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import tensorflow.compat.v1 as tf
+from object_detection.core import standard_fields as fields
+from object_detection.metrics import lvis_evaluation
+from object_detection.utils import tf_version
+
+
+def _get_categories_list():
+  return [{
+      'id': 1,
+      'name': 'person',
+      'frequency': 'f'
+  }, {
+      'id': 2,
+      'name': 'dog',
+      'frequency': 'c'
+  }, {
+      'id': 3,
+      'name': 'cat',
+      'frequency': 'r'
+  }]
+
+
+class LvisMaskEvaluationTest(tf.test.TestCase):
+
+  def testGetOneMAPWithMatchingGroundtruthAndDetections(self):
+    """Tests that mAP is calculated correctly on GT and Detections."""
+    masks1 = np.expand_dims(np.pad(
+        np.ones([100, 100], dtype=np.uint8),
+        ((100, 56), (100, 56)), mode='constant'), axis=0)
+    masks2 = np.expand_dims(np.pad(
+        np.ones([50, 50], dtype=np.uint8),
+        ((50, 156), (50, 156)), mode='constant'), axis=0)
+    masks3 = np.expand_dims(np.pad(
+        np.ones([25, 25], dtype=np.uint8),
+        ((25, 206), (25, 206)), mode='constant'), axis=0)
+
+    lvis_evaluator = lvis_evaluation.LVISMaskEvaluator(
+        _get_categories_list())
+    lvis_evaluator.add_single_ground_truth_image_info(
+        image_id='image1',
+        groundtruth_dict={
+            fields.InputDataFields.groundtruth_boxes:
+                np.array([[100., 100., 200., 200.]]),
+            fields.InputDataFields.groundtruth_classes: np.array([1]),
+            fields.InputDataFields.groundtruth_instance_masks: masks1,
+            fields.InputDataFields.groundtruth_verified_neg_classes:
+                np.array([0, 0, 0]),
+            fields.InputDataFields.groundtruth_not_exhaustive_classes:
+                np.array([0, 0, 0])
+        })
+    lvis_evaluator.add_single_detected_image_info(
+        image_id='image1',
+        detections_dict={
+            fields.DetectionResultFields.detection_masks: masks1,
+            fields.DetectionResultFields.detection_scores:
+                np.array([.8]),
+            fields.DetectionResultFields.detection_classes:
+                np.array([1])
+        })
+    lvis_evaluator.add_single_ground_truth_image_info(
+        image_id='image2',
+        groundtruth_dict={
+            fields.InputDataFields.groundtruth_boxes:
+                np.array([[50., 50., 100., 100.]]),
+            fields.InputDataFields.groundtruth_classes: np.array([1]),
+            fields.InputDataFields.groundtruth_instance_masks: masks2,
+            fields.InputDataFields.groundtruth_verified_neg_classes:
+                np.array([0, 0, 0]),
+            fields.InputDataFields.groundtruth_not_exhaustive_classes:
+                np.array([0, 0, 0])
+        })
+    lvis_evaluator.add_single_detected_image_info(
+        image_id='image2',
+        detections_dict={
+            fields.DetectionResultFields.detection_masks: masks2,
+            fields.DetectionResultFields.detection_scores:
+                np.array([.8]),
+            fields.DetectionResultFields.detection_classes:
+                np.array([1])
+        })
+    lvis_evaluator.add_single_ground_truth_image_info(
+        image_id='image3',
+        groundtruth_dict={
+            fields.InputDataFields.groundtruth_boxes:
+                np.array([[25., 25., 50., 50.]]),
+            fields.InputDataFields.groundtruth_classes: np.array([1]),
+            fields.InputDataFields.groundtruth_instance_masks: masks3,
+            fields.InputDataFields.groundtruth_verified_neg_classes:
+                np.array([0, 0, 0]),
+            fields.InputDataFields.groundtruth_not_exhaustive_classes:
+                np.array([0, 0, 0])
+        })
+    lvis_evaluator.add_single_detected_image_info(
+        image_id='image3',
+        detections_dict={
+            fields.DetectionResultFields.detection_masks: masks3,
+            fields.DetectionResultFields.detection_scores:
+                np.array([.8]),
+            fields.DetectionResultFields.detection_classes:
+                np.array([1])
+        })
+    metrics = lvis_evaluator.evaluate()
+    self.assertAlmostEqual(metrics['DetectionMasks_AP'], 1.0)
+
+
+@unittest.skipIf(tf_version.is_tf1(), 'Only Supported in TF2.X')
+class LVISMaskEvaluationPyFuncTest(tf.test.TestCase):
+
+  def testAddEvalDict(self):
+    lvis_evaluator = lvis_evaluation.LVISMaskEvaluator(_get_categories_list())
+    image_id = tf.constant('image1', dtype=tf.string)
+    groundtruth_boxes = tf.constant(
+        np.array([[100., 100., 200., 200.], [50., 50., 100., 100.]]),
+        dtype=tf.float32)
+    groundtruth_classes = tf.constant(np.array([1, 2]), dtype=tf.float32)
+    groundtruth_masks = tf.constant(np.stack([
+        np.pad(np.ones([100, 100], dtype=np.uint8), ((10, 10), (10, 10)),
+               mode='constant'),
+        np.pad(np.ones([50, 50], dtype=np.uint8), ((0, 70), (0, 70)),
+               mode='constant')
+    ]), dtype=tf.uint8)
+    original_image_spatial_shapes = tf.constant([[120, 120], [120, 120]],
+                                                dtype=tf.int32)
+    groundtruth_verified_neg_classes = tf.constant(np.array([0, 0, 0]),
+                                                   dtype=tf.float32)
+    groundtruth_not_exhaustive_classes = tf.constant(np.array([0, 0, 0]),
+                                                     dtype=tf.float32)
+    detection_scores = tf.constant(np.array([.9, .8]), dtype=tf.float32)
+    detection_classes = tf.constant(np.array([2, 1]), dtype=tf.float32)
+    detection_masks = tf.constant(np.stack([
+        np.pad(np.ones([50, 50], dtype=np.uint8), ((0, 70), (0, 70)),
+               mode='constant'),
+        np.pad(np.ones([100, 100], dtype=np.uint8), ((10, 10), (10, 10)),
+               mode='constant'),
+    ]), dtype=tf.uint8)
+
+    input_data_fields = fields.InputDataFields
+    detection_fields = fields.DetectionResultFields
+    eval_dict = {
+        input_data_fields.key: image_id,
+        input_data_fields.groundtruth_boxes: groundtruth_boxes,
+        input_data_fields.groundtruth_classes: groundtruth_classes,
+        input_data_fields.groundtruth_instance_masks: groundtruth_masks,
+        input_data_fields.groundtruth_verified_neg_classes:
+            groundtruth_verified_neg_classes,
+        input_data_fields.groundtruth_not_exhaustive_classes:
+            groundtruth_not_exhaustive_classes,
+        input_data_fields.original_image_spatial_shape:
+            original_image_spatial_shapes,
+        detection_fields.detection_scores: detection_scores,
+        detection_fields.detection_classes: detection_classes,
+        detection_fields.detection_masks: detection_masks
+    }
+    lvis_evaluator.add_eval_dict(eval_dict)
+    self.assertLen(lvis_evaluator._groundtruth_list, 2)
+    self.assertLen(lvis_evaluator._detection_masks_list, 2)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/research/object_detection/metrics/lvis_tools.py b/research/object_detection/metrics/lvis_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a102efd89b7be4e24a50f92b9d542e451d5952
--- /dev/null
+++ b/research/object_detection/metrics/lvis_tools.py
@@ -0,0 +1,259 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrappers for third party lvis to be used within object_detection.

+Usage example: given a set of images with ids in the list image_ids
+and corresponding lists of numpy arrays encoding groundtruth (boxes,
+masks and classes) and detections (masks, scores and classes), where
+elements of each list correspond to detections/annotations of a single image,
+then evaluation can be invoked as follows:
+
+  groundtruth = lvis_tools.LVISWrapper(groundtruth_dict)
+  detections = lvis_results.LVISResults(groundtruth, detections_list)
+  evaluator = lvis_tools.LVISEvalWrapper(groundtruth, detections,
+                                         iou_type='segm')
+  summary_metrics = evaluator.ComputeMetrics()
+
+TODO(jonathanhuang): Add support for exporting to JSON.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+
+from lvis import eval as lvis_eval
+from lvis import lvis
+import numpy as np
+from pycocotools import mask
+import six
+from six.moves import range
+
+
+def RleCompress(masks):
+  """Compresses mask using Run-length encoding provided by pycocotools.
+
+  Args:
+    masks: uint8 numpy array of shape [mask_height, mask_width] with values in
+      {0, 1}.
+
+  Returns:
+    A pycocotools Run-length encoding of the mask.
+  """
+  rle = mask.encode(np.asfortranarray(masks))
+  rle['counts'] = six.ensure_str(rle['counts'])
+  return rle
+
+
+def _ConvertBoxToCOCOFormat(box):
+  """Converts a box in [ymin, xmin, ymax, xmax] format to COCO format.
+
+  This is a utility function for converting from our internal
+  [ymin, xmin, ymax, xmax] convention to the convention used by the COCO API
+  i.e., [xmin, ymin, width, height].
+
+  Args:
+    box: a [ymin, xmin, ymax, xmax] numpy array
+
+  Returns:
+    a list of floats representing [xmin, ymin, width, height]
+  """
+  return [float(box[1]), float(box[0]), float(box[3] - box[1]),
+          float(box[2] - box[0])]
+
+
+class LVISWrapper(lvis.LVIS):
+  """Wrapper for the lvis.LVIS class."""
+
+  def __init__(self, dataset, detection_type='bbox'):
+    """LVISWrapper constructor.
+
+    See https://www.lvisdataset.org/dataset for a description of the format.
+    By default, the lvis.LVIS class constructor reads from a JSON file.
+    This function duplicates the same behavior but loads from a dictionary,
+    allowing us to perform evaluation without writing to external storage.
+
+    Args:
+      dataset: a dictionary holding annotations in the LVIS format.
+      detection_type: type of detections being wrapped. Can be one of ['bbox',
+        'segmentation']
+
+    Raises:
+      ValueError: if detection_type is unsupported.
+    """
+    self.logger = logging.getLogger(__name__)
+    self.logger.info('Loading annotations.')
+    self.dataset = dataset
+    self._create_index()
+
+
+class LVISEvalWrapper(lvis_eval.LVISEval):
+  """LVISEval wrapper."""
+
+  def __init__(self, groundtruth=None, detections=None, iou_type='bbox'):
+    lvis_eval.LVISEval.__init__(
+        self, groundtruth, detections, iou_type=iou_type)
+    self._iou_type = iou_type
+
+  def ComputeMetrics(self):
+    """Runs the LVIS evaluation and returns its results dictionary."""
+    self.run()
+    summary_metrics = self.results
+    return summary_metrics
+
+
+def ExportSingleImageGroundtruthToLVIS(image_id,
+                                       next_annotation_id,
+                                       category_id_set,
+                                       groundtruth_boxes,
+                                       groundtruth_classes,
+                                       groundtruth_masks=None,
+                                       groundtruth_area=None):
+  """Export groundtruth of a single image to LVIS format.
+
+  This function converts groundtruth detection annotations represented as numpy
+  arrays to dictionaries that can be ingested by the LVIS evaluation API. Note
+  that the image_ids provided here must match the ones given to
+  ExportSingleImageDetectionMasksToLVIS. We assume that boxes, classes and masks
+  are in correspondence - that is, e.g., groundtruth_boxes[i, :] and
+  groundtruth_classes[i] are associated with the same groundtruth annotation.
+
+  In the exported result, "area" fields are set to the area of the groundtruth
+  bounding box, unless groundtruth_area is provided.
+
+  Args:
+    image_id: a unique image identifier either of type integer or string.
+    next_annotation_id: integer specifying the first id to use for the
+      groundtruth annotations. All annotations are assigned a continuous integer
+      id starting from this value.
+    category_id_set: A set of valid class ids. Groundtruth with classes not in
+      category_id_set are dropped.
+    groundtruth_boxes: numpy array (float32) with shape [num_gt_boxes, 4]
+    groundtruth_classes: numpy array (int) with shape [num_gt_boxes]
+    groundtruth_masks: optional uint8 numpy array of shape [num_gt_boxes,
+      image_height, image_width] containing groundtruth masks.
+    groundtruth_area: numpy array (float32) with shape [num_gt_boxes]. If
+      provided, then the area values (in the original absolute coordinates) will
+      be populated instead of calculated from bounding box coordinates.
+
+  Returns:
+    a list of groundtruth annotations for a single image in the COCO format.
+
+  Raises:
+    ValueError: if (1) groundtruth_boxes and groundtruth_classes do not have
+      the right lengths or (2) if each of the elements inside these lists do
+      not have the correct shapes.
+  """
+
+  if len(groundtruth_classes.shape) != 1:
+    raise ValueError('groundtruth_classes is '
+                     'expected to be of rank 1.')
+  if len(groundtruth_boxes.shape) != 2:
+    raise ValueError('groundtruth_boxes is expected to be of '
+                     'rank 2.')
+  if groundtruth_boxes.shape[1] != 4:
+    raise ValueError('groundtruth_boxes should have '
+                     'shape[1] == 4.')
+  num_boxes = groundtruth_classes.shape[0]
+  if num_boxes != groundtruth_boxes.shape[0]:
+    raise ValueError('Corresponding entries in groundtruth_classes, '
+                     'and groundtruth_boxes should have '
+                     'compatible shapes (i.e., agree on the 0th dimension). '
+                     'Classes shape: %d. Boxes shape: %d. Image ID: %s' % (
+                         groundtruth_classes.shape[0],
+                         groundtruth_boxes.shape[0], image_id))
+
+  groundtruth_list = []
+  for i in range(num_boxes):
+    if groundtruth_classes[i] in category_id_set:
+      if groundtruth_area is not None and groundtruth_area[i] > 0:
+        area = float(groundtruth_area[i])
+      else:
+        area = float((groundtruth_boxes[i, 2] - groundtruth_boxes[i, 0]) *
+                     (groundtruth_boxes[i, 3] - groundtruth_boxes[i, 1]))
+      export_dict = {
+          'id':
+              next_annotation_id + i,
+          'image_id':
+              image_id,
+          'category_id':
+              int(groundtruth_classes[i]),
+          'bbox':
+              list(_ConvertBoxToCOCOFormat(groundtruth_boxes[i, :])),
+          'area': area,
+      }
+      if groundtruth_masks is not None:
+        export_dict['segmentation'] = RleCompress(groundtruth_masks[i])
+
+      groundtruth_list.append(export_dict)
+  return groundtruth_list
+
+
+def ExportSingleImageDetectionMasksToLVIS(image_id,
+                                          category_id_set,
+                                          detection_masks,
+                                          detection_scores,
+                                          detection_classes):
+  """Export detection masks of a single image to LVIS format.
+
+  This function converts detections represented as numpy arrays to dictionaries
+  that can be ingested by the LVIS evaluation API. We assume that
+  detection_masks, detection_scores, and detection_classes are in correspondence
+  - that is: detection_masks[i, :], detection_classes[i] and detection_scores[i]
+  are associated with the same annotation.
+
+  Args:
+    image_id: unique image identifier either of type integer or string.
+    category_id_set: A set of valid class ids. Detections with classes not in
+      category_id_set are dropped.
+    detection_masks: uint8 numpy array of shape [num_detections, image_height,
+      image_width] containing detection_masks.
+    detection_scores: float numpy array of shape [num_detections] containing
+      scores for detection masks.
+    detection_classes: integer numpy array of shape [num_detections] containing
+      the classes for detection masks.
+
+  Returns:
+    a list of detection mask annotations for a single image in the COCO format.
+
+  Raises:
+    ValueError: if (1) detection_masks, detection_scores and detection_classes
+      do not have the right lengths or (2) if each of the elements inside these
+      lists do not have the correct shapes.
+  """
+
+  if len(detection_classes.shape) != 1 or len(detection_scores.shape) != 1:
+    raise ValueError('All entries in detection_classes and detection_scores '
+                     'are expected to be of rank 1.')
+  num_boxes = detection_classes.shape[0]
+  if not num_boxes == len(detection_masks) == detection_scores.shape[0]:
+    raise ValueError('Corresponding entries in detection_classes, '
+                     'detection_scores and detection_masks should have '
+                     'compatible lengths and shapes. '
+                     'Classes length: %d. Masks length: %d. '
+                     'Scores length: %d' % (
+                         detection_classes.shape[0], len(detection_masks),
+                         detection_scores.shape[0]
+                     ))
+  detections_list = []
+  for i in range(num_boxes):
+    if detection_classes[i] in category_id_set:
+      detections_list.append({
+          'image_id': image_id,
+          'category_id': int(detection_classes[i]),
+          'segmentation': RleCompress(detection_masks[i]),
+          'score': float(detection_scores[i])
+      })
+  return detections_list
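The helpers above compose as sketched in the module docstring; the test file that follows exercises the same path. A minimal end-to-end sketch with toy data (the image ids, box/mask values, and category entries are assumptions for the example; the 'neg_category_ids' and 'not_exhaustive_category_ids' image fields mirror what LVISMaskEvaluator.evaluate() populates):

    import numpy as np
    from lvis import results as lvis_results
    from object_detection.metrics import lvis_tools

    # One 100x100 square mask on a 256x256 canvas, values in {0, 1}.
    single_mask = np.pad(np.ones([100, 100], dtype=np.uint8),
                         ((100, 56), (100, 56)), mode='constant')

    groundtruth_dict = {
        'annotations': lvis_tools.ExportSingleImageGroundtruthToLVIS(
            image_id='first',
            next_annotation_id=1,
            category_id_set=set([1]),
            groundtruth_boxes=np.array([[100., 100., 200., 200.]]),
            groundtruth_classes=np.array([1]),
            groundtruth_masks=single_mask[np.newaxis, ...]),
        'images': [{'id': 'first', 'height': 256, 'width': 256,
                    'neg_category_ids': [],
                    'not_exhaustive_category_ids': []}],
        'categories': [{'id': 1, 'name': 'person', 'frequency': 'f'}]
    }
    detections_list = lvis_tools.ExportSingleImageDetectionMasksToLVIS(
        image_id='first',
        category_id_set=set([1]),
        detection_masks=single_mask[np.newaxis, ...],
        detection_scores=np.array([.8]),
        detection_classes=np.array([1]))

    groundtruth = lvis_tools.LVISWrapper(groundtruth_dict)
    detections = lvis_results.LVISResults(groundtruth, detections_list)
    evaluator = lvis_tools.LVISEvalWrapper(groundtruth, detections,
                                           iou_type='segm')
    summary_metrics = evaluator.ComputeMetrics()  # e.g. summary_metrics['AP']

Building the groundtruth dictionary in memory and handing it to LVISWrapper avoids the round trip through a JSON file that the stock lvis.LVIS constructor requires.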
diff --git a/research/object_detection/metrics/lvis_tools_test.py b/research/object_detection/metrics/lvis_tools_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c6ef107ad113120d48bbafba9d7ae1971092463
--- /dev/null
+++ b/research/object_detection/metrics/lvis_tools_test.py
@@ -0,0 +1,158 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow_models.object_detection.metrics.lvis_tools."""
+from lvis import results as lvis_results
+import numpy as np
+from pycocotools import mask
+import tensorflow.compat.v1 as tf
+from object_detection.metrics import lvis_tools
+
+
+class LVISToolsTest(tf.test.TestCase):
+
+  def setUp(self):
+    super(LVISToolsTest, self).setUp()
+    mask1 = np.pad(
+        np.ones([100, 100], dtype=np.uint8),
+        ((100, 56), (100, 56)), mode='constant')
+    mask2 = np.pad(
+        np.ones([50, 50], dtype=np.uint8),
+        ((50, 156), (50, 156)), mode='constant')
+    mask1_rle = lvis_tools.RleCompress(mask1)
+    mask2_rle = lvis_tools.RleCompress(mask2)
+    groundtruth_annotations_list = [
+        {
+            'id': 1,
+            'image_id': 'first',
+            'category_id': 1,
+            'bbox': [100., 100., 100., 100.],
+            'area': 100.**2,
+            'segmentation': mask1_rle
+        },
+        {
+            'id': 2,
+            'image_id': 'second',
+            'category_id': 1,
+            'bbox': [50., 50., 50., 50.],
+            'area': 50.**2,
+            'segmentation': mask2_rle
+        },
+    ]
+    image_list = [
+        {
+            'id': 'first',
+            'neg_category_ids': [],
+            'not_exhaustive_category_ids': [],
+            'height': 256,
+            'width': 256
+        },
+        {
+            'id': 'second',
+            'neg_category_ids': [],
+            'not_exhaustive_category_ids': [],
+            'height': 256,
+            'width': 256
+        }
+    ]
+    category_list = [{'id': 0, 'name': 'person', 'frequency': 'f'},
+                     {'id': 1, 'name': 'cat', 'frequency': 'c'},
+                     {'id': 2, 'name': 'dog', 'frequency': 'r'}]
+    self._groundtruth_dict = {
+        'annotations': groundtruth_annotations_list,
+        'images': image_list,
+        'categories': category_list
+    }
+
+    self._detections_list = [
+        {
+            'image_id': 'first',
+            'category_id': 1,
+            'segmentation': mask1_rle,
+            'score': .8
+        },
+        {
+            'image_id': 'second',
+            'category_id': 1,
+            'segmentation': mask2_rle,
+            'score': .7
+        },
+    ]
+
+  def testLVISWrappers(self):
+    groundtruth = lvis_tools.LVISWrapper(self._groundtruth_dict)
+    detections = lvis_results.LVISResults(groundtruth, self._detections_list)
+    evaluator = lvis_tools.LVISEvalWrapper(groundtruth, detections,
+                                           iou_type='segm')
+    summary_metrics = evaluator.ComputeMetrics()
+    self.assertAlmostEqual(1.0, summary_metrics['AP'])
+
+  def testSingleImageDetectionMaskExport(self):
+    masks = np.array(
+        [[[1, 1,], [1, 1]],
+         [[0, 0], [0, 1]],
+         [[0, 0], [0, 0]]], dtype=np.uint8)
+    classes = np.array([1, 2, 3], dtype=np.int32)
+    scores = np.array([0.8, 0.2, 0.7], dtype=np.float32)
+    lvis_annotations = lvis_tools.ExportSingleImageDetectionMasksToLVIS(
+        image_id='first_image',
+        category_id_set=set([1, 2, 3]),
+        detection_classes=classes,
+        detection_scores=scores,
+        detection_masks=masks)
+    expected_counts = ['04', '31', '4']
+    for i, mask_annotation in enumerate(lvis_annotations):
+      self.assertEqual(mask_annotation['segmentation']['counts'],
+                       expected_counts[i])
+      self.assertTrue(np.all(np.equal(mask.decode(
+          mask_annotation['segmentation']), masks[i])))
+      self.assertEqual(mask_annotation['image_id'], 'first_image')
+      self.assertEqual(mask_annotation['category_id'], classes[i])
+      self.assertAlmostEqual(mask_annotation['score'], scores[i])
+
+  def testSingleImageGroundtruthExport(self):
+    masks = np.array(
+        [[[1, 1,], [1, 1]],
+         [[0, 0], [0, 1]],
+         [[0, 0], [0, 0]]], dtype=np.uint8)
+    boxes = np.array([[0, 0, 1, 1],
+                      [0, 0, .5, .5],
+                      [.5, .5, 1, 1]], dtype=np.float32)
+    lvis_boxes = np.array([[0, 0, 1, 1],
+                           [0, 0, .5, .5],
+                           [.5, .5, .5, .5]], dtype=np.float32)
+    classes = np.array([1, 2, 3], dtype=np.int32)
+    next_annotation_id = 1
+    expected_counts = ['04', '31', '4']
+
+    lvis_annotations = lvis_tools.ExportSingleImageGroundtruthToLVIS(
+        image_id='first_image',
+        category_id_set=set([1, 2, 3]),
+        next_annotation_id=next_annotation_id,
+        groundtruth_boxes=boxes,
+        groundtruth_classes=classes,
+        groundtruth_masks=masks)
+    for i, annotation in enumerate(lvis_annotations):
+      self.assertEqual(annotation['segmentation']['counts'],
+                       expected_counts[i])
+      self.assertTrue(np.all(np.equal(mask.decode(
+          annotation['segmentation']), masks[i])))
+      self.assertTrue(np.all(np.isclose(annotation['bbox'], lvis_boxes[i])))
+      self.assertEqual(annotation['image_id'], 'first_image')
+      self.assertEqual(annotation['category_id'], classes[i])
+      self.assertEqual(annotation['id'], i + next_annotation_id)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/research/object_detection/packages/tf1/setup.py b/research/object_detection/packages/tf1/setup.py
index dc3bfaca0b8949c372b12e808cc3304a3d963ff7..a40a368a6f5fddbccfc13b4d76f38a49d3c1c8d3 100644
--- a/research/object_detection/packages/tf1/setup.py
+++ b/research/object_detection/packages/tf1/setup.py
@@ -4,8 +4,8 @@ from setuptools import find_packages
 from setuptools import setup

 REQUIRED_PACKAGES = ['pillow', 'lxml', 'matplotlib', 'Cython',
-                     'contextlib2', 'tf-slim', 'six', 'pycocotools', 'scipy',
-                     'pandas']
+                     'contextlib2', 'tf-slim', 'six', 'pycocotools', 'lvis',
+                     'scipy', 'pandas']

 setup(
     name='object_detection',
diff --git a/research/object_detection/packages/tf2/setup.py b/research/object_detection/packages/tf2/setup.py
index cb997241a91603f36e4f29a28b7bdf5907310128..3f9f0e35363cde03bee00641f3fb53ccc85c55ad 100644
--- a/research/object_detection/packages/tf2/setup.py
+++ b/research/object_detection/packages/tf2/setup.py
@@ -18,6 +18,7 @@ REQUIRED_PACKAGES = [
     'tf-slim',
     'six',
     'pycocotools',
+    'lvis',
     'scipy',
     'pandas',
     'tf-models-official'
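With the lvis dependency declared above, the evaluator plugs into a TF1 estimator eval loop the same way the COCO evaluators do. A hedged sketch of the wiring via get_estimator_eval_metric_ops(); the make_eval_spec helper and the categories, eval_dict, and total_loss arguments are hypothetical stand-ins for whatever the surrounding model_fn provides, with eval_dict assumed to come from eval_util and to carry the groundtruth_verified_neg_classes, groundtruth_not_exhaustive_classes, and original_image_spatial_shape fields used by add_eval_dict():

    import tensorflow.compat.v1 as tf
    from object_detection.metrics import lvis_evaluation

    def make_eval_spec(categories, eval_dict, total_loss):
      """Builds an EVAL-mode EstimatorSpec reporting LVIS mask metrics."""
      evaluator = lvis_evaluation.LVISMaskEvaluator(categories)
      # add_eval_dict() is invoked internally; per the docstring, all update
      # ops must be run together and all value ops must be run together.
      eval_metric_ops = evaluator.get_estimator_eval_metric_ops(eval_dict)
      return tf.estimator.EstimatorSpec(
          mode=tf.estimator.ModeKeys.EVAL,
          loss=total_loss,
          eval_metric_ops=eval_metric_ops)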