Unverified · Commit 9a1ba140, authored by Francisco Massa, committed by GitHub

add support for pascal voc dataset and evaluate (#207)

* add support for pascal voc dataset and evaluate

* optimization for adding voc dataset

* make inference.py dataset-agnostic; add use_difficult option to voc dataset

* handle voc difficult objects correctly

* Remove dependency on lxml plus minor improvements

* More cleanups

* More comments and improvements

* Lint fix

* Move configs to their own folder
Parent 7bc87084
@@ -83,6 +83,8 @@ ln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2014 datasets/coco/train2014
ln -s /path_to_coco_dataset/test2014 datasets/coco/test2014
ln -s /path_to_coco_dataset/val2014 datasets/coco/val2014
# for pascal voc dataset:
ln -s /path_to_VOCdevkit_dir datasets/voc
```
You can also configure your own paths to the datasets.
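If symlinking into `datasets/` is not convenient, the dataset locations can also be changed through the dataset catalog (the `DatasetCatalog` class extended further down in this diff). A minimal sketch, assuming the package is importable and that the catalog module lives at `maskrcnn_benchmark/config/paths_catalog.py`:

```python
# Minimal sketch: point the catalog at a custom base directory instead of the
# datasets/ symlinks. The import path of the catalog module is an assumption.
from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog

DatasetCatalog.DATA_DIR = "/data"  # base directory for all datasets
spec = DatasetCatalog.get("voc_2007_trainval")
print(spec["factory"])  # "PascalVOCDataset"
print(spec["args"])     # {"data_dir": "/data/voc/VOC2007", "split": "trainval"}
```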
......
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  RPN:
    PRE_NMS_TOP_N_TEST: 6000
    POST_NMS_TOP_N_TEST: 300
    ANCHOR_SIZES: (128, 256, 512)
  ROI_BOX_HEAD:
    NUM_CLASSES: 21
DATASETS:
  TRAIN: ("voc_2007_trainval",)
  TEST: ("voc_2007_test",)
SOLVER:
  BASE_LR: 0.001
  WEIGHT_DECAY: 0.0001
  STEPS: (50000, )
  MAX_ITER: 70000
  IMS_PER_BATCH: 1
TEST:
  IMS_PER_BATCH: 1

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  RPN:
    PRE_NMS_TOP_N_TEST: 6000
    POST_NMS_TOP_N_TEST: 300
    ANCHOR_SIZES: (128, 256, 512)
  ROI_BOX_HEAD:
    NUM_CLASSES: 21
DATASETS:
  TRAIN: ("voc_2007_trainval",)
  TEST: ("voc_2007_test",)
SOLVER:
  BASE_LR: 0.004
  WEIGHT_DECAY: 0.0001
  STEPS: (12500, )
  MAX_ITER: 17500
  IMS_PER_BATCH: 4
TEST:
  IMS_PER_BATCH: 4
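The two YAML blocks above are two separate Faster R-CNN R-50-C4 configs for Pascal VOC: a single-image-per-batch schedule (70k iterations at LR 0.001) and a 4-image-per-batch schedule (17.5k iterations at LR 0.004). Both set `NUM_CLASSES: 21` (20 VOC classes plus background) and train on `voc_2007_trainval` / test on `voc_2007_test`. A hedged sketch of loading such a config through the project's yacs-based `cfg`; the file path below is hypothetical, since file names are not shown in this view:

```python
# Sketch only: the config path is a placeholder, not taken from this diff.
from maskrcnn_benchmark.config import cfg

cfg.merge_from_file("configs/pascal_voc/e2e_faster_rcnn_R_50_C4_voc.yaml")  # hypothetical path
print(cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES)  # 21 (20 VOC classes + background)
print(cfg.DATASETS.TRAIN)                  # ('voc_2007_trainval',)
print(cfg.SOLVER.IMS_PER_BATCH)            # 1 or 4, depending on the variant
```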
@@ -21,6 +21,13 @@ class DatasetCatalog(object):
"coco/val2014",
"coco/annotations/instances_valminusminival2014.json",
),
"voc_2007_trainval": ("voc/VOC2007", 'trainval'),
"voc_2007_test": ("voc/VOC2007", 'test'),
"voc_2012_train": ("voc/VOC2012", 'train'),
"voc_2012_trainval": ("voc/VOC2012", 'trainval'),
"voc_2012_val": ("voc/VOC2012", 'val'),
"voc_2012_test": ("voc/VOC2012", 'test'),
}
@staticmethod
@@ -36,6 +43,17 @@ class DatasetCatalog(object):
factory="COCODataset",
args=args,
)
elif "voc" in name:
data_dir = DatasetCatalog.DATA_DIR
attrs = DatasetCatalog.DATASETS[name]
args = dict(
data_dir=os.path.join(data_dir, attrs[0]),
split=attrs[1],
)
return dict(
factory="PascalVOCDataset",
args=args,
)
raise RuntimeError("Dataset not available: {}".format(name))
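For illustration, this is roughly how a catalog entry becomes a dataset object in `build_dataset` (next hunks): the `factory` string is resolved to a class exported by `maskrcnn_benchmark.data.datasets` and called with the catalog `args`. The `getattr`-based lookup is an assumption about the elided part of `build.py`, and the snippet needs a VOC copy on disk at the resolved `data_dir`:

```python
# Rough sketch of catalog entry -> dataset object; requires VOC2007 on disk.
from maskrcnn_benchmark.data import datasets as D
from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog

spec = DatasetCatalog.get("voc_2007_test")
factory = getattr(D, spec["factory"])                            # PascalVOCDataset
args = dict(spec["args"], use_difficult=True, transforms=None)   # test-time settings
dataset = factory(**args)
print(len(dataset), dataset.map_class_id_to_class_name(1))       # e.g. 4952 aeroplane
```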
......
@@ -26,7 +26,8 @@ def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
"""
if not isinstance(dataset_list, (list, tuple)):
raise RuntimeError(
"dataset_list should be a list of strings, got {}".format(dataset_list))
"dataset_list should be a list of strings, got {}".format(dataset_list)
)
datasets = []
for dataset_name in dataset_list:
data = dataset_catalog.get(dataset_name)
@@ -36,6 +37,8 @@ def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
# during training
if data["factory"] == "COCODataset":
args["remove_images_without_annotations"] = is_train
if data["factory"] == "PascalVOCDataset":
args["use_difficult"] = not is_train
args["transforms"] = transforms
# make dataset from factory
dataset = factory(**args)
@@ -95,7 +98,9 @@ def make_batch_data_sampler(
sampler, images_per_batch, drop_last=False
)
if num_iters is not None:
batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
batch_sampler = samplers.IterationBasedBatchSampler(
batch_sampler, num_iters, start_iter
)
return batch_sampler
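`IterationBasedBatchSampler`, used above, wraps an ordinary batch sampler so that training is driven by a fixed number of iterations rather than epochs, re-cycling the underlying sampler as needed. A simplified, self-contained sketch of the idea (not the repository's exact implementation):

```python
# Illustrative sketch: yield batches from the wrapped batch sampler until
# num_iterations batches have been produced, restarting it as necessary.
from torch.utils.data.sampler import BatchSampler, SequentialSampler


class SimpleIterationBasedBatchSampler(object):
    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration < self.num_iterations:
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations


batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
wrapped = SimpleIterationBasedBatchSampler(batch_sampler, num_iterations=5)
print(len(list(wrapped)))  # 5 batches, cycling over the 3 underlying batches
```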
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
__all__ = ["COCODataset", "ConcatDataset"]
__all__ = ["COCODataset", "ConcatDataset", "PascalVOCDataset"]
@@ -11,7 +11,6 @@ class COCODataset(torchvision.datasets.coco.CocoDetection):
self, ann_file, root, remove_images_without_annotations, transforms=None
):
super(COCODataset, self).__init__(root, ann_file)
# sort indices for reproducible results
self.ids = sorted(self.ids)
......
from maskrcnn_benchmark.data import datasets
from .coco import coco_evaluation
from .voc import voc_evaluation
def evaluate(dataset, predictions, output_folder, **kwargs):
"""evaluate dataset using different methods based on dataset type.
Args:
dataset: Dataset object
predictions(list[BoxList]): each item in the list represents the
prediction results for one image.
output_folder: output folder, to save evaluation files or results.
**kwargs: other args.
Returns:
evaluation result
"""
args = dict(
dataset=dataset, predictions=predictions, output_folder=output_folder, **kwargs
)
if isinstance(dataset, datasets.COCODataset):
return coco_evaluation(**args)
elif isinstance(dataset, datasets.PascalVOCDataset):
return voc_evaluation(**args)
else:
dataset_name = dataset.__class__.__name__
raise NotImplementedError("Unsupported dataset type {}.".format(dataset_name))
from .coco_eval import do_coco_evaluation
def coco_evaluation(
dataset,
predictions,
output_folder,
box_only,
iou_types,
expected_results,
expected_results_sigma_tol,
):
return do_coco_evaluation(
dataset=dataset,
predictions=predictions,
box_only=box_only,
output_folder=output_folder,
iou_types=iou_types,
expected_results=expected_results,
expected_results_sigma_tol=expected_results_sigma_tol,
)
import logging
import tempfile
import os
import torch
from collections import OrderedDict
from tqdm import tqdm
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
def do_coco_evaluation(
dataset,
predictions,
box_only,
output_folder,
iou_types,
expected_results,
expected_results_sigma_tol,
):
logger = logging.getLogger("maskrcnn_benchmark.inference")
if box_only:
logger.info("Evaluating bbox proposals")
areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
res = COCOResults("box_proposal")
for limit in [100, 1000]:
for area, suffix in areas.items():
stats = evaluate_box_proposals(
predictions, dataset, area=area, limit=limit
)
key = "AR{}@{:d}".format(suffix, limit)
res.results["box_proposal"][key] = stats["ar"].item()
logger.info(res)
check_expected_results(res, expected_results, expected_results_sigma_tol)
if output_folder:
torch.save(res, os.path.join(output_folder, "box_proposals.pth"))
return
logger.info("Preparing results for COCO format")
coco_results = {}
if "bbox" in iou_types:
logger.info("Preparing bbox results")
coco_results["bbox"] = prepare_for_coco_detection(predictions, dataset)
if "segm" in iou_types:
logger.info("Preparing segm results")
coco_results["segm"] = prepare_for_coco_segmentation(predictions, dataset)
results = COCOResults(*iou_types)
logger.info("Evaluating predictions")
for iou_type in iou_types:
with tempfile.NamedTemporaryFile() as f:
file_path = f.name
if output_folder:
file_path = os.path.join(output_folder, iou_type + ".json")
res = evaluate_predictions_on_coco(
dataset.coco, coco_results[iou_type], file_path, iou_type
)
results.update(res)
logger.info(results)
check_expected_results(results, expected_results, expected_results_sigma_tol)
if output_folder:
torch.save(results, os.path.join(output_folder, "coco_results.pth"))
return results, coco_results
def prepare_for_coco_detection(predictions, dataset):
# assert isinstance(dataset, COCODataset)
coco_results = []
for image_id, prediction in enumerate(predictions):
original_id = dataset.id_to_img_map[image_id]
if len(prediction) == 0:
continue
# TODO replace with get_img_info?
image_width = dataset.coco.imgs[original_id]["width"]
image_height = dataset.coco.imgs[original_id]["height"]
prediction = prediction.resize((image_width, image_height))
prediction = prediction.convert("xywh")
boxes = prediction.bbox.tolist()
scores = prediction.get_field("scores").tolist()
labels = prediction.get_field("labels").tolist()
mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
coco_results.extend(
[
{
"image_id": original_id,
"category_id": mapped_labels[k],
"bbox": box,
"score": scores[k],
}
for k, box in enumerate(boxes)
]
)
return coco_results
def prepare_for_coco_segmentation(predictions, dataset):
import pycocotools.mask as mask_util
import numpy as np
masker = Masker(threshold=0.5, padding=1)
# assert isinstance(dataset, COCODataset)
coco_results = []
for image_id, prediction in tqdm(enumerate(predictions)):
original_id = dataset.id_to_img_map[image_id]
if len(prediction) == 0:
continue
# TODO replace with get_img_info?
image_width = dataset.coco.imgs[original_id]["width"]
image_height = dataset.coco.imgs[original_id]["height"]
prediction = prediction.resize((image_width, image_height))
masks = prediction.get_field("mask")
# t = time.time()
masks = masker(masks, prediction)
# logger.info('Time mask: {}'.format(time.time() - t))
# prediction = prediction.convert('xywh')
# boxes = prediction.bbox.tolist()
scores = prediction.get_field("scores").tolist()
labels = prediction.get_field("labels").tolist()
# rles = prediction.get_field('mask')
rles = [
mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
coco_results.extend(
[
{
"image_id": original_id,
"category_id": mapped_labels[k],
"segmentation": rle,
"score": scores[k],
}
for k, rle in enumerate(rles)
]
)
return coco_results
# inspired from Detectron
def evaluate_box_proposals(
predictions, dataset, thresholds=None, area="all", limit=None
):
"""Evaluate detection proposal recall metrics. This function is a much
faster alternative to the official COCO API recall evaluation code. However,
it produces slightly different results.
"""
# Record max overlap value for each gt box
# Return vector of overlap values
areas = {
"all": 0,
"small": 1,
"medium": 2,
"large": 3,
"96-128": 4,
"128-256": 5,
"256-512": 6,
"512-inf": 7,
}
area_ranges = [
[0 ** 2, 1e5 ** 2], # all
[0 ** 2, 32 ** 2], # small
[32 ** 2, 96 ** 2], # medium
[96 ** 2, 1e5 ** 2], # large
[96 ** 2, 128 ** 2], # 96-128
[128 ** 2, 256 ** 2], # 128-256
[256 ** 2, 512 ** 2], # 256-512
[512 ** 2, 1e5 ** 2],
] # 512-inf
assert area in areas, "Unknown area range: {}".format(area)
area_range = area_ranges[areas[area]]
gt_overlaps = []
num_pos = 0
for image_id, prediction in enumerate(predictions):
original_id = dataset.id_to_img_map[image_id]
# TODO replace with get_img_info?
image_width = dataset.coco.imgs[original_id]["width"]
image_height = dataset.coco.imgs[original_id]["height"]
prediction = prediction.resize((image_width, image_height))
# sort predictions in descending order
# TODO maybe remove this and make it explicit in the documentation
inds = prediction.get_field("objectness").sort(descending=True)[1]
prediction = prediction[inds]
ann_ids = dataset.coco.getAnnIds(imgIds=original_id)
anno = dataset.coco.loadAnns(ann_ids)
gt_boxes = [obj["bbox"] for obj in anno if obj["iscrowd"] == 0]
gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes
gt_boxes = BoxList(gt_boxes, (image_width, image_height), mode="xywh").convert(
"xyxy"
)
gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
if len(gt_boxes) == 0:
continue
valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
gt_boxes = gt_boxes[valid_gt_inds]
num_pos += len(gt_boxes)
if len(gt_boxes) == 0:
continue
if len(prediction) == 0:
continue
if limit is not None and len(prediction) > limit:
prediction = prediction[:limit]
overlaps = boxlist_iou(prediction, gt_boxes)
_gt_overlaps = torch.zeros(len(gt_boxes))
for j in range(min(len(prediction), len(gt_boxes))):
# find which proposal box maximally covers each gt box
# and get the iou amount of coverage for each gt box
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
# find which gt box is 'best' covered (i.e. 'best' = most iou)
gt_ovr, gt_ind = max_overlaps.max(dim=0)
assert gt_ovr >= 0
# find the proposal box that covers the best covered gt box
box_ind = argmax_overlaps[gt_ind]
# record the iou coverage of this gt box
_gt_overlaps[j] = overlaps[box_ind, gt_ind]
assert _gt_overlaps[j] == gt_ovr
# mark the proposal box and the gt box as used
overlaps[box_ind, :] = -1
overlaps[:, gt_ind] = -1
# append recorded iou coverage level
gt_overlaps.append(_gt_overlaps)
gt_overlaps = torch.cat(gt_overlaps, dim=0)
gt_overlaps, _ = torch.sort(gt_overlaps)
if thresholds is None:
step = 0.05
thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
recalls = torch.zeros_like(thresholds)
# compute recall for each iou threshold
for i, t in enumerate(thresholds):
recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
# ar = 2 * np.trapz(recalls, thresholds)
ar = recalls.mean()
return {
"ar": ar,
"recalls": recalls,
"thresholds": thresholds,
"gt_overlaps": gt_overlaps,
"num_pos": num_pos,
}
def evaluate_predictions_on_coco(
coco_gt, coco_results, json_result_file, iou_type="bbox"
):
import json
with open(json_result_file, "w") as f:
json.dump(coco_results, f)
from pycocotools.cocoeval import COCOeval
coco_dt = coco_gt.loadRes(str(json_result_file))
# coco_dt = coco_gt.loadRes(coco_results)
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval
class COCOResults(object):
METRICS = {
"bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
"segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
"box_proposal": [
"AR@100",
"ARs@100",
"ARm@100",
"ARl@100",
"AR@1000",
"ARs@1000",
"ARm@1000",
"ARl@1000",
],
"keypoint": ["AP", "AP50", "AP75", "APm", "APl"],
}
def __init__(self, *iou_types):
allowed_types = ("box_proposal", "bbox", "segm")
assert all(iou_type in allowed_types for iou_type in iou_types)
results = OrderedDict()
for iou_type in iou_types:
results[iou_type] = OrderedDict(
[(metric, -1) for metric in COCOResults.METRICS[iou_type]]
)
self.results = results
def update(self, coco_eval):
if coco_eval is None:
return
from pycocotools.cocoeval import COCOeval
assert isinstance(coco_eval, COCOeval)
s = coco_eval.stats
iou_type = coco_eval.params.iouType
res = self.results[iou_type]
metrics = COCOResults.METRICS[iou_type]
for idx, metric in enumerate(metrics):
res[metric] = s[idx]
def __repr__(self):
# TODO make it pretty
return repr(self.results)
def check_expected_results(results, expected_results, sigma_tol):
if not expected_results:
return
logger = logging.getLogger("maskrcnn_benchmark.inference")
for task, metric, (mean, std) in expected_results:
actual_val = results.results[task][metric]
lo = mean - sigma_tol * std
hi = mean + sigma_tol * std
ok = (lo < actual_val) and (actual_val < hi)
msg = (
"{} > {} sanity check (actual vs. expected): "
"{:.3f} vs. mean={:.4f}, std={:.4}, range=({:.4f}, {:.4f})"
).format(task, metric, actual_val, mean, std, lo, hi)
if not ok:
msg = "FAIL: " + msg
logger.error(msg)
else:
msg = "PASS: " + msg
logger.info(msg)
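Each entry of `expected_results` is a `(task, metric, (mean, std))` tuple, and a run passes when the actual value lies within `sigma_tol` standard deviations of the expected mean. A small example using `COCOResults` and `check_expected_results` from this file (the module path in the import is assumed from this diff's layout):

```python
# Toy sanity check: AP = 0.384 against an expectation of 0.386 +/- 0.003 with
# sigma_tol = 4 passes, because 0.384 lies inside (0.374, 0.398).
import logging
from maskrcnn_benchmark.data.datasets.evaluation.coco.coco_eval import (
    COCOResults,
    check_expected_results,
)

logging.basicConfig(level=logging.INFO)
res = COCOResults("bbox")
res.results["bbox"]["AP"] = 0.384
check_expected_results(res, [("bbox", "AP", (0.386, 0.003))], sigma_tol=4)
# logs something like: PASS: bbox > AP sanity check (actual vs. expected): 0.384 vs. ...
```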
import logging
from .voc_eval import do_voc_evaluation
def voc_evaluation(dataset, predictions, output_folder, box_only, **_):
logger = logging.getLogger("maskrcnn_benchmark.inference")
if box_only:
logger.warning("voc evaluation doesn't support box_only, ignored.")
logger.info("performing voc evaluation, ignored iou_types.")
return do_voc_evaluation(
dataset=dataset,
predictions=predictions,
output_folder=output_folder,
logger=logger,
)
# A modified version from the chainercv repository.
# (See https://github.com/chainer/chainercv/blob/master/chainercv/evaluations/eval_detection_voc.py)
from __future__ import division
import os
from collections import defaultdict
import numpy as np
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
def do_voc_evaluation(dataset, predictions, output_folder, logger):
# TODO need to make the use_07_metric format available
# for the user to choose
pred_boxlists = []
gt_boxlists = []
for image_id, prediction in enumerate(predictions):
img_info = dataset.get_img_info(image_id)
if len(prediction) == 0:
continue
image_width = img_info["width"]
image_height = img_info["height"]
prediction = prediction.resize((image_width, image_height))
pred_boxlists.append(prediction)
gt_boxlist = dataset.get_groundtruth(image_id)
gt_boxlists.append(gt_boxlist)
result = eval_detection_voc(
pred_boxlists=pred_boxlists,
gt_boxlists=gt_boxlists,
iou_thresh=0.5,
use_07_metric=True,
)
result_str = "mAP: {:.4f}\n".format(result["map"])
for i, ap in enumerate(result["ap"]):
if i == 0: # skip background
continue
result_str += "{:<16}: {:.4f}\n".format(
dataset.map_class_id_to_class_name(i), ap
)
logger.info(result_str)
if output_folder:
with open(os.path.join(output_folder, "result.txt"), "w") as fid:
fid.write(result_str)
return result
def eval_detection_voc(pred_boxlists, gt_boxlists, iou_thresh=0.5, use_07_metric=False):
"""Evaluate on voc dataset.
Args:
pred_boxlists(list[BoxList]): pred boxlist, has labels and scores fields.
gt_boxlists(list[BoxList]): ground truth boxlist, has labels field.
iou_thresh: iou thresh
use_07_metric: boolean
Returns:
dict represents the results
"""
assert len(gt_boxlists) == len(
pred_boxlists
), "Length of gt and pred lists need to be same."
prec, rec = calc_detection_voc_prec_rec(
pred_boxlists=pred_boxlists, gt_boxlists=gt_boxlists, iou_thresh=iou_thresh
)
ap = calc_detection_voc_ap(prec, rec, use_07_metric=use_07_metric)
return {"ap": ap, "map": np.nanmean(ap)}
def calc_detection_voc_prec_rec(gt_boxlists, pred_boxlists, iou_thresh=0.5):
"""Calculate precision and recall based on evaluation code of PASCAL VOC.
This function calculates precision and recall of
predicted bounding boxes obtained from a dataset which has :math:`N`
images.
The code is based on the evaluation code used in PASCAL VOC Challenge.
"""
n_pos = defaultdict(int)
score = defaultdict(list)
match = defaultdict(list)
for gt_boxlist, pred_boxlist in zip(gt_boxlists, pred_boxlists):
pred_bbox = pred_boxlist.bbox.numpy()
pred_label = pred_boxlist.get_field("labels").numpy()
pred_score = pred_boxlist.get_field("scores").numpy()
gt_bbox = gt_boxlist.bbox.numpy()
gt_label = gt_boxlist.get_field("labels").numpy()
gt_difficult = gt_boxlist.get_field("difficult").numpy()
for l in np.unique(np.concatenate((pred_label, gt_label)).astype(int)):
pred_mask_l = pred_label == l
pred_bbox_l = pred_bbox[pred_mask_l]
pred_score_l = pred_score[pred_mask_l]
# sort by score
order = pred_score_l.argsort()[::-1]
pred_bbox_l = pred_bbox_l[order]
pred_score_l = pred_score_l[order]
gt_mask_l = gt_label == l
gt_bbox_l = gt_bbox[gt_mask_l]
gt_difficult_l = gt_difficult[gt_mask_l]
n_pos[l] += np.logical_not(gt_difficult_l).sum()
score[l].extend(pred_score_l)
if len(pred_bbox_l) == 0:
continue
if len(gt_bbox_l) == 0:
match[l].extend((0,) * pred_bbox_l.shape[0])
continue
# VOC evaluation follows integer typed bounding boxes.
pred_bbox_l = pred_bbox_l.copy()
pred_bbox_l[:, 2:] += 1
gt_bbox_l = gt_bbox_l.copy()
gt_bbox_l[:, 2:] += 1
iou = boxlist_iou(
BoxList(pred_bbox_l, gt_boxlist.size),
BoxList(gt_bbox_l, gt_boxlist.size),
).numpy()
gt_index = iou.argmax(axis=1)
# set -1 if there is no matching ground truth
gt_index[iou.max(axis=1) < iou_thresh] = -1
del iou
selec = np.zeros(gt_bbox_l.shape[0], dtype=bool)
for gt_idx in gt_index:
if gt_idx >= 0:
if gt_difficult_l[gt_idx]:
match[l].append(-1)
else:
if not selec[gt_idx]:
match[l].append(1)
else:
match[l].append(0)
selec[gt_idx] = True
else:
match[l].append(0)
n_fg_class = max(n_pos.keys()) + 1
prec = [None] * n_fg_class
rec = [None] * n_fg_class
for l in n_pos.keys():
score_l = np.array(score[l])
match_l = np.array(match[l], dtype=np.int8)
order = score_l.argsort()[::-1]
match_l = match_l[order]
tp = np.cumsum(match_l == 1)
fp = np.cumsum(match_l == 0)
# If an element of fp + tp is 0,
# the corresponding element of prec[l] is nan.
prec[l] = tp / (fp + tp)
# If n_pos[l] is 0, rec[l] is None.
if n_pos[l] > 0:
rec[l] = tp / n_pos[l]
return prec, rec
def calc_detection_voc_ap(prec, rec, use_07_metric=False):
"""Calculate average precisions based on evaluation code of PASCAL VOC.
This function calculates average precisions
from given precisions and recalls.
The code is based on the evaluation code used in PASCAL VOC Challenge.
Args:
prec (list of numpy.array): A list of arrays.
:obj:`prec[l]` indicates precision for class :math:`l`.
If :obj:`prec[l]` is :obj:`None`, this function returns
:obj:`numpy.nan` for class :math:`l`.
rec (list of numpy.array): A list of arrays.
:obj:`rec[l]` indicates recall for class :math:`l`.
If :obj:`rec[l]` is :obj:`None`, this function returns
:obj:`numpy.nan` for class :math:`l`.
use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric
for calculating average precision. The default value is
:obj:`False`.
Returns:
~numpy.ndarray:
This function returns an array of average precisions.
The :math:`l`-th value corresponds to the average precision
for class :math:`l`. If :obj:`prec[l]` or :obj:`rec[l]` is
:obj:`None`, the corresponding value is set to :obj:`numpy.nan`.
"""
n_fg_class = len(prec)
ap = np.empty(n_fg_class)
for l in range(n_fg_class):
if prec[l] is None or rec[l] is None:
ap[l] = np.nan
continue
if use_07_metric:
# 11 point metric
ap[l] = 0
for t in np.arange(0.0, 1.1, 0.1):
if np.sum(rec[l] >= t) == 0:
p = 0
else:
p = np.max(np.nan_to_num(prec[l])[rec[l] >= t])
ap[l] += p / 11
else:
# correct AP calculation
# first append sentinel values at the end
mpre = np.concatenate(([0], np.nan_to_num(prec[l]), [0]))
mrec = np.concatenate(([0], rec[l], [1]))
mpre = np.maximum.accumulate(mpre[::-1])[::-1]
# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
ap[l] = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
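A worked example of the two AP variants on a toy precision/recall curve for one class (three ranked detections, two of them true positives, with two ground-truth boxes): the 11-point metric averages the best precision at recall thresholds 0.0, 0.1, ..., 1.0, while the other branch integrates the interpolated precision-recall curve. The module path in the import is assumed:

```python
import numpy as np
from maskrcnn_benchmark.data.datasets.evaluation.voc.voc_eval import calc_detection_voc_ap

# prec/rec after each ranked detection: TP, FP, TP with 2 ground-truth boxes.
prec = [None, np.array([1.0, 0.5, 2.0 / 3.0])]  # index 0 (background) stays NaN
rec = [None, np.array([0.5, 0.5, 1.0])]

print(calc_detection_voc_ap(prec, rec, use_07_metric=True)[1])   # (6*1 + 5*2/3)/11 ~ 0.848
print(calc_detection_voc_ap(prec, rec, use_07_metric=False)[1])  # 0.5*1 + 0.5*2/3 ~ 0.833
```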
import os
import torch
import torch.utils.data
from PIL import Image
import sys
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
import xml.etree.ElementTree as ET
from maskrcnn_benchmark.structures.bounding_box import BoxList
class PascalVOCDataset(torch.utils.data.Dataset):
CLASSES = (
"__background__ ",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
)
def __init__(self, data_dir, split, use_difficult=False, transforms=None):
self.root = data_dir
self.image_set = split
self.keep_difficult = use_difficult
self.transforms = transforms
self._annopath = os.path.join(self.root, "Annotations", "%s.xml")
self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg")
self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt")
with open(self._imgsetpath % self.image_set) as f:
self.ids = f.readlines()
self.ids = [x.strip("\n") for x in self.ids]
self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
cls = PascalVOCDataset.CLASSES
self.class_to_ind = dict(zip(cls, range(len(cls))))
def __getitem__(self, index):
img_id = self.ids[index]
img = Image.open(self._imgpath % img_id).convert("RGB")
target = self.get_groundtruth(index)
target = target.clip_to_image(remove_empty=True)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target, index
def __len__(self):
return len(self.ids)
def get_groundtruth(self, index):
img_id = self.ids[index]
anno = ET.parse(self._annopath % img_id).getroot()
anno = self._preprocess_annotation(anno)
height, width = anno["im_info"]
target = BoxList(anno["boxes"], (width, height), mode="xyxy")
target.add_field("labels", anno["labels"])
target.add_field("difficult", anno["difficult"])
return target
def _preprocess_annotation(self, target):
boxes = []
gt_classes = []
difficult_boxes = []
for obj in target.iter("object"):
difficult = int(obj.find("difficult").text) == 1
if not self.keep_difficult and difficult:
continue
name = obj.find("name").text.lower().strip()
bb = obj.find("bndbox")
bndbox = tuple(
map(
int,
[
bb.find("xmin").text,
bb.find("ymin").text,
bb.find("xmax").text,
bb.find("ymax").text,
],
)
)
boxes.append(bndbox)
gt_classes.append(self.class_to_ind[name])
difficult_boxes.append(difficult)
size = target.find("size")
im_info = tuple(map(int, (size.find("height").text, size.find("width").text)))
res = {
"boxes": torch.tensor(boxes, dtype=torch.float32),
"labels": torch.tensor(gt_classes),
"difficult": torch.tensor(difficult_boxes),
"im_info": im_info,
}
return res
def get_img_info(self, index):
img_id = self.ids[index]
anno = ET.parse(self._annopath % img_id).getroot()
size = anno.find("size")
im_info = tuple(map(int, (size.find("height").text, size.find("width").text)))
return {"height": im_info[0], "width": im_info[1]}
def map_class_id_to_class_name(self, class_id):
return PascalVOCDataset.CLASSES[class_id]
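Typical standalone usage of the new dataset class; the `data_dir` below is a placeholder for a local VOC2007 copy (e.g. the `datasets/voc` symlink from the README):

```python
# Sketch: requires a VOC2007 copy on disk at data_dir.
from maskrcnn_benchmark.data.datasets import PascalVOCDataset

dataset = PascalVOCDataset(
    data_dir="datasets/voc/VOC2007", split="trainval", use_difficult=False
)
img, target, idx = dataset[0]      # PIL image, BoxList target, sample index
print(len(dataset))                # number of ids listed in ImageSets/Main/trainval.txt
print(target.bbox)                 # (num_objects, 4) boxes in xyxy format
print(target.get_field("labels"))  # indices into PascalVOCDataset.CLASSES
print(dataset.map_class_id_to_class_name(int(target.get_field("labels")[0])))
```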
@@ -6,4 +6,3 @@ from .transforms import ToTensor
from .transforms import Normalize
from .build import build_transforms
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import tempfile
import time
import os
from collections import OrderedDict
import torch
from tqdm import tqdm
from ..structures.bounding_box import BoxList
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process
from ..utils.comm import scatter_gather
from ..utils.comm import synchronize
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
def compute_on_dataset(model, data_loader, device):
model.eval()
results_dict = {}
@@ -36,231 +29,6 @@ def compute_on_dataset(model, data_loader, device):
return results_dict
def prepare_for_coco_detection(predictions, dataset):
# assert isinstance(dataset, COCODataset)
coco_results = []
for image_id, prediction in enumerate(predictions):
original_id = dataset.id_to_img_map[image_id]
if len(prediction) == 0:
continue
# TODO replace with get_img_info?
image_width = dataset.coco.imgs[original_id]["width"]
image_height = dataset.coco.imgs[original_id]["height"]
prediction = prediction.resize((image_width, image_height))
prediction = prediction.convert("xywh")
boxes = prediction.bbox.tolist()
scores = prediction.get_field("scores").tolist()
labels = prediction.get_field("labels").tolist()
mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
coco_results.extend(
[
{
"image_id": original_id,
"category_id": mapped_labels[k],
"bbox": box,
"score": scores[k],
}
for k, box in enumerate(boxes)
]
)
return coco_results
def prepare_for_coco_segmentation(predictions, dataset):
import pycocotools.mask as mask_util
import numpy as np
masker = Masker(threshold=0.5, padding=1)
# assert isinstance(dataset, COCODataset)
coco_results = []
for image_id, prediction in tqdm(enumerate(predictions)):
original_id = dataset.id_to_img_map[image_id]
if len(prediction) == 0:
continue
# TODO replace with get_img_info?
image_width = dataset.coco.imgs[original_id]["width"]
image_height = dataset.coco.imgs[original_id]["height"]
prediction = prediction.resize((image_width, image_height))
masks = prediction.get_field("mask")
# t = time.time()
# Masker is necessary only if masks haven't been already resized.
if list(masks.shape[-2:]) != [image_height, image_width]:
masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
masks = masks[0]
# logger.info('Time mask: {}'.format(time.time() - t))
# prediction = prediction.convert('xywh')
# boxes = prediction.bbox.tolist()
scores = prediction.get_field("scores").tolist()
labels = prediction.get_field("labels").tolist()
# rles = prediction.get_field('mask')
rles = [
mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels]
coco_results.extend(
[
{
"image_id": original_id,
"category_id": mapped_labels[k],
"segmentation": rle,
"score": scores[k],
}
for k, rle in enumerate(rles)
]
)
return coco_results
# inspired from Detectron
def evaluate_box_proposals(
predictions, dataset, thresholds=None, area="all", limit=None
):
"""Evaluate detection proposal recall metrics. This function is a much
faster alternative to the official COCO API recall evaluation code. However,
it produces slightly different results.
"""
# Record max overlap value for each gt box
# Return vector of overlap values
areas = {
"all": 0,
"small": 1,
"medium": 2,
"large": 3,
"96-128": 4,
"128-256": 5,
"256-512": 6,
"512-inf": 7,
}
area_ranges = [
[0 ** 2, 1e5 ** 2], # all
[0 ** 2, 32 ** 2], # small
[32 ** 2, 96 ** 2], # medium
[96 ** 2, 1e5 ** 2], # large
[96 ** 2, 128 ** 2], # 96-128
[128 ** 2, 256 ** 2], # 128-256
[256 ** 2, 512 ** 2], # 256-512
[512 ** 2, 1e5 ** 2],
] # 512-inf
assert area in areas, "Unknown area range: {}".format(area)
area_range = area_ranges[areas[area]]
gt_overlaps = []
num_pos = 0
for image_id, prediction in enumerate(predictions):
original_id = dataset.id_to_img_map[image_id]
# TODO replace with get_img_info?
image_width = dataset.coco.imgs[original_id]["width"]
image_height = dataset.coco.imgs[original_id]["height"]
prediction = prediction.resize((image_width, image_height))
# sort predictions in descending order
# TODO maybe remove this and make it explicit in the documentation
inds = prediction.get_field("objectness").sort(descending=True)[1]
prediction = prediction[inds]
ann_ids = dataset.coco.getAnnIds(imgIds=original_id)
anno = dataset.coco.loadAnns(ann_ids)
gt_boxes = [obj["bbox"] for obj in anno if obj["iscrowd"] == 0]
gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes
gt_boxes = BoxList(gt_boxes, (image_width, image_height), mode="xywh").convert(
"xyxy"
)
gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
if len(gt_boxes) == 0:
continue
valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
gt_boxes = gt_boxes[valid_gt_inds]
num_pos += len(gt_boxes)
if len(gt_boxes) == 0:
continue
if len(prediction) == 0:
continue
if limit is not None and len(prediction) > limit:
prediction = prediction[:limit]
overlaps = boxlist_iou(prediction, gt_boxes)
_gt_overlaps = torch.zeros(len(gt_boxes))
for j in range(min(len(prediction), len(gt_boxes))):
# find which proposal box maximally covers each gt box
# and get the iou amount of coverage for each gt box
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
# find which gt box is 'best' covered (i.e. 'best' = most iou)
gt_ovr, gt_ind = max_overlaps.max(dim=0)
assert gt_ovr >= 0
# find the proposal box that covers the best covered gt box
box_ind = argmax_overlaps[gt_ind]
# record the iou coverage of this gt box
_gt_overlaps[j] = overlaps[box_ind, gt_ind]
assert _gt_overlaps[j] == gt_ovr
# mark the proposal box and the gt box as used
overlaps[box_ind, :] = -1
overlaps[:, gt_ind] = -1
# append recorded iou coverage level
gt_overlaps.append(_gt_overlaps)
gt_overlaps = torch.cat(gt_overlaps, dim=0)
gt_overlaps, _ = torch.sort(gt_overlaps)
if thresholds is None:
step = 0.05
thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
recalls = torch.zeros_like(thresholds)
# compute recall for each iou threshold
for i, t in enumerate(thresholds):
recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
# ar = 2 * np.trapz(recalls, thresholds)
ar = recalls.mean()
return {
"ar": ar,
"recalls": recalls,
"thresholds": thresholds,
"gt_overlaps": gt_overlaps,
"num_pos": num_pos,
}
def evaluate_predictions_on_coco(
coco_gt, coco_results, json_result_file, iou_type="bbox"
):
import json
with open(json_result_file, "w") as f:
json.dump(coco_results, f)
from pycocotools.cocoeval import COCOeval
coco_dt = coco_gt.loadRes(str(json_result_file))
# coco_dt = coco_gt.loadRes(coco_results)
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval
def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
all_predictions = scatter_gather(predictions_per_gpu)
if not is_main_process():
@@ -283,84 +51,17 @@ def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
return predictions
class COCOResults(object):
METRICS = {
"bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
"segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
"box_proposal": [
"AR@100",
"ARs@100",
"ARm@100",
"ARl@100",
"AR@1000",
"ARs@1000",
"ARm@1000",
"ARl@1000",
],
"keypoint": ["AP", "AP50", "AP75", "APm", "APl"],
}
def __init__(self, *iou_types):
allowed_types = ("box_proposal", "bbox", "segm")
assert all(iou_type in allowed_types for iou_type in iou_types)
results = OrderedDict()
for iou_type in iou_types:
results[iou_type] = OrderedDict(
[(metric, -1) for metric in COCOResults.METRICS[iou_type]]
)
self.results = results
def update(self, coco_eval):
if coco_eval is None:
return
from pycocotools.cocoeval import COCOeval
assert isinstance(coco_eval, COCOeval)
s = coco_eval.stats
iou_type = coco_eval.params.iouType
res = self.results[iou_type]
metrics = COCOResults.METRICS[iou_type]
for idx, metric in enumerate(metrics):
res[metric] = s[idx]
def __repr__(self):
# TODO make it pretty
return repr(self.results)
def check_expected_results(results, expected_results, sigma_tol):
if not expected_results:
return
logger = logging.getLogger("maskrcnn_benchmark.inference")
for task, metric, (mean, std) in expected_results:
actual_val = results.results[task][metric]
lo = mean - sigma_tol * std
hi = mean + sigma_tol * std
ok = (lo < actual_val) and (actual_val < hi)
msg = (
"{} > {} sanity check (actual vs. expected): "
"{:.3f} vs. mean={:.4f}, std={:.4}, range=({:.4f}, {:.4f})"
).format(task, metric, actual_val, mean, std, lo, hi)
if not ok:
msg = "FAIL: " + msg
logger.error(msg)
else:
msg = "PASS: " + msg
logger.info(msg)
def inference(
model,
data_loader,
iou_types=("bbox",),
box_only=False,
device="cuda",
expected_results=(),
expected_results_sigma_tol=4,
output_folder=None,
model,
data_loader,
dataset_name,
iou_types=("bbox",),
box_only=False,
device="cuda",
expected_results=(),
expected_results_sigma_tol=4,
output_folder=None,
):
# convert to a torch.device for efficiency
device = torch.device(device)
num_devices = (
@@ -370,7 +71,7 @@ def inference(
)
logger = logging.getLogger("maskrcnn_benchmark.inference")
dataset = data_loader.dataset
logger.info("Start evaluation on {} images".format(len(dataset)))
logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
start_time = time.time()
predictions = compute_on_dataset(model, data_loader, device)
# wait for all processes to complete before measuring the time
@@ -390,46 +91,14 @@
if output_folder:
torch.save(predictions, os.path.join(output_folder, "predictions.pth"))
if box_only:
logger.info("Evaluating bbox proposals")
areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
res = COCOResults("box_proposal")
for limit in [100, 1000]:
for area, suffix in areas.items():
stats = evaluate_box_proposals(
predictions, dataset, area=area, limit=limit
)
key = "AR{}@{:d}".format(suffix, limit)
res.results["box_proposal"][key] = stats["ar"].item()
logger.info(res)
check_expected_results(res, expected_results, expected_results_sigma_tol)
if output_folder:
torch.save(res, os.path.join(output_folder, "box_proposals.pth"))
return
logger.info("Preparing results for COCO format")
coco_results = {}
if "bbox" in iou_types:
logger.info("Preparing bbox results")
coco_results["bbox"] = prepare_for_coco_detection(predictions, dataset)
if "segm" in iou_types:
logger.info("Preparing segm results")
coco_results["segm"] = prepare_for_coco_segmentation(predictions, dataset)
results = COCOResults(*iou_types)
logger.info("Evaluating predictions")
for iou_type in iou_types:
with tempfile.NamedTemporaryFile() as f:
file_path = f.name
if output_folder:
file_path = os.path.join(output_folder, iou_type + ".json")
res = evaluate_predictions_on_coco(
dataset.coco, coco_results[iou_type], file_path, iou_type
)
results.update(res)
logger.info(results)
check_expected_results(results, expected_results, expected_results_sigma_tol)
if output_folder:
torch.save(results, os.path.join(output_folder, "coco_results.pth"))
return results, coco_results, predictions
extra_args = dict(
box_only=box_only,
iou_types=iou_types,
expected_results=expected_results,
expected_results_sigma_tol=expected_results_sigma_tol,
)
return evaluate(dataset=dataset,
predictions=predictions,
output_folder=output_folder,
**extra_args)
@@ -68,17 +68,18 @@ def main():
if cfg.MODEL.MASK_ON:
iou_types = iou_types + ("segm",)
output_folders = [None] * len(cfg.DATASETS.TEST)
dataset_names = cfg.DATASETS.TEST
if cfg.OUTPUT_DIR:
dataset_names = cfg.DATASETS.TEST
for idx, dataset_name in enumerate(dataset_names):
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
mkdir(output_folder)
output_folders[idx] = output_folder
data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
for output_folder, data_loader_val in zip(output_folders, data_loaders_val):
for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
inference(
model,
data_loader_val,
dataset_name=dataset_name,
iou_types=iou_types,
box_only=cfg.MODEL.RPN_ONLY,
device=cfg.MODEL.DEVICE,
......
@@ -84,17 +84,18 @@ def test(cfg, model, distributed):
if cfg.MODEL.MASK_ON:
iou_types = iou_types + ("segm",)
output_folders = [None] * len(cfg.DATASETS.TEST)
dataset_names = cfg.DATASETS.TEST
if cfg.OUTPUT_DIR:
dataset_names = cfg.DATASETS.TEST
for idx, dataset_name in enumerate(dataset_names):
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
mkdir(output_folder)
output_folders[idx] = output_folder
data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
for output_folder, data_loader_val in zip(output_folders, data_loaders_val):
for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
inference(
model,
data_loader_val,
dataset_name=dataset_name,
iou_types=iou_types,
box_only=cfg.MODEL.RPN_ONLY,
device=cfg.MODEL.DEVICE,
......