# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torchvision

from fcos_core.structures.bounding_box import BoxList
from fcos_core.structures.segmentation_mask import SegmentationMask
from fcos_core.structures.keypoint import PersonKeypoints


# Minimum total number of visible keypoints (summed over all of an image's
# annotations, as counted by _count_visible_keypoints) that has_valid_annotation
# requires before it accepts an image whose annotations carry "keypoints".
min_keypoints_per_image = 10


def _count_visible_keypoints(anno):
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


def _has_only_empty_bbox(anno):
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)


def has_valid_annotation(anno):
    """Decide whether an image's annotation list should be kept.

    Args:
        anno: list of annotation dicts for one image.

    Returns:
        False if ``anno`` is empty or every box is degenerate (see
        ``_has_only_empty_bbox``). Otherwise True when the first annotation
        has no ``"keypoints"`` key. When it does have one, True only if the
        total number of visible keypoints is at least
        ``min_keypoints_per_image``.
    """
    # if it's empty, there is no annotation
    if not anno:
        return False
    # if all boxes have close to zero area, there is no annotation
    if _has_only_empty_bbox(anno):
        return False
    # keypoint tasks use a slightly different criterion for deciding
    # whether an annotation is valid
    if "keypoints" not in anno[0]:
        return True
    # for keypoint detection tasks, only images containing at least
    # min_keypoints_per_image visible keypoints are considered valid
    return _count_visible_keypoints(anno) >= min_keypoints_per_image


class COCODataset(torchvision.datasets.coco.CocoDetection):
    """COCO-style detection dataset yielding ``(img, target, idx)`` triples.

    Built on torchvision's ``CocoDetection``. Crowd annotations are dropped,
    boxes are converted from ``xywh`` to ``xyxy`` inside a ``BoxList``, JSON
    category ids are remapped to contiguous labels starting at 1, and polygon
    masks (plus keypoints, when present) are attached as ``BoxList`` fields.
    """

    def __init__(
        self, ann_file, root, remove_images_without_annotations, transforms=None
    ):
        """
        Args:
            ann_file: annotation file, passed to ``CocoDetection`` as ``annFile``.
            root: image directory, passed to ``CocoDetection`` as ``root``.
            remove_images_without_annotations: if true, keep only images whose
                annotation list is accepted by ``has_valid_annotation``.
            transforms: optional callable, applied as
                ``img, target = transforms(img, target)`` at the end of
                ``__getitem__``.
        """
        super(COCODataset, self).__init__(root, ann_file)
        # sort indices for reproducible results
        self.ids = sorted(self.ids)

        # filter images without detection annotations
        if remove_images_without_annotations:
            # NOTE(review): iscrowd=None also loads crowd annotations, which
            # __getitem__ later discards, so an image annotated only with crowd
            # regions can pass this filter and still yield an empty target.
            kept_ids = []
            for img_id in self.ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                if has_valid_annotation(self.coco.loadAnns(ann_ids)):
                    kept_ids.append(img_id)
            self.ids = kept_ids

        # JSON category ids (in getCatIds order) -> contiguous labels
        # 1..num_categories; label 0 is left unused, presumably for background.
        self.json_category_id_to_contiguous_id = {
            json_id: i + 1 for i, json_id in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }
        # dataset position -> COCO image id
        self.id_to_img_map = dict(enumerate(self.ids))
        # Applied manually in __getitem__; the parent constructor is not given
        # any transforms.
        self._transforms = transforms

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` for the image at dataset position ``idx``.

        ``target`` is a ``BoxList`` in ``xyxy`` mode with fields ``labels``
        (int64 contiguous category ids), ``masks`` (polygon
        ``SegmentationMask``) and, when the annotations carry them,
        ``keypoints``. It is passed through ``clip_to_image(remove_empty=True)``
        and then, if set, ``self._transforms``.
        """
        img, anno = super(COCODataset, self).__getitem__(idx)

        # filter crowd annotations
        # TODO might be better to add an extra field
        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")

        classes = [
            self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in anno
        ]
        # Explicit dtype: torch.tensor([]) defaults to float32, which would give
        # images without non-crowd annotations float labels instead of int64.
        classes = torch.tensor(classes, dtype=torch.int64)
        target.add_field("labels", classes)

        masks = [obj["segmentation"] for obj in anno]
        masks = SegmentationMask(masks, img.size, mode='poly')
        target.add_field("masks", masks)

        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = PersonKeypoints(keypoints, img.size)
            target.add_field("keypoints", keypoints)

        target = target.clip_to_image(remove_empty=True)

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        return img, target, idx

    def get_img_info(self, index):
        """Return the ``self.coco.imgs`` record for the image at position ``index``."""
        img_id = self.id_to_img_map[index]
        return self.coco.imgs[img_id]