From 0ea6ea0627eebb3a76bc09771616fbe59a989424 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 8 Oct 2019 22:50:40 +1030 Subject: [PATCH] compatibility for torchvision 0.4.0 --- fcos_core/data/datasets/coco.py | 6 +++--- fcos_core/data/transforms/transforms.py | 7 +++++-- fcos_core/structures/segmentation_mask.py | 3 ++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/fcos_core/data/datasets/coco.py b/fcos_core/data/datasets/coco.py index 4d08a32..ac697e7 100644 --- a/fcos_core/data/datasets/coco.py +++ b/fcos_core/data/datasets/coco.py @@ -61,7 +61,7 @@ class COCODataset(torchvision.datasets.coco.CocoDetection): v: k for k, v in self.json_category_id_to_contiguous_id.items() } self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} - self.transforms = transforms + self._transforms = transforms def __getitem__(self, idx): img, anno = super(COCODataset, self).__getitem__(idx) @@ -90,8 +90,8 @@ class COCODataset(torchvision.datasets.coco.CocoDetection): target = target.clip_to_image(remove_empty=True) - if self.transforms is not None: - img, target = self.transforms(img, target) + if self._transforms is not None: + img, target = self._transforms(img, target) return img, target, idx diff --git a/fcos_core/data/transforms/transforms.py b/fcos_core/data/transforms/transforms.py index 179723e..102c47e 100644 --- a/fcos_core/data/transforms/transforms.py +++ b/fcos_core/data/transforms/transforms.py @@ -57,9 +57,12 @@ class Resize(object): def __call__(self, image, target=None): size = self.get_size(image.size) image = F.resize(image, size) - if target is None: + if isinstance(target, list): + target = [t.resize(image.size) for t in target] + elif target is None: return image - target = target.resize(image.size) + else: + target = target.resize(image.size) return image, target diff --git a/fcos_core/structures/segmentation_mask.py b/fcos_core/structures/segmentation_mask.py index aca97a9..31486b6 100644 --- a/fcos_core/structures/segmentation_mask.py +++ b/fcos_core/structures/segmentation_mask.py @@ -414,7 +414,8 @@ class PolygonList(object): else: # advanced indexing on a single dimension selected_polygons = [] - if isinstance(item, torch.Tensor) and item.dtype == torch.uint8: + if isinstance(item, torch.Tensor) and \ + item.dtype == torch.uint8 or item.dtype == torch.bool: item = item.nonzero() item = item.squeeze(1) if item.numel() > 0 else item item = item.tolist() -- GitLab