diff --git a/mrcnn/utils.py b/mrcnn/utils.py index 0133dbf914439276b6076e656c158c535846c853..85e51487894cd240bb1d261da0e96d0aff4e5086 100644 --- a/mrcnn/utils.py +++ b/mrcnn/utils.py @@ -753,6 +753,30 @@ def compute_ap(gt_boxes, gt_class_ids, gt_masks, return mAP, precisions, recalls, overlaps +def compute_ap_range(gt_box, gt_class_id, gt_mask, + pred_box, pred_class_id, pred_score, pred_mask, + iou_thresholds=None, verbose=1): + """Compute AP over a range or IoU thresholds. Default range is 0.5-0.95.""" + # Default is 0.5 to 0.95 with increments of 0.05 + iou_thresholds = iou_thresholds or np.arange(0.5, 1.0, 0.05) + + # Compute AP over range of IoU thresholds + AP = [] + for iou_threshold in iou_thresholds: + ap, precisions, recalls, overlaps =\ + compute_ap(gt_box, gt_class_id, gt_mask, + pred_box, pred_class_id, pred_score, pred_mask, + iou_threshold=iou_threshold) + if verbose: + print("AP @{:.2f}:\t {:.3f}".format(iou_threshold, ap)) + AP.append(ap) + AP = np.array(AP).mean() + if verbose: + print("AP @{:.2f}-{:.2f}:\t {:.3f}".format( + iou_thresholds[0], iou_thresholds[-1], AP)) + return AP + + def compute_recall(pred_boxes, gt_boxes, iou): """Compute the recall at the given IoU threshold. It's an indication of how many GT boxes were found by the given prediction boxes.